1use crate::integrity::{self, integrity_algorithm};
6use crate::key::exchange::Key;
7use crate::keywrap::{self, keywrap_algorithm};
8
9use crate::{Error, ProtectionInfo, rsn_ensure};
10use anyhow::{anyhow, ensure};
11use fidl_fuchsia_wlan_mlme::SaeFrame;
12use wlan_common::ie::rsn::akm::Akm;
13use wlan_common::ie::rsn::cipher::{CIPHER_BIP_CMAC_128, Cipher, GROUP_CIPHER_SUITE, TKIP};
14use wlan_common::ie::rsn::rsne::{RsnCapabilities, Rsne};
15use wlan_common::ie::wpa::WpaIe;
16
17use zerocopy::SplitByteSlice;
18
19pub mod esssa;
20#[cfg(test)]
21pub mod test_util;
22
/// Records which information element a `NegotiatedProtection` was built from.
///
/// `NegotiatedProtection::from_legacy_wpa` produces `LegacyWpa1` and
/// `NegotiatedProtection::from_rsne` produces `Rsne`; the value in turn selects the EAPOL key
/// descriptor type used for key descriptor version derivation.
#[derive(Clone, Debug, PartialEq)]
pub enum ProtectionType {
    /// Built from a legacy WPA1 information element (`WpaIe`).
    LegacyWpa1,
    /// Built from an RSN element (`Rsne`).
    Rsne,
}
28
/// Management frame protection level advertised by the negotiated RSN capabilities.
///
/// Produced by `NegotiatedProtection::igtk_support`: the "MFP required" capability bit maps to
/// `Required`, otherwise the "MFP capable" bit maps to `Capable`, and anything else (including
/// absent capabilities) maps to `Unsupported`.
///
/// Derives `PartialEq`/`Eq` so callers can compare levels with `==` instead of pattern matching,
/// and `Copy` since it is a plain fieldless enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IgtkSupport {
    /// Neither the "MFP capable" nor the "MFP required" capability bit is set.
    Unsupported,
    /// The "MFP capable" bit is set but "MFP required" is not.
    Capable,
    /// The "MFP required" bit is set.
    Required,
}
35
/// The single set of cipher suites and AKM selected for an association.
///
/// Built from either an RSNE (`from_rsne`) or a legacy WPA1 IE (`from_legacy_wpa`), and
/// convertible back into a `ProtectionInfo` via `to_full_protection`.
#[derive(Debug, Clone, PartialEq)]
pub struct NegotiatedProtection {
    /// Group data cipher: the RSNE group data cipher suite, or the WPA1 multicast cipher.
    pub group_data: Cipher,
    /// Pairwise cipher: the first RSNE pairwise suite, or the single WPA1 unicast cipher.
    pub pairwise: Cipher,
    /// Group management cipher from the RSNE, if present; always `None` for legacy WPA1.
    pub group_mgmt: Option<Cipher>,
    /// The negotiated AKM suite.
    pub akm: Akm,
    /// MIC length in bytes as reported by `akm.mic_bytes()`.
    pub mic_size: u16,
    /// Which element this protection was built from.
    pub protection_type: ProtectionType,
    /// RSN capabilities copied from the RSNE; always `None` for legacy WPA1.
    caps: Option<RsnCapabilities>,
}
48
49impl NegotiatedProtection {
50 pub fn from_protection(protection: &ProtectionInfo) -> Result<Self, anyhow::Error> {
51 match protection {
52 ProtectionInfo::Rsne(rsne) => Self::from_rsne(rsne),
53 ProtectionInfo::LegacyWpa(wpa) => Self::from_legacy_wpa(wpa),
54 }
55 }
56
57 fn key_descriptor_version(&self) -> u16 {
58 let key_descriptor_type = match self.protection_type {
59 ProtectionType::LegacyWpa1 => eapol::KeyDescriptor::LEGACY_WPA1,
60 ProtectionType::Rsne => eapol::KeyDescriptor::IEEE802DOT11,
61 };
62 derive_key_descriptor_version(key_descriptor_type, self)
63 }
64
65 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
66 pub fn integrity_algorithm(&self) -> Result<Box<dyn integrity::Algorithm>, Error> {
67 integrity_algorithm(self.key_descriptor_version(), &self.akm)
68 .ok_or(Error::UnknownIntegrityAlgorithm)
69 }
70
71 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
72 pub fn keywrap_algorithm(&self) -> Result<Box<dyn keywrap::Algorithm>, Error> {
73 keywrap_algorithm(self.key_descriptor_version(), &self.akm)
74 .ok_or(Error::UnknownKeywrapAlgorithm)
75 }
76
77 pub fn from_rsne(rsne: &Rsne) -> Result<Self, anyhow::Error> {
80 rsne.ensure_valid_s_rsne()
81 .map_err(|e| anyhow!(e).context(Error::InvalidNegotiatedProtection))?;
82
83 let group_data = rsne.group_data_cipher_suite.as_ref().unwrap();
86 let pairwise = &rsne.pairwise_cipher_suites[0];
87 let akm = &rsne.akm_suites[0];
88 let mic_size = akm.mic_bytes();
89 let mic_size = mic_size.unwrap();
90
91 Ok(Self {
92 group_data: group_data.clone(),
93 pairwise: pairwise.clone(),
94 group_mgmt: rsne.group_mgmt_cipher_suite.clone(),
95 akm: akm.clone(),
96 mic_size,
97 protection_type: ProtectionType::Rsne,
98 caps: rsne.rsn_capabilities.clone(),
99 })
100 }
101
102 pub fn from_legacy_wpa(wpa: &WpaIe) -> Result<Self, anyhow::Error> {
105 ensure!(wpa.unicast_cipher_list.len() == 1, Error::InvalidNegotiatedProtection);
106 ensure!(wpa.akm_list.len() == 1, Error::InvalidNegotiatedProtection);
107 let akm = wpa.akm_list[0].clone();
108 let mic_size = akm.mic_bytes().ok_or(Error::InvalidNegotiatedProtection)?;
109 let group_data = wpa.multicast_cipher.clone();
110 let pairwise = wpa.unicast_cipher_list[0].clone();
111 Ok(Self {
112 group_data,
113 pairwise,
114 group_mgmt: None,
115 akm,
116 mic_size,
117 protection_type: ProtectionType::LegacyWpa1,
118 caps: None,
119 })
120 }
121
122 pub fn to_full_protection(&self) -> ProtectionInfo {
125 match self.protection_type {
126 ProtectionType::Rsne => ProtectionInfo::Rsne(Rsne {
127 group_data_cipher_suite: Some(self.group_data.clone()),
128 pairwise_cipher_suites: vec![self.pairwise.clone()],
129 group_mgmt_cipher_suite: self.group_mgmt.clone(),
130 akm_suites: vec![self.akm.clone()],
131 rsn_capabilities: self.caps.clone(),
132 ..Default::default()
133 }),
134 ProtectionType::LegacyWpa1 => ProtectionInfo::LegacyWpa(WpaIe {
135 multicast_cipher: self.group_data.clone(),
136 unicast_cipher_list: vec![self.pairwise.clone()],
137 akm_list: vec![self.akm.clone()],
138 }),
139 }
140 }
141
142 pub fn igtk_support(&self) -> IgtkSupport {
143 match &self.caps {
144 Some(caps) => {
145 if caps.mgmt_frame_protection_req() {
146 IgtkSupport::Required
147 } else if caps.mgmt_frame_protection_cap() {
148 IgtkSupport::Capable
149 } else {
150 IgtkSupport::Unsupported
151 }
152 }
153 None => IgtkSupport::Unsupported,
154 }
155 }
156
157 pub fn group_mgmt_cipher(&self) -> Cipher {
158 self.group_mgmt.clone().unwrap_or(CIPHER_BIP_CMAC_128)
160 }
161}
162
163pub struct EncryptedKeyData<B: SplitByteSlice>(eapol::KeyFrameRx<B>);
165
166impl<B: SplitByteSlice> EncryptedKeyData<B> {
167 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
168 pub fn decrypt(
171 self,
172 kek: &[u8],
173 protection: &NegotiatedProtection,
174 ) -> Result<(eapol::KeyFrameRx<B>, Vec<u8>), Error> {
175 let key_data = protection.keywrap_algorithm()?.unwrap_key(
176 kek,
177 &self.0.key_frame_fields.key_iv,
178 &self.0.key_data[..],
179 )?;
180 Ok((self.0, key_data))
181 }
182}
183
184#[derive(Debug)]
186pub struct WithUnverifiedMic<B: SplitByteSlice>(eapol::KeyFrameRx<B>);
187
188impl<B: SplitByteSlice> WithUnverifiedMic<B> {
189 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
190 pub fn verify_mic(
194 self,
195 kck: &[u8],
196 protection: &NegotiatedProtection,
197 ) -> Result<UnverifiedKeyData<B>, Error> {
198 let mic_bytes = protection.akm.mic_bytes().ok_or(Error::UnsupportedAkmSuite)?;
201 rsn_ensure!(self.0.key_mic.len() == mic_bytes as usize, Error::InvalidMicSize);
202
203 let buf = self.0.to_bytes(true);
205 let valid_mic =
206 protection.integrity_algorithm()?.verify(kck, &buf[..], &self.0.key_mic[..]);
207 rsn_ensure!(valid_mic, Error::InvalidMic);
208
209 if self.0.key_frame_fields.key_info().encrypted_key_data() {
210 Ok(UnverifiedKeyData::Encrypted(EncryptedKeyData(self.0)))
211 } else {
212 Ok(UnverifiedKeyData::NotEncrypted(self.0))
213 }
214 }
215}
216
/// A key frame whose MIC was verified by `WithUnverifiedMic::verify_mic`, split by whether its
/// key data is encrypted (the frame's `encrypted_key_data` key-info bit).
///
/// The contents of the key data are not validated at this stage.
pub enum UnverifiedKeyData<B: SplitByteSlice> {
    /// Key data is wrapped; unwrap it with `EncryptedKeyData::decrypt`.
    Encrypted(EncryptedKeyData<B>),
    /// Key data is not encrypted.
    NotEncrypted(eapol::KeyFrameRx<B>),
}
223
/// An EAPOL key frame that passed the header checks in `Dot11VerifiedKeyFrame::from_frame`.
///
/// Header checks do not include the MIC: frames with the key MIC bit set still need
/// `WithUnverifiedMic::verify_mic`.
#[derive(Debug)]
pub enum Dot11VerifiedKeyFrame<B: SplitByteSlice> {
    /// The key MIC bit is set; the MIC has not been verified yet.
    WithUnverifiedMic(WithUnverifiedMic<B>),
    /// The key MIC bit is clear.
    WithoutMic(eapol::KeyFrameRx<B>),
}
231
impl<B: SplitByteSlice> Dot11VerifiedKeyFrame<B> {
    /// Validates the header fields of a received EAPOL key frame.
    ///
    /// The frame is checked against the negotiated `protection` and the local `role`; the frame's
    /// sender is assumed to hold the opposite role. Checks run in this order and the first
    /// failure is returned:
    /// 1. Descriptor type: IEEE 802.11, or legacy WPA1 only when WPA1 was negotiated.
    /// 2. Key descriptor version equals `derive_key_descriptor_version` for that type.
    /// 3. The install bit is clear for group/SMK key types.
    /// 4. The key ack bit is clear when sent by a supplicant. The error and request bits are
    ///    clear when sent by an authenticator.
    /// 5. The SMK message bit is clear (SMK is not supported).
    /// 6. For pairwise frames, `key_len` equals the pairwise cipher's TK length. A supplicant
    ///    may also send 0.
    /// 7. If `key_replay_counter > 0`, a supplicant's counter must be `>=` it and an
    ///    authenticator's strictly `>` it.
    /// 8. The encrypted key data bit requires the key MIC bit.
    ///
    /// The MIC itself is not verified here. Frames with the key MIC bit set are returned as
    /// `WithUnverifiedMic`.
    ///
    /// NOTE(review): these checks appear to follow IEEE Std 802.11-2016, 12.7.2 — confirm.
    #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
    pub fn from_frame(
        frame: eapol::KeyFrameRx<B>,
        role: &Role,
        protection: &NegotiatedProtection,
        key_replay_counter: u64,
    ) -> Result<Dot11VerifiedKeyFrame<B>, Error> {
        // The frame was sent by our peer, which plays the opposite role.
        let sender = match role {
            Role::Supplicant => Role::Authenticator,
            Role::Authenticator => Role::Supplicant,
        };

        // Accept the IEEE 802.11 descriptor, and legacy WPA1 only for WPA1 associations.
        // RC4 is reported as an invalid descriptor (naming the expected one). Anything else,
        // including WPA1 on an RSNE association, is unsupported.
        let key_descriptor = match frame.key_frame_fields.descriptor_type {
            eapol::KeyDescriptor::IEEE802DOT11 => eapol::KeyDescriptor::IEEE802DOT11,
            eapol::KeyDescriptor::LEGACY_WPA1
                if protection.protection_type == ProtectionType::LegacyWpa1 =>
            {
                eapol::KeyDescriptor::LEGACY_WPA1
            }
            eapol::KeyDescriptor::RC4 => {
                return Err(Error::InvalidKeyDescriptor(
                    frame.key_frame_fields.descriptor_type,
                    eapol::KeyDescriptor::IEEE802DOT11,
                )
                .into());
            }
            _ => {
                return Err(
                    Error::UnsupportedKeyDescriptor(frame.key_frame_fields.descriptor_type).into()
                );
            }
        };

        // The version in the key info must match the one implied by the negotiated suites.
        let frame_key_descriptor_version =
            frame.key_frame_fields.key_info().key_descriptor_version();
        let expected_version = derive_key_descriptor_version(key_descriptor, protection);
        rsn_ensure!(
            frame_key_descriptor_version == expected_version,
            Error::UnsupportedKeyDescriptorVersion(frame_key_descriptor_version)
        );

        // Group/SMK key frames must not request key installation.
        match frame.key_frame_fields.key_info().key_type() {
            eapol::KeyType::PAIRWISE => {}
            eapol::KeyType::GROUP_SMK => {
                rsn_ensure!(
                    !frame.key_frame_fields.key_info().install(),
                    Error::InvalidInstallBitGroupSmkHandshake
                );
            }
        };

        // A supplicant must never set the key ack bit.
        if let Role::Supplicant = sender {
            rsn_ensure!(
                !frame.key_frame_fields.key_info().key_ack(),
                Error::InvalidKeyAckBitSupplicant
            );
        }

        // An authenticator must never set the error bit.
        if let Role::Authenticator = sender {
            rsn_ensure!(
                !frame.key_frame_fields.key_info().error(),
                Error::InvalidErrorBitAuthenticator
            );
        }

        // An authenticator must never set the request bit.
        if let Role::Authenticator = sender {
            rsn_ensure!(
                !frame.key_frame_fields.key_info().request(),
                Error::InvalidRequestBitAuthenticator
            );
        }

        // SMK handshakes are not supported.
        rsn_ensure!(
            !frame.key_frame_fields.key_info().smk_message(),
            Error::SmkHandshakeNotSupported
        );

        // Key length checks apply to pairwise frames only. A supplicant may send 0, but any
        // non-zero value must equal the pairwise TK length. An authenticator must always send
        // the TK length. Group/SMK key lengths are not checked here.
        match frame.key_frame_fields.key_info().key_type() {
            eapol::KeyType::PAIRWISE => match sender {
                Role::Supplicant if frame.key_frame_fields.key_len.get() != 0 => {
                    let tk_len =
                        protection.pairwise.tk_bytes().ok_or(Error::UnsupportedCipherSuite)?;
                    rsn_ensure!(
                        frame.key_frame_fields.key_len.get() == tk_len.into(),
                        Error::InvalidKeyLength(
                            frame.key_frame_fields.key_len.get().into(),
                            tk_len.into()
                        )
                    );
                }
                Role::Authenticator => {
                    let tk_len: usize =
                        protection.pairwise.tk_bytes().ok_or(Error::UnsupportedCipherSuite)?.into();
                    rsn_ensure!(
                        usize::from(frame.key_frame_fields.key_len.get()) == tk_len,
                        Error::InvalidKeyLength(
                            frame.key_frame_fields.key_len.get().into(),
                            tk_len
                        )
                    );
                }
                _ => {}
            },
            eapol::KeyType::GROUP_SMK => {}
        };

        // Replay protection is skipped when `key_replay_counter` is 0 (presumably meaning no
        // counter has been established yet — confirm with callers).
        if key_replay_counter > 0 {
            match sender {
                // A supplicant's frame may reuse the counter it is responding to.
                Role::Supplicant => {
                    rsn_ensure!(
                        frame.key_frame_fields.key_replay_counter.get() >= key_replay_counter,
                        Error::InvalidKeyReplayCounter(
                            frame.key_frame_fields.key_replay_counter.get(),
                            key_replay_counter
                        )
                    );
                }
                // An authenticator's frame must carry a strictly larger counter.
                Role::Authenticator => {
                    rsn_ensure!(
                        frame.key_frame_fields.key_replay_counter.get() > key_replay_counter,
                        Error::InvalidKeyReplayCounter(
                            frame.key_frame_fields.key_replay_counter.get(),
                            key_replay_counter
                        )
                    );
                }
            }
        }

        // Encrypted key data is only acceptable on MIC-protected frames.
        if frame.key_frame_fields.key_info().encrypted_key_data() {
            rsn_ensure!(
                frame.key_frame_fields.key_info().key_mic(),
                Error::InvalidMicBitForEncryptedKeyData
            );
        }

        // MIC verification is deferred to `WithUnverifiedMic::verify_mic`.
        if frame.key_frame_fields.key_info().key_mic() {
            Ok(Dot11VerifiedKeyFrame::WithUnverifiedMic(WithUnverifiedMic(frame)))
        } else {
            Ok(Dot11VerifiedKeyFrame::WithoutMic(frame))
        }
    }

