1pub mod psk;
6
7use crate::Error;
8use crate::key::exchange::Key;
9use crate::rsna::{
10 AuthRejectedReason, AuthStatus, Dot11VerifiedKeyFrame, SecAssocUpdate, UpdateSink,
11};
12use fidl_fuchsia_wlan_mlme::SaeFrame;
13use ieee80211::{MacAddr, MacAddrBytes, Ssid};
14use log::warn;
15use wlan_common::ie::rsn::akm::AKM_SAE;
16use wlan_sae as sae;
17use zerocopy::SplitByteSlice;
18
/// Finite cyclic group passed to `sae::new_sae_handshake` for SME-driven SAE handshakes.
///
/// NOTE(review): in the IANA group registry, 19 is the 256-bit random ECP group (NIST P-256);
/// confirm this remains the intended default.
const DEFAULT_GROUP_ID: u16 = 19;
23
24#[derive(Error, Debug)]
25pub enum AuthError {
26 #[error("Failed to construct auth method from the given configuration: {:?}", _0)]
27 FailedConstruction(anyhow::Error),
28 #[error("Non-SAE auth method received an SAE event")]
29 UnexpectedSaeEvent,
30}
31
32pub struct SaeData {
33 peer: MacAddr,
34 pub pmk: Option<sae::Key>,
35 handshake: Box<dyn sae::SaeHandshake>,
36 retransmit_timeout_id: u64,
39}
40
41#[derive(Debug, PartialEq, Clone)]
42pub enum Config {
43 ComputedPsk(psk::Psk),
44 Sae { ssid: Ssid, password: Vec<u8>, mac: MacAddr, peer_mac: MacAddr },
45 DriverSae { password: Vec<u8> },
46}
47
48impl Config {
49 pub fn method_name(&self) -> MethodName {
50 match self {
51 Config::ComputedPsk(_) => MethodName::Psk,
52 Config::Sae { .. } | Config::DriverSae { .. } => MethodName::Sae,
53 }
54 }
55}
56
57pub enum Method {
58 Psk(psk::Psk),
59 Sae(SaeData),
60 DriverSae(Option<sae::Key>),
62}
63
/// Coarse classification of an authentication method, as returned by `Config::method_name`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MethodName {
    /// Pre-shared key authentication.
    Psk,
    /// SAE authentication, regardless of whether it runs here or in the driver.
    Sae,
}
69
70impl std::fmt::Debug for Method {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
72 match self {
73 Self::Psk(psk) => write!(f, "Method::Psk({:?})", psk),
74 Self::Sae(sae_data) => write!(
75 f,
76 "Method::Sae {{ peer: {:?}, pmk: {}, .. }}",
77 sae_data.peer,
78 match sae_data.pmk {
79 Some(_) => "Some(_)",
80 None => "None",
81 }
82 ),
83 Self::DriverSae(key) => write!(f, "Method::DriverSae({:?})", key),
84 }
85 }
86}
87
88impl Method {
89 pub fn from_config(cfg: Config) -> Result<Method, AuthError> {
90 match cfg {
91 Config::ComputedPsk(psk) => Ok(Method::Psk(psk)),
92 Config::Sae { ssid, password, mac, peer_mac } => {
93 let handshake = sae::new_sae_handshake(
95 DEFAULT_GROUP_ID,
96 AKM_SAE,
97 wlan_sae::PweMethod::Loop,
98 ssid,
99 password,
100 None, mac,
102 peer_mac.clone(),
103 )
104 .map_err(AuthError::FailedConstruction)?;
105 Ok(Method::Sae(SaeData {
106 peer: peer_mac,
107 pmk: None,
108 handshake,
109 retransmit_timeout_id: 0,
110 }))
111 }
112 Config::DriverSae { .. } => Ok(Method::DriverSae(None)),
113 }
114 }
115
116 pub fn on_eapol_key_frame<B: SplitByteSlice>(
118 &self,
119 _update_sink: &mut UpdateSink,
120 _frame: Dot11VerifiedKeyFrame<B>,
121 ) -> Result<(), AuthError> {
122 Ok(())
123 }
124
125 pub fn on_pmk_available(
128 &mut self,
129 pmk: &[u8],
130 pmkid: &[u8],
131 assoc_update_sink: &mut UpdateSink,
132 ) -> Result<(), AuthError> {
133 match self {
134 Method::DriverSae(key) => {
135 key.replace(sae::Key { pmk: pmk.to_vec(), pmkid: pmkid.to_vec() });
136 assoc_update_sink.push(SecAssocUpdate::Key(Key::Pmk(pmk.to_vec())));
137 Ok(())
138 }
139 _ => Err(AuthError::UnexpectedSaeEvent),
140 }
141 }
142
143 pub fn on_sae_handshake_ind(
144 &mut self,
145 assoc_update_sink: &mut UpdateSink,
146 ) -> Result<(), AuthError> {
147 match self {
148 Method::Sae(sae_data) => {
149 let mut sae_update_sink = sae::SaeUpdateSink::default();
150 sae_data.handshake.initiate_sae(&mut sae_update_sink);
151 process_sae_updates(sae_data, assoc_update_sink, sae_update_sink);
152 Ok(())
153 }
154 _ => Err(AuthError::UnexpectedSaeEvent),
155 }
156 }
157
158 pub fn on_sae_frame_rx(
159 &mut self,
160 assoc_update_sink: &mut UpdateSink,
161 frame: SaeFrame,
162 ) -> Result<(), AuthError> {
163 match self {
164 Method::Sae(sae_data) => {
165 let mut sae_update_sink = sae::SaeUpdateSink::default();
166 let frame_rx = sae::AuthFrameRx {
167 seq: frame.seq_num,
168 status_code: frame.status_code,
169 body: &frame.sae_fields[..],
170 };
171 sae_data.handshake.handle_frame(&mut sae_update_sink, &frame_rx);
172 process_sae_updates(sae_data, assoc_update_sink, sae_update_sink);
173 Ok(())
174 }
175 _ => Err(AuthError::UnexpectedSaeEvent),
176 }
177 }
178
179 pub fn on_sae_timeout(
180 &mut self,
181 assoc_update_sink: &mut UpdateSink,
182 event_id: u64,
183 ) -> Result<(), AuthError> {
184 match self {
185 Method::Sae(sae_data) => {
186 if sae_data.retransmit_timeout_id == event_id {
187 sae_data.retransmit_timeout_id += 1;
188 let mut sae_update_sink = sae::SaeUpdateSink::default();
189 sae_data
190 .handshake
191 .handle_timeout(&mut sae_update_sink, sae::Timeout::Retransmission);
192 process_sae_updates(sae_data, assoc_update_sink, sae_update_sink);
193 }
194 Ok(())
195 }
196 _ => Err(AuthError::UnexpectedSaeEvent),
197 }
198 }
199}
200
201fn process_sae_updates(
202 sae_data: &mut SaeData,
203 assoc_update_sink: &mut UpdateSink,
204 sae_update_sink: sae::SaeUpdateSink,
205) {
206 for sae_update in sae_update_sink {
207 match sae_update {
208 sae::SaeUpdate::SendFrame(frame) => {
209 let sae_frame = SaeFrame {
210 peer_sta_address: sae_data.peer.clone().to_array(),
211 status_code: frame.status_code,
212 seq_num: frame.seq,
213 sae_fields: frame.body,
214 };
215 assoc_update_sink.push(SecAssocUpdate::TxSaeFrame(sae_frame));
216 }
217 sae::SaeUpdate::Success(key) => {
218 sae_data.pmk.replace(key.clone());
219 assoc_update_sink.push(SecAssocUpdate::Key(Key::Pmk(key.pmk)));
220 assoc_update_sink.push(SecAssocUpdate::SaeAuthStatus(AuthStatus::Success));
221 }
222 sae::SaeUpdate::Reject(reason) => {
223 warn!("SAE handshake rejected: {:?}", reason);
224 let status = match reason {
225 sae::RejectReason::AuthFailed => {
226 AuthStatus::Rejected(AuthRejectedReason::AuthFailed)
227 }
228 sae::RejectReason::KeyExpiration => {
229 AuthStatus::Rejected(AuthRejectedReason::PmksaExpired)
230 }
231 sae::RejectReason::TooManyRetries => {
232 AuthStatus::Rejected(AuthRejectedReason::TooManyRetries)
233 }
234 sae::RejectReason::InternalError(_) => AuthStatus::InternalError,
235 };
236 assoc_update_sink.push(SecAssocUpdate::SaeAuthStatus(status));
237 }
238 sae::SaeUpdate::ResetTimeout(timer) => {
239 match timer {
240 sae::Timeout::KeyExpiration => (), sae::Timeout::Retransmission => {
242 sae_data.retransmit_timeout_id += 1;
243 assoc_update_sink.push(SecAssocUpdate::ScheduleSaeTimeout(
244 sae_data.retransmit_timeout_id,
245 ));
246 }
247 };
248 }
249 sae::SaeUpdate::CancelTimeout(timer) => {
250 match timer {
251 sae::Timeout::KeyExpiration => (),
252 sae::Timeout::Retransmission => {
253 sae_data.retransmit_timeout_id += 1;
254 }
255 };
256 }
257 }
258 }
259}
260
#[cfg(test)]
mod test {
    use super::*;
    use assert_matches::assert_matches;
    use fuchsia_sync::Mutex;
    use std::sync::Arc;

