bt_common/core/
ltv.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use thiserror::Error;
6
7use crate::packet_encoding::{Decodable, Encodable};
8
9/// Implement Ltv when a collection of types is represented in the Bluetooth
10/// specifications as a length-type-value structure.  They should have an
11/// associated type which can be retrieved from a type byte.
12pub trait LtValue: Sized {
13    type Type: Into<u8> + Copy + std::fmt::Debug;
14
15    const NAME: &'static str;
16
17    /// Given a type octet, return the associated Type if it is possible.
18    /// Returns None if the value is unrecognized.
19    fn type_from_octet(x: u8) -> Option<Self::Type>;
20
21    /// Returns length bounds for the type indicated, **including** the type
22    /// byte. Note that the assigned numbers from the Bluetooth SIG include
23    /// the type byte in their Length specifications.
24    // TODO: use impl std::ops::RangeBounds when RPITIT is sufficiently stable
25    fn length_range_from_type(ty: Self::Type) -> std::ops::RangeInclusive<u8>;
26
27    /// Retrieve the type of the current value.
28    fn into_type(&self) -> Self::Type;
29
30    /// The length of the encoded value, without the length and type byte.
31    /// This cannot be 255 in practice, as the length byte is only one octet
32    /// long.
33    fn value_encoded_len(&self) -> u8;
34
35    /// Decodes the value from a buffer, which does not include the type or
36    /// length bytes. The `buf` slice length is exactly what was specified
37    /// for this value in the encoded source.
38    fn decode_value(ty: &Self::Type, buf: &[u8]) -> Result<Self, crate::packet_encoding::Error>;
39
40    /// Encodes a value into `buf`, which is verified to be the correct length
41    /// as indicated by [LtValue::value_encoded_len].
42    fn encode_value(&self, buf: &mut [u8]) -> Result<(), crate::packet_encoding::Error>;
43
44    /// Decode a collection of LtValue structures that are present in a buffer.
45    /// If it is possible to continue decoding after encountering an error, does
46    /// so and includes the error. If an unrecoverable error occurs, does
47    /// not consume the final item and the last element in the result is the
48    /// error.
49    fn decode_all(buf: &[u8]) -> (Vec<Result<Self, Error<Self::Type>>>, usize) {
50        let mut results = Vec::new();
51        let mut total_consumed = 0;
52        loop {
53            if buf.len() <= total_consumed {
54                return (results, std::cmp::min(buf.len(), total_consumed));
55            }
56            let indicated_len = buf[total_consumed] as usize;
57            let range_end = std::cmp::min(buf.len() - 1, total_consumed + indicated_len);
58            match Self::decode(&buf[total_consumed..=range_end]) {
59                (Ok(item), consumed) => {
60                    results.push(Ok(item));
61                    total_consumed += consumed;
62                }
63                (Err(e), consumed) => {
64                    results.push(Err(e));
65                    total_consumed += consumed;
66                }
67            }
68        }
69    }
70
71    /// Encode a collection of LtValue structures into a buffer.
72    /// Even if the encoding fails, `buf` may still be modified by
73    /// previous encoding successes.
74    fn encode_all(
75        iter: impl Iterator<Item = Self>,
76        buf: &mut [u8],
77    ) -> Result<(), crate::packet_encoding::Error> {
78        let mut idx = 0;
79        for item in iter {
80            item.encode(&mut buf[idx..])?;
81            idx += item.encoded_len();
82        }
83        Ok(())
84    }
85}
86
87#[derive(Error, Debug, PartialEq)]
88pub enum Error<Type: std::fmt::Debug + Into<u8> + Copy> {
89    #[error("Buffer too short for next type")]
90    MissingType,
91    #[error("Buffer missing data indicated by length (type {0:?})")]
92    MissingData(Type),
93    #[error("Unrecognized type value for {0}: {1}")]
94    UnrecognizedType(String, u8),
95    #[error("Length of item ({0}) is outside allowed range for {1:?}: {2:?}")]
96    LengthOutOfRange(u8, Type, std::ops::RangeInclusive<u8>),
97    #[error("Error decoding type {0:?}: {1}")]
98    TypeFailedToDecode(Type, crate::packet_encoding::Error),
99}
100
101impl<Type: std::fmt::Debug + Into<u8> + Copy> Error<Type> {
102    pub fn type_value(&self) -> Option<u8> {
103        match self {
104            Self::MissingType => None,
105            Self::MissingData(t)
106            | Self::LengthOutOfRange(_, t, _)
107            | Self::TypeFailedToDecode(t, _) => Some((*t).into()),
108            Self::UnrecognizedType(_, value) => Some(*value),
109        }
110    }
111}
112
113impl<T: LtValue> Encodable for T {
114    type Error = crate::packet_encoding::Error;
115
116    fn encoded_len(&self) -> core::primitive::usize {
117        2 + self.value_encoded_len() as usize
118    }
119
120    fn encode(&self, buf: &mut [u8]) -> core::result::Result<(), Self::Error> {
121        if buf.len() < self.encoded_len() {
122            return Err(crate::packet_encoding::Error::BufferTooSmall);
123        }
124        buf[0] = self.value_encoded_len() + 1;
125        buf[1] = self.into_type().into();
126        self.encode_value(&mut buf[2..self.encoded_len()])?;
127        Ok(())
128    }
129}
130
131impl<T> Decodable for T
132where
133    T: LtValue,
134{
135    type Error = Error<T::Type>;
136
137    fn decode(buf: &[u8]) -> (core::result::Result<Self, Self::Error>, usize) {
138        if buf.len() < 2 {
139            return (Err(Error::MissingType), buf.len());
140        }
141        let indicated_len = buf[0] as usize;
142        let too_short = buf.len() < indicated_len + 1;
143
144        let Some(ty) = Self::type_from_octet(buf[1]) else {
145            return (
146                Err(Error::UnrecognizedType(Self::NAME.to_owned(), buf[1])),
147                if too_short { buf.len() } else { indicated_len + 1 },
148            );
149        };
150        if too_short {
151            return (Err(Error::MissingData(ty)), buf.len());
152        }
153        let size_range = Self::length_range_from_type(ty);
154        let remaining_len = (buf.len() - 1) as u8;
155        if !size_range.contains(&remaining_len) {
156            return (Err(Error::LengthOutOfRange(remaining_len, ty, size_range)), buf.len());
157        }
158        match Self::decode_value(&ty, &buf[2..=indicated_len]) {
159            Err(e) => (Err(Error::TypeFailedToDecode(ty, e)), indicated_len + 1),
160            Ok(s) => (Ok(s), indicated_len + 1),
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Copy, Clone, PartialEq, Debug)]
    enum TestType {
        OneByte,
        TwoBytes,
        TwoBytesLittleEndian,
        UnicodeString,
        AlwaysError,
    }

    impl From<TestType> for u8 {
        fn from(value: TestType) -> Self {
            match value {
                TestType::OneByte => 1,
                TestType::TwoBytes => 2,
                TestType::TwoBytesLittleEndian => 3,
                TestType::UnicodeString => 4,
                TestType::AlwaysError => 0xFF,
            }
        }
    }

    #[derive(PartialEq, Debug)]
    enum TestValues {
        OneByte(u8),
        TwoBytes(u16),
        TwoBytesLittleEndian(u16),
        UnicodeString(String),
        AlwaysError,
    }

    impl LtValue for TestValues {
        type Type = TestType;

        const NAME: &'static str = "TestValues";

        fn type_from_octet(x: u8) -> Option<Self::Type> {
            match x {
                1 => Some(TestType::OneByte),
                2 => Some(TestType::TwoBytes),
                3 => Some(TestType::TwoBytesLittleEndian),
                4 => Some(TestType::UnicodeString),
                0xFF => Some(TestType::AlwaysError),
                _ => None,
            }
        }

        fn length_range_from_type(ty: Self::Type) -> std::ops::RangeInclusive<u8> {
            match ty {
                TestType::OneByte => 2..=2,
                TestType::TwoBytes => 3..=3,
                TestType::TwoBytesLittleEndian => 3..=3,
                TestType::UnicodeString => 2..=255,
                // AlwaysError fields can be any length (value will be thrown away)
                TestType::AlwaysError => 1..=255,
            }
        }

        fn into_type(&self) -> Self::Type {
            match self {
                TestValues::TwoBytes(_) => TestType::TwoBytes,
                TestValues::TwoBytesLittleEndian(_) => TestType::TwoBytesLittleEndian,
                TestValues::OneByte(_) => TestType::OneByte,
                TestValues::UnicodeString(_) => TestType::UnicodeString,
                TestValues::AlwaysError => TestType::AlwaysError,
            }
        }

        fn value_encoded_len(&self) -> u8 {
            match self {
                TestValues::TwoBytes(_) => 2,
                TestValues::TwoBytesLittleEndian(_) => 2,
                TestValues::OneByte(_) => 1,
                TestValues::UnicodeString(s) => s.len() as u8,
                TestValues::AlwaysError => 0,
            }
        }

        fn decode_value(
            ty: &Self::Type,
            buf: &[u8],
        ) -> Result<Self, crate::packet_encoding::Error> {
            match ty {
                TestType::OneByte => Ok(TestValues::OneByte(buf[0])),
                TestType::TwoBytes => {
                    Ok(TestValues::TwoBytes(u16::from_be_bytes([buf[0], buf[1]])))
                }
                TestType::TwoBytesLittleEndian => {
                    Ok(TestValues::TwoBytesLittleEndian(u16::from_le_bytes([buf[0], buf[1]])))
                }
                TestType::UnicodeString => {
                    Ok(TestValues::UnicodeString(String::from_utf8_lossy(buf).into_owned()))
                }
                TestType::AlwaysError => Err(crate::packet_encoding::Error::OutOfRange),
            }
        }

        fn encode_value(&self, buf: &mut [u8]) -> Result<(), crate::packet_encoding::Error> {
            if buf.len() < self.value_encoded_len() as usize {
                return Err(crate::packet_encoding::Error::BufferTooSmall);
            }
            match self {
                TestValues::TwoBytes(x) => {
                    [buf[0], buf[1]] = x.to_be_bytes();
                }
                TestValues::TwoBytesLittleEndian(x) => {
                    [buf[0], buf[1]] = x.to_le_bytes();
                }
                TestValues::OneByte(x) => {
                    buf[0] = *x;
                }
                TestValues::UnicodeString(s) => {
                    buf.copy_from_slice(s.as_bytes());
                }
                TestValues::AlwaysError => {
                    return Err(crate::packet_encoding::Error::InvalidParameter("test".to_owned()));
                }
            }
            Ok(())
        }
    }

    #[test]
    fn decode_twobytes() {
        let encoded = [0x03, 0x02, 0x10, 0x01, 0x03, 0x03, 0x10, 0x01];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded[0], Ok(TestValues::TwoBytes(4097)));
        assert_eq!(decoded[1], Ok(TestValues::TwoBytesLittleEndian(272)));
    }

    #[test]
    fn decode_unrecognized() {
        let encoded = [0x03, 0x02, 0x10, 0x01, 0x03, 0x06, 0x10, 0x01];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded[0], Ok(TestValues::TwoBytes(4097)));
        assert_eq!(decoded[1], Err(Error::UnrecognizedType("TestValues".to_owned(), 6)));
    }

    #[test]
    fn decode_wronglength() {
        let encoded = [0x03, 0x02, 0x10, 0x01, 0x04, 0x03, 0x10, 0x01];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded[0], Ok(TestValues::TwoBytes(4097)));
        assert_eq!(decoded[1], Err(Error::MissingData(TestType::TwoBytesLittleEndian)));
    }

    #[test]
    fn encode_twobytes() {
        let value = TestValues::TwoBytes(0x0A0B);
        let mut buf = [0; 4];
        value.encode(&mut buf[..]).expect("should succeed");
        assert_eq!(buf, [0x03, 0x02, 0x0A, 0x0B]);
    }

    #[test]
    fn encode_all() {
        let value1 = TestValues::OneByte(0x0A);
        let value2 = TestValues::UnicodeString("Bluetooth".to_string());
        let mut buf = [0; 14];
        LtValue::encode_all(vec![value1, value2].into_iter(), &mut buf).expect("should succeed");
        assert_eq!(
            buf,
            [0x02, 0x01, 0x0a, 0x0a, 0x04, 0x42, 0x6c, 0x75, 0x65, 0x74, 0x6f, 0x6f, 0x74, 0x68]
        );
    }

    #[track_caller]
    fn u8char(c: char) -> u8 {
        c.try_into().unwrap()
    }

    #[test]
    fn decode_variable_lengths() {
        let encoded = [
            0x03,
            0x02,
            0x10,
            0x01,
            0x0A,
            0x04,
            u8char('B'),
            u8char('l'),
            u8char('u'),
            u8char('e'),
            u8char('t'),
            u8char('o'),
            u8char('o'),
            u8char('t'),
            u8char('h'),
            0x02,
            0x01,
            0x01,
        ];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded[0], Ok(TestValues::TwoBytes(4097)));
        assert_eq!(decoded[1], Ok(TestValues::UnicodeString("Bluetooth".to_owned())));
        assert_eq!(decoded[2], Ok(TestValues::OneByte(1)));
    }

    #[test]
    fn decode_with_error() {
        let encoded = [0x03, 0x02, 0x10, 0x01, 0x02, 0xFF, 0xFF, 0x02, 0x01, 0x03];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded[0], Ok(TestValues::TwoBytes(4097)));
        assert_eq!(
            decoded[1],
            Err(Error::TypeFailedToDecode(
                TestType::AlwaysError,
                crate::packet_encoding::Error::OutOfRange
            ))
        );
        assert_eq!(decoded[2], Ok(TestValues::OneByte(3)));
    }

    #[test]
    fn encode_with_error() {
        let value = TestValues::AlwaysError;
        let mut buf = [0; 10];
        assert!(matches!(
            value.encode(&mut buf),
            Err(crate::packet_encoding::Error::InvalidParameter(_)),
        ));

        let value1 = TestValues::TwoBytes(0x0A0B);
        let value2 = TestValues::OneByte(0x0A);
        let mut buf = [0; 2]; // not enough buffer space.
        LtValue::encode_all(vec![value1, value2].into_iter(), &mut buf).expect_err("should fail");
    }

    #[test]
    fn decode_empty() {
        let (decoded, consumed) = TestValues::decode_all(&[]);
        assert!(decoded.is_empty());
        assert_eq!(consumed, 0);
    }

    #[test]
    fn decode_zero_length_item() {
        // A zero length octet carries no type; decoding resumes at the next octet.
        let encoded = [0x00, 0x02, 0x01, 0x01];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], Err(Error::MissingType));
        assert_eq!(decoded[1], Ok(TestValues::OneByte(1)));
    }

    #[test]
    fn decode_length_out_of_range_recovers() {
        // OneByte requires a length octet of exactly 2; the next item still decodes.
        let encoded = [0x03, 0x01, 0x05, 0x06, 0x02, 0x01, 0x07];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], Err(Error::LengthOutOfRange(3, TestType::OneByte, 2..=2)));
        assert_eq!(decoded[1], Ok(TestValues::OneByte(7)));
    }

    #[test]
    fn decode_truncated_unrecognized() {
        let encoded = [0x02, 0x01, 0x05, 0x05, 0x09, 0x01];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], Ok(TestValues::OneByte(5)));
        assert_eq!(decoded[1], Err(Error::UnrecognizedType("TestValues".to_owned(), 9)));
    }

    #[test]
    fn decode_missing_type_at_end() {
        let encoded = [0x02, 0x01, 0x05, 0x01];
        let (decoded, consumed) = TestValues::decode_all(&encoded);
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], Ok(TestValues::OneByte(5)));
        assert_eq!(decoded[1], Err(Error::MissingType));
    }

    #[test]
    fn encode_decode_round_trip() {
        let mut buf = [0; 11];
        LtValue::encode_all(
            vec![
                TestValues::OneByte(0x0A),
                TestValues::TwoBytesLittleEndian(0x1234),
                TestValues::UnicodeString("BT".to_owned()),
            ]
            .into_iter(),
            &mut buf,
        )
        .expect("should succeed");
        let (decoded, consumed) = TestValues::decode_all(&buf);
        assert_eq!(consumed, buf.len());
        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0], Ok(TestValues::OneByte(0x0A)));
        assert_eq!(decoded[1], Ok(TestValues::TwoBytesLittleEndian(0x1234)));
        assert_eq!(decoded[2], Ok(TestValues::UnicodeString("BT".to_owned())));
    }

    #[test]
    fn error_type_value() {
        assert_eq!(Error::<TestType>::MissingType.type_value(), None);
        assert_eq!(Error::MissingData(TestType::TwoBytes).type_value(), Some(2));
        assert_eq!(Error::LengthOutOfRange(5, TestType::OneByte, 2..=2).type_value(), Some(1));
        assert_eq!(
            Error::TypeFailedToDecode(
                TestType::AlwaysError,
                crate::packet_encoding::Error::OutOfRange
            )
            .type_value(),
            Some(0xFF)
        );
        let unrecognized: Error<TestType> = Error::UnrecognizedType("TestValues".to_owned(), 6);
        assert_eq!(unrecognized.type_value(), Some(6));
    }
}