    /// Returns the underlying frame for either variant, without any MIC verification.
    ///
    /// NOTE(review): the `unsafe_` prefix presumably warns that MIC-protected contents obtained
    /// this way have not been authenticated — confirm intended usage.
    pub fn unsafe_get_raw(&self) -> &eapol::KeyFrameRx<B> {
        match self {
            Dot11VerifiedKeyFrame::WithUnverifiedMic(WithUnverifiedMic(frame)) => frame,
            Dot11VerifiedKeyFrame::WithoutMic(frame) => frame,
        }
    }
}
453
454pub fn derive_key_descriptor_version(
457 key_descriptor_type: eapol::KeyDescriptor,
458 protection: &NegotiatedProtection,
459) -> u16 {
460 let akm = &protection.akm;
461 let pairwise = &protection.pairwise;
462
463 if !akm.has_known_algorithm() || !pairwise.has_known_usage() {
464 return 0;
465 }
466
467 match akm.suite_type {
468 1 | 2 => match key_descriptor_type {
469 eapol::KeyDescriptor::RC4 => match pairwise.suite_type {
470 TKIP | GROUP_CIPHER_SUITE => 1,
471 _ => 0,
472 },
473 eapol::KeyDescriptor::IEEE802DOT11 | eapol::KeyDescriptor::LEGACY_WPA1 => {
474 if pairwise.suite_type == TKIP || pairwise.suite_type == GROUP_CIPHER_SUITE {
475 1
476 } else if pairwise.is_enhanced() || protection.group_data.is_enhanced() {
477 2
478 } else {
479 0
480 }
481 }
482 _ => 0,
483 },
484 3..=6 => 3,
487 _ => 0,
488 }
489}
490
/// The local entity's role in the key exchange.
///
/// `Dot11VerifiedKeyFrame::from_frame` treats the sender of a received frame as holding the
/// opposite role.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Role {
    /// The authenticator side of the exchange.
    Authenticator,
    /// The supplicant side of the exchange.
    Supplicant,
}
496
/// Security association status, reported through `SecAssocUpdate::Status`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SecAssocStatus {
    WrongPassword,
    PmkSaEstablished,
    EssSaEstablished,
}
503
/// Reason attached to `AuthStatus::Rejected`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AuthRejectedReason {
    AuthFailed,
    TooManyRetries,
    PmksaExpired,
}

