1use crate::integrity::{self, integrity_algorithm};
6use crate::key::exchange::Key;
7use crate::keywrap::{self, keywrap_algorithm};
8
9use crate::{Error, ProtectionInfo, rsn_ensure};
10use anyhow::{anyhow, ensure};
11use fidl_fuchsia_wlan_mlme::SaeFrame;
12use wlan_common::ie::rsn::akm::Akm;
13use wlan_common::ie::rsn::cipher::{CIPHER_BIP_CMAC_128, Cipher, GROUP_CIPHER_SUITE, TKIP};
14use wlan_common::ie::rsn::rsne::{RsnCapabilities, Rsne};
15use wlan_common::ie::wpa::WpaIe;
16
17use zerocopy::SplitByteSlice;
18
19pub mod esssa;
20#[cfg(test)]
21pub mod test_util;
22
/// Which kind of protection element a [`NegotiatedProtection`] was derived from.
///
/// `LegacyWpa1` is produced from a WPA1 vendor IE (`from_legacy_wpa`), `Rsne` from an RSN
/// element (`from_rsne`).
#[derive(Clone, Debug, PartialEq)]
pub enum ProtectionType {
    /// Pre-RSN WPA1 protection.
    LegacyWpa1,
    /// Protection negotiated through an RSN element.
    Rsne,
}
28
/// How strongly an association calls for an IGTK, as reported by
/// `NegotiatedProtection::igtk_support` from the RSN capabilities' management frame
/// protection bits.
#[derive(Debug)]
pub enum IgtkSupport {
    /// Neither management-frame-protection bit is set, or no RSN capabilities are present.
    Unsupported,
    /// Management frame protection is advertised as capable but not required.
    Capable,
    /// Management frame protection is advertised as required.
    Required,
}
35
36#[derive(Debug, Clone, PartialEq)]
37pub struct NegotiatedProtection {
38 pub group_data: Cipher,
39 pub pairwise: Cipher,
40 pub group_mgmt: Option<Cipher>,
41 pub akm: Akm,
42 pub mic_size: u16,
43 pub protection_type: ProtectionType,
44 caps: Option<RsnCapabilities>,
47}
48
49impl NegotiatedProtection {
50 pub fn from_protection(protection: &ProtectionInfo) -> Result<Self, anyhow::Error> {
51 match protection {
52 ProtectionInfo::Rsne(rsne) => Self::from_rsne(rsne),
53 ProtectionInfo::LegacyWpa(wpa) => Self::from_legacy_wpa(wpa),
54 }
55 }
56
57 fn key_descriptor_version(&self) -> u16 {
58 let key_descriptor_type = match self.protection_type {
59 ProtectionType::LegacyWpa1 => eapol::KeyDescriptor::LEGACY_WPA1,
60 ProtectionType::Rsne => eapol::KeyDescriptor::IEEE802DOT11,
61 };
62 derive_key_descriptor_version(key_descriptor_type, self)
63 }
64
65 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
66 pub fn integrity_algorithm(&self) -> Result<Box<dyn integrity::Algorithm>, Error> {
67 integrity_algorithm(self.key_descriptor_version(), &self.akm)
68 .ok_or(Error::UnknownIntegrityAlgorithm)
69 }
70
71 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
72 pub fn keywrap_algorithm(&self) -> Result<Box<dyn keywrap::Algorithm>, Error> {
73 keywrap_algorithm(self.key_descriptor_version(), &self.akm)
74 .ok_or(Error::UnknownKeywrapAlgorithm)
75 }
76
77 pub fn from_rsne(rsne: &Rsne) -> Result<Self, anyhow::Error> {
80 rsne.ensure_valid_s_rsne()
81 .map_err(|e| anyhow!(e).context(Error::InvalidNegotiatedProtection))?;
82
83 let group_data = rsne.group_data_cipher_suite.as_ref().unwrap();
86 let pairwise = &rsne.pairwise_cipher_suites[0];
87 let akm = &rsne.akm_suites[0];
88 let mic_size = akm.mic_bytes();
89 let mic_size = mic_size.unwrap();
90
91 Ok(Self {
92 group_data: group_data.clone(),
93 pairwise: pairwise.clone(),
94 group_mgmt: rsne.group_mgmt_cipher_suite.clone(),
95 akm: akm.clone(),
96 mic_size,
97 protection_type: ProtectionType::Rsne,
98 caps: rsne.rsn_capabilities.clone(),
99 })
100 }
101
102 pub fn from_legacy_wpa(wpa: &WpaIe) -> Result<Self, anyhow::Error> {
105 ensure!(wpa.unicast_cipher_list.len() == 1, Error::InvalidNegotiatedProtection);
106 ensure!(wpa.akm_list.len() == 1, Error::InvalidNegotiatedProtection);
107 let akm = wpa.akm_list[0].clone();
108 let mic_size = akm.mic_bytes().ok_or(Error::InvalidNegotiatedProtection)?;
109 let group_data = wpa.multicast_cipher.clone();
110 let pairwise = wpa.unicast_cipher_list[0].clone();
111 Ok(Self {
112 group_data,
113 pairwise,
114 group_mgmt: None,
115 akm,
116 mic_size,
117 protection_type: ProtectionType::LegacyWpa1,
118 caps: None,
119 })
120 }
121
122 pub fn to_full_protection(&self) -> ProtectionInfo {
125 match self.protection_type {
126 ProtectionType::Rsne => ProtectionInfo::Rsne(Rsne {
127 group_data_cipher_suite: Some(self.group_data.clone()),
128 pairwise_cipher_suites: vec![self.pairwise.clone()],
129 group_mgmt_cipher_suite: self.group_mgmt.clone(),
130 akm_suites: vec![self.akm.clone()],
131 rsn_capabilities: self.caps.clone(),
132 ..Default::default()
133 }),
134 ProtectionType::LegacyWpa1 => ProtectionInfo::LegacyWpa(WpaIe {
135 multicast_cipher: self.group_data.clone(),
136 unicast_cipher_list: vec![self.pairwise.clone()],
137 akm_list: vec![self.akm.clone()],
138 }),
139 }
140 }
141
142 pub fn igtk_support(&self) -> IgtkSupport {
143 match &self.caps {
144 Some(caps) => {
145 if caps.mgmt_frame_protection_req() {
146 IgtkSupport::Required
147 } else if caps.mgmt_frame_protection_cap() {
148 IgtkSupport::Capable
149 } else {
150 IgtkSupport::Unsupported
151 }
152 }
153 None => IgtkSupport::Unsupported,
154 }
155 }
156
157 pub fn group_mgmt_cipher(&self) -> Cipher {
158 self.group_mgmt.clone().unwrap_or(CIPHER_BIP_CMAC_128)
160 }
161}
162
163pub struct EncryptedKeyData<B: SplitByteSlice>(eapol::KeyFrameRx<B>);
165
166impl<B: SplitByteSlice> EncryptedKeyData<B> {
167 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
168 pub fn decrypt(
171 self,
172 kek: &[u8],
173 protection: &NegotiatedProtection,
174 ) -> Result<(eapol::KeyFrameRx<B>, Vec<u8>), Error> {
175 let key_data = protection.keywrap_algorithm()?.unwrap_key(
176 kek,
177 &self.0.key_frame_fields.key_iv,
178 &self.0.key_data[..],
179 )?;
180 Ok((self.0, key_data))
181 }
182}
183
184#[derive(Debug)]
186pub struct WithUnverifiedMic<B: SplitByteSlice>(eapol::KeyFrameRx<B>);
187
188impl<B: SplitByteSlice> WithUnverifiedMic<B> {
189 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
190 pub fn verify_mic(
194 self,
195 kck: &[u8],
196 protection: &NegotiatedProtection,
197 ) -> Result<UnverifiedKeyData<B>, Error> {
198 let mic_bytes = protection.akm.mic_bytes().ok_or(Error::UnsupportedAkmSuite)?;
201 rsn_ensure!(self.0.key_mic.len() == mic_bytes as usize, Error::InvalidMicSize);
202
203 let buf = self.0.to_bytes(true);
205 let valid_mic =
206 protection.integrity_algorithm()?.verify(kck, &buf[..], &self.0.key_mic[..]);
207 rsn_ensure!(valid_mic, Error::InvalidMic);
208
209 if self.0.key_frame_fields.key_info().encrypted_key_data() {
210 Ok(UnverifiedKeyData::Encrypted(EncryptedKeyData(self.0)))
211 } else {
212 Ok(UnverifiedKeyData::NotEncrypted(self.0))
213 }
214 }
215}
216
217pub enum UnverifiedKeyData<B: SplitByteSlice> {
220 Encrypted(EncryptedKeyData<B>),
221 NotEncrypted(eapol::KeyFrameRx<B>),
222}
223
224#[derive(Debug)]
227pub enum Dot11VerifiedKeyFrame<B: SplitByteSlice> {
228 WithUnverifiedMic(WithUnverifiedMic<B>),
229 WithoutMic(eapol::KeyFrameRx<B>),
230}
231
232impl<B: SplitByteSlice> Dot11VerifiedKeyFrame<B> {
233 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
241 pub fn from_frame(
242 frame: eapol::KeyFrameRx<B>,
243 role: &Role,
244 protection: &NegotiatedProtection,
245 key_replay_counter: u64,
246 ) -> Result<Dot11VerifiedKeyFrame<B>, Error> {
247 let sender = match role {
248 Role::Supplicant => Role::Authenticator,
249 Role::Authenticator => Role::Supplicant,
250 };
251
252 let key_descriptor = match frame.key_frame_fields.descriptor_type {
255 eapol::KeyDescriptor::IEEE802DOT11 => eapol::KeyDescriptor::IEEE802DOT11,
256 eapol::KeyDescriptor::LEGACY_WPA1
257 if protection.protection_type == ProtectionType::LegacyWpa1 =>
258 {
259 eapol::KeyDescriptor::LEGACY_WPA1
260 }
261 eapol::KeyDescriptor::RC4 => {
262 return Err(Error::InvalidKeyDescriptor(
263 frame.key_frame_fields.descriptor_type,
264 eapol::KeyDescriptor::IEEE802DOT11,
265 )
266 .into());
267 }
268 _ => {
270 return Err(Error::UnsupportedKeyDescriptor(
271 frame.key_frame_fields.descriptor_type,
272 )
273 .into());
274 }
275 };
276
277 let frame_key_descriptor_version =
279 frame.key_frame_fields.key_info().key_descriptor_version();
280 let expected_version = derive_key_descriptor_version(key_descriptor, protection);
281 rsn_ensure!(
282 frame_key_descriptor_version == expected_version,
283 Error::UnsupportedKeyDescriptorVersion(frame_key_descriptor_version)
284 );
285
286 match frame.key_frame_fields.key_info().key_type() {
289 eapol::KeyType::PAIRWISE => {}
290 eapol::KeyType::GROUP_SMK => {
291 rsn_ensure!(
293 !frame.key_frame_fields.key_info().install(),
294 Error::InvalidInstallBitGroupSmkHandshake
295 );
296 }
297 };
298
299 if let Role::Supplicant = sender {
301 rsn_ensure!(
302 !frame.key_frame_fields.key_info().key_ack(),
303 Error::InvalidKeyAckBitSupplicant
304 );
305 }
306
307 if let Role::Authenticator = sender {
315 rsn_ensure!(
316 !frame.key_frame_fields.key_info().error(),
317 Error::InvalidErrorBitAuthenticator
318 );
319 }
320
321 if let Role::Authenticator = sender {
323 rsn_ensure!(
324 !frame.key_frame_fields.key_info().request(),
325 Error::InvalidRequestBitAuthenticator
326 );
327 }
328
329 rsn_ensure!(
334 !frame.key_frame_fields.key_info().smk_message(),
335 Error::SmkHandshakeNotSupported
336 );
337
338 match frame.key_frame_fields.key_info().key_type() {
340 eapol::KeyType::PAIRWISE => match sender {
341 Role::Supplicant if frame.key_frame_fields.key_len.get() != 0 => {
350 let tk_len =
351 protection.pairwise.tk_bytes().ok_or(Error::UnsupportedCipherSuite)?;
352 rsn_ensure!(
353 frame.key_frame_fields.key_len.get() == u16::from(tk_len),
354 Error::InvalidKeyLength(
355 frame.key_frame_fields.key_len.get().into(),
356 tk_len.into()
357 )
358 );
359 }
360 Role::Authenticator => {
362 let tk_len: usize =
363 protection.pairwise.tk_bytes().ok_or(Error::UnsupportedCipherSuite)?.into();
364 rsn_ensure!(
365 usize::from(frame.key_frame_fields.key_len.get()) == tk_len,
366 Error::InvalidKeyLength(
367 frame.key_frame_fields.key_len.get().into(),
368 tk_len
369 )
370 );
371 }
372 _ => {}
373 },
374 eapol::KeyType::GROUP_SMK => {}
378 };
379
380 if key_replay_counter > 0 {
381 match sender {
382 Role::Supplicant => {
385 rsn_ensure!(
386 frame.key_frame_fields.key_replay_counter.get() >= key_replay_counter,
387 Error::InvalidKeyReplayCounter(
388 frame.key_frame_fields.key_replay_counter.get(),
389 key_replay_counter
390 )
391 );
392 }
393 Role::Authenticator => {
399 rsn_ensure!(
400 frame.key_frame_fields.key_replay_counter.get() > key_replay_counter,
401 Error::InvalidKeyReplayCounter(
402 frame.key_frame_fields.key_replay_counter.get(),
403 key_replay_counter
404 )
405 );
406 }
407 }
408 }
409
410 if frame.key_frame_fields.key_info().encrypted_key_data() {
413 rsn_ensure!(
414 frame.key_frame_fields.key_info().key_mic(),
415 Error::InvalidMicBitForEncryptedKeyData
416 );
417 }
418
419 if frame.key_frame_fields.key_info().key_mic() {
438 Ok(Dot11VerifiedKeyFrame::WithUnverifiedMic(WithUnverifiedMic(frame)))
439 } else {
440 Ok(Dot11VerifiedKeyFrame::WithoutMic(frame))
441 }
442 }
443
444 pub fn unsafe_get_raw(&self) -> &eapol::KeyFrameRx<B> {
448 match self {
449 Dot11VerifiedKeyFrame::WithUnverifiedMic(WithUnverifiedMic(frame)) => frame,
450 Dot11VerifiedKeyFrame::WithoutMic(frame) => frame,
451 }
452 }
453}
454
455pub fn derive_key_descriptor_version(
458 key_descriptor_type: eapol::KeyDescriptor,
459 protection: &NegotiatedProtection,
460) -> u16 {
461 let akm = &protection.akm;
462 let pairwise = &protection.pairwise;
463
464 if !akm.has_known_algorithm() || !pairwise.has_known_usage() {
465 return 0;
466 }
467
468 match akm.suite_type {
469 1 | 2 => match key_descriptor_type {
470 eapol::KeyDescriptor::RC4 => match pairwise.suite_type {
471 TKIP | GROUP_CIPHER_SUITE => 1,
472 _ => 0,
473 },
474 eapol::KeyDescriptor::IEEE802DOT11 | eapol::KeyDescriptor::LEGACY_WPA1 => {
475 if pairwise.suite_type == TKIP || pairwise.suite_type == GROUP_CIPHER_SUITE {
476 1
477 } else if pairwise.is_enhanced() || protection.group_data.is_enhanced() {
478 2
479 } else {
480 0
481 }
482 }
483 _ => 0,
484 },
485 3..=6 => 3,
488 _ => 0,
489 }
490}
491
/// Which side of a key exchange a party plays.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Role {
    /// The party that authenticates the peer.
    Authenticator,
    /// The party being authenticated.
    Supplicant,
}
497
/// Security-association status reported through `SecAssocUpdate::Status`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SecAssocStatus {
    /// The credential was rejected as a wrong password.
    WrongPassword,
    /// A PMKSA has been established.
    PmkSaEstablished,
    /// An ESS SA has been established.
    EssSaEstablished,
}
504
/// Why authentication was rejected; carried by [`AuthStatus::Rejected`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AuthRejectedReason {
    /// Authentication failed.
    AuthFailed,
    /// Too many retries were attempted.
    TooManyRetries,
    /// The PMKSA expired.
    PmksaExpired,
}

