1use crate::integrity::{self, integrity_algorithm};
6use crate::key::exchange::Key;
7use crate::keywrap::{self, keywrap_algorithm};
8
9use crate::{rsn_ensure, Error, ProtectionInfo};
10use anyhow::{anyhow, ensure};
11use fidl_fuchsia_wlan_mlme::SaeFrame;
12use wlan_common::ie::rsn::akm::Akm;
13use wlan_common::ie::rsn::cipher::{Cipher, CIPHER_BIP_CMAC_128, GROUP_CIPHER_SUITE, TKIP};
14use wlan_common::ie::rsn::rsne::{RsnCapabilities, Rsne};
15use wlan_common::ie::wpa::WpaIe;
16use zerocopy::SplitByteSlice;
17
18pub mod esssa;
19#[cfg(test)]
20pub mod test_util;
21
/// Which information element a `NegotiatedProtection` was derived from.
#[derive(Debug, Clone, PartialEq)]
pub enum ProtectionType {
    /// Legacy WPA1 IE (produced by `NegotiatedProtection::from_legacy_wpa`); selects the
    /// `LEGACY_WPA1` EAPOL key descriptor.
    LegacyWpa1,
    /// RSN element (produced by `NegotiatedProtection::from_rsne`); selects the
    /// `IEEE802DOT11` EAPOL key descriptor.
    Rsne,
}
27
/// IGTK support level, derived by `NegotiatedProtection::igtk_support` from the management
/// frame protection (MFP) bits of the negotiated RSN capabilities.
#[derive(Debug)]
pub enum IgtkSupport {
    /// Neither MFP bit is set, or no RSN capabilities were negotiated.
    Unsupported,
    /// The MFP-capable bit is set but the MFP-required bit is not.
    Capable,
    /// The MFP-required bit is set.
    Required,
}
34
/// Protection parameters agreed on for an association: exactly one pairwise cipher and one
/// AKM, plus the group ciphers and the MIC length implied by the AKM.
#[derive(Debug, Clone, PartialEq)]
pub struct NegotiatedProtection {
    /// Group data cipher suite.
    pub group_data: Cipher,
    /// The single negotiated pairwise cipher suite.
    pub pairwise: Cipher,
    /// Group management cipher suite, if one was present. `group_mgmt_cipher()` supplies a
    /// default when this is `None`.
    pub group_mgmt: Option<Cipher>,
    /// The single negotiated AKM suite.
    pub akm: Akm,
    /// MIC length in bytes, taken from `akm.mic_bytes()` at construction time.
    pub mic_size: u16,
    /// Whether this was negotiated from an RSNE or a legacy WPA1 IE.
    pub protection_type: ProtectionType,
    /// RSN capabilities copied from the RSNE; always `None` for legacy WPA1.
    caps: Option<RsnCapabilities>,
}
47
48impl NegotiatedProtection {
49 pub fn from_protection(protection: &ProtectionInfo) -> Result<Self, anyhow::Error> {
50 match protection {
51 ProtectionInfo::Rsne(rsne) => Self::from_rsne(rsne),
52 ProtectionInfo::LegacyWpa(wpa) => Self::from_legacy_wpa(wpa),
53 }
54 }
55
56 fn key_descriptor_version(&self) -> u16 {
57 let key_descriptor_type = match self.protection_type {
58 ProtectionType::LegacyWpa1 => eapol::KeyDescriptor::LEGACY_WPA1,
59 ProtectionType::Rsne => eapol::KeyDescriptor::IEEE802DOT11,
60 };
61 derive_key_descriptor_version(key_descriptor_type, self)
62 }
63
64 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
65 pub fn integrity_algorithm(&self) -> Result<Box<dyn integrity::Algorithm>, Error> {
66 integrity_algorithm(self.key_descriptor_version(), &self.akm)
67 .ok_or(Error::UnknownIntegrityAlgorithm)
68 }
69
70 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
71 pub fn keywrap_algorithm(&self) -> Result<Box<dyn keywrap::Algorithm>, Error> {
72 keywrap_algorithm(self.key_descriptor_version(), &self.akm)
73 .ok_or(Error::UnknownKeywrapAlgorithm)
74 }
75
76 pub fn from_rsne(rsne: &Rsne) -> Result<Self, anyhow::Error> {
79 rsne.ensure_valid_s_rsne()
80 .map_err(|e| anyhow!(e).context(Error::InvalidNegotiatedProtection))?;
81
82 let group_data = rsne.group_data_cipher_suite.as_ref().unwrap();
85 let pairwise = &rsne.pairwise_cipher_suites[0];
86 let akm = &rsne.akm_suites[0];
87 let mic_size = akm.mic_bytes();
88 let mic_size = mic_size.unwrap();
89
90 Ok(Self {
91 group_data: group_data.clone(),
92 pairwise: pairwise.clone(),
93 group_mgmt: rsne.group_mgmt_cipher_suite.clone(),
94 akm: akm.clone(),
95 mic_size,
96 protection_type: ProtectionType::Rsne,
97 caps: rsne.rsn_capabilities.clone(),
98 })
99 }
100
101 pub fn from_legacy_wpa(wpa: &WpaIe) -> Result<Self, anyhow::Error> {
104 ensure!(wpa.unicast_cipher_list.len() == 1, Error::InvalidNegotiatedProtection);
105 ensure!(wpa.akm_list.len() == 1, Error::InvalidNegotiatedProtection);
106 let akm = wpa.akm_list[0].clone();
107 let mic_size = akm.mic_bytes().ok_or(Error::InvalidNegotiatedProtection)?;
108 let group_data = wpa.multicast_cipher.clone();
109 let pairwise = wpa.unicast_cipher_list[0].clone();
110 Ok(Self {
111 group_data,
112 pairwise,
113 group_mgmt: None,
114 akm,
115 mic_size,
116 protection_type: ProtectionType::LegacyWpa1,
117 caps: None,
118 })
119 }
120
121 pub fn to_full_protection(&self) -> ProtectionInfo {
124 match self.protection_type {
125 ProtectionType::Rsne => ProtectionInfo::Rsne(Rsne {
126 group_data_cipher_suite: Some(self.group_data.clone()),
127 pairwise_cipher_suites: vec![self.pairwise.clone()],
128 group_mgmt_cipher_suite: self.group_mgmt.clone(),
129 akm_suites: vec![self.akm.clone()],
130 rsn_capabilities: self.caps.clone(),
131 ..Default::default()
132 }),
133 ProtectionType::LegacyWpa1 => ProtectionInfo::LegacyWpa(WpaIe {
134 multicast_cipher: self.group_data.clone(),
135 unicast_cipher_list: vec![self.pairwise.clone()],
136 akm_list: vec![self.akm.clone()],
137 }),
138 }
139 }
140
141 pub fn igtk_support(&self) -> IgtkSupport {
142 match &self.caps {
143 Some(caps) => {
144 if caps.mgmt_frame_protection_req() {
145 IgtkSupport::Required
146 } else if caps.mgmt_frame_protection_cap() {
147 IgtkSupport::Capable
148 } else {
149 IgtkSupport::Unsupported
150 }
151 }
152 None => IgtkSupport::Unsupported,
153 }
154 }
155
156 pub fn group_mgmt_cipher(&self) -> Cipher {
157 self.group_mgmt.clone().unwrap_or(CIPHER_BIP_CMAC_128)
159 }
160}
161
162pub struct EncryptedKeyData<B: SplitByteSlice>(eapol::KeyFrameRx<B>);
164
165impl<B: SplitByteSlice> EncryptedKeyData<B> {
166 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
167 pub fn decrypt(
170 self,
171 kek: &[u8],
172 protection: &NegotiatedProtection,
173 ) -> Result<(eapol::KeyFrameRx<B>, Vec<u8>), Error> {
174 let key_data = protection.keywrap_algorithm()?.unwrap_key(
175 kek,
176 &self.0.key_frame_fields.key_iv,
177 &self.0.key_data[..],
178 )?;
179 Ok((self.0, key_data))
180 }
181}
182
183#[derive(Debug)]
185pub struct WithUnverifiedMic<B: SplitByteSlice>(eapol::KeyFrameRx<B>);
186
187impl<B: SplitByteSlice> WithUnverifiedMic<B> {
188 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
189 pub fn verify_mic(
193 self,
194 kck: &[u8],
195 protection: &NegotiatedProtection,
196 ) -> Result<UnverifiedKeyData<B>, Error> {
197 let mic_bytes = protection.akm.mic_bytes().ok_or(Error::UnsupportedAkmSuite)?;
200 rsn_ensure!(self.0.key_mic.len() == mic_bytes as usize, Error::InvalidMicSize);
201
202 let buf = self.0.to_bytes(true);
204 let valid_mic =
205 protection.integrity_algorithm()?.verify(kck, &buf[..], &self.0.key_mic[..]);
206 rsn_ensure!(valid_mic, Error::InvalidMic);
207
208 if self.0.key_frame_fields.key_info().encrypted_key_data() {
209 Ok(UnverifiedKeyData::Encrypted(EncryptedKeyData(self.0)))
210 } else {
211 Ok(UnverifiedKeyData::NotEncrypted(self.0))
212 }
213 }
214}
215
/// Output of a successful `WithUnverifiedMic::verify_mic`.
///
/// The MIC has been checked, but the key data itself has not been validated.
pub enum UnverifiedKeyData<B: SplitByteSlice> {
    /// The frame's encrypted-key-data bit was set; unwrap with `EncryptedKeyData::decrypt`.
    Encrypted(EncryptedKeyData<B>),
    /// The frame's encrypted-key-data bit was clear.
    NotEncrypted(eapol::KeyFrameRx<B>),
}
222
/// An EAPOL key frame that passed the header checks in `Dot11VerifiedKeyFrame::from_frame`.
#[derive(Debug)]
pub enum Dot11VerifiedKeyFrame<B: SplitByteSlice> {
    /// Key MIC bit set; the MIC still has to be checked with `WithUnverifiedMic::verify_mic`.
    WithUnverifiedMic(WithUnverifiedMic<B>),
    /// Key MIC bit clear.
    WithoutMic(eapol::KeyFrameRx<B>),
}
230
impl<B: SplitByteSlice> Dot11VerifiedKeyFrame<B> {
    /// Validates the header fields of a received EAPOL key frame against the negotiated
    /// `protection`. The MIC is not checked.
    ///
    /// - `role` is the local role; the frame is treated as sent by the opposite role.
    /// - `key_replay_counter` is compared with the frame's replay counter only when it is
    ///   non-zero.
    ///
    /// On success the frame is returned as `WithUnverifiedMic` if its key MIC bit is set,
    /// otherwise as `WithoutMic`.
    ///
    /// # Errors
    /// The first failing check determines the `Error`. Examples:
    /// - `InvalidKeyDescriptor` / `UnsupportedKeyDescriptor` for the descriptor type.
    /// - `UnsupportedKeyDescriptorVersion`.
    /// - One of the `Invalid*Bit*` variants for disallowed key info bits.
    /// - `SmkHandshakeNotSupported`.
    /// - `UnsupportedCipherSuite` / `InvalidKeyLength`.
    /// - `InvalidKeyReplayCounter`.
    /// - `InvalidMicBitForEncryptedKeyData`.
    #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
    pub fn from_frame(
        frame: eapol::KeyFrameRx<B>,
        role: &Role,
        protection: &NegotiatedProtection,
        key_replay_counter: u64,
    ) -> Result<Dot11VerifiedKeyFrame<B>, Error> {
        // The frame was sent by our peer, i.e. the role opposite to ours.
        let sender = match role {
            Role::Supplicant => Role::Authenticator,
            Role::Authenticator => Role::Supplicant,
        };

        // The IEEE 802.11 descriptor is always accepted. The legacy WPA1 descriptor is
        // accepted only when WPA1 was negotiated. RC4 is reported as invalid; any other
        // descriptor (including WPA1 under RSNE protection) is reported as unsupported.
        let key_descriptor = match frame.key_frame_fields.descriptor_type {
            eapol::KeyDescriptor::IEEE802DOT11 => eapol::KeyDescriptor::IEEE802DOT11,
            eapol::KeyDescriptor::LEGACY_WPA1
                if protection.protection_type == ProtectionType::LegacyWpa1 =>
            {
                eapol::KeyDescriptor::LEGACY_WPA1
            }
            eapol::KeyDescriptor::RC4 => {
                return Err(Error::InvalidKeyDescriptor(
                    frame.key_frame_fields.descriptor_type,
                    eapol::KeyDescriptor::IEEE802DOT11,
                )
                .into())
            }
            _ => {
                return Err(
                    Error::UnsupportedKeyDescriptor(frame.key_frame_fields.descriptor_type).into()
                )
            }
        };

        // The version in the key info must equal the one implied by the negotiated AKM and
        // ciphers for the accepted descriptor type.
        let frame_key_descriptor_version =
            frame.key_frame_fields.key_info().key_descriptor_version();
        let expected_version = derive_key_descriptor_version(key_descriptor, protection);
        rsn_ensure!(
            frame_key_descriptor_version == expected_version,
            Error::UnsupportedKeyDescriptorVersion(frame_key_descriptor_version)
        );

        // Group/SMK key frames must not have the install bit set.
        match frame.key_frame_fields.key_info().key_type() {
            eapol::KeyType::PAIRWISE => {}
            eapol::KeyType::GROUP_SMK => {
                rsn_ensure!(
                    !frame.key_frame_fields.key_info().install(),
                    Error::InvalidInstallBitGroupSmkHandshake
                );
            }
        };

        // Frames sent by a Supplicant must not have the key ack bit set.
        if let Role::Supplicant = sender {
            rsn_ensure!(
                !frame.key_frame_fields.key_info().key_ack(),
                Error::InvalidKeyAckBitSupplicant
            );
        }

        // Frames sent by an Authenticator must not have the error bit set...
        if let Role::Authenticator = sender {
            rsn_ensure!(
                !frame.key_frame_fields.key_info().error(),
                Error::InvalidErrorBitAuthenticator
            );
        }

        // ...nor the request bit.
        if let Role::Authenticator = sender {
            rsn_ensure!(
                !frame.key_frame_fields.key_info().request(),
                Error::InvalidRequestBitAuthenticator
            );
        }

        // SMK messages are rejected regardless of sender.
        rsn_ensure!(
            !frame.key_frame_fields.key_info().smk_message(),
            Error::SmkHandshakeNotSupported
        );

        // Key length is only checked for pairwise frames.
        match frame.key_frame_fields.key_info().key_type() {
            eapol::KeyType::PAIRWISE => match sender {
                // A Supplicant may send a key length of 0. A non-zero value must equal the
                // pairwise cipher's TK length.
                Role::Supplicant if frame.key_frame_fields.key_len.to_native() != 0 => {
                    let tk_len =
                        protection.pairwise.tk_bytes().ok_or(Error::UnsupportedCipherSuite)?;
                    rsn_ensure!(
                        frame.key_frame_fields.key_len.to_native() == tk_len.into(),
                        Error::InvalidKeyLength(
                            frame.key_frame_fields.key_len.to_native().into(),
                            tk_len.into()
                        )
                    );
                }
                // An Authenticator must always send exactly the TK length.
                Role::Authenticator => {
                    let tk_len: usize =
                        protection.pairwise.tk_bytes().ok_or(Error::UnsupportedCipherSuite)?.into();
                    rsn_ensure!(
                        usize::from(frame.key_frame_fields.key_len.to_native()) == tk_len,
                        Error::InvalidKeyLength(
                            frame.key_frame_fields.key_len.to_native().into(),
                            tk_len
                        )
                    );
                }
                // Remaining case: a Supplicant frame with key length 0, which is accepted.
                _ => {}
            },
            eapol::KeyType::GROUP_SMK => {}
        };

        // Replay counter checks are skipped when the given counter is 0.
        if key_replay_counter > 0 {
            match sender {
                // A Supplicant's frame may carry a counter equal to the given one.
                Role::Supplicant => {
                    rsn_ensure!(
                        frame.key_frame_fields.key_replay_counter.to_native() >= key_replay_counter,
                        Error::InvalidKeyReplayCounter(
                            frame.key_frame_fields.key_replay_counter.to_native(),
                            key_replay_counter
                        )
                    );
                }
                // An Authenticator's frame must carry a strictly larger counter.
                Role::Authenticator => {
                    rsn_ensure!(
                        frame.key_frame_fields.key_replay_counter.to_native() > key_replay_counter,
                        Error::InvalidKeyReplayCounter(
                            frame.key_frame_fields.key_replay_counter.to_native(),
                            key_replay_counter
                        )
                    );
                }
            }
        }

        // Encrypted key data is only accepted in frames that also have the key MIC bit set.
        if frame.key_frame_fields.key_info().encrypted_key_data() {
            rsn_ensure!(
                frame.key_frame_fields.key_info().key_mic(),
                Error::InvalidMicBitForEncryptedKeyData
            );
        }

        // The MIC itself is not checked here. Frames with the MIC bit set are wrapped so the
        // MIC can be checked later via `WithUnverifiedMic::verify_mic`.
        if frame.key_frame_fields.key_info().key_mic() {
            Ok(Dot11VerifiedKeyFrame::WithUnverifiedMic(WithUnverifiedMic(frame)))
        } else {
            Ok(Dot11VerifiedKeyFrame::WithoutMic(frame))
        }
    }