/// Outcome of an SAE authentication attempt, reported through `SecAssocUpdate::SaeAuthStatus`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AuthStatus {
    /// Authentication succeeded.
    Success,
    /// Authentication was rejected for the given reason.
    Rejected(AuthRejectedReason),
    /// Authentication could not complete due to an internal error.
    InternalError,
}
520
/// An action or notification emitted by security association logic for its caller to act on.
#[derive(Debug, PartialEq, Clone)]
pub enum SecAssocUpdate {
    /// An EAPOL key frame to transmit.
    ///
    /// NOTE(review): `expect_response` presumably tells the caller whether a reply is awaited
    /// (e.g. for retransmission handling) — confirm.
    TxEapolKeyFrame {
        frame: eapol::KeyFrameBuf,
        expect_response: bool,
    },
    /// A key produced by the exchange (presumably for installation — confirm).
    Key(Key),
    /// A change in security association status.
    Status(SecAssocStatus),
    /// An SAE frame to transmit.
    TxSaeFrame(SaeFrame),
    /// The outcome of SAE authentication.
    SaeAuthStatus(AuthStatus),
    /// A request to schedule an SAE timeout.
    ///
    /// NOTE(review): the meaning of the `u64` (timeout id vs. duration) is not visible here —
    /// confirm.
    ScheduleSaeTimeout(u64),
}

