Skip to main content

netlink_packet_utils/
nla.rs

1// SPDX-License-Identifier: MIT
2
3use crate::DecodeError;
4use crate::traits::{Emitable, Parseable};
5use byteorder::{ByteOrder, NativeEndian};
6use core::ops::Range;
7use thiserror::Error;
8
/// A fixed-size, multi-byte field of a packet, expressed as the byte-offset
/// range it occupies.
type Field = Range<usize>;
11
/// Bit of an NLA `type` field that carries the "nested" flag.
pub const NLA_F_NESTED: u16 = 0x8000;
/// Bit of an NLA `type` field that carries the "network byte order" flag.
pub const NLA_F_NET_BYTEORDER: u16 = 0x4000;
/// Mask selecting the attribute type proper from an NLA `type` field: every
/// bit except the two flag bits above.
pub const NLA_TYPE_MASK: u16 = !(NLA_F_NESTED | NLA_F_NET_BYTEORDER);
/// Alignment, in bytes, that NLA (RTA) lengths are rounded up to.
pub const NLA_ALIGNTO: usize = 4;
/// Size in bytes of an NLA (RTA) header: a `u16` length (`rta_len`) followed
/// by a `u16` type (`rta_type`).
pub const NLA_HEADER_SIZE: usize = 4;
23
24#[derive(Debug, Clone, Error)]
25pub enum NlaError {
26    #[error("buffer has length {buffer_len}, but an NLA header is {} bytes", TYPE.end)]
27    BufferTooSmall { buffer_len: usize },
28
29    #[error("buffer has length: {buffer_len}, but the NLA is {nla_len} bytes")]
30    LengthMismatch { buffer_len: usize, nla_len: u16 },
31
32    #[error(
33        "NLA has invalid length: {nla_len} (should be at least {} bytes", TYPE.end
34    )]
35    InvalidLength { nla_len: u16 },
36}
37
/// Round `$len` up to the next multiple of `NLA_ALIGNTO` (4 bytes).
///
/// `NLA_ALIGNTO` is referenced unqualified, so it is looked up where the
/// macro is expanded. NOTE(review): callers outside this module presumably
/// need `NLA_ALIGNTO` in scope — confirm.
///
/// The intermediate sum `$len + NLA_ALIGNTO - 1` overflows when `$len` is
/// within `NLA_ALIGNTO - 1` of `usize::MAX`. With overflow checks enabled
/// this panics, which is what `test_align_overflow` expects.
#[macro_export]
macro_rules! nla_align {
    ($len: expr) => {
        ($len + NLA_ALIGNTO - 1) & !(NLA_ALIGNTO - 1)
    };
}
44
45const LENGTH: Field = 0..2;
46const TYPE: Field = 2..4;
47#[allow(non_snake_case)]
48fn VALUE(length: usize) -> Field {
49    TYPE.end..TYPE.end + length
50}
51
/// A view over a byte buffer `T` that holds a netlink attribute (NLA).
///
/// `Copy` is derived so that an `NlaBuffer<&'buffer T>` can be duplicated
/// freely, which is convenient. Doing so only copies a reference, so it is
/// cheap.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NlaBuffer<T: AsRef<[u8]>> {
    buffer: T,
}
59
60impl<T: AsRef<[u8]>> NlaBuffer<T> {
61    pub(crate) fn new_unchecked(buffer: T) -> Self {
62        Self { buffer }
63    }
64
65    pub fn new(buffer: T) -> Result<NlaBuffer<T>, NlaError> {
66        let buffer = Self::new_unchecked(buffer);
67        buffer.check_buffer_length()?;
68        Ok(buffer)
69    }
70
71    pub fn check_buffer_length(&self) -> Result<(), NlaError> {
72        let len = self.buffer.as_ref().len();
73        if len < TYPE.end {
74            Err(NlaError::BufferTooSmall { buffer_len: len }.into())
75        } else if len < self.length() as usize {
76            Err(NlaError::LengthMismatch { buffer_len: len, nla_len: self.length() }.into())
77        } else if (self.length() as usize) < TYPE.end {
78            Err(NlaError::InvalidLength { nla_len: self.length() }.into())
79        } else {
80            Ok(())
81        }
82    }
83
84    /// Consume the buffer, returning the underlying buffer.
85    pub fn into_inner(self) -> T {
86        self.buffer
87    }
88
89    /// Return a reference to the underlying buffer
90    pub fn inner(&self) -> &T {
91        &self.buffer
92    }
93
94    /// Return a mutable reference to the underlying buffer
95    pub fn inner_mut(&mut self) -> &mut T {
96        &mut self.buffer
97    }
98
99    /// Return the `type` field
100    pub fn kind(&self) -> u16 {
101        let data = self.buffer.as_ref();
102        NativeEndian::read_u16(&data[TYPE]) & NLA_TYPE_MASK
103    }
104
105    pub fn nested_flag(&self) -> bool {
106        let data = self.buffer.as_ref();
107        (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NESTED) != 0
108    }
109
110    pub fn network_byte_order_flag(&self) -> bool {
111        let data = self.buffer.as_ref();
112        (NativeEndian::read_u16(&data[TYPE]) & NLA_F_NET_BYTEORDER) != 0
113    }
114
115    /// Return the `length` field. The `length` field corresponds to the length
116    /// of the nla header (type and length fields, and the value field).
117    /// However, it does not account for the potential padding that follows
118    /// the value field.
119    pub fn length(&self) -> u16 {
120        let data = self.buffer.as_ref();
121        NativeEndian::read_u16(&data[LENGTH])
122    }
123
124    /// Return the length of the `value` field
125    ///
126    /// # Panic
127    ///
128    /// This panics if the length field value is less than the attribut header
129    /// size.
130    pub fn value_length(&self) -> usize {
131        self.length() as usize - TYPE.end
132    }
133}
134
135impl<T: AsRef<[u8]> + AsMut<[u8]>> NlaBuffer<T> {
136    /// Set the `type` field
137    pub fn set_kind(&mut self, kind: u16) {
138        let data = self.buffer.as_mut();
139        NativeEndian::write_u16(&mut data[TYPE], kind & NLA_TYPE_MASK)
140    }
141
142    pub fn set_nested_flag(&mut self) {
143        let kind = self.kind();
144        let data = self.buffer.as_mut();
145        NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NESTED)
146    }
147
148    pub fn set_network_byte_order_flag(&mut self) {
149        let kind = self.kind();
150        let data = self.buffer.as_mut();
151        NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NET_BYTEORDER)
152    }
153
154    /// Set the `length` field
155    pub fn set_length(&mut self, length: u16) {
156        let data = self.buffer.as_mut();
157        NativeEndian::write_u16(&mut data[LENGTH], length)
158    }
159}
160
161impl<'buffer, T: AsRef<[u8]> + ?Sized> NlaBuffer<&'buffer T> {
162    /// Return the `value` field
163    pub fn value(&self) -> &[u8] {
164        &self.buffer.as_ref()[VALUE(self.value_length())]
165    }
166}
167
168impl<'buffer, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&'buffer mut T> {
169    /// Return the `value` field
170    pub fn value_mut(&mut self) -> &mut [u8] {
171        let length = VALUE(self.value_length());
172        &mut self.buffer.as_mut()[length]
173    }
174}
175
/// A generic attribute that keeps its raw `kind` and payload bytes.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DefaultNla {
    /// Attribute type. When produced by parsing, it may include the
    /// `NLA_F_NESTED` / `NLA_F_NET_BYTEORDER` bits.
    kind: u16,
    /// Payload bytes, without header or padding.
    value: Vec<u8>,
}

