1use super::{AuthFrameRx, AuthFrameTx, PweMethod};
6use anyhow::{Error, anyhow, bail};
7use fidl_fuchsia_wlan_ieee80211::StatusCode;
8use wlan_common::append::Append;
9use wlan_common::buffer_reader::BufferReader;
10
/// Contents of a received SAE frame with seq 1 and status `AntiCloggingTokenRequired`,
/// as produced by [`parse`].
///
/// The token slice borrows from the received frame body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AntiCloggingTokenMsg<'a> {
    /// Group ID read from the first two bytes of the frame body.
    pub group_id: u16,
    /// Every byte following the group ID; parsing rejects an empty token.
    pub anti_clogging_token: &'a [u8],
}
18
/// Contents of a received SAE Commit frame (seq 1, status `Success` or `SaeHashToElement`),
/// as produced by [`parse`].
///
/// Body layout handled by this module: group ID, optional anti-clogging token, scalar,
/// element. All slices borrow from the received frame body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommitMsg<'a> {
    /// Group ID; only group 19 is currently accepted when parsing.
    pub group_id: u16,
    /// Bytes between the group ID and the scalar, or `None` when there are none.
    pub anti_clogging_token: Option<&'a [u8]>,
    /// Scalar field (32 bytes for group 19).
    pub scalar: &'a [u8],
    /// Element field (64 bytes for group 19).
    pub element: &'a [u8],
}
28
/// Contents of a received SAE Confirm frame (seq 2, status `Success`), as produced by
/// [`parse`].
///
/// The confirm slice borrows from the received frame body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfirmMsg<'a> {
    /// Send-confirm value read from the first two bytes of the body.
    pub send_confirm: u16,
    /// Confirm field; parsing requires exactly `CONFIRM_BYTES` (32) bytes here.
    pub confirm: &'a [u8],
}
36
/// A successfully parsed SAE authentication frame, returned by [`parse`].
#[derive(Debug)]
pub enum ParseSuccess<'a> {
    /// Seq 1 frame with status `Success` or `SaeHashToElement`.
    Commit(CommitMsg<'a>),
    /// Seq 2 frame with status `Success`.
    Confirm(ConfirmMsg<'a>),
    /// Seq 1 frame with status `AntiCloggingTokenRequired`.
    AntiCloggingToken(AntiCloggingTokenMsg<'a>),
}
43
44fn get_scalar_and_element_len_bytes(group_id: u16) -> Result<(usize, usize), Error> {
45 match group_id {
46 19 => Ok((32, 64)),
47 _ => bail!("Unsupported SAE group ID: {}", group_id),
48 }
49}
50
51pub fn parse<'a>(frame: &'a AuthFrameRx<'_>) -> Result<ParseSuccess<'a>, Error> {
52 match (frame.seq, frame.status_code) {
54 (1, StatusCode::Success) => parse_commit(frame.body).map(ParseSuccess::Commit),
55 (1, StatusCode::SaeHashToElement) => parse_commit(frame.body).map(ParseSuccess::Commit),
56 (1, StatusCode::AntiCloggingTokenRequired) => {
57 parse_anti_clogging_token(frame.body).map(ParseSuccess::AntiCloggingToken)
58 }
59 (2, StatusCode::Success) => parse_confirm(frame.body).map(ParseSuccess::Confirm),
60 _ => bail!("Could not parse received SAE frame"),
61 }
62}
63
64fn parse_anti_clogging_token(body: &[u8]) -> Result<AntiCloggingTokenMsg<'_>, Error> {
65 let mut reader = BufferReader::new(body);
66 let group_id = reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read group ID"))?;
67 if reader.bytes_remaining() == 0 {
68 bail!("Commit indicated AntiCloggingTokenRequired, but no token provided");
69 }
70 let anti_clogging_token = reader.into_remaining();
71 Ok(AntiCloggingTokenMsg { group_id, anti_clogging_token })
72}
73
74fn parse_commit(body: &[u8]) -> Result<CommitMsg<'_>, Error> {
75 let mut reader = BufferReader::new(body);
76 let group_id = reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read group ID"))?;
77
78 let (scalar_size, element_size) = get_scalar_and_element_len_bytes(group_id)?;
79 let bytes_remaining = reader.bytes_remaining();
80 let anti_clogging_token = match bytes_remaining.cmp(&(scalar_size + element_size)) {
81 std::cmp::Ordering::Equal => None,
82 std::cmp::Ordering::Greater => Some(
83 reader
84 .read_bytes(bytes_remaining - scalar_size - element_size)
85 .ok_or_else(|| anyhow!("Unexpected buffer end"))?,
86 ),
87 std::cmp::Ordering::Less => bail!("Buffer truncated"),
88 };
89
90 let scalar = reader.read_bytes(scalar_size).ok_or_else(|| anyhow!("Unexpected buffer end"))?;
91 let element =
92 reader.read_bytes(element_size).ok_or_else(|| anyhow!("Unexpected buffer end"))?;
93
94 Ok(CommitMsg { group_id, scalar, element, anti_clogging_token })
95}
96
/// Exact length of the confirm field accepted by `parse_confirm`: a confirm body must be a
/// 2-byte send-confirm value followed by exactly this many bytes.
// NOTE(review): fixed regardless of group ID; presumably sized for group 19 — confirm if
// more groups are added.
const CONFIRM_BYTES: usize = 32;
98
99fn parse_confirm(body: &[u8]) -> Result<ConfirmMsg<'_>, Error> {
100 let mut reader = BufferReader::new(body);
101 let send_confirm =
102 reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read send confirm"))?;
103 let confirm = reader.read_bytes(CONFIRM_BYTES).ok_or_else(|| anyhow!("Buffer truncated"))?;
104 match reader.bytes_remaining() {
105 0 => Ok(ConfirmMsg { send_confirm, confirm }),
106 _ => bail!("Buffer too long"),
107 }
108}
109
110#[allow(unused_must_use)]
112pub fn write_commit(
113 group_id: u16,
114 scalar: &[u8],
115 element: &[u8],
116 anti_clogging_token: &[u8],
117 pwe_method: PweMethod,
118) -> AuthFrameTx {
119 let mut body = Vec::with_capacity(2 + scalar.len() + element.len() + anti_clogging_token.len());
120 body.append_value(&group_id);
121 body.append_bytes(anti_clogging_token);
122 body.append_bytes(scalar);
123 body.append_bytes(element);
124 let status_code = match pwe_method {
125 PweMethod::Direct => StatusCode::SaeHashToElement,
126 _ => StatusCode::Success,
127 };
128 AuthFrameTx { seq: 1, status_code, body }
129}
130
131#[allow(unused_must_use)]
133#[allow(dead_code)]
135pub fn write_token(group_id: u16, token: &[u8]) -> AuthFrameTx {
136 let mut body = Vec::with_capacity(2 + token.len());
137 body.append_value(&group_id);
138 body.append_bytes(token);
139 AuthFrameTx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body }
140}
141
142#[allow(unused_must_use)]
144pub fn write_confirm(send_confirm: u16, confirm: &[u8]) -> AuthFrameTx {
145 let mut body = Vec::with_capacity(2 + confirm.len());
146 body.append_value(&send_confirm);
147 body.append_bytes(confirm);
148 AuthFrameTx { seq: 2, status_code: StatusCode::Success, body }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    #[rustfmt::skip]
    const ECC_COMMIT_BODY: &[u8] = &[
        // Group ID
        19, 00,
        // Scalar
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
        // Element
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
    ];

    #[rustfmt::skip]
    const ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN: &[u8] = &[
        // Group ID
        19, 00,
        // Anti-clogging token (sits between the group ID and the scalar)
        4, 4, 4, 4, 4, 4, 4, 4,
        // Scalar
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
        // Element
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
    ];

    #[rustfmt::skip]
    const ECC_CONFIRM_BODY: &[u8] = &[
        // Send-confirm
        0x01, 0x00,
        // Confirm
        3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,
    ];

    #[rustfmt::skip]
    const ECC_ACT_REQUIRED_BODY: &[u8] = &[
        // Group ID
        19, 00,
        // Anti-clogging token
        4,4,4,4,4,4,4,4,
    ];

    #[test]
    fn test_parse_commit() {
        let commit_msg =
            AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: ECC_COMMIT_BODY };
        let parse_result = parse(&commit_msg);
        let commit = assert_matches!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
        assert!(commit.anti_clogging_token.is_none());
    }

    #[test]
    fn parse_commit_hash_to_element() {
        let commit_msg = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::SaeHashToElement,
            body: ECC_COMMIT_BODY,
        };
        let parse_result = parse(&commit_msg);
        let commit = assert_matches!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
        assert!(commit.anti_clogging_token.is_none());
    }

    #[test]
    fn commit_with_anti_clogging_token() {
        let commit_msg = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::Success,
            body: ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN,
        };
        let parse_result = parse(&commit_msg);
        let commit = assert_matches!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        let anti_clogging_token = assert_matches!(commit.anti_clogging_token, Some(token) => token);
        assert_eq!(anti_clogging_token, &[0x4; 8]);
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
    }

    #[test]
    fn unknown_group_id_commit() {
        let mut body = ECC_COMMIT_BODY.to_vec();
        body[0] = 0xff;
        let commit_msg = AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &body[..] };
        assert_matches!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Unsupported SAE group ID: 255"))
        });
    }

    #[test]
    fn truncated_commit() {
        let commit_msg =
            AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &ECC_COMMIT_BODY[..20] };
        assert_matches!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer truncated"))
        });

        let commit_msg = AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &[] };
        assert_matches!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });
    }

    #[test]
    fn test_parse_confirm() {
        let confirm_msg =
            AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: ECC_CONFIRM_BODY };
        let parse_result = parse(&confirm_msg);
        let confirm = assert_matches!(parse_result, Ok(ParseSuccess::Confirm(confirm)) => confirm);
        assert_eq!(confirm.send_confirm, 1);
        assert_eq!(confirm.confirm, &[3u8; 32][..]);
    }

    #[test]
    fn truncated_confirm() {
        let confirm_msg =
            AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &ECC_CONFIRM_BODY[..20] };
        assert_matches!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer truncated"))
        });

        let confirm_msg = AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &[] };
        assert_matches!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read send confirm"))
        });
    }

    #[test]
    fn padded_confirm() {
        let mut body = ECC_CONFIRM_BODY.to_vec();
        body.push(0xff);
        let confirm_msg = AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &body[..] };
        assert_matches!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer too long"))
        });
    }

    #[test]
    fn unexpected_seq_or_status() {
        let frame =
            AuthFrameRx { seq: 3, status_code: StatusCode::Success, body: ECC_CONFIRM_BODY };
        assert_matches!(parse(&frame), Err(e) => {
            assert!(format!("{:?}", e).contains("Could not parse received SAE frame"))
        });

        let frame = AuthFrameRx {
            seq: 2,
            status_code: StatusCode::AntiCloggingTokenRequired,
            body: ECC_ACT_REQUIRED_BODY,
        };
        assert_matches!(parse(&frame), Err(e) => {
            assert!(format!("{:?}", e).contains("Could not parse received SAE frame"))
        });
    }

    #[test]
    fn test_parse_anti_clogging_token_required() {
        let act_required = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::AntiCloggingTokenRequired,
            body: ECC_ACT_REQUIRED_BODY,
        };
        let parse_result = parse(&act_required);
        let act = assert_matches!(parse_result, Ok(ParseSuccess::AntiCloggingToken(act)) => act);
        assert_eq!(act.group_id, 19);
        assert_eq!(act.anti_clogging_token, &[0x4; 8][..]);
    }

    #[test]
    fn truncated_anti_clogging_token() {
        let act_required = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::AntiCloggingTokenRequired,
            body: &[19, 00],
        };
        assert_matches!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("no token provided"))
        });

        let act_required =
            AuthFrameRx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body: &[19] };
        assert_matches!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });

        let act_required =
            AuthFrameRx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body: &[] };
        assert_matches!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });
    }

    #[test]
    fn test_write_commit() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[], PweMethod::Loop);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        assert_eq!(&auth_frame.body[..], ECC_COMMIT_BODY);
    }

    #[test]
    fn test_write_commit_direct() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[], PweMethod::Direct);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::SaeHashToElement);
        assert_eq!(&auth_frame.body[..], ECC_COMMIT_BODY);
    }

    #[test]
    fn test_write_commit_with_anti_clogging_token() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[4u8; 8], PweMethod::Loop);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        // The token is written between the group ID and the scalar, not appended at the end.
        assert_eq!(&auth_frame.body[..], ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN);
    }

    #[test]
    fn write_then_parse_commit_with_anti_clogging_token() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[4u8; 8], PweMethod::Loop);
        let received =
            AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &auth_frame.body[..] };
        let parse_result = parse(&received);
        let commit = assert_matches!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        assert_eq!(commit.anti_clogging_token, Some(&[4u8; 8][..]));
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
    }

    #[test]
    fn test_write_confirm() {
        let auth_frame = write_confirm(1, &[3u8; 32]);
        assert_eq!(auth_frame.seq, 2);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        assert_eq!(&auth_frame.body[..], ECC_CONFIRM_BODY);
    }

    #[test]
    fn test_write_anticlogging_token() {
        let auth_frame = write_token(19, &[4u8; 8]);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::AntiCloggingTokenRequired);
        assert_eq!(&auth_frame.body[..], ECC_ACT_REQUIRED_BODY);
    }
}