    /// Returns the underlying frame for either variant.
    ///
    /// For `WithUnverifiedMic` the MIC has NOT been checked, so the returned frame's contents
    /// must not be trusted.
    pub fn unsafe_get_raw(&self) -> &eapol::KeyFrameRx<B> {
        match self {
            Dot11VerifiedKeyFrame::WithUnverifiedMic(WithUnverifiedMic(frame)) => frame,
            Dot11VerifiedKeyFrame::WithoutMic(frame) => frame,
        }
    }
}
452
453pub fn derive_key_descriptor_version(
456 key_descriptor_type: eapol::KeyDescriptor,
457 protection: &NegotiatedProtection,
458) -> u16 {
459 let akm = &protection.akm;
460 let pairwise = &protection.pairwise;
461
462 if !akm.has_known_algorithm() || !pairwise.has_known_usage() {
463 return 0;
464 }
465
466 match akm.suite_type {
467 1 | 2 => match key_descriptor_type {
468 eapol::KeyDescriptor::RC4 => match pairwise.suite_type {
469 TKIP | GROUP_CIPHER_SUITE => 1,
470 _ => 0,
471 },
472 eapol::KeyDescriptor::IEEE802DOT11 | eapol::KeyDescriptor::LEGACY_WPA1 => {
473 if pairwise.suite_type == TKIP || pairwise.suite_type == GROUP_CIPHER_SUITE {
474 1
475 } else if pairwise.is_enhanced() || protection.group_data.is_enhanced() {
476 2
477 } else {
478 0
479 }
480 }
481 _ => 0,
482 },
483 3..=6 => 3,
486 _ => 0,
487 }
488}
489
/// Role of a party in the EAPOL key exchange.
///
/// `Dot11VerifiedKeyFrame::from_frame` takes the local role and treats the opposite role as
/// the sender of the received frame.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Role {
    Authenticator,
    Supplicant,
}
495
/// Security association status value, reported through `SecAssocUpdate::Status`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum SecAssocStatus {
    WrongPassword,
    PmkSaEstablished,
    EssSaEstablished,
}
502
/// Reason attached to `AuthStatus::Rejected`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum AuthRejectedReason {
    AuthFailed,
    TooManyRetries,
    PmksaExpired,
}
512
/// Authentication outcome, reported through `SecAssocUpdate::SaeAuthStatus`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum AuthStatus {
    Success,
    /// Authentication was rejected; the payload gives the reason.
    Rejected(AuthRejectedReason),
    InternalError,
}
519
/// A single update emitted by the security association; updates are collected in an
/// `UpdateSink`.
#[derive(Debug, PartialEq, Clone)]
pub enum SecAssocUpdate {
    /// An EAPOL key frame, presumably to be transmitted (per the `Tx` prefix).
    TxEapolKeyFrame {
        frame: eapol::KeyFrameBuf,
        /// NOTE(review): flags whether a reply to `frame` is expected; how it is used is not
        /// visible here, so confirm with the consumer.
        expect_response: bool,
    },
    /// A key produced by the key exchange (`crate::key::exchange::Key`).
    Key(Key),
    /// A status change of the security association.
    Status(SecAssocStatus),
    /// An SAE frame, presumably to be transmitted (per the `Tx` prefix).
    TxSaeFrame(SaeFrame),
    /// Outcome of SAE authentication.
    SaeAuthStatus(AuthStatus),
    /// NOTE(review): request to schedule an SAE timeout. Whether the `u64` is an identifier
    /// or a duration is not visible here; confirm.
    ScheduleSaeTimeout(u64),
}
535
/// Ordered list into which `SecAssocUpdate`s are accumulated.
pub type UpdateSink = Vec<SecAssocUpdate>;
537
#[cfg(test)]
mod tests {
    use super::*;
    use wlan_common::assert_variant;
    use wlan_common::ie::rsn::akm::{self, AKM_PSK};
    use wlan_common::ie::rsn::cipher::{self, CIPHER_CCMP_128, CIPHER_GCMP_256};
    use wlan_common::ie::rsn::fake_wpa2_s_rsne;