impl DefaultNla {
    /// Build an attribute from its `kind` and its raw `value` bytes.
    pub fn new(kind: u16, value: Vec<u8>) -> Self {
        DefaultNla { value, kind }
    }
}
187
188impl Nla for DefaultNla {
189    fn value_len(&self) -> usize {
190        self.value.len()
191    }
192    fn kind(&self) -> u16 {
193        self.kind
194    }
195    fn emit_value(&self, buffer: &mut [u8]) {
196        buffer.copy_from_slice(self.value.as_slice());
197    }
198}
199
200impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'buffer T>> for DefaultNla {
201    type Error = DecodeError;
202
203    fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, Self::Error> {
204        let mut kind = buf.kind();
205
206        if buf.network_byte_order_flag() {
207            kind |= NLA_F_NET_BYTEORDER;
208        }
209
210        if buf.nested_flag() {
211            kind |= NLA_F_NESTED;
212        }
213
214        Ok(DefaultNla { kind, value: buf.value().to_vec() })
215    }
216}
217
218pub trait Nla {
219    fn value_len(&self) -> usize;
220    fn kind(&self) -> u16;
221    fn emit_value(&self, buffer: &mut [u8]);
222
223    #[inline]
224    fn is_nested(&self) -> bool {
225        (self.kind() & NLA_F_NESTED) != 0
226    }
227
228    #[inline]
229    fn is_network_byteorder(&self) -> bool {
230        (self.kind() & NLA_F_NET_BYTEORDER) != 0
231    }
232}
233
234impl<T: Nla> Emitable for T {
235    fn buffer_len(&self) -> usize {
236        nla_align!(self.value_len()) + NLA_HEADER_SIZE
237    }
238    fn emit(&self, buffer: &mut [u8]) {
239        let mut buffer = NlaBuffer::new_unchecked(buffer);
240        buffer.set_kind(self.kind());
241
242        if self.is_network_byteorder() {
243            buffer.set_network_byte_order_flag()
244        }
245
246        if self.is_nested() {
247            buffer.set_nested_flag()
248        }
249
250        // do not include the padding here, but do include the header
251        buffer.set_length(self.value_len() as u16 + NLA_HEADER_SIZE as u16);
252
253        self.emit_value(buffer.value_mut());
254
255        let padding = nla_align!(self.value_len()) - self.value_len();
256        for i in 0..padding {
257            buffer.inner_mut()[NLA_HEADER_SIZE + self.value_len() + i] = 0;
258        }
259    }
260}
261
262// FIXME: whern specialization lands, why can actually have
263//
264// impl<'a, T: Nla, I: Iterator<Item=T>> Emitable for I { ...}
265//
266// The reason this does not work today is because it conflicts with
267//
268// impl<T: Nla> Emitable for T { ... }
269impl<'a, T: Nla> Emitable for &'a [T] {
270    fn buffer_len(&self) -> usize {
271        self.iter().fold(0, |acc, nla| {
272            assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
273            acc + nla.buffer_len()
274        })
275    }
276
277    fn emit(&self, buffer: &mut [u8]) {
278        let mut start = 0;
279        let mut end: usize;
280        for nla in self.iter() {
281            let attr_len = nla.buffer_len();
282            assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
283            end = start + attr_len;
284            nla.emit(&mut buffer[start..end]);
285            start = end;
286        }
287    }
288}
289
/// An iterator over the attributes of a buffer that yields each one as an
/// `NlaBuffer` without decoding its payload. This is useful when looking for
/// specific attributes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NlasIterator<T> {
    /// Byte offset of the next attribute within `buffer`.
    position: usize,
    buffer: T,
}

