1pub mod psk;
6
7use crate::key::exchange::Key;
8use crate::rsna::{
9 AuthRejectedReason, AuthStatus, Dot11VerifiedKeyFrame, SecAssocUpdate, UpdateSink,
10};
11use crate::Error;
12use fidl_fuchsia_wlan_mlme::SaeFrame;
13use ieee80211::{MacAddr, MacAddrBytes, Ssid};
14use log::warn;
15use wlan_common::ie::rsn::akm::AKM_SAE;
16use wlan_sae as sae;
17use zerocopy::SplitByteSlice;
18
/// Group identifier handed to `sae::new_sae_handshake` as its first argument when
/// an SME-managed SAE method is built in `Method::from_config`.
// NOTE(review): 19 is presumably the IANA-registered 256-bit random ECP group —
// confirm against the wlan_sae crate before relying on this.
const DEFAULT_GROUP_ID: u16 = 19;
23
/// Errors returned by the authentication [`Method`] API in this module.
#[derive(Error, Debug)]
pub enum AuthError {
    /// Building a method from a `Config` failed. Wraps the error reported while
    /// creating the SAE handshake in `Method::from_config`.
    #[error("Failed to construct auth method from the given configuration: {:?}", _0)]
    FailedConstruction(anyhow::Error),
    /// An SAE-related callback (`on_pmk_available`, `on_sae_handshake_ind`,
    /// `on_sae_frame_rx`, `on_sae_timeout`) was invoked on a `Method` variant
    /// that does not handle that callback.
    #[error("Non-SAE auth method received an SAE event")]
    UnexpectedSaeEvent,
}
31
/// Runtime state of an SAE handshake that is driven by this module; stored in
/// `Method::Sae`.
pub struct SaeData {
    /// Peer address; written into `peer_sta_address` of every outgoing `SaeFrame`.
    peer: MacAddr,
    /// Key produced by the handshake. `None` until the handshake emits
    /// `SaeUpdate::Success`.
    pub pmk: Option<sae::Key>,
    /// Handshake implementation that receives frames and timeouts and emits
    /// `SaeUpdate`s in response.
    handshake: Box<dyn sae::SaeHandshake>,
    /// Id of the currently valid retransmission timeout. It is incremented every
    /// time a timeout is scheduled, cancelled, or consumed, so a timeout event
    /// carrying any earlier id is ignored by `Method::on_sae_timeout`.
    retransmit_timeout_id: u64,
}
40
/// Description of an authentication method; turned into a live [`Method`] by
/// `Method::from_config`.
#[derive(Debug, PartialEq, Clone)]
pub enum Config {
    /// A PSK that is already available; becomes `Method::Psk` unchanged.
    ComputedPsk(psk::Psk),
    /// SAE whose handshake is run by this module. All four fields are forwarded
    /// to `sae::new_sae_handshake`; `peer_mac` is additionally kept as the
    /// destination for outgoing SAE frames.
    Sae { ssid: Ssid, password: Vec<u8>, mac: MacAddr, peer_mac: MacAddr },
    /// SAE whose PMK is delivered later through `Method::on_pmk_available`.
    /// `Method::from_config` does not read `password` for this variant.
    // NOTE(review): presumably the handshake itself runs in the driver, per the
    // variant name — confirm with the caller that supplies the PMK.
    DriverSae { password: Vec<u8> },
}
47
48impl Config {
49 pub fn method_name(&self) -> MethodName {
50 match self {
51 Config::ComputedPsk(_) => MethodName::Psk,
52 Config::Sae { .. } | Config::DriverSae { .. } => MethodName::Sae,
53 }
54 }
55}
56
/// A constructed authentication method together with its runtime state.
pub enum Method {
    /// Authentication with a precomputed PSK. Rejects every SAE callback.
    Psk(psk::Psk),
    /// SAE with the handshake executed by this module via `SaeData::handshake`.
    Sae(SaeData),
    /// SAE where only the resulting key is reported to this module. Holds the
    /// key stored by `on_pmk_available`; `None` until that callback is invoked.
    DriverSae(Option<sae::Key>),
}
63
/// Data-free identifier of an authentication method family, as reported by
/// `Config::method_name`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MethodName {
    /// Reported for `Config::ComputedPsk`.
    Psk,
    /// Reported for both `Config::Sae` and `Config::DriverSae`.
    Sae,
}
69
70impl std::fmt::Debug for Method {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
72 match self {
73 Self::Psk(psk) => write!(f, "Method::Psk({:?})", psk),
74 Self::Sae(sae_data) => write!(
75 f,
76 "Method::Sae {{ peer: {:?}, pmk: {}, .. }}",
77 sae_data.peer,
78 match sae_data.pmk {
79 Some(_) => "Some(_)",
80 None => "None",
81 }
82 ),
83 Self::DriverSae(key) => write!(f, "Method::DriverSae({:?})", key),
84 }
85 }
86}
87
88impl Method {
89 pub fn from_config(cfg: Config) -> Result<Method, AuthError> {
90 match cfg {
91 Config::ComputedPsk(psk) => Ok(Method::Psk(psk)),
92 Config::Sae { ssid, password, mac, peer_mac } => {
93 let handshake = sae::new_sae_handshake(
95 DEFAULT_GROUP_ID,
96 AKM_SAE,
97 wlan_sae::PweMethod::Loop,
98 ssid,
99 password,
100 None, mac,
102 peer_mac.clone(),
103 )
104 .map_err(AuthError::FailedConstruction)?;
105 Ok(Method::Sae(SaeData {
106 peer: peer_mac,
107 pmk: None,
108 handshake,
109 retransmit_timeout_id: 0,
110 }))
111 }
112 Config::DriverSae { .. } => Ok(Method::DriverSae(None)),
113 }
114 }
115
116 pub fn on_eapol_key_frame<B: SplitByteSlice>(
118 &self,
119 _update_sink: &mut UpdateSink,
120 _frame: Dot11VerifiedKeyFrame<B>,
121 ) -> Result<(), AuthError> {
122 Ok(())
123 }
124
125 pub fn on_pmk_available(
128 &mut self,
129 pmk: &[u8],
130 pmkid: &[u8],
131 assoc_update_sink: &mut UpdateSink,
132 ) -> Result<(), AuthError> {
133 match self {
134 Method::DriverSae(key) => {
135 key.replace(sae::Key { pmk: pmk.to_vec(), pmkid: pmkid.to_vec() });
136 assoc_update_sink.push(SecAssocUpdate::Key(Key::Pmk(pmk.to_vec())));
137 Ok(())
138 }
139 _ => Err(AuthError::UnexpectedSaeEvent),
140 }
141 }
142
143 pub fn on_sae_handshake_ind(
144 &mut self,
145 assoc_update_sink: &mut UpdateSink,
146 ) -> Result<(), AuthError> {
147 match self {
148 Method::Sae(sae_data) => {
149 let mut sae_update_sink = sae::SaeUpdateSink::default();
150 sae_data.handshake.initiate_sae(&mut sae_update_sink);
151 process_sae_updates(sae_data, assoc_update_sink, sae_update_sink);
152 Ok(())
153 }
154 _ => Err(AuthError::UnexpectedSaeEvent),
155 }
156 }
157
158 pub fn on_sae_frame_rx(
159 &mut self,
160 assoc_update_sink: &mut UpdateSink,
161 frame: SaeFrame,
162 ) -> Result<(), AuthError> {
163 match self {
164 Method::Sae(sae_data) => {
165 let mut sae_update_sink = sae::SaeUpdateSink::default();
166 let frame_rx = sae::AuthFrameRx {
167 seq: frame.seq_num,
168 status_code: frame.status_code,
169 body: &frame.sae_fields[..],
170 };
171 sae_data.handshake.handle_frame(&mut sae_update_sink, &frame_rx);
172 process_sae_updates(sae_data, assoc_update_sink, sae_update_sink);
173 Ok(())
174 }
175 _ => Err(AuthError::UnexpectedSaeEvent),
176 }
177 }
178
179 pub fn on_sae_timeout(
180 &mut self,
181 assoc_update_sink: &mut UpdateSink,
182 event_id: u64,
183 ) -> Result<(), AuthError> {
184 match self {
185 Method::Sae(sae_data) => {
186 if sae_data.retransmit_timeout_id == event_id {
187 sae_data.retransmit_timeout_id += 1;
188 let mut sae_update_sink = sae::SaeUpdateSink::default();
189 sae_data
190 .handshake
191 .handle_timeout(&mut sae_update_sink, sae::Timeout::Retransmission);
192 process_sae_updates(sae_data, assoc_update_sink, sae_update_sink);
193 }
194 Ok(())
195 }
196 _ => Err(AuthError::UnexpectedSaeEvent),
197 }
198 }
199}
200
201fn process_sae_updates(
202 sae_data: &mut SaeData,
203 assoc_update_sink: &mut UpdateSink,
204 sae_update_sink: sae::SaeUpdateSink,
205) {
206 for sae_update in sae_update_sink {
207 match sae_update {
208 sae::SaeUpdate::SendFrame(frame) => {
209 let sae_frame = SaeFrame {
210 peer_sta_address: sae_data.peer.clone().to_array(),
211 status_code: frame.status_code,
212 seq_num: frame.seq,
213 sae_fields: frame.body,
214 };
215 assoc_update_sink.push(SecAssocUpdate::TxSaeFrame(sae_frame));
216 }
217 sae::SaeUpdate::Success(key) => {
218 sae_data.pmk.replace(key.clone());
219 assoc_update_sink.push(SecAssocUpdate::Key(Key::Pmk(key.pmk)));
220 assoc_update_sink.push(SecAssocUpdate::SaeAuthStatus(AuthStatus::Success));
221 }
222 sae::SaeUpdate::Reject(reason) => {
223 warn!("SAE handshake rejected: {:?}", reason);
224 let status = match reason {
225 sae::RejectReason::AuthFailed => {
226 AuthStatus::Rejected(AuthRejectedReason::AuthFailed)
227 }
228 sae::RejectReason::KeyExpiration => {
229 AuthStatus::Rejected(AuthRejectedReason::PmksaExpired)
230 }
231 sae::RejectReason::TooManyRetries => {
232 AuthStatus::Rejected(AuthRejectedReason::TooManyRetries)
233 }
234 sae::RejectReason::InternalError(_) => AuthStatus::InternalError,
235 };
236 assoc_update_sink.push(SecAssocUpdate::SaeAuthStatus(status));
237 }
238 sae::SaeUpdate::ResetTimeout(timer) => {
239 match timer {
240 sae::Timeout::KeyExpiration => (), sae::Timeout::Retransmission => {
242 sae_data.retransmit_timeout_id += 1;
243 assoc_update_sink.push(SecAssocUpdate::ScheduleSaeTimeout(
244 sae_data.retransmit_timeout_id,
245 ));
246 }
247 };
248 }
249 sae::SaeUpdate::CancelTimeout(timer) => {
250 match timer {
251 sae::Timeout::KeyExpiration => (),
252 sae::Timeout::Retransmission => {
253 sae_data.retransmit_timeout_id += 1;
254 }
255 };
256 }
257 }
258 }
259}
260
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::{Arc, Mutex};
    use wlan_common::assert_variant;

    /// A PSK method must refuse every SAE-specific callback and queue nothing.
    #[test]
    fn psk_rejects_sae() {
        let mut auth = Method::from_config(Config::ComputedPsk(Box::new([0x8; 16])))
            .expect("Failed to construct PSK auth method");
        let mut sink = UpdateSink::default();
        auth.on_sae_handshake_ind(&mut sink).expect_err("PSK auth method accepted SAE ind");
        let frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 1,
            sae_fields: vec![0u8; 10],
        };
        auth.on_sae_frame_rx(&mut sink, frame).expect_err("PSK auth method accepted SAE frame");
        auth.on_sae_timeout(&mut sink, 0).expect_err("PSK auth method accepted SAE timeout");
        auth.on_pmk_available(&[0xcc; 8][..], &[0xdd; 8][..], &mut sink)
            .expect_err("PSK auth method accepted an externally supplied PMK");
        // None of the rejected calls may leave an update behind.
        assert!(sink.is_empty());
    }

    /// Call counters shared between a `DummySae` and the test that owns it.
    #[derive(Default)]
    struct SaeCounter {
        initiated: bool,
        handled_commits: u32,
        handled_confirms: u32,
        handled_timeouts: u32,
    }

    /// Fake handshake that records which callbacks ran and emits canned updates.
    struct DummySae(Arc<Mutex<SaeCounter>>);

    impl sae::SaeHandshake for DummySae {
        /// Marks the handshake as started and emits a single frame with seq 1.
        fn initiate_sae(&mut self, sink: &mut sae::SaeUpdateSink) {
            self.0.lock().unwrap().initiated = true;
            sink.push(sae::SaeUpdate::SendFrame(sae::AuthFrameTx {
                seq: 1,
                status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
                body: vec![],
            }));
        }

        /// Counts the commit; emits no updates.
        fn handle_commit(
            &mut self,
            _sink: &mut sae::SaeUpdateSink,
            _commit_msg: &sae::CommitMsg<'_>,
        ) {
            assert!(self.0.lock().unwrap().initiated);
            self.0.lock().unwrap().handled_commits += 1;
        }

        /// Counts the confirm, then emits a frame with seq 2 followed by success.
        fn handle_confirm(
            &mut self,
            sink: &mut sae::SaeUpdateSink,
            _confirm_msg: &sae::ConfirmMsg<'_>,
        ) {
            assert!(self.0.lock().unwrap().initiated);
            self.0.lock().unwrap().handled_confirms += 1;
            sink.push(sae::SaeUpdate::SendFrame(sae::AuthFrameTx {
                seq: 2,
                status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
                body: vec![],
            }));
            sink.push(sae::SaeUpdate::Success(sae::Key { pmk: vec![0xaa], pmkid: vec![0xbb] }));
        }

        fn handle_anti_clogging_token(
            &mut self,
            _sink: &mut sae::SaeUpdateSink,
            _msg: &sae::AntiCloggingTokenMsg<'_>,
        ) {
            panic!("The SAE initiator should never receive an anti-clogging token.");
        }

        /// Counts the timeout; emits no updates.
        fn handle_timeout(&mut self, _sink: &mut sae::SaeUpdateSink, _timeout: sae::Timeout) {
            self.0.lock().unwrap().handled_timeouts += 1;
        }
    }

    // Frame bodies used only to drive the dummy handshake: `DummySae` never looks
    // at the message contents.
    // NOTE(review): the lengths presumably have to satisfy the frame parser behind
    // `handle_frame` for group 19 — confirm before changing them.
    const COMMIT: [u8; 98] = [
        0x13, 0x00, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
        0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
        0xaa, 0xaa, 0xaa, 0xaa, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
    ];
    const CONFIRM: [u8; 34] = [
        0xaa, 0xaa, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb,
    ];

    /// Walks the SME-managed method through initiate -> commit -> confirm and
    /// checks the updates produced at each step.
    #[test]
    fn sae_executes_handshake() {
        let sae_counter = Arc::new(Mutex::new(SaeCounter::default()));
        let mut auth = Method::Sae(SaeData {
            peer: MacAddr::from([0xaa; 6]),
            pmk: None,
            handshake: Box::new(DummySae(sae_counter.clone())),
            retransmit_timeout_id: 0,
        });
        let mut sink = UpdateSink::default();

        auth.on_sae_handshake_ind(&mut sink).expect("SAE handshake should accept SAE ind");
        assert!(sae_counter.lock().unwrap().initiated);
        assert_variant!(sink.pop(), Some(SecAssocUpdate::TxSaeFrame(_)));

        let commit_frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 1,
            sae_fields: COMMIT.to_vec(),
        };
        auth.on_sae_frame_rx(&mut sink, commit_frame).expect("SAE handshake should accept commit");
        assert_eq!(sae_counter.lock().unwrap().handled_commits, 1);
        assert!(sink.is_empty());

        let confirm_frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 2,
            sae_fields: CONFIRM.to_vec(),
        };
        auth.on_sae_frame_rx(&mut sink, confirm_frame)
            .expect("SAE handshake should accept confirm");
        assert_eq!(sae_counter.lock().unwrap().handled_confirms, 1);
        // Confirm yields, in order: the outgoing frame, the PMK, the success status.
        assert_eq!(sink.len(), 3);
        assert_variant!(sink.remove(0), SecAssocUpdate::TxSaeFrame(_));
        assert_variant!(sink.remove(0), SecAssocUpdate::Key(_));
        assert_variant!(sink.remove(0), SecAssocUpdate::SaeAuthStatus(AuthStatus::Success));
        match auth {
            Method::Sae(sae_data) => assert!(sae_data.pmk.is_some()),
            _ => unreachable!(),
        };
    }

    /// Only the currently valid timeout id reaches the handshake; repeated or
    /// cancelled timeouts are dropped.
    #[test]
    fn sae_handles_current_timeouts() {
        let sae_counter = Arc::new(Mutex::new(SaeCounter::default()));
        let mut sae = Method::Sae(SaeData {
            peer: MacAddr::from([0xaa; 6]),
            pmk: None,
            handshake: Box::new(DummySae(sae_counter.clone())),
            retransmit_timeout_id: 0,
        });
        let mut sink = UpdateSink::default();

        if let Method::Sae(data) = &mut sae {
            process_sae_updates(
                data,
                &mut sink,
                vec![sae::SaeUpdate::ResetTimeout(sae::Timeout::Retransmission)],
            );
        };
        let event_id = assert_variant!(sink.pop(),
            Some(SecAssocUpdate::ScheduleSaeTimeout(id)) => id,
        );
        sae.on_sae_timeout(&mut sink, event_id).expect("SAE handshake should accept timeout");
        assert_eq!(sae_counter.lock().unwrap().handled_timeouts, 1);
        // The same id a second time is stale and must not reach the handshake.
        sae.on_sae_timeout(&mut sink, event_id).expect("SAE handshake should accept timeout");
        assert_eq!(sae_counter.lock().unwrap().handled_timeouts, 1);

        // Schedule a timeout and cancel it right away.
        if let Method::Sae(data) = &mut sae {
            process_sae_updates(
                data,
                &mut sink,
                vec![
                    sae::SaeUpdate::ResetTimeout(sae::Timeout::Retransmission),
                    sae::SaeUpdate::CancelTimeout(sae::Timeout::Retransmission),
                ],
            );
        };
        let event_id = assert_variant!(sink.pop(),
            Some(SecAssocUpdate::ScheduleSaeTimeout(id)) => id,
        );
        sae.on_sae_timeout(&mut sink, event_id).expect("SAE handshake should accept timeout");
        // The cancelled timeout must not reach the handshake either.
        assert_eq!(sae_counter.lock().unwrap().handled_timeouts, 1);
    }

    /// Key-expiration timers are ignored: no updates are queued for them.
    #[test]
    fn sae_key_expiration_no_op() {
        let sae_counter = Arc::new(Mutex::new(SaeCounter::default()));
        let mut data = SaeData {
            peer: MacAddr::from([0xaa; 6]),
            pmk: None,
            handshake: Box::new(DummySae(sae_counter.clone())),
            retransmit_timeout_id: 0,
        };
        let mut sink = UpdateSink::new();
        process_sae_updates(
            &mut data,
            &mut sink,
            vec![
                sae::SaeUpdate::ResetTimeout(sae::Timeout::KeyExpiration),
                sae::SaeUpdate::CancelTimeout(sae::Timeout::KeyExpiration),
            ],
        );
        assert!(sink.is_empty(), "KeyExpiration should not produce updates.");
    }

    /// A driver-SAE method forwards an externally supplied PMK as a key update.
    #[test]
    fn driver_sae_handles_pmk() {
        let mut auth = Method::from_config(Config::DriverSae { password: vec![0xbb; 8] })
            .expect("Failed to construct driver SAE auth method");
        let mut sink = UpdateSink::default();
        auth.on_pmk_available(&[0xcc; 8][..], &[0xdd; 8][..], &mut sink)
            .expect("Driver SAE should handle on_pmk_available");
        assert_eq!(sink.len(), 1);
        let pmk = assert_variant!(sink.get(0), Some(SecAssocUpdate::Key(Key::Pmk(pmk))) => pmk);
        assert_eq!(*pmk, vec![0xcc; 8]);
    }

    /// A driver-SAE method refuses the callbacks that belong to the SME-managed
    /// handshake and queues nothing.
    #[test]
    fn driver_sae_rejects_sme_sae_calls() {
        let mut auth = Method::from_config(Config::DriverSae { password: vec![0xbb; 8] })
            .expect("Failed to construct driver SAE auth method");
        let mut sink = UpdateSink::default();
        auth.on_sae_handshake_ind(&mut sink).expect_err("Driver SAE shouldn't handle SAE ind");
        let frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 1,
            sae_fields: COMMIT.to_vec(),
        };
        auth.on_sae_frame_rx(&mut sink, frame).expect_err("Driver SAE shouldn't handle frames");
        auth.on_sae_timeout(&mut sink, 0).expect_err("Driver SAE shouldn't handle SAE timeouts");
        assert!(sink.is_empty());
    }
}