    #[test]
    fn test_negotiated_protection_from_rsne() {
        // Group data, pairwise and AKM suites all present.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_GCMP_256),
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect("error, could not create negotiated RSNE");

        let rsne = Rsne::wpa3_rsne();
        NegotiatedProtection::from_rsne(&rsne).expect("error, could not create negotiated RSNE");

        // Missing group data cipher suite.
        let rsne = Rsne {
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");

        // Missing pairwise cipher suites.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_CCMP_128),
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");

        // Missing AKM suites.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_CCMP_128),
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");
    }

    // Message 2 from the Supplicant is accepted with a key length of either 0 or 16 (the TK
    // length of the pairwise cipher in `fake_wpa2_s_rsne`).
    #[test]
    fn test_supplicant_sends_zeroed_and_non_zeroed_key_length() {
        let protection = NegotiatedProtection::from_rsne(&fake_wpa2_s_rsne())
            .expect("could not derive negotiated RSNE");
        let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);

        let msg1 = env.initiate(11.into());
        let (msg2_base, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());

        // Zeroed key length.
        let mut buf = vec![];
        let mut msg2 = msg2_base.copy_keyframe_mut(&mut buf);
        msg2.key_frame_fields.key_len.set_from_native(0);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_ok(), "failed verifying message: {}", result.unwrap_err());