/// Ordered collection of `SecAssocUpdate`s accumulated for the caller.
pub type UpdateSink = Vec<SecAssocUpdate>;
538
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use wlan_common::ie::rsn::akm::{self, AKM_PSK};
    use wlan_common::ie::rsn::cipher::{self, CIPHER_CCMP_128, CIPHER_GCMP_256};
    use wlan_common::ie::rsn::fake_wpa2_s_rsne;
    use zerocopy::byteorder::big_endian::U16;

    #[test]
    fn test_negotiated_protection_from_rsne() {
        // A complete single-suite RSNE is accepted.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_GCMP_256),
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect("error, could not create negotiated RSNE");

        let rsne = Rsne::wpa3_rsne();
        NegotiatedProtection::from_rsne(&rsne).expect("error, could not create negotiated RSNE");

        // Missing group data cipher suite.
        let rsne = Rsne {
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");

        // Missing pairwise cipher suite.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_CCMP_128),
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");

        // Missing AKM suite.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_CCMP_128),
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");
    }

    /// A supplicant's message 2 is accepted both with key length 0 and with the TK length.
    #[test]
    fn test_supplicant_sends_zeroed_and_non_zeroed_key_length() {
        let protection = NegotiatedProtection::from_rsne(&fake_wpa2_s_rsne())
            .expect("could not derive negotiated RSNE");
        let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);

        let msg1 = env.initiate(11.into());
        let (msg2_base, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());

        // Key length 0.
        let mut buf = vec![];
        let mut msg2 = msg2_base.copy_keyframe_mut(&mut buf);
        msg2.key_frame_fields.key_len = U16::new(0);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_ok(), "failed verifying message: {}", result.unwrap_err());

        // Key length 16, presumably the TK length of the fake RSNE's pairwise cipher.
        let mut buf = vec![];
        let mut msg2 = msg2_base.copy_keyframe_mut(&mut buf);
        msg2.key_frame_fields.key_len = U16::new(16);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_ok(), "failed verifying message: {}", result.unwrap_err());
    }

    /// A supplicant's message 2 with a key length that is neither 0 nor the TK length is rejected.
    #[test]
    fn test_supplicant_sends_random_key_length() {
        let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);

        let msg1 = env.initiate(12.into());
        let (msg2, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 12.into());
        let mut buf = vec![];
        let mut msg2 = msg2.copy_keyframe_mut(&mut buf);

        msg2.key_frame_fields.key_len = U16::new(29);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));

        let protection = NegotiatedProtection::from_rsne(&fake_wpa2_s_rsne())
            .expect("could not derive negotiated RSNE");
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_err(), "successfully verified illegal message");
    }

    #[test]
    fn test_to_rsne() {
        let rsne = Rsne::wpa2_rsne();
        let negotiated_protection = NegotiatedProtection::from_rsne(&rsne)
            .expect("error, could not create negotiated RSNE")
            .to_full_protection();
        assert_matches!(negotiated_protection, ProtectionInfo::Rsne(actual_protection) => {
            assert_eq!(actual_protection, rsne);
        });
    }

    #[test]
    fn test_to_legacy_wpa() {
        let wpa_ie = make_wpa(Some(cipher::TKIP), vec![cipher::TKIP], vec![akm::PSK]);
        let negotiated_protection = NegotiatedProtection::from_legacy_wpa(&wpa_ie)
            .expect("error, could not create negotiated WPA")
            .to_full_protection();
        assert_matches!(negotiated_protection, ProtectionInfo::LegacyWpa(actual_protection) => {
            assert_eq!(actual_protection, wpa_ie);
        });
    }

    #[test]
    fn test_igtk_support() {
        // The WPA3 RSNE yields `Required` and the default group management cipher.
        let rsne = Rsne::wpa3_rsne();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Required);
        assert_eq!(negotiated_protection.group_mgmt_cipher(), CIPHER_BIP_CMAC_128);

        // Only the "MFP capable" bit set.
        let mut rsne = Rsne::wpa3_rsne();
        rsne.rsn_capabilities.replace(RsnCapabilities(0).with_mgmt_frame_protection_cap(true));
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Capable);

        // The WPA2 RSNE yields `Unsupported`.
        let rsne = Rsne::wpa2_rsne();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Unsupported);
    }

    /// Without a group management cipher suite, BIP-CMAC-128 is used.
    #[test]
    fn test_default_igtk_cipher() {
        let mut rsne = Rsne::wpa3_rsne();
        rsne.group_mgmt_cipher_suite.take();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Required);
        assert_eq!(negotiated_protection.group_mgmt_cipher(), CIPHER_BIP_CMAC_128);
    }

    /// Builds a WPA1 IE from raw suite types.
    ///
    /// `multicast` becomes the multicast cipher and `unicast` the unicast cipher list.
    fn make_wpa(multicast: Option<u8>, unicast: Vec<u8>, akms: Vec<u8>) -> WpaIe {
        WpaIe {
            multicast_cipher: multicast
                .map(cipher::Cipher::new_dot11)
                .expect("failed to make wpa ie!"),
            unicast_cipher_list: unicast.into_iter().map(cipher::Cipher::new_dot11).collect(),
            akm_list: akms.into_iter().map(akm::Akm::new_dot11).collect(),
        }
    }
}