impl<T> NlasIterator<T> {
    /// Start iterating at the beginning of `buffer`.
    pub fn new(buffer: T) -> Self {
        Self { buffer, position: 0 }
    }
}
303
304impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator for NlasIterator<&'buffer T> {
305    type Item = Result<NlaBuffer<&'buffer [u8]>, NlaError>;
306
307    fn next(&mut self) -> Option<Self::Item> {
308        if self.position >= self.buffer.as_ref().len() {
309            return None;
310        }
311
312        match NlaBuffer::new(&self.buffer.as_ref()[self.position..]) {
313            Ok(nla_buffer) => {
314                self.position += nla_align!(nla_buffer.length() as usize);
315                Some(Ok(nla_buffer))
316            }
317            Err(e) => {
318                // Make sure next time we call `next()`, we return None. We
319                // don't try to continue iterating after we
320                // failed to return a buffer.
321                self.position = self.buffer.as_ref().len();
322                Some(Err(e))
323            }
324        }
325    }
326}
327
/// Describes how to handle errors when parsing attributes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NlaParseMode {
    /// Any attribute parsing errors result in a failure of the entire message
    /// parsing operation. Corresponds to the `NETLINK_GET_STRICT_CHK` option.
    Strict,
    /// Attribute parsing errors do not impact the success of message parsing.
    /// Attributes that failed to parse are ignored.
    Relaxed,
}
337
338/// A type that can iterate over and parse Netlink attributes.
339pub trait HasNlas {
340    /// Returns an iterator over the Netlink attribute buffers in `self`, or
341    /// errors where the attribute length is invalid.
342    fn attributes(&self) -> impl Iterator<Item = Result<NlaBuffer<&[u8]>, NlaError>>;
343
344    /// Parses the Netlink attributes from `self`, using `parse_fn` to parse
345    /// each attribute from an attribute buffer.
346    fn parse_attributes<'a, A, E>(
347        &'a self,
348        mode: NlaParseMode,
349        mut parse_fn: impl FnMut(&NlaBuffer<&'a [u8]>) -> Result<A, E>,
350    ) -> Result<Vec<A>, E>
351    where
352        E: From<NlaError>,
353    {
354        self.attributes()
355            .filter_map(|nla_buf| {
356                match nla_buf.map_err(|e| e.into()).map(|b| parse_fn(&b)).flatten() {
357                    Ok(attr) => Some(Ok(attr)),
358                    Err(e) => match mode {
359                        NlaParseMode::Strict => Some(Err(e)),
360                        NlaParseMode::Relaxed => None,
361                    },
362                }
363            })
364            .collect()
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;

    #[test]
    fn network_byteorder() {
        // The IPSET_ATTR_TIMEOUT attribute should have the network byte order
        // flag set. IPSET_ATTR_TIMEOUT(3600)
        static TEST_ATTRIBUTE: &[u8] = &[0x08, 0x00, 0x06, 0x40, 0x00, 0x00, 0x0e, 0x10];
        let buffer = NlaBuffer::new(TEST_ATTRIBUTE).unwrap();
        let buffer_is_net = buffer.network_byte_order_flag();
        let buffer_is_nest = buffer.nested_flag();

        let nla = DefaultNla::parse(&buffer).unwrap();
        let mut emitted_buffer = vec![0; nla.buffer_len()];

        nla.emit(&mut emitted_buffer);

        let attr_is_net = nla.is_network_byteorder();
        let attr_is_nest = nla.is_nested();

        let emit = NlaBuffer::new(emitted_buffer).unwrap();
        let emit_is_net = emit.network_byte_order_flag();
        let emit_is_nest = emit.nested_flag();

        assert_eq!([buffer_is_net, buffer_is_nest], [attr_is_net, attr_is_nest]);
        assert_eq!([attr_is_net, attr_is_nest], [emit_is_net, emit_is_nest]);
    }

    /// Returns `usize::MAX` through a function call.
    ///
    /// Presumably this keeps the value opaque to the compiler's compile-time
    /// overflow lint, so the overflow in `test_align_overflow` happens at run
    /// time. `usize::MAX` is used rather than a 64-bit literal so that the
    /// tests also build on 32-bit targets.
    fn get_len() -> usize {
        usize::MAX
    }

    #[test]
    fn test_align() {
        assert_eq!(nla_align!(13), 16);
        assert_eq!(nla_align!(16), 16);
        assert_eq!(nla_align!(0), 0);
        assert_eq!(nla_align!(1), 4);
        assert_eq!(nla_align!(get_len() - 4), usize::MAX - 3);
    }
    #[test]
    #[should_panic]
    fn test_align_overflow() {
        assert_eq!(nla_align!(get_len() - 3), usize::MAX);
    }

    struct TestHasNlas(Vec<Result<Vec<u8>, NlaError>>);

    impl HasNlas for TestHasNlas {
        fn attributes(&self) -> impl Iterator<Item = Result<NlaBuffer<&[u8]>, NlaError>> {
            let Self(buffers) = self;
            buffers.iter().map(|res| match res {
                Ok(bytes) => Ok(NlaBuffer::new_unchecked(bytes.as_slice())),
                Err(e) => Err(e.clone()),
            })
        }
    }

    #[derive(Debug, PartialEq, Eq)]
    struct TestAttr(Vec<u8>);

    fn parse_test_attr(buf: &NlaBuffer<&[u8]>) -> Result<TestAttr, NlaError> {
        Ok(TestAttr(buf.inner().to_vec()))
    }

    #[test]
    fn parse_attributes_strict_fail_first() {
        let nlas =
            TestHasNlas(vec![Err(NlaError::InvalidLength { nla_len: 0 }), Ok(vec![1, 2, 3])]);

        assert_matches!(
            nlas.parse_attributes(NlaParseMode::Strict, parse_test_attr),
            Err(NlaError::InvalidLength { .. })
        );
    }

    #[test]
    fn parse_attributes_strict_fail_middle() {
        let nlas = TestHasNlas(vec![
            Ok(vec![1, 2, 3]),
            Err(NlaError::InvalidLength { nla_len: 0 }),
            Ok(vec![4, 5, 6]),
        ]);

        assert_matches!(
            nlas.parse_attributes(NlaParseMode::Strict, parse_test_attr),
            Err(NlaError::InvalidLength { .. })
        );
    }

    #[test]
    fn parse_attributes_strict_fail_last() {
        let nlas =
            TestHasNlas(vec![Ok(vec![1, 2, 3]), Err(NlaError::InvalidLength { nla_len: 0 })]);

        assert_matches!(
            nlas.parse_attributes(NlaParseMode::Strict, parse_test_attr),
            Err(NlaError::InvalidLength { .. })
        );
    }

    #[test]
    fn parse_attributes_strict_inner_fn_fails() {
        let nlas = TestHasNlas(vec![Ok(vec![1, 2, 3]), Ok(vec![]), Ok(vec![4, 5, 6])]);

        let res = nlas.parse_attributes(NlaParseMode::Strict, |buf| {
            if buf.inner().is_empty() {
                return Err(NlaError::BufferTooSmall { buffer_len: 0 });
            }
            Ok(TestAttr(buf.inner().to_vec()))
        });
        assert_matches!(res, Err(NlaError::BufferTooSmall { .. }));
    }

    #[test]
    fn parse_attributes_relaxed_fail_first() {
        let nlas =
            TestHasNlas(vec![Err(NlaError::InvalidLength { nla_len: 0 }), Ok(vec![1, 2, 3])]);

        assert_eq!(
            nlas.parse_attributes(NlaParseMode::Relaxed, parse_test_attr).unwrap(),
            vec![TestAttr(vec![1, 2, 3])]
        );
    }

    #[test]
    fn parse_attributes_relaxed_fail_middle() {
        let nlas = TestHasNlas(vec![
            Ok(vec![1, 2, 3]),
            Err(NlaError::InvalidLength { nla_len: 0 }),
            Ok(vec![4, 5, 6]),
        ]);

        assert_eq!(
            nlas.parse_attributes(NlaParseMode::Relaxed, parse_test_attr).unwrap(),
            vec![TestAttr(vec![1, 2, 3]), TestAttr(vec![4, 5, 6])]
        );
    }

    #[test]
    fn parse_attributes_relaxed_fail_last() {
        let nlas =
            TestHasNlas(vec![Ok(vec![1, 2, 3]), Err(NlaError::InvalidLength { nla_len: 0 })]);

        assert_eq!(
            nlas.parse_attributes(NlaParseMode::Relaxed, parse_test_attr).unwrap(),
            vec![TestAttr(vec![1, 2, 3])]
        );
    }

    #[test]
    fn parse_attributes_relaxed_inner_fn_fails() {
        let nlas = TestHasNlas(vec![Ok(vec![1, 2, 3]), Ok(vec![]), Ok(vec![4, 5, 6])]);

        assert_eq!(
            nlas.parse_attributes(NlaParseMode::Relaxed, |buf| {
                if buf.inner().is_empty() {
                    return Err(NlaError::BufferTooSmall { buffer_len: 0 });
                }
                Ok(TestAttr(buf.inner().to_vec()))
            })
            .unwrap(),
            vec![TestAttr(vec![1, 2, 3]), TestAttr(vec![4, 5, 6])]
        );
    }
}