        // Key length equal to the TK length.
        let mut buf = vec![];
        let mut msg2 = msg2_base.copy_keyframe_mut(&mut buf);
        msg2.key_frame_fields.key_len.set_from_native(16);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_ok(), "failed verifying message: {}", result.unwrap_err());
    }

    // A non-zero key length that differs from the TK length is rejected.
    #[test]
    fn test_supplicant_sends_random_key_length() {
        let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);

        let msg1 = env.initiate(12.into());
        let (msg2, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 12.into());
        let mut buf = vec![];
        let mut msg2 = msg2.copy_keyframe_mut(&mut buf);

        msg2.key_frame_fields.key_len.set_from_native(29);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));

        let protection = NegotiatedProtection::from_rsne(&fake_wpa2_s_rsne())
            .expect("could not derive negotiated RSNE");
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_err(), "successfully verified illegal message");
    }

    #[test]
    fn test_to_rsne() {
        let rsne = Rsne::wpa2_rsne();
        let negotiated_protection = NegotiatedProtection::from_rsne(&rsne)
            .expect("error, could not create negotiated RSNE")
            .to_full_protection();
        assert_variant!(negotiated_protection, ProtectionInfo::Rsne(actual_protection) => {
            assert_eq!(actual_protection, rsne);
        });
    }

    #[test]
    fn test_to_legacy_wpa() {
        let wpa_ie = make_wpa(cipher::TKIP, vec![cipher::TKIP], vec![akm::PSK]);
        let negotiated_protection = NegotiatedProtection::from_legacy_wpa(&wpa_ie)
            .expect("error, could not create negotiated WPA")
            .to_full_protection();
        assert_variant!(negotiated_protection, ProtectionInfo::LegacyWpa(actual_protection) => {
            assert_eq!(actual_protection, wpa_ie);
        });
    }

    #[test]
    fn test_igtk_support() {
        // The WPA3 RSNE requires management frame protection.
        let rsne = Rsne::wpa3_rsne();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_variant!(negotiated_protection.igtk_support(), IgtkSupport::Required);
        assert_eq!(negotiated_protection.group_mgmt_cipher(), CIPHER_BIP_CMAC_128);

        // Only the MFP-capable bit set.
        let mut rsne = Rsne::wpa3_rsne();
        rsne.rsn_capabilities.replace(RsnCapabilities(0).with_mgmt_frame_protection_cap(true));
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_variant!(negotiated_protection.igtk_support(), IgtkSupport::Capable);

        // The WPA2 RSNE sets neither MFP bit.
        let rsne = Rsne::wpa2_rsne();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_variant!(negotiated_protection.igtk_support(), IgtkSupport::Unsupported);
    }

    #[test]
    fn test_default_igtk_cipher() {
        let mut rsne = Rsne::wpa3_rsne();
        // Without a group management cipher suite, the default cipher is reported.
        rsne.group_mgmt_cipher_suite.take();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_variant!(negotiated_protection.igtk_support(), IgtkSupport::Required);
        assert_eq!(negotiated_protection.group_mgmt_cipher(), CIPHER_BIP_CMAC_128);
    }

    /// Builds a WPA1 IE from raw IEEE 802.11 suite type values: one multicast (group)
    /// cipher, a list of unicast (pairwise) ciphers and a list of AKMs.
    fn make_wpa(multicast: u8, unicast: Vec<u8>, akms: Vec<u8>) -> WpaIe {
        WpaIe {
            multicast_cipher: cipher::Cipher::new_dot11(multicast),
            unicast_cipher_list: unicast.into_iter().map(cipher::Cipher::new_dot11).collect(),
            akm_list: akms.into_iter().map(akm::Akm::new_dot11).collect(),
        }
    }
}