1use crate::boringssl::{self, Bignum, BignumCtx, EcGroup, EcGroupId, EcGroupParams, EcPoint};
6use crate::internal::{FiniteCyclicGroup, SaeParameters};
7use crate::PweMethod;
8use anyhow::{bail, Error};
9use ieee80211::MacAddr;
10use log::warn;
11use num::integer::Integer;
12use num::ToPrimitive;
13
14pub struct Group {
16 id: EcGroupId,
17 group: EcGroup,
18 bn_ctx: BignumCtx,
19}
20
21impl Group {
22 pub fn new(ec_group: EcGroupId) -> Result<Self, Error> {
25 Ok(Self { id: ec_group.clone(), group: EcGroup::new(ec_group)?, bn_ctx: BignumCtx::new()? })
26 }
27}
28
29fn concat_mac_addrs(sta_a_mac: &MacAddr, sta_b_mac: &MacAddr) -> Vec<u8> {
31 let mut result: Vec<u8> = Vec::with_capacity(sta_a_mac.len() + sta_b_mac.len());
32 match sta_a_mac.cmp(sta_b_mac) {
33 std::cmp::Ordering::Less => {
34 result.extend_from_slice(sta_b_mac.as_slice());
35 result.extend_from_slice(sta_a_mac.as_slice());
36 }
37 _ => {
38 result.extend_from_slice(sta_a_mac.as_slice());
39 result.extend_from_slice(sta_b_mac.as_slice());
40 }
41 };
42
43 result
44}
45
/// Minimum number of hunting-and-pecking iterations. Every counter value up to and including
/// this one is tried even after a candidate has been found, so the iteration count does not
/// reveal which counter succeeded.
const MIN_PWE_ITER: u8 = 50;
/// KDF label used when deriving each candidate x coordinate in the looping PWE method.
const KDF_LABEL: &str = "SAE Hunting and Pecking";
54
55fn compute_y_squared(x: &Bignum, curve: &EcGroupParams, ctx: &BignumCtx) -> Result<Bignum, Error> {
57 let y = x.mod_exp(&Bignum::new_from_u64(3)?, &curve.p, ctx)?;
59 let y = y.mod_add(&curve.a.mod_mul(&x, &curve.p, ctx)?, &curve.p, ctx)?;
61 y.mod_add(&curve.b, &curve.p, ctx)
63}
64
/// Outcome of a Legendre-symbol computation `(a / p)`.
#[derive(Debug, PartialEq)]
enum LegendreSymbol {
    /// `a` is a nonzero square modulo `p` (symbol 1).
    QuadResidue,
    /// `a` is not a square modulo `p` (symbol -1).
    NonQuadResidue,
    /// `a` is divisible by `p` (symbol 0).
    ZeroCongruent,
}
71
72fn legendre(a: &Bignum, p: &Bignum, ctx: &BignumCtx) -> Result<LegendreSymbol, Error> {
74 let exp = p.sub(Bignum::one()?)?.rshift1()?;
75 let res = a.mod_exp(&exp, p, ctx)?;
76 if res.is_one() {
77 Ok(LegendreSymbol::QuadResidue)
78 } else if res.is_zero() {
79 Ok(LegendreSymbol::ZeroCongruent)
80 } else {
81 Ok(LegendreSymbol::NonQuadResidue)
82 }
83}
84
85fn generate_qr_and_qnr(p: &Bignum, ctx: &BignumCtx) -> Result<(Bignum, Bignum), Error> {
87 let mut qr = Bignum::rand(p)?;
90 while legendre(&qr, p, ctx)? != LegendreSymbol::QuadResidue {
91 qr = Bignum::rand(p)?;
92 }
93
94 let mut qnr = Bignum::rand(p)?;
95 while legendre(&qnr, p, ctx)? != LegendreSymbol::NonQuadResidue {
96 qnr = Bignum::rand(p)?;
97 }
98
99 Ok((qr, qnr))
100}
101
102fn is_quadratic_residue_blind(
106 v: &Bignum,
107 p: &Bignum,
108 qr: &Bignum,
109 qnr: &Bignum,
110 ctx: &BignumCtx,
111) -> Result<bool, Error> {
112 let r = Bignum::rand(&p.sub(Bignum::one()?)?)?.add(Bignum::one()?)?;
114 let num = v.mod_mul(&r, p, ctx)?.mod_mul(&r, p, ctx)?;
115 if num.is_odd() {
116 let num = num.mod_mul(qr, p, ctx)?;
117 Ok(legendre(&num, p, ctx)? == LegendreSymbol::QuadResidue)
118 } else {
119 let num = num.mod_mul(qnr, p, ctx)?;
120 Ok(legendre(&num, p, ctx)? == LegendreSymbol::NonQuadResidue)
121 }
122}
123
124impl Group {
125 fn generate_pwe_loop(&self, params: &SaeParameters) -> Result<EcPoint, Error> {
127 if params.password_id.is_some() {
128 bail!("Password ID cannot be used with looping PWE generation");
130 }
131
132 let group_params = self.group.get_params(&self.bn_ctx)?;
133 let length = group_params.p.bits();
134 let p_vec = group_params.p.to_be_vec(group_params.p.len());
135 let (qr, qnr) = generate_qr_and_qnr(&group_params.p, &self.bn_ctx)?;
136 let mut x: Option<Bignum> = None;
138 let mut save: Option<Vec<u8>> = None;
139
140 let mut counter = 1;
141 while counter <= MIN_PWE_ITER || x.is_none() {
142 let pwd_seed = {
143 let salt = concat_mac_addrs(¶ms.sta_a_mac, ¶ms.sta_b_mac);
144 let mut ikm = params.password.clone();
145 ikm.push(counter as u8);
146 params.hmac.hkdf_extract(&salt[..], &ikm[..])
147 };
148 let pwd_value =
149 params.hmac.kdf_hash_length(&pwd_seed[..], KDF_LABEL, &p_vec[..], length);
150 let pwd_value = Bignum::new_from_slice(&pwd_value[..])?;
153 if pwd_value < group_params.p {
154 let y_squared = compute_y_squared(&pwd_value, &group_params, &self.bn_ctx)?;
155 if is_quadratic_residue_blind(&y_squared, &group_params.p, &qr, &qnr, &self.bn_ctx)?
156 {
157 if x.is_none() {
159 x = Some(pwd_value);
160 save = Some(pwd_seed);
161 }
162 }
163 }
164 counter += 1;
165 }
166
167 let x = x.unwrap();
169 let save = save.unwrap();
170
171 let y_squared = compute_y_squared(&x, &group_params, &self.bn_ctx)?;
173 let mut y = y_squared.mod_sqrt(&group_params.p, &self.bn_ctx)?;
174 if save[save.len() - 1].is_odd() != y.is_odd() {
176 y = group_params.p.copy()?.sub(y)?;
177 }
178 EcPoint::new_from_affine_coords(x, y, &self.group, &self.bn_ctx)
179 }
180
181 fn generate_sswu_z_c1_c2(&self) -> Result<(Bignum, Bignum, Bignum), Error> {
185 let group_params = self.group.get_params(&self.bn_ctx)?;
186 let z = match self.id {
187 EcGroupId::P256 => Bignum::new_from_u64(10)?.set_negative(),
188 EcGroupId::P384 => Bignum::new_from_u64(12)?.set_negative(),
189 EcGroupId::P521 => Bignum::new_from_u64(4)?.set_negative(),
190 };
191 let c1 = group_params
192 .b
193 .mod_mul(
194 &group_params.a.mod_inverse(&group_params.p, &self.bn_ctx)?,
195 &group_params.p,
196 &self.bn_ctx,
197 )?
198 .set_negative();
199 let c2 = z.mod_inverse(&group_params.p, &self.bn_ctx)?.set_negative();
200 Ok((z, c1, c2))
201 }
202
203 fn calculate_sswu(
218 &self,
219 u: &Bignum,
220 z: &Bignum,
221 c1: &Bignum,
222 c2: &Bignum,
223 qr: &Bignum,
224 qnr: &Bignum,
225 ) -> Result<EcPoint, Error> {
226 let group_params = self.group.get_params(&self.bn_ctx)?;
227 let p = &group_params.p;
228 let p_2 = p.sub(Bignum::new_from_u64(2)?)?;
229
230 let tv1 = z.mod_mul(&u.mod_square(p, &self.bn_ctx)?, p, &self.bn_ctx)?; let tv2 = tv1.mod_square(p, &self.bn_ctx)?; let x1 = tv1.mod_add(&tv2, p, &self.bn_ctx)?; let x1 = x1.mod_exp(&p_2, p, &self.bn_ctx)?; let e1 = x1.is_zero(); let x1 = x1.mod_add(&Bignum::one()?, p, &self.bn_ctx)?; let x1 = if e1 { c2 } else { &x1 }; let x1 = x1.mod_mul(c1, p, &self.bn_ctx)?; let gx1 = x1.mod_square(&group_params.p, &self.bn_ctx)?; let gx1 = gx1.mod_add(&group_params.a, p, &self.bn_ctx)?; let gx1 = gx1.mod_mul(&x1, p, &self.bn_ctx)?; let gx1 = gx1.mod_add(&group_params.b, p, &self.bn_ctx)?; let x2 = tv1.mod_mul(&x1, p, &self.bn_ctx)?; let tv2 = tv1.mod_mul(&tv2, p, &self.bn_ctx)?; let gx2 = gx1.mod_mul(&tv2, p, &self.bn_ctx)?; let e2 = is_quadratic_residue_blind(&gx1, p, qr, qnr, &self.bn_ctx)?; let x = if e2 { x1 } else { x2 }; let y2 = if e2 { gx1 } else { gx2 }; let y = y2.mod_sqrt(p, &self.bn_ctx)?; let e3 = u.is_odd() == y.is_odd(); let negative_y = p.sub(y.copy()?)?.mod_nonnegative(p, &self.bn_ctx)?;
251 let y = if e3 { y } else { negative_y }; EcPoint::new_from_affine_coords(x, y, &self.group, &self.bn_ctx)
254 }
255
256 fn generate_pt(&self, params: &SaeParameters) -> Result<EcPoint, Error> {
259 let group_params = self.group.get_params(&self.bn_ctx)?;
260 let p = &group_params.p;
261 let len = p.len() + (p.len() / 2);
262
263 let mut password_with_id = params.password.clone();
264 match ¶ms.password_id {
265 Some(password_id) => password_with_id.extend_from_slice(&password_id),
266 _ => (),
267 };
268 let pwd_seed = params.hmac.hkdf_extract(¶ms.ssid, &password_with_id);
269
270 let (z, c1, c2) = self.generate_sswu_z_c1_c2()?;
271 let (qr, qnr) = generate_qr_and_qnr(p, &self.bn_ctx)?;
272
273 let pwd_value_1 = params.hmac.hkdf_expand(&pwd_seed, "SAE Hash to Element u1 P1", len);
274 let u1 = Bignum::new_from_slice(&pwd_value_1)?;
275 let u1 = u1.mod_nonnegative(p, &self.bn_ctx)?;
276 let p1 = self.calculate_sswu(&u1, &z, &c1, &c2, &qr, &qnr)?;
277
278 let pwd_value_2 = params.hmac.hkdf_expand(&pwd_seed, "SAE Hash to Element u2 P2", len);
279 let u2 = Bignum::new_from_slice(&pwd_value_2)?;
280 let u2 = u2.mod_nonnegative(p, &self.bn_ctx)?;
281 let p2 = self.calculate_sswu(&u2, &z, &c1, &c1, &qr, &qnr)?;
282
283 self.elem_op(&p1, &p2)
284 }
285
286 fn generate_pwe_direct(&self, params: &SaeParameters) -> Result<EcPoint, Error> {
288 let pt = self.generate_pt(params)?;
292
293 let salt = vec![0u8; params.hmac.bits() / 8];
295 let ikm = concat_mac_addrs(¶ms.sta_a_mac, ¶ms.sta_b_mac);
296 let val = Bignum::new_from_slice(¶ms.hmac.hkdf_extract(&salt, &ikm))?;
297 let val = val
298 .mod_nonnegative(
299 &self.group.get_order(&self.bn_ctx)?.sub(Bignum::one()?)?,
300 &self.bn_ctx,
301 )?
302 .add(Bignum::one()?)?;
303
304 self.scalar_op(&val, &pt)
305 }
306}
307
308impl FiniteCyclicGroup for Group {
309 type Element = boringssl::EcPoint;
310
311 fn group_id(&self) -> u16 {
312 self.id.to_u16().unwrap()
313 }
314
315 fn generate_pwe(&self, params: &SaeParameters) -> Result<Self::Element, Error> {
316 match params.pwe_method {
317 PweMethod::Loop => self.generate_pwe_loop(params),
318 PweMethod::Direct => self.generate_pwe_direct(params),
319 }
320 }
321
322 fn scalar_op(&self, scalar: &Bignum, element: &Self::Element) -> Result<Self::Element, Error> {
323 element.mul(&self.group, &scalar, &self.bn_ctx)
324 }
325
326 fn elem_op(
327 &self,
328 element1: &Self::Element,
329 element2: &Self::Element,
330 ) -> Result<Self::Element, Error> {
331 element1.add(&self.group, &element2, &self.bn_ctx)
332 }
333
334 fn inverse_op(&self, element: Self::Element) -> Result<Self::Element, Error> {
335 element.invert(&self.group, &self.bn_ctx)
336 }
337
338 fn order(&self) -> Result<Bignum, Error> {
339 self.group.get_order(&self.bn_ctx)
340 }
341
342 fn map_to_secret_value(&self, element: &Self::Element) -> Result<Option<Vec<u8>>, Error> {
343 if element.is_point_at_infinity(&self.group) {
345 Ok(None)
346 } else {
347 let group_params = self.group.get_params(&self.bn_ctx)?;
348 let (x, _y) = element.to_affine_coords(&self.group, &self.bn_ctx)?;
349 Ok(Some(x.to_be_vec(group_params.p.len())))
350 }
351 }
352
353 fn element_to_octets(&self, element: &Self::Element) -> Result<Vec<u8>, Error> {
355 let group_params = self.group.get_params(&self.bn_ctx)?;
356 let length = group_params.p.len();
357 let (x, y) = element.to_affine_coords(&self.group, &self.bn_ctx)?;
358 let mut res = x.to_be_vec(length);
359 res.append(&mut y.to_be_vec(length));
360 Ok(res)
361 }
362
363 fn element_from_octets(&self, octets: &[u8]) -> Result<Option<Self::Element>, Error> {
365 let group_params = self.group.get_params(&self.bn_ctx)?;
366 let length = group_params.p.len();
367 if octets.len() != length * 2 {
368 warn!("element_from_octets called with wrong number of octets");
369 return Ok(None);
370 }
371 let x = Bignum::new_from_slice(&octets[0..length])?;
372 let y = Bignum::new_from_slice(&octets[length..])?;
373 Ok(EcPoint::new_from_affine_coords(x, y, &self.group, &self.bn_ctx).ok())
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hmac_utils::HmacUtilsImpl;
    use ieee80211::{MacAddr, Ssid};
    use lazy_static::lazy_static;
    use mundane::hash::Sha256;
    use std::convert::TryFrom;

    /// One simplified-SWU known-answer vector: the field element `u` and the expected point.
    #[derive(Debug)]
    struct SswuTestVector {
        curve: EcGroupId,
        u: &'static str,
        q_x: &'static str,
        q_y: &'static str,
    }

    const TEST_GROUP: EcGroupId = EcGroupId::P256;
    const TEST_SSID: &'static str = "byteme";
    const TEST_PWD: &'static str = "mekmitasdigoat";
    const TEST_PWD_ID: &'static str = "psk4internet";

    // Inputs and expected PWE for the looping (hunting-and-pecking) method.
    lazy_static! {
        static ref TEST_LOOP_STA_A: MacAddr = MacAddr::from([0x4d, 0x3f, 0x2f, 0xff, 0xe3, 0x87]);
        static ref TEST_LOOP_STA_B: MacAddr = MacAddr::from([0xa5, 0xd8, 0xaa, 0x95, 0x8e, 0x3c]);
    }
    const TEST_LOOP_PWE_X: &'static str =
        "da6eb7b06a1ac5624974f90afdd6a8e9d5722634cf987c34defc91a9874e5658";
    const TEST_LOOP_PWE_Y: &'static str =
        "f4fefd130bd5be08fe68af3e4a290272ec065fd3671f3c25bf8ec419ddc9b822";

    // Inputs, intermediates and expected PWE for the direct (hash-to-element) method.
    lazy_static! {
        static ref TEST_DIRECT_STA_A: MacAddr = MacAddr::from([0x00, 0x09, 0x5b, 0x66, 0xec, 0x1e]);
        static ref TEST_DIRECT_STA_B: MacAddr = MacAddr::from([0x00, 0x0b, 0x6b, 0xd9, 0x02, 0x46]);
    }
    const TEST_DIRECT_Z: &'static str =
        "ffffffff00000001000000000000000000000000fffffffffffffffffffffff5";
    const TEST_DIRECT_C1: &'static str =
        "73976747e368dbf83bf93f1c7cdd823ecc5f023b441be5a76944bebf629b756e";
    const TEST_DIRECT_C2: &'static str =
        "e666666580000000e666666666666666666666674ccccccccccccccccccccccc";
    const TEST_DIRECT_U1: &'static str =
        "dc941bc3c6a2b4948b6c61d55590ecb1f0c51c4b1bebaff677e593698d5a53c6";
    const TEST_DIRECT_U2: &'static str =
        "1b8375a518bc21396ad6a65e5597e0bf80d793b6d66e2534a6e7dfe3ee22616f";
    const TEST_DIRECT_P1_X: &'static str =
        "a07c260764a13445ff8cd97c5acc644e7119bde51bad42583eed6f4109639e6b";
    const TEST_DIRECT_P1_Y: &'static str =
        "3bdc8df0d32337936c74df604933a454142251c53c576c0351b28deaf9428d7e";
    const TEST_DIRECT_P2_X: &'static str =
        "72cd2a967a837fea5051f0133db46227775ba09f7b6dfb99ae7a8ef22c7d34a0";
    const TEST_DIRECT_P2_Y: &'static str =
        "864390d797d352b368d311af515bde116fe54459fec867ee18a8a1619ca3ff59";
    const TEST_DIRECT_PT_X: &'static str =
        "b6e38c98750c684b5d17c3d8c9a4100b39931279187ca6cced5f37ef46ddfa97";
    const TEST_DIRECT_PT_Y: &'static str =
        "5687e972e50f73e3898861e7edad21bea7d5f622df88243bb804920ae8e647fa";
    const TEST_DIRECT_PWE_X: &'static str =
        "c93049b9e64000f848201649e999f2b5c22dea69b5632c9df4d633b8aa1f6c1e";
    const TEST_DIRECT_PWE_Y: &'static str =
        "73634e94b53d82e7383a8d258199d9dc1a5ee8269d060382ccbf33e614ff59a0";

    // Simplified-SWU known-answer vectors for each supported curve.
    const TEST_SSWU_CURVES: &'static [SswuTestVector] = &[
        SswuTestVector {
            curve: EcGroupId::P256,
            u: "ad5342c66a6dd0ff080df1da0ea1c04b96e0330dd89406465eeba11582515009",
            q_x: "ab640a12220d3ff283510ff3f4b1953d09fad35795140b1c5d64f313967934d5",
            q_y: "dccb558863804a881d4fff3455716c836cef230e5209594ddd33d85c565b19b1",
        },
        SswuTestVector {
            curve: EcGroupId::P256,
            u: "8c0f1d43204bd6f6ea70ae8013070a1518b43873bcd850aafa0a9e220e2eea5a",
            q_x: "51cce63c50d972a6e51c61334f0f4875c9ac1cd2d3238412f84e31da7d980ef5",
            q_y: "b45d1a36d00ad90e5ec7840a60a4de411917fbe7c82c3949a6e699e5a1b66aac",
        },
        SswuTestVector {
            curve: EcGroupId::P256,
            u: "afe47f2ea2b10465cc26ac403194dfb68b7f5ee865cda61e9f3e07a537220af1",
            q_x: "5219ad0ddef3cc49b714145e91b2f7de6ce0a7a7dc7406c7726c7e373c58cb48",
            q_y: "7950144e52d30acbec7b624c203b1996c99617d0b61c2442354301b191d93ecf",
        },
        SswuTestVector {
            curve: EcGroupId::P256,
            u: "379a27833b0bfe6f7bdca08e1e83c760bf9a338ab335542704edcd69ce9e46e0",
            q_x: "019b7cb4efcfeaf39f738fe638e31d375ad6837f58a852d032ff60c69ee3875f",
            q_y: "589a62d2b22357fed5449bc38065b760095ebe6aeac84b01156ee4252715446e",
        },
        SswuTestVector {
            curve: EcGroupId::P384,
            u: "425c1d0b099ffa6c15069b08299e6e21a204e08c2a0627f5afc24215d19e45bc\
                47d70da5972ff77e33f176b5e18e8485",
            q_x: "4589af7986491d42b7ee23726c57abeade65c7b8eba12d07fbce48065a01a78c\
                  4b018c739034d9fabc2c4ef6176c7c40",
            q_y: "5b2985027c29802bf2afdb8a3c95fa655ad3189a2118209bd285d420268bf71e\
                  610c9533e3f4f438ba4b64f66f6fbed9",
        },
        SswuTestVector {
            curve: EcGroupId::P384,
            u: "cbefdd543ed48b5a9bbbd460f559d23b388aa72157279ba02069231881eb2a94\
                7d887a5b1e0a6173bc92a5700f679a14",
            q_x: "cbd6c34a12a266b447b444b303d577cd5d61e3c0af19d4676ababb470bb79574\
                  1ebf167caa9f0910a4fcc899134596d7",
            q_y: "63df08d5d3aa8090cbb94222b34aad35e1b11414d3aef8f1a26205c81b4d15bb\
                  be4faf25d77924705bf09afd8812d2f0",
        },
        SswuTestVector {
            curve: EcGroupId::P521,
            u: "01e5f09974e5724f25286763f00ce76238c7a6e03dc396600350ee2c4135fb17\
                dc555be99a4a4bae0fd303d4f66d984ed7b6a3ba386093752a855d26d559d69e\
                7e9e",
            q_x: "00b70ae99b6339fffac19cb9bfde2098b84f75e50ac1e80d6acb954e4534af5f\
                  0e9c4a5b8a9c10317b8e6421574bae2b133b4f2b8c6ce4b3063da1d91d34fa2b\
                  3a3c",
            q_y: "007f368d98a4ddbf381fb354de40e44b19e43bb11a1278759f4ea7b485e1b6db\
                  33e750507c071250e3e443c1aaed61f2c28541bb54b1b456843eda1eb15ec2a9\
                  b36e",
        },
        SswuTestVector {
            curve: EcGroupId::P521,
            u: "00ae593b42ca2ef93ac488e9e09a5fe5a2f6fb330d18913734ff602f2a761fca\
                af5f596e790bcc572c9140ec03f6cccc38f767f1c1975a0b4d70b392d95a0c72\
                78aa",
            q_x: "01143d0e9cddcdacd6a9aafe1bcf8d218c0afc45d4451239e821f5d2a56df92b\
                  e942660b532b2aa59a9c635ae6b30e803c45a6ac871432452e685d661cd41cf6\
                  7214",
            q_y: "00ff75515df265e996d702a5380defffab1a6d2bc232234c7bcffa433cd8aa79\
                  1fbc8dcf667f08818bffa739ae25773b32073213cae9a0f2a917a0b1301a242d\
                  da0c",
        },
    ];

    // A fixed residue / non-residue pair modulo the P-256 prime for deterministic SSWU checks.
    const TEST_DIRECT_QR: &'static str =
        "22d92ad59d5e2681443903612413e0da06650cf2ec4278fd1f4308418a2041b0";
    const TEST_DIRECT_QNR: &'static str =
        "07204a4749c26085a78cea57031524c21575d114d71f0e2ca7d742d7d99fdbe6";

    fn make_group() -> Group {
        let group = boringssl::EcGroup::new(TEST_GROUP).unwrap();
        let bn_ctx = boringssl::BignumCtx::new().unwrap();
        Group { id: TEST_GROUP, group, bn_ctx }
    }

    fn bn(value: u64) -> Bignum {
        Bignum::new_from_u64(value).unwrap()
    }

    #[test]
    fn get_group_id() {
        let group = make_group();
        assert_eq!(group.group_id(), TEST_GROUP.to_u16().unwrap());
    }

    #[test]
    fn concat_mac_addrs_puts_larger_address_first() {
        // TEST_LOOP_STA_B (a5:...) compares greater than TEST_LOOP_STA_A (4d:...).
        let mut expected = TEST_LOOP_STA_B.as_slice().to_vec();
        expected.extend_from_slice(TEST_LOOP_STA_A.as_slice());
        assert_eq!(concat_mac_addrs(&TEST_LOOP_STA_A, &TEST_LOOP_STA_B), expected);
        assert_eq!(concat_mac_addrs(&TEST_LOOP_STA_B, &TEST_LOOP_STA_A), expected);
    }

    #[test]
    fn generate_pwe_loop() {
        let group = make_group();
        let group_params = group.group.get_params(&group.bn_ctx).unwrap();
        let params = SaeParameters {
            hmac: Box::new(HmacUtilsImpl::<Sha256>::new()),
            pwe_method: PweMethod::Loop,
            ssid: Ssid::try_from(TEST_SSID).unwrap(),
            password: Vec::from(TEST_PWD),
            password_id: None,
            sta_a_mac: *TEST_LOOP_STA_A,
            sta_b_mac: *TEST_LOOP_STA_B,
        };
        let pwe = group.generate_pwe(&params).unwrap();
        let (x, y) = pwe.to_affine_coords(&group.group, &group.bn_ctx).unwrap();
        assert_eq!(x.to_be_vec(group_params.p.len()), hex::decode(TEST_LOOP_PWE_X).unwrap());
        assert_eq!(y.to_be_vec(group_params.p.len()), hex::decode(TEST_LOOP_PWE_Y).unwrap());

        // Swapping the MAC addresses must not change the PWE.
        let params =
            SaeParameters { sta_a_mac: *TEST_LOOP_STA_B, sta_b_mac: *TEST_LOOP_STA_A, ..params };
        let pwe = group.generate_pwe(&params).unwrap();
        let (x, y) = pwe.to_affine_coords(&group.group, &group.bn_ctx).unwrap();
        assert_eq!(x.to_be_vec(group_params.p.len()), hex::decode(TEST_LOOP_PWE_X).unwrap());
        assert_eq!(y.to_be_vec(group_params.p.len()), hex::decode(TEST_LOOP_PWE_Y).unwrap());
    }

    #[test]
    fn generate_pwe_direct() {
        let group = make_group();
        let group_params = group.group.get_params(&group.bn_ctx).unwrap();
        let params = SaeParameters {
            hmac: Box::new(HmacUtilsImpl::<Sha256>::new()),
            pwe_method: PweMethod::Direct,
            ssid: Ssid::try_from(TEST_SSID).unwrap(),
            password: Vec::from(TEST_PWD),
            password_id: Some(Vec::from(TEST_PWD_ID)),
            sta_a_mac: *TEST_DIRECT_STA_A,
            sta_b_mac: *TEST_DIRECT_STA_B,
        };
        let pwe = group.generate_pwe(&params).unwrap();
        let (x, y) = pwe.to_affine_coords(&group.group, &group.bn_ctx).unwrap();
        assert_eq!(x.to_be_vec(group_params.p.len()), hex::decode(TEST_DIRECT_PWE_X).unwrap());
        assert_eq!(y.to_be_vec(group_params.p.len()), hex::decode(TEST_DIRECT_PWE_Y).unwrap());

        // Swapping the MAC addresses must not change the PWE.
        let params = SaeParameters {
            sta_a_mac: *TEST_DIRECT_STA_B,
            sta_b_mac: *TEST_DIRECT_STA_A,
            ..params
        };
        let pwe = group.generate_pwe(&params).unwrap();
        let (x, y) = pwe.to_affine_coords(&group.group, &group.bn_ctx).unwrap();
        assert_eq!(x.to_be_vec(group_params.p.len()), hex::decode(TEST_DIRECT_PWE_X).unwrap());
        assert_eq!(y.to_be_vec(group_params.p.len()), hex::decode(TEST_DIRECT_PWE_Y).unwrap());
    }

    #[test]
    fn generate_pwe_loop_no_pwd_id() {
        let group = make_group();
        let params = SaeParameters {
            hmac: Box::new(HmacUtilsImpl::<Sha256>::new()),
            pwe_method: PweMethod::Loop,
            ssid: Ssid::try_from(TEST_SSID).unwrap(),
            password: Vec::from(TEST_PWD),
            password_id: Some(Vec::from(TEST_PWD_ID)),
            sta_a_mac: *TEST_LOOP_STA_A,
            sta_b_mac: *TEST_LOOP_STA_B,
        };
        let pwe = group.generate_pwe(&params);
        assert!(pwe.is_err());
    }

    #[test]
    fn test_legendre() {
        let ctx = BignumCtx::new().unwrap();
        assert_eq!(legendre(&bn(13), &bn(23), &ctx).unwrap(), LegendreSymbol::QuadResidue);
        assert_eq!(legendre(&bn(19), &bn(23), &ctx).unwrap(), LegendreSymbol::NonQuadResidue);
        assert_eq!(legendre(&bn(26), &bn(13), &ctx).unwrap(), LegendreSymbol::ZeroCongruent);
    }

    #[test]
    fn generate_qr_qnr() {
        let ctx = BignumCtx::new().unwrap();
        // Modulo 3, the only nonzero residue is 1 and the only non-residue is 2.
        let (qr, qnr) = generate_qr_and_qnr(&bn(3), &ctx).unwrap();
        assert_eq!(qr, bn(1));
        assert_eq!(qnr, bn(2));
    }

    #[test]
    fn quadratic_residue_blind() {
        // qr_table[i] is true iff i is a nonzero quadratic residue modulo 67.
        let qr_table = [
            false, true, false, false, true, false, true, false, false, true, true, false, false,
            false, true, true, true, true, false, true, false, true, true, true, true, true, true,
            false, false, true, false,
        ];
        let prime = bn(67);
        let ctx = BignumCtx::new().unwrap();
        let (qr, qnr) = generate_qr_and_qnr(&prime, &ctx).unwrap();
        // The blinding factor is random; repeat the checks so both branches are exercised.
        for _ in 0..8 {
            for (i, &is_residue) in qr_table.iter().enumerate() {
                assert_eq!(
                    is_residue,
                    is_quadratic_residue_blind(&bn(i as u64), &prime, &qr, &qnr, &ctx).unwrap(),
                    "value {}",
                    i
                );
            }
        }
    }

    #[test]
    fn calculate_sswu() {
        {
            let group = make_group();
            let group_params = group.group.get_params(&group.bn_ctx).unwrap();
            let p = &group_params.p;

            let z = Bignum::new_from_slice(&hex::decode(TEST_DIRECT_Z).unwrap()).unwrap();
            let c1 = Bignum::new_from_slice(&hex::decode(TEST_DIRECT_C1).unwrap()).unwrap();
            let c2 = Bignum::new_from_slice(&hex::decode(TEST_DIRECT_C2).unwrap()).unwrap();
            let qr = Bignum::new_from_slice(&hex::decode(TEST_DIRECT_QR).unwrap()).unwrap();
            let qnr = Bignum::new_from_slice(&hex::decode(TEST_DIRECT_QNR).unwrap()).unwrap();

            let u1 = Bignum::new_from_slice(&hex::decode(TEST_DIRECT_U1).unwrap()).unwrap();
            let p1 = group.calculate_sswu(&u1, &z, &c1, &c2, &qr, &qnr).unwrap();
            let (p1_x, p1_y) = p1.to_affine_coords(&group.group, &group.bn_ctx).unwrap();
            assert_eq!(hex::encode(p1_x.to_be_vec(p.len())), TEST_DIRECT_P1_X);
            assert_eq!(hex::encode(p1_y.to_be_vec(p.len())), TEST_DIRECT_P1_Y);

            let u2 = Bignum::new_from_slice(&hex::decode(TEST_DIRECT_U2).unwrap()).unwrap();
            let p2 = group.calculate_sswu(&u2, &z, &c1, &c2, &qr, &qnr).unwrap();
            let (p2_x, p2_y) = p2.to_affine_coords(&group.group, &group.bn_ctx).unwrap();
            assert_eq!(hex::encode(p2_x.to_be_vec(p.len())), TEST_DIRECT_P2_X);
            assert_eq!(hex::encode(p2_y.to_be_vec(p.len())), TEST_DIRECT_P2_Y);
        }

        for vector in TEST_SSWU_CURVES {
            let group = boringssl::EcGroup::new(vector.curve).unwrap();
            let bn_ctx = boringssl::BignumCtx::new().unwrap();
            let group = Group { id: vector.curve, group, bn_ctx };
            let group_params = group.group.get_params(&group.bn_ctx).unwrap();
            let p = &group_params.p;

            let (z, c1, c2) = group.generate_sswu_z_c1_c2().unwrap();
            let (qr, qnr) = generate_qr_and_qnr(p, &group.bn_ctx).unwrap();
            let u = Bignum::new_from_slice(&hex::decode(vector.u).unwrap()).unwrap();
            let (q_x, q_y) = group
                .calculate_sswu(&u, &z, &c1, &c2, &qr, &qnr)
                .unwrap()
                .to_affine_coords(&group.group, &group.bn_ctx)
                .unwrap();
            assert_eq!(
                hex::encode(q_x.to_be_vec(p.len())),
                vector.q_x,
                "test vector: {:?}",
                vector
            );
            assert_eq!(
                hex::encode(q_y.to_be_vec(p.len())),
                vector.q_y,
                "test vector: {:?}",
                vector
            );
        }
    }

    #[test]
    fn generate_pt() {
        let group = make_group();
        let params = SaeParameters {
            hmac: Box::new(HmacUtilsImpl::<Sha256>::new()),
            pwe_method: PweMethod::Direct,
            ssid: Ssid::try_from(TEST_SSID).unwrap(),
            password: Vec::from(TEST_PWD),
            password_id: Some(Vec::from(TEST_PWD_ID)),
            sta_a_mac: *TEST_DIRECT_STA_A,
            sta_b_mac: *TEST_DIRECT_STA_B,
        };

        let pt = group.generate_pt(&params).unwrap();
        let (pt_x, pt_y) = pt.to_affine_coords(&group.group, &group.bn_ctx).unwrap();
        assert_eq!(hex::encode(pt_x.to_be_vec(0)), TEST_DIRECT_PT_X);
        assert_eq!(hex::encode(pt_y.to_be_vec(0)), TEST_DIRECT_PT_Y);
    }

    #[test]
    fn test_element_to_octets() {
        let x = Bignum::new_from_slice(&hex::decode(TEST_LOOP_PWE_X).unwrap()).unwrap();
        let y = Bignum::new_from_slice(&hex::decode(TEST_LOOP_PWE_Y).unwrap()).unwrap();
        let group = make_group();
        let element = EcPoint::new_from_affine_coords(x, y, &group.group, &group.bn_ctx).unwrap();

        let octets = group.element_to_octets(&element).unwrap();
        let mut expected = hex::decode(TEST_LOOP_PWE_X).unwrap();
        expected.extend_from_slice(&hex::decode(TEST_LOOP_PWE_Y).unwrap());
        assert_eq!(octets, expected);
    }

    #[test]
    fn test_element_to_octets_padding() {
        let group = make_group();
        let params = group.group.get_params(&group.bn_ctx).unwrap();
        // A small x coordinate checks that the output is left-padded with zeros.
        let x = bn(0xffffffff);
        let y = compute_y_squared(&x, &params, &group.bn_ctx)
            .unwrap()
            .mod_sqrt(&params.p, &group.bn_ctx)
            .unwrap();
        let element = EcPoint::new_from_affine_coords(x, y, &group.group, &group.bn_ctx).unwrap();

        let octets = group.element_to_octets(&element).unwrap();
        let mut expected_x = vec![0x00; 28];
        expected_x.extend_from_slice(&[0xff; 4]);
        assert_eq!(octets.len(), 64);
        assert_eq!(&octets[0..32], &expected_x[0..32]);
    }

    #[test]
    fn test_element_from_octets() {
        let mut octets = hex::decode(TEST_LOOP_PWE_X).unwrap();
        octets.extend_from_slice(&hex::decode(TEST_LOOP_PWE_Y).unwrap());
        let group = make_group();
        let element = group.element_from_octets(&octets).unwrap();
        assert!(element.is_some());
        let element = element.unwrap();

        let expected_x = Bignum::new_from_slice(&hex::decode(TEST_LOOP_PWE_X).unwrap()).unwrap();
        let expected_y = Bignum::new_from_slice(&hex::decode(TEST_LOOP_PWE_Y).unwrap()).unwrap();
        let (x, y) = element.to_affine_coords(&group.group, &group.bn_ctx).unwrap();

        assert_eq!(x, expected_x);
        assert_eq!(y, expected_y);
    }

    #[test]
    fn test_element_from_octets_padded() {
        let mut octets = hex::decode(TEST_LOOP_PWE_X).unwrap();
        octets.extend_from_slice(&hex::decode(TEST_LOOP_PWE_Y).unwrap());
        octets.extend_from_slice(&[0xff; 10]);
        let group = make_group();
        let element = group.element_from_octets(&octets).unwrap();
        assert!(element.is_none());
    }

    #[test]
    fn test_element_from_octets_truncated() {
        let mut octets = hex::decode(TEST_LOOP_PWE_X).unwrap();
        octets.extend_from_slice(&hex::decode(TEST_LOOP_PWE_Y).unwrap());
        octets.truncate(octets.len() - 10);
        let group = make_group();
        let element = group.element_from_octets(&octets).unwrap();
        assert!(element.is_none());
    }

    #[test]
    fn test_element_from_octets_bad_point() {
        let mut octets = hex::decode(TEST_LOOP_PWE_X).unwrap();
        octets.extend_from_slice(&hex::decode(TEST_LOOP_PWE_Y).unwrap());
        // Perturb the last octet of y so the point is no longer on the curve.
        let last = octets.len() - 1;
        octets[last] = octets[last].wrapping_add(1);
        let group = make_group();
        let element = group.element_from_octets(&octets).unwrap();
        assert!(element.is_none());
    }
}