/// Authentication outcome reported through `SecAssocUpdate::SaeAuthStatus`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AuthStatus {
    /// Authentication succeeded.
    Success,
    /// Authentication was rejected for the given reason.
    Rejected(AuthRejectedReason),
    /// Authentication could not complete because of an internal error.
    InternalError,
}
521
522#[derive(Debug, PartialEq, Clone)]
523pub enum SecAssocUpdate {
524 TxEapolKeyFrame {
525 frame: eapol::KeyFrameBuf,
526 expect_response: bool,
529 },
530 Key(Key),
531 Status(SecAssocStatus),
532 TxSaeFrame(SaeFrame),
534 SaeAuthStatus(AuthStatus),
535 ScheduleSaeTimeout(u64),
536 TxOwePublicKey {
538 group_id: u16,
539 key: Vec<u8>,
540 },
541}
542
543pub type UpdateSink = Vec<SecAssocUpdate>;
544
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use wlan_common::ie::rsn::akm::{self, AKM_PSK};
    use wlan_common::ie::rsn::cipher::{self, CIPHER_CCMP_128, CIPHER_GCMP_256};
    use wlan_common::ie::rsn::fake_wpa2_s_rsne;
    use zerocopy::byteorder::big_endian::U16;

    #[test]
    fn test_negotiated_protection_from_rsne() {
        // Group data cipher, pairwise cipher and AKM all present: accepted.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_GCMP_256),
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect("error, could not create negotiated RSNE");

        let rsne = Rsne::wpa3_rsne();
        NegotiatedProtection::from_rsne(&rsne).expect("error, could not create negotiated RSNE");

        // Missing group data cipher: rejected.
        let rsne = Rsne {
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");

        // Missing pairwise cipher: rejected.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_CCMP_128),
            akm_suites: vec![AKM_PSK],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");

        // Missing AKM: rejected.
        let rsne = Rsne {
            group_data_cipher_suite: Some(CIPHER_CCMP_128),
            pairwise_cipher_suites: vec![CIPHER_CCMP_128],
            ..Default::default()
        };
        NegotiatedProtection::from_rsne(&rsne).expect_err("error, created negotiated RSNE");
    }

    // A supplicant may send message 2 either with a zero Key Length or with one matching the
    // pairwise cipher's TK length; both must be accepted by the authenticator.
    #[test]
    fn test_supplicant_sends_zeroed_and_non_zeroed_key_length() {
        let protection = NegotiatedProtection::from_rsne(&fake_wpa2_s_rsne())
            .expect("could not derive negotiated RSNE");
        let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);

        let msg1 = env.initiate(11.into());
        let (msg2_base, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 11.into());

        // Zero Key Length.
        let mut buf = vec![];
        let mut msg2 = msg2_base.copy_keyframe_mut(&mut buf);
        msg2.key_frame_fields.key_len = U16::new(0);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_ok(), "failed verifying message: {}", result.unwrap_err());

        // Key Length equal to the TK length.
        let mut buf = vec![];
        let mut msg2 = msg2_base.copy_keyframe_mut(&mut buf);
        msg2.key_frame_fields.key_len = U16::new(16);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_ok(), "failed verifying message: {}", result.unwrap_err());
    }

    // A non-zero Key Length that does not match the TK length must be rejected.
    #[test]
    fn test_supplicant_sends_random_key_length() {
        let mut env = test_util::FourwayTestEnv::new(test_util::HandshakeKind::Wpa2, 1, 3);

        let msg1 = env.initiate(12.into());
        let (msg2, ptk) = env.send_msg1_to_supplicant(msg1.keyframe(), 12.into());
        let mut buf = vec![];
        let mut msg2 = msg2.copy_keyframe_mut(&mut buf);

        msg2.key_frame_fields.key_len = U16::new(29);
        env.finalize_key_frame(&mut msg2, Some(ptk.kck()));

        let protection = NegotiatedProtection::from_rsne(&fake_wpa2_s_rsne())
            .expect("could not derive negotiated RSNE");
        let result = Dot11VerifiedKeyFrame::from_frame(msg2, &Role::Authenticator, &protection, 12);
        assert!(result.is_err(), "successfully verified illegal message");
    }

    #[test]
    fn test_to_rsne() {
        let rsne = Rsne::wpa2_rsne();
        let negotiated_protection = NegotiatedProtection::from_rsne(&rsne)
            .expect("error, could not create negotiated RSNE")
            .to_full_protection();
        assert_matches!(negotiated_protection, ProtectionInfo::Rsne(actual_protection) => {
            assert_eq!(actual_protection, rsne);
        });
    }

    #[test]
    fn test_to_legacy_wpa() {
        let wpa_ie = make_wpa(cipher::TKIP, vec![cipher::TKIP], vec![akm::PSK]);
        let negotiated_protection = NegotiatedProtection::from_legacy_wpa(&wpa_ie)
            .expect("error, could not create negotiated WPA")
            .to_full_protection();
        assert_matches!(negotiated_protection, ProtectionInfo::LegacyWpa(actual_protection) => {
            assert_eq!(actual_protection, wpa_ie);
        });
    }

    #[test]
    fn test_igtk_support() {
        // WPA3 RSNE: management frame protection required.
        let rsne = Rsne::wpa3_rsne();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Required);
        assert_eq!(negotiated_protection.group_mgmt_cipher(), CIPHER_BIP_CMAC_128);

        // Only the "capable" bit set.
        let mut rsne = Rsne::wpa3_rsne();
        rsne.rsn_capabilities.replace(RsnCapabilities(0).with_mgmt_frame_protection_cap(true));
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Capable);

        // WPA2 RSNE: no management frame protection.
        let rsne = Rsne::wpa2_rsne();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Unsupported);
    }

    #[test]
    fn test_default_igtk_cipher() {
        // Without an explicit group management cipher, BIP-CMAC-128 is used.
        let mut rsne = Rsne::wpa3_rsne();
        rsne.group_mgmt_cipher_suite.take();
        let negotiated_protection =
            NegotiatedProtection::from_rsne(&rsne).expect("Could not create negotiated RSNE");
        assert_matches!(negotiated_protection.igtk_support(), IgtkSupport::Required);
        assert_eq!(negotiated_protection.group_mgmt_cipher(), CIPHER_BIP_CMAC_128);
    }

    /// Builds a WPA1 IE from raw suite types.
    ///
    /// Takes one multicast (group) cipher, the unicast (pairwise) cipher list and the AKM list.
    fn make_wpa(multicast: u8, unicast: Vec<u8>, akms: Vec<u8>) -> WpaIe {
        WpaIe {
            multicast_cipher: cipher::Cipher::new_dot11(multicast),
            unicast_cipher_list: unicast.into_iter().map(cipher::Cipher::new_dot11).collect(),
            akm_list: akms.into_iter().map(akm::Akm::new_dot11).collect(),
        }
    }
}