wlan_sae/
frame.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::{AuthFrameRx, AuthFrameTx};
6use anyhow::{anyhow, bail, Error};
7use fidl_fuchsia_wlan_ieee80211::StatusCode;
8use wlan_common::append::Append;
9use wlan_common::buffer_reader::BufferReader;
10
/// IEEE Std 802.11-2016, 12.4.6
/// An anticlogging token sent to a peer.
///
/// The token is borrowed (lifetime `'a`) rather than copied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AntiCloggingTokenMsg<'a> {
    /// Group ID, read as the leading `u16` of the frame body.
    pub group_id: u16,
    /// Token bytes: everything in the frame body after the group ID.
    /// The parser in this file rejects a body whose token would be empty.
    pub anti_clogging_token: &'a [u8],
}
18
/// IEEE Std 802.11-2016, 12.4.7.4
/// An SAE Commit message received or sent to a peer.
///
/// All byte fields are borrowed (lifetime `'a`) rather than copied. On the
/// wire the fields appear in the order: group ID, optional anti-clogging
/// token, scalar, element.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommitMsg<'a> {
    /// Group ID, read as the leading `u16` of the frame body.
    pub group_id: u16,
    /// Anti-clogging token, present only when the body is longer than the
    /// group ID plus the group's fixed scalar and element lengths.
    pub anti_clogging_token: Option<&'a [u8]>,
    /// Scalar bytes; the length is determined by the group ID.
    pub scalar: &'a [u8],
    /// Element bytes; the length is determined by the group ID.
    pub element: &'a [u8],
}
28
/// IEEE Std 802.11-2016, 12.4.7.5
/// An SAE Confirm message received or sent to a peer.
///
/// The confirm bytes are borrowed (lifetime `'a`) rather than copied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfirmMsg<'a> {
    /// Send-confirm counter, read as the leading `u16` of the frame body.
    pub send_confirm: u16,
    /// Confirm bytes following the counter. The parser in this file only
    /// accepts exactly `CONFIRM_BYTES` bytes here.
    pub confirm: &'a [u8],
}
36
/// The result of successfully parsing a received SAE authentication frame.
///
/// Which variant is produced depends on the frame's sequence number and
/// status code; see `parse`.
#[derive(Debug)]
pub enum ParseSuccess<'a> {
    /// Sequence number 1 with a `Success` status code.
    Commit(CommitMsg<'a>),
    /// Sequence number 2 with a `Success` status code.
    Confirm(ConfirmMsg<'a>),
    /// Sequence number 1 with an `AntiCloggingTokenRequired` status code.
    AntiCloggingToken(AntiCloggingTokenMsg<'a>),
}
43
44fn get_scalar_and_element_len_bytes(group_id: u16) -> Result<(usize, usize), Error> {
45    match group_id {
46        19 => Ok((32, 64)),
47        _ => bail!("Unsupported SAE group ID: {}", group_id),
48    }
49}
50
51pub fn parse<'a>(frame: &'a AuthFrameRx<'_>) -> Result<ParseSuccess<'a>, Error> {
52    // IEEE 802.11 9.3.3.12 Table 9-36 specifies all SAE auth frame formats.
53    match (frame.seq, frame.status_code) {
54        (1, StatusCode::Success) => parse_commit(frame.body).map(ParseSuccess::Commit),
55        (1, StatusCode::AntiCloggingTokenRequired) => {
56            parse_anti_clogging_token(frame.body).map(ParseSuccess::AntiCloggingToken)
57        }
58        (2, StatusCode::Success) => parse_confirm(frame.body).map(ParseSuccess::Confirm),
59        _ => bail!("Could not parse received SAE frame"),
60    }
61}
62
63fn parse_anti_clogging_token(body: &[u8]) -> Result<AntiCloggingTokenMsg<'_>, Error> {
64    let mut reader = BufferReader::new(body);
65    let group_id = reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read group ID"))?;
66    if reader.bytes_remaining() == 0 {
67        bail!("Commit indicated AntiCloggingTokenRequired, but no token provided");
68    }
69    let anti_clogging_token = reader.into_remaining();
70    Ok(AntiCloggingTokenMsg { group_id, anti_clogging_token })
71}
72
73fn parse_commit(body: &[u8]) -> Result<CommitMsg<'_>, Error> {
74    let mut reader = BufferReader::new(body);
75    let group_id = reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read group ID"))?;
76
77    let (scalar_size, element_size) = get_scalar_and_element_len_bytes(group_id)?;
78    let bytes_remaining = reader.bytes_remaining();
79    let anti_clogging_token = match bytes_remaining.cmp(&(scalar_size + element_size)) {
80        std::cmp::Ordering::Equal => None,
81        std::cmp::Ordering::Greater => Some(
82            reader
83                .read_bytes(bytes_remaining - scalar_size - element_size)
84                .ok_or_else(|| anyhow!("Unexpected buffer end"))?,
85        ),
86        std::cmp::Ordering::Less => bail!("Buffer truncated"),
87    };
88
89    let scalar = reader.read_bytes(scalar_size).ok_or_else(|| anyhow!("Unexpected buffer end"))?;
90    let element =
91        reader.read_bytes(element_size).ok_or_else(|| anyhow!("Unexpected buffer end"))?;
92
93    Ok(CommitMsg { group_id, scalar, element, anti_clogging_token })
94}
95
96const CONFIRM_BYTES: usize = 32;
97
98fn parse_confirm(body: &[u8]) -> Result<ConfirmMsg<'_>, Error> {
99    let mut reader = BufferReader::new(body);
100    let send_confirm =
101        reader.read_value::<u16>().ok_or_else(|| anyhow!("Failed to read send confirm"))?;
102    let confirm = reader.read_bytes(CONFIRM_BYTES).ok_or_else(|| anyhow!("Buffer truncated"))?;
103    match reader.bytes_remaining() {
104        0 => Ok(ConfirmMsg { send_confirm, confirm }),
105        _ => bail!("Buffer too long"),
106    }
107}
108
109// Allow skipping checks on append_bytes() and append_value()
110#[allow(unused_must_use)]
111pub fn write_commit(
112    group_id: u16,
113    scalar: &[u8],
114    element: &[u8],
115    anti_clogging_token: &[u8],
116) -> AuthFrameTx {
117    let mut body = Vec::with_capacity(2 + scalar.len() + element.len() + anti_clogging_token.len());
118    body.append_value(&group_id);
119    body.append_bytes(anti_clogging_token);
120    body.append_bytes(scalar);
121    body.append_bytes(element);
122    AuthFrameTx { seq: 1, status_code: StatusCode::Success, body }
123}
124
125// Allow skipping checks on append_bytes() and append_value()
126#[allow(unused_must_use)]
127// This function is currently unused, but planned for future use
128#[allow(dead_code)]
129pub fn write_token(group_id: u16, token: &[u8]) -> AuthFrameTx {
130    let mut body = Vec::with_capacity(2 + token.len());
131    body.append_value(&group_id);
132    body.append_bytes(token);
133    AuthFrameTx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body }
134}
135
136// Allow skipping checks on append_bytes() and append_value()
137#[allow(unused_must_use)]
138pub fn write_confirm(send_confirm: u16, confirm: &[u8]) -> AuthFrameTx {
139    let mut body = Vec::with_capacity(2 + confirm.len());
140    body.append_value(&send_confirm);
141    body.append_bytes(confirm);
142    AuthFrameTx { seq: 2, status_code: StatusCode::Success, body }
143}
144
/// Unit tests for SAE frame parsing and serialization.
#[cfg(test)]
mod tests {
    use super::*;
    use wlan_common::assert_variant;

