1use super::{AuthFrameRx, AuthFrameTx};
6use anyhow::{anyhow, bail, Error};
7use fidl_fuchsia_wlan_ieee80211::StatusCode;
8use wlan_common::append::Append;
9use wlan_common::buffer_reader::BufferReader;
10
/// Parsed body of an SAE frame with status `AntiCloggingTokenRequired`.
///
/// Holds the finite cyclic group the peer referenced and the opaque token it
/// returned, which a subsequent Commit can carry (see `write_commit`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AntiCloggingTokenMsg<'a> {
    /// Finite cyclic group identifier (first two bytes of the body).
    pub group_id: u16,
    /// Every byte following the group ID; never empty after a successful parse.
    pub anti_clogging_token: &'a [u8],
}
18
/// Parsed body of an SAE Commit frame (sequence 1, status `Success`).
///
/// All slices borrow directly from the received frame body.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CommitMsg<'a> {
    /// Finite cyclic group identifier.
    pub group_id: u16,
    /// Bytes between the group ID and the scalar, if any were present.
    pub anti_clogging_token: Option<&'a [u8]>,
    /// Scalar field; its length is fixed by `group_id`.
    pub scalar: &'a [u8],
    /// Element field; its length is fixed by `group_id`.
    pub element: &'a [u8],
}
28
/// Parsed body of an SAE Confirm frame (sequence 2, status `Success`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ConfirmMsg<'a> {
    /// Send-confirm counter (first two bytes of the body).
    pub send_confirm: u16,
    /// Confirm value, borrowed from the received frame body.
    pub confirm: &'a [u8],
}
36
37#[derive(Debug)]
38pub enum ParseSuccess<'a> {
39 Commit(CommitMsg<'a>),
40 Confirm(ConfirmMsg<'a>),
41 AntiCloggingToken(AntiCloggingTokenMsg<'a>),
42}
43
44fn get_scalar_and_element_len_bytes(group_id: u16) -> Result<(usize, usize), Error> {
45 match group_id {
46 19 => Ok((32, 64)),
47 _ => bail!("Unsupported SAE group ID: {}", group_id),
48 }
49}
50
51pub fn parse<'a>(frame: &'a AuthFrameRx<'_>) -> Result<ParseSuccess<'a>, Error> {
52 match (frame.seq, frame.status_code) {
54 (1, StatusCode::Success) => parse_commit(frame.body).map(ParseSuccess::Commit),
55 (1, StatusCode::AntiCloggingTokenRequired) => {
56 parse_anti_clogging_token(frame.body).map(ParseSuccess::AntiCloggingToken)
57 }
58 (2, StatusCode::Success) => parse_confirm(frame.body).map(ParseSuccess::Confirm),
59 _ => bail!("Could not parse received SAE frame"),
60 }
61}
62
63fn parse_anti_clogging_token(body: &[u8]) -> Result<AntiCloggingTokenMsg<'_>, Error> {
64 let mut reader = BufferReader::new(body);
65 let group_id = reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read group ID"))?;
66 if reader.bytes_remaining() == 0 {
67 bail!("Commit indicated AntiCloggingTokenRequired, but no token provided");
68 }
69 let anti_clogging_token = reader.into_remaining();
70 Ok(AntiCloggingTokenMsg { group_id, anti_clogging_token })
71}
72
73fn parse_commit(body: &[u8]) -> Result<CommitMsg<'_>, Error> {
74 let mut reader = BufferReader::new(body);
75 let group_id = reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read group ID"))?;
76
77 let (scalar_size, element_size) = get_scalar_and_element_len_bytes(group_id)?;
78 let bytes_remaining = reader.bytes_remaining();
79 let anti_clogging_token = match bytes_remaining.cmp(&(scalar_size + element_size)) {
80 std::cmp::Ordering::Equal => None,
81 std::cmp::Ordering::Greater => Some(
82 reader
83 .read_bytes(bytes_remaining - scalar_size - element_size)
84 .ok_or_else(|| anyhow!("Unexpected buffer end"))?,
85 ),
86 std::cmp::Ordering::Less => bail!("Buffer truncated"),
87 };
88
89 let scalar = reader.read_bytes(scalar_size).ok_or_else(|| anyhow!("Unexpected buffer end"))?;
90 let element =
91 reader.read_bytes(element_size).ok_or_else(|| anyhow!("Unexpected buffer end"))?;
92
93 Ok(CommitMsg { group_id, scalar, element, anti_clogging_token })
94}
95
/// Exact length of the Confirm field accepted by `parse_confirm`: 256 bits (32 bytes).
const CONFIRM_BYTES: usize = 256 / 8;
97
98fn parse_confirm(body: &[u8]) -> Result<ConfirmMsg<'_>, Error> {
99 let mut reader = BufferReader::new(body);
100 let send_confirm =
101 reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read send confirm"))?;
102 let confirm = reader.read_bytes(CONFIRM_BYTES).ok_or_else(|| anyhow!("Buffer truncated"))?;
103 match reader.bytes_remaining() {
104 0 => Ok(ConfirmMsg { send_confirm, confirm }),
105 _ => bail!("Buffer too long"),
106 }
107}
108
109#[allow(unused_must_use)]
111pub fn write_commit(
112 group_id: u16,
113 scalar: &[u8],
114 element: &[u8],
115 anti_clogging_token: &[u8],
116) -> AuthFrameTx {
117 let mut body = Vec::with_capacity(2 + scalar.len() + element.len() + anti_clogging_token.len());
118 body.append_value(&group_id);
119 body.append_bytes(anti_clogging_token);
120 body.append_bytes(scalar);
121 body.append_bytes(element);
122 AuthFrameTx { seq: 1, status_code: StatusCode::Success, body }
123}
124
125#[allow(unused_must_use)]
127#[allow(dead_code)]
129pub fn write_token(group_id: u16, token: &[u8]) -> AuthFrameTx {
130 let mut body = Vec::with_capacity(2 + token.len());
131 body.append_value(&group_id);
132 body.append_bytes(token);
133 AuthFrameTx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body }
134}
135
136#[allow(unused_must_use)]
138pub fn write_confirm(send_confirm: u16, confirm: &[u8]) -> AuthFrameTx {
139 let mut body = Vec::with_capacity(2 + confirm.len());
140 body.append_value(&send_confirm);
141 body.append_bytes(confirm);
142 AuthFrameTx { seq: 2, status_code: StatusCode::Success, body }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use wlan_common::assert_variant;

    #[rustfmt::skip]
    const ECC_COMMIT_BODY: &[u8] = &[
        // Group ID
        19, 00,
        // Scalar
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
        // Element
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
    ];

    #[rustfmt::skip]
    const ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN: &[u8] = &[
        // Group ID
        19, 00,
        // Anti-clogging token (sits between the group ID and the scalar)
        4, 4, 4, 4, 4, 4, 4, 4,
        // Scalar
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
        // Element
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
    ];

    #[rustfmt::skip]
    const ECC_CONFIRM_BODY: &[u8] = &[
        // Send confirm
        0x01, 0x00,
        // Confirm
        3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,
    ];

    #[rustfmt::skip]
    const ECC_ACT_REQUIRED_BODY: &[u8] = &[
        // Group ID
        19, 00,
        // Anti-clogging token
        4,4,4,4,4,4,4,4
    ];

    #[test]
    fn test_parse_commit() {
        let commit_msg =
            AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: ECC_COMMIT_BODY };
        let parse_result = parse(&commit_msg);
        let commit = assert_variant!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
        assert!(commit.anti_clogging_token.is_none());
    }

    #[test]
    fn commit_with_anti_clogging_token() {
        let commit_msg = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::Success,
            body: ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN,
        };
        let parse_result = parse(&commit_msg);
        let commit = assert_variant!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        let anti_clogging_token = assert_variant!(commit.anti_clogging_token, Some(token) => token);
        assert_eq!(anti_clogging_token, &[0x4; 8]);
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
    }

    #[test]
    fn unknown_group_id_commit() {
        let mut body = ECC_COMMIT_BODY.to_vec();
        // Low byte of the little-endian group ID -> 255.
        body[0] = 0xff;
        let commit_msg = AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &body[..] };
        assert_variant!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Unsupported SAE group ID: 255"))
        });
    }

    #[test]
    fn truncated_commit() {
        let commit_msg =
            AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &ECC_COMMIT_BODY[..20] };
        assert_variant!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer truncated"))
        });

        let commit_msg = AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &[] };
        assert_variant!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });
    }

    #[test]
    fn test_parse_confirm() {
        let confirm_msg =
            AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: ECC_CONFIRM_BODY };
        let parse_result = parse(&confirm_msg);
        let confirm = assert_variant!(parse_result, Ok(ParseSuccess::Confirm(confirm)) => confirm);
        assert_eq!(confirm.send_confirm, 1);
        assert_eq!(confirm.confirm, &[3u8; 32][..]);
    }

    #[test]
    fn truncated_confirm() {
        let confirm_msg =
            AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &ECC_CONFIRM_BODY[..20] };
        assert_variant!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer truncated"))
        });

        let confirm_msg = AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &[] };
        assert_variant!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read send confirm"))
        });
    }

    #[test]
    fn padded_confirm() {
        let mut body = ECC_CONFIRM_BODY.to_vec();
        body.push(0xff);
        let confirm_msg = AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &body[..] };
        assert_variant!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer too long"))
        });
    }

    #[test]
    fn test_parse_anti_clogging_token_required() {
        let act_required = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::AntiCloggingTokenRequired,
            body: ECC_ACT_REQUIRED_BODY,
        };
        let parse_result = parse(&act_required);
        let act = assert_variant!(parse_result, Ok(ParseSuccess::AntiCloggingToken(act)) => act);
        assert_eq!(act.group_id, 19);
        assert_eq!(act.anti_clogging_token, &[0x4; 8][..]);
    }

    #[test]
    fn truncated_anti_clogging_token() {
        let act_required = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::AntiCloggingTokenRequired,
            body: &[19, 00],
        };
        assert_variant!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("no token provided"))
        });

        let act_required =
            AuthFrameRx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body: &[19] };
        assert_variant!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });

        let act_required =
            AuthFrameRx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body: &[] };
        assert_variant!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });
    }

    #[test]
    fn test_write_commit() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[]);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        assert_eq!(&auth_frame.body[..], ECC_COMMIT_BODY);
    }

    #[test]
    fn test_write_commit_with_anti_clogging_token() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[4u8; 8]);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        // The token must sit between the group ID and the scalar, not at the end of the body.
        assert_eq!(&auth_frame.body[..], ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN);
    }

    #[test]
    fn test_write_confirm() {
        let auth_frame = write_confirm(1, &[3u8; 32]);
        assert_eq!(auth_frame.seq, 2);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        assert_eq!(&auth_frame.body[..], ECC_CONFIRM_BODY);
    }

    #[test]
    fn test_write_anticlogging_token() {
        let auth_frame = write_token(19, &[4u8; 8]);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::AntiCloggingTokenRequired);
        assert_eq!(&auth_frame.body[..], ECC_ACT_REQUIRED_BODY);
    }

    #[test]
    fn commit_round_trip() {
        let tx = write_commit(19, &[1u8; 32], &[2u8; 64], &[4u8; 8]);
        let rx = AuthFrameRx { seq: tx.seq, status_code: tx.status_code, body: &tx.body[..] };
        let commit = assert_variant!(parse(&rx), Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        assert_eq!(commit.anti_clogging_token, Some(&[4u8; 8][..]));
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
    }

    #[test]
    fn confirm_round_trip() {
        let tx = write_confirm(7, &[3u8; 32]);
        let rx = AuthFrameRx { seq: tx.seq, status_code: tx.status_code, body: &tx.body[..] };
        let confirm = assert_variant!(parse(&rx), Ok(ParseSuccess::Confirm(confirm)) => confirm);
        assert_eq!(confirm.send_confirm, 7);
        assert_eq!(confirm.confirm, &[3u8; 32][..]);
    }

    #[test]
    fn anti_clogging_token_round_trip() {
        let tx = write_token(19, &[9u8; 5]);
        let rx = AuthFrameRx { seq: tx.seq, status_code: tx.status_code, body: &tx.body[..] };
        let act = assert_variant!(parse(&rx), Ok(ParseSuccess::AntiCloggingToken(act)) => act);
        assert_eq!(act.group_id, 19);
        assert_eq!(act.anti_clogging_token, &[9u8; 5][..]);
    }
}