    /// A PSK method must reject every SAE-specific event and produce no updates.
    #[test]
    fn psk_rejects_sae() {
        let mut auth = Method::from_config(Config::ComputedPsk(Box::new([0x8; 16])))
            .expect("Failed to construct PSK auth method");
        let mut sink = UpdateSink::default();
        auth.on_sae_handshake_ind(&mut sink).expect_err("PSK auth method accepted SAE ind");
        let frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 1,
            sae_fields: vec![0u8; 10],
        };
        auth.on_sae_frame_rx(&mut sink, frame).expect_err("PSK auth method accepted SAE frame");
        auth.on_sae_timeout(&mut sink, 0).expect_err("PSK auth method accepted SAE timeout");
        auth.on_pmk_available(&[0xcc; 8][..], &[0xdd; 8][..], &mut sink)
            .expect_err("PSK auth method accepted a driver-provided PMK");
        assert!(sink.is_empty());
    }

    /// Counts the calls made into `DummySae`, shared with the test through an `Arc<Mutex<_>>`.
    #[derive(Default)]
    struct SaeCounter {
        initiated: bool,
        handled_commits: u32,
        handled_confirms: u32,
        handled_timeouts: u32,
    }

    /// Scripted SAE handshake: emits a commit on initiation and succeeds on the first confirm.
    struct DummySae(Arc<Mutex<SaeCounter>>);

    impl sae::SaeHandshake for DummySae {
        fn initiate_sae(&mut self, sink: &mut sae::SaeUpdateSink) {
            self.0.lock().initiated = true;
            sink.push(sae::SaeUpdate::SendFrame(sae::AuthFrameTx {
                seq: 1,
                status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
                body: vec![],
            }));
        }
        fn handle_commit(
            &mut self,
            _sink: &mut sae::SaeUpdateSink,
            _commit_msg: &sae::CommitMsg<'_>,
        ) {
            assert!(self.0.lock().initiated);
            self.0.lock().handled_commits += 1;
        }
        fn handle_confirm(
            &mut self,
            sink: &mut sae::SaeUpdateSink,
            _confirm_msg: &sae::ConfirmMsg<'_>,
        ) {
            assert!(self.0.lock().initiated);
            self.0.lock().handled_confirms += 1;
            sink.push(sae::SaeUpdate::SendFrame(sae::AuthFrameTx {
                seq: 2,
                status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
                body: vec![],
            }));
            sink.push(sae::SaeUpdate::Success(sae::Key { pmk: vec![0xaa], pmkid: vec![0xbb] }))
        }
        fn handle_anti_clogging_token(
            &mut self,
            _sink: &mut sae::SaeUpdateSink,
            _msg: &sae::AntiCloggingTokenMsg<'_>,
        ) {
            panic!("The SAE initiator should never receive an anti-clogging token.");
        }
        fn handle_timeout(&mut self, _sink: &mut sae::SaeUpdateSink, _timeout: sae::Timeout) {
            self.0.lock().handled_timeouts += 1;
        }
    }

    // Commit body: 2-byte group id (19, little-endian), then 32 bytes of 0xaa and 64 bytes of
    // 0xbb -- presumably standing in for the commit scalar and element.
    const COMMIT: [u8; 98] = [
        0x13, 0x00, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
        0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
        0xaa, 0xaa, 0xaa, 0xaa, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
    ];
    // Confirm body: 2 bytes of 0xaa followed by 32 bytes of 0xbb.
    const CONFIRM: [u8; 34] = [
        0xaa, 0xaa, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb,
        0xbb, 0xbb, 0xbb, 0xbb,
    ];

    /// Drives a full handshake through `DummySae` and checks the emitted updates and stored key.
    #[test]
    fn sae_executes_handshake() {
        let sae_counter = Arc::new(Mutex::new(SaeCounter::default()));
        let mut auth = Method::Sae(SaeData {
            peer: MacAddr::from([0xaa; 6]),
            pmk: None,
            handshake: Box::new(DummySae(sae_counter.clone())),
            retransmit_timeout_id: 0,
        });
        let mut sink = UpdateSink::default();

        auth.on_sae_handshake_ind(&mut sink).expect("SAE handshake should accept SAE ind");
        assert!(sae_counter.lock().initiated);
        assert_matches!(sink.pop(), Some(SecAssocUpdate::TxSaeFrame(_)));

        let commit_frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 1,
            sae_fields: COMMIT.to_vec(),
        };
        auth.on_sae_frame_rx(&mut sink, commit_frame).expect("SAE handshake should accept commit");
        assert_eq!(sae_counter.lock().handled_commits, 1);
        assert!(sink.is_empty());

        let confirm_frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 2,
            sae_fields: CONFIRM.to_vec(),
        };
        auth.on_sae_frame_rx(&mut sink, confirm_frame)
            .expect("SAE handshake should accept confirm");
        assert_eq!(sae_counter.lock().handled_confirms, 1);
        assert_eq!(sink.len(), 3);
        assert_matches!(sink.remove(0), SecAssocUpdate::TxSaeFrame(_));
        assert_matches!(sink.remove(0), SecAssocUpdate::Key(Key::Pmk(pmk)) => {
            assert_eq!(pmk, vec![0xaa]);
        });
        assert_matches!(sink.remove(0), SecAssocUpdate::SaeAuthStatus(AuthStatus::Success));
        match auth {
            Method::Sae(sae_data) => {
                let key = sae_data.pmk.expect("SAE handshake should store the PMK on success");
                assert_eq!(key.pmk, vec![0xaa]);
                assert_eq!(key.pmkid, vec![0xbb]);
            }
            other => panic!("expected Method::Sae, got {:?}", other),
        }
    }

    /// Only the retransmission timeout carrying the current id reaches the handshake.
    #[test]
    fn sae_handles_current_timeouts() {
        let sae_counter = Arc::new(Mutex::new(SaeCounter::default()));
        let mut sae = Method::Sae(SaeData {
            peer: MacAddr::from([0xaa; 6]),
            pmk: None,
            handshake: Box::new(DummySae(sae_counter.clone())),
            retransmit_timeout_id: 0,
        });
        let mut sink = UpdateSink::default();

        // Schedule a retransmission timeout and fire it with its current id: it is handled.
        if let Method::Sae(data) = &mut sae {
            process_sae_updates(
                data,
                &mut sink,
                vec![sae::SaeUpdate::ResetTimeout(sae::Timeout::Retransmission)],
            );
        }
        let event_id = assert_matches!(sink.pop(),
            Some(SecAssocUpdate::ScheduleSaeTimeout(id)) => id
        );
        sae.on_sae_timeout(&mut sink, event_id).expect("SAE handshake should accept timeout");
        assert_eq!(sae_counter.lock().handled_timeouts, 1);

        // Handling a timeout bumps the id, so delivering the same id again is ignored.
        sae.on_sae_timeout(&mut sink, event_id).expect("SAE handshake should accept timeout");
        assert_eq!(sae_counter.lock().handled_timeouts, 1);

        // A timeout that is cancelled after being scheduled is stale and also ignored.
        if let Method::Sae(data) = &mut sae {
            process_sae_updates(
                data,
                &mut sink,
                vec![
                    sae::SaeUpdate::ResetTimeout(sae::Timeout::Retransmission),
                    sae::SaeUpdate::CancelTimeout(sae::Timeout::Retransmission),
                ],
            );
        }
        let event_id = assert_matches!(sink.pop(),
            Some(SecAssocUpdate::ScheduleSaeTimeout(id)) => id
        );
        sae.on_sae_timeout(&mut sink, event_id).expect("SAE handshake should accept timeout");
        assert_eq!(sae_counter.lock().handled_timeouts, 1);
    }

    /// Key-expiration timer updates are intentionally ignored.
    #[test]
    fn sae_key_expiration_no_op() {
        let sae_counter = Arc::new(Mutex::new(SaeCounter::default()));
        let mut data = SaeData {
            peer: MacAddr::from([0xaa; 6]),
            pmk: None,
            handshake: Box::new(DummySae(sae_counter.clone())),
            retransmit_timeout_id: 0,
        };
        let mut sink = UpdateSink::new();
        process_sae_updates(
            &mut data,
            &mut sink,
            vec![
                sae::SaeUpdate::ResetTimeout(sae::Timeout::KeyExpiration),
                sae::SaeUpdate::CancelTimeout(sae::Timeout::KeyExpiration),
            ],
        );
        assert!(sink.is_empty(), "KeyExpiration should not produce updates.");
    }

    /// Driver SAE stores the delivered PMK/PMKID and forwards the PMK.
    #[test]
    fn driver_sae_handles_pmk() {
        let mut auth = Method::from_config(Config::DriverSae { password: vec![0xbb; 8] })
            .expect("Failed to construct driver SAE auth method");
        let mut sink = UpdateSink::default();
        auth.on_pmk_available(&[0xcc; 8][..], &[0xdd; 8][..], &mut sink)
            .expect("Driver SAE should handle on_pmk_available");
        assert_eq!(sink.len(), 1);
        let pmk = assert_matches!(sink.get(0), Some(SecAssocUpdate::Key(Key::Pmk(pmk))) => pmk);
        assert_eq!(*pmk, vec![0xcc; 8]);
        match auth {
            Method::DriverSae(Some(key)) => {
                assert_eq!(key.pmk, vec![0xcc; 8]);
                assert_eq!(key.pmkid, vec![0xdd; 8]);
            }
            other => panic!("expected Method::DriverSae(Some(_)), got {:?}", other),
        }
    }

    /// Driver SAE must reject the SME-driven SAE entry points and produce no updates.
    #[test]
    fn driver_sae_rejects_sme_sae_calls() {
        let mut auth = Method::from_config(Config::DriverSae { password: vec![0xbb; 8] })
            .expect("Failed to construct driver SAE auth method");
        let mut sink = UpdateSink::default();
        auth.on_sae_handshake_ind(&mut sink).expect_err("Driver SAE shouldn't handle SAE ind");
        let frame = SaeFrame {
            peer_sta_address: [0xaa; 6],
            status_code: fidl_fuchsia_wlan_ieee80211::StatusCode::Success,
            seq_num: 1,
            sae_fields: COMMIT.to_vec(),
        };
        auth.on_sae_frame_rx(&mut sink, frame).expect_err("Driver SAE shouldn't handle frames");
        auth.on_sae_timeout(&mut sink, 0).expect_err("Driver SAE shouldn't handle SAE timeouts");
        assert!(sink.is_empty());
    }
}