    /// Commit body for group 19 without an anti-clogging token.
    #[rustfmt::skip]
    const ECC_COMMIT_BODY: &[u8] = &[
        // group id
        19, 00,
        // scalar [0x1; 32]
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
        // element [0x2; 64]
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
    ];

    /// Commit body for group 19 with the token between group ID and scalar.
    #[rustfmt::skip]
    const ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN: &[u8] = &[
        // group id
        19, 00,
        // anti-clogging token
        4, 4, 4, 4, 4, 4, 4, 4,
        // scalar [0x1; 32]
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
        // element [0x2; 64]
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
    ];

    /// Confirm body with a send-confirm counter of 1.
    #[rustfmt::skip]
    const ECC_CONFIRM_BODY: &[u8] = &[
        // send-confirm
        0x01, 0x00,
        // confirm [0x3; 32]
        3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,
    ];

    /// Body of an "anti-clogging token required" frame for group 19.
    #[rustfmt::skip]
    const ECC_ACT_REQUIRED_BODY: &[u8] = &[
        // group id
        19, 00,
        // anti-clogging token [0x4; 8]
        4,4,4,4,4,4,4,4
    ];

    #[test]
    fn test_parse_commit() {
        let commit_msg =
            AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: ECC_COMMIT_BODY };
        let parse_result = parse(&commit_msg);
        let commit = assert_variant!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
        assert!(commit.anti_clogging_token.is_none());
    }

    #[test]
    fn commit_with_anti_clogging_token() {
        let commit_msg = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::Success,
            body: ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN,
        };
        let parse_result = parse(&commit_msg);
        let commit = assert_variant!(parse_result, Ok(ParseSuccess::Commit(commit)) => commit);
        assert_eq!(commit.group_id, 19);
        let anti_clogging_token = assert_variant!(commit.anti_clogging_token, Some(token) => token);
        assert_eq!(anti_clogging_token, &[0x4; 8]);
        assert_eq!(commit.scalar, &[1u8; 32][..]);
        assert_eq!(commit.element, &[2u8; 64][..]);
    }

    #[test]
    fn unknown_group_id_commit() {
        let mut body = ECC_COMMIT_BODY.to_vec();
        body[0] = 0xff; // not a real group
        let commit_msg = AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &body[..] };
        assert_variant!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Unsupported SAE group ID: 255"))
        });
    }

    #[test]
    fn truncated_commit() {
        let commit_msg =
            AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &ECC_COMMIT_BODY[..20] };
        assert_variant!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer truncated"))
        });

        let commit_msg = AuthFrameRx { seq: 1, status_code: StatusCode::Success, body: &[] };
        assert_variant!(parse(&commit_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });
    }

    #[test]
    fn test_parse_confirm() {
        let confirm_msg =
            AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: ECC_CONFIRM_BODY };
        let parse_result = parse(&confirm_msg);
        let confirm = assert_variant!(parse_result, Ok(ParseSuccess::Confirm(confirm)) => confirm);
        assert_eq!(confirm.send_confirm, 1);
        assert_eq!(confirm.confirm, &[3u8; 32][..]);
    }

    #[test]
    fn truncated_confirm() {
        let confirm_msg =
            AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &ECC_CONFIRM_BODY[..20] };
        assert_variant!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer truncated"))
        });

        let confirm_msg = AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &[] };
        assert_variant!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read send confirm"))
        });
    }

    #[test]
    fn padded_confirm() {
        let mut body = ECC_CONFIRM_BODY.to_vec();
        body.push(0xff);
        let confirm_msg = AuthFrameRx { seq: 2, status_code: StatusCode::Success, body: &body[..] };
        assert_variant!(parse(&confirm_msg), Err(e) => {
            assert!(format!("{:?}", e).contains("Buffer too long"))
        });
    }

    #[test]
    fn test_parse_anti_clogging_token_required() {
        let act_required = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::AntiCloggingTokenRequired,
            body: ECC_ACT_REQUIRED_BODY,
        };
        let parse_result = parse(&act_required);
        let act = assert_variant!(parse_result, Ok(ParseSuccess::AntiCloggingToken(act)) => act);
        assert_eq!(act.group_id, 19);
        assert_eq!(act.anti_clogging_token, &[0x4; 8][..]);
    }

    #[test]
    fn truncated_anti_clogging_token() {
        let act_required = AuthFrameRx {
            seq: 1,
            status_code: StatusCode::AntiCloggingTokenRequired,
            body: &[19, 00],
        };
        assert_variant!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("no token provided"))
        });

        let act_required =
            AuthFrameRx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body: &[19] };
        assert_variant!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });

        let act_required =
            AuthFrameRx { seq: 1, status_code: StatusCode::AntiCloggingTokenRequired, body: &[] };
        assert_variant!(parse(&act_required), Err(e) => {
            assert!(format!("{:?}", e).contains("Failed to read group ID"))
        });
    }

    /// Sequence number / status code pairs outside the three handled by
    /// `parse` must be rejected rather than misinterpreted.
    #[test]
    fn unsupported_frame_type() {
        let frame =
            AuthFrameRx { seq: 3, status_code: StatusCode::Success, body: ECC_CONFIRM_BODY };
        assert_variant!(parse(&frame), Err(e) => {
            assert!(format!("{:?}", e).contains("Could not parse received SAE frame"))
        });
    }

    #[test]
    fn test_write_commit() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[]);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        assert_eq!(&auth_frame.body[..], ECC_COMMIT_BODY);
    }

    #[test]
    fn test_write_commit_with_anti_clogging_token() {
        let auth_frame = write_commit(19, &[1u8; 32], &[2u8; 64], &[4u8; 8]);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        // The token is written between the group ID and the scalar, so the
        // expected bytes are the dedicated fixture, not the plain commit body
        // with a token appended.
        assert_eq!(&auth_frame.body[..], ECC_COMMIT_BODY_WITH_ANTI_CLOGGING_TOKEN);
    }

    #[test]
    fn test_write_confirm() {
        let auth_frame = write_confirm(1, &[3u8; 32]);
        assert_eq!(auth_frame.seq, 2);
        assert_eq!(auth_frame.status_code, StatusCode::Success);
        assert_eq!(&auth_frame.body[..], ECC_CONFIRM_BODY);
    }

    #[test]
    fn test_write_anticlogging_token() {
        let auth_frame = write_token(19, &[4u8; 8]);
        assert_eq!(auth_frame.seq, 1);
        assert_eq!(auth_frame.status_code, StatusCode::AntiCloggingTokenRequired);
        assert_eq!(&auth_frame.body[..], ECC_ACT_REQUIRED_BODY);
    }
}