netlink_packet_core/
error.rs

1// SPDX-License-Identifier: MIT
2
3use std::mem::size_of;
4use std::num::NonZeroI32;
5use std::{fmt, io};
6
7use byteorder::{ByteOrder, NativeEndian};
8use netlink_packet_utils::DecodeError;
9
10use crate::{Emitable, Field, Parseable, Rest};
11
12const CODE: Field = 0..4;
13const PAYLOAD: Rest = 4..;
14const ERROR_HEADER_LEN: usize = PAYLOAD.start;
15
/// Thin wrapper over the raw bytes of an `NLMSG_ERROR` message body.
///
/// The first four bytes hold the error code (read and written in native byte
/// order); whatever follows is the payload. Construct it with
/// `ErrorBuffer::new_checked` to validate the length before using accessors.
#[derive(Debug, PartialEq, Eq, Clone)]
#[non_exhaustive]
pub struct ErrorBuffer<T> {
    /// Underlying storage; only accessed through `AsRef`/`AsMut<[u8]>`.
    buffer: T,
}
21
22impl<T: AsRef<[u8]>> ErrorBuffer<T> {
23    pub fn new(buffer: T) -> ErrorBuffer<T> {
24        ErrorBuffer { buffer }
25    }
26
27    /// Consume the packet, returning the underlying buffer.
28    pub fn into_inner(self) -> T {
29        self.buffer
30    }
31
32    pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
33        let packet = Self::new(buffer);
34        packet.check_buffer_length()?;
35        Ok(packet)
36    }
37
38    fn check_buffer_length(&self) -> Result<(), DecodeError> {
39        let len = self.buffer.as_ref().len();
40        if len < ERROR_HEADER_LEN {
41            Err(format!(
42                "invalid ErrorBuffer: length is {len} but ErrorBuffer are \
43                at least {ERROR_HEADER_LEN} bytes"
44            )
45            .into())
46        } else {
47            Ok(())
48        }
49    }
50
51    /// Return the error code.
52    ///
53    /// Returns `None` when there is no error to report (the message is an ACK),
54    /// or a `Some(e)` if there is a non-zero error code `e` to report (the
55    /// message is a NACK).
56    pub fn code(&self) -> Option<NonZeroI32> {
57        let data = self.buffer.as_ref();
58        NonZeroI32::new(NativeEndian::read_i32(&data[CODE]))
59    }
60}
61
62impl<'a, T: AsRef<[u8]> + ?Sized> ErrorBuffer<&'a T> {
63    /// Return a pointer to the payload.
64    pub fn payload(&self) -> &'a [u8] {
65        let data = self.buffer.as_ref();
66        &data[PAYLOAD]
67    }
68}
69
70impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> ErrorBuffer<&'a mut T> {
71    /// Return a mutable pointer to the payload.
72    pub fn payload_mut(&mut self) -> &mut [u8] {
73        let data = self.buffer.as_mut();
74        &mut data[PAYLOAD]
75    }
76}
77
78impl<T: AsRef<[u8]> + AsMut<[u8]>> ErrorBuffer<T> {
79    /// set the error code field
80    pub fn set_code(&mut self, value: i32) {
81        let data = self.buffer.as_mut();
82        NativeEndian::write_i32(&mut data[CODE], value)
83    }
84}
85
/// An `NLMSG_ERROR` message.
///
/// As described in [RFC 3549 section 2.3.2.2], this message carries the return
/// code of a request, signalling either success (an ACK) or failure (a NACK).
///
/// `ErrorMessage::default()` is an ACK with an empty header.
///
/// [RFC 3549 section 2.3.2.2]: https://datatracker.ietf.org/doc/html/rfc3549#section-2.3.2.2
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct ErrorMessage {
    /// The error code.
    ///
    /// `None` means there is no error to report (the message is an ACK);
    /// `Some(e)` carries the non-zero error code `e` (the message is a NACK).
    ///
    /// See [Netlink message types] for details.
    ///
    /// [Netlink message types]: https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types
    pub code: Option<NonZeroI32>,
    /// The original request's header, kept as raw, undecoded bytes.
    pub header: Vec<u8>,
}
108
109impl Emitable for ErrorMessage {
110    fn buffer_len(&self) -> usize {
111        size_of::<i32>() + self.header.len()
112    }
113    fn emit(&self, buffer: &mut [u8]) {
114        let mut buffer = ErrorBuffer::new(buffer);
115        buffer.set_code(self.raw_code());
116        buffer.payload_mut().copy_from_slice(&self.header)
117    }
118}
119
120impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<ErrorBuffer<&'buffer T>> for ErrorMessage {
121    type Error = DecodeError;
122    fn parse(buf: &ErrorBuffer<&'buffer T>) -> Result<ErrorMessage, DecodeError> {
123        // FIXME: The payload of an error is basically a truncated packet, which
124        // requires custom logic to parse correctly. For now we just
125        // return it as a Vec<u8> let header: NetlinkHeader = {
126        //     NetlinkBuffer::new_checked(self.payload())
127        //         .context("failed to parse netlink header")?
128        //         .parse()
129        //         .context("failed to parse nelink header")?
130        // };
131        Ok(ErrorMessage { code: buf.code(), header: buf.payload().to_vec() })
132    }
133}
134
135impl ErrorMessage {
136    /// Returns the raw error code.
137    pub fn raw_code(&self) -> i32 {
138        self.code.map_or(0, NonZeroI32::get)
139    }
140
141    /// According to [`netlink(7)`](https://linux.die.net/man/7/netlink)
142    /// the `NLMSG_ERROR` return Negative errno or 0 for acknowledgements.
143    ///
144    /// convert into [`std::io::Error`](https://doc.rust-lang.org/std/io/struct.Error.html)
145    /// using the absolute value from errno code
146    pub fn to_io(&self) -> io::Error {
147        io::Error::from_raw_os_error(self.raw_code().abs())
148    }
149}
150
151impl fmt::Display for ErrorMessage {
152    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153        fmt::Display::fmt(&self.to_io(), f)
154    }
155}
156
157impl From<ErrorMessage> for io::Error {
158    fn from(e: ErrorMessage) -> io::Error {
159        e.to_io()
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse raw NLMSG_ERROR bytes through the checked buffer path.
    fn parse_bytes(bytes: &Vec<u8>) -> Result<ErrorMessage, DecodeError> {
        ErrorBuffer::new_checked(bytes).and_then(|buf| ErrorMessage::parse(&buf))
    }

    #[test]
    fn into_io_error() {
        let io_err = io::Error::from_raw_os_error(95);
        let err_msg = ErrorMessage { code: NonZeroI32::new(-95), header: vec![] };

        let to_io: io::Error = err_msg.to_io();

        assert_eq!(err_msg.to_string(), io_err.to_string());
        assert_eq!(to_io.raw_os_error(), io_err.raw_os_error());
    }

    #[test]
    fn parse_ack() {
        let bytes = vec![0, 0, 0, 0];
        let msg = parse_bytes(&bytes).expect("failed to parse NLMSG_ERROR");
        assert_eq!(ErrorMessage { code: None, header: Vec::new() }, msg);
        assert_eq!(msg.raw_code(), 0);
    }

    #[test]
    fn parse_nack() {
        // The safe constructor replaces the former `unsafe { new_unchecked }`.
        let error_code = NonZeroI32::new(-1234).expect("-1234 is non-zero");
        let mut bytes = vec![0, 0, 0, 0];
        NativeEndian::write_i32(&mut bytes, error_code.get());
        let msg = parse_bytes(&bytes).expect("failed to parse NLMSG_ERROR");
        assert_eq!(ErrorMessage { code: Some(error_code), header: Vec::new() }, msg);
        assert_eq!(msg.raw_code(), error_code.get());
    }

    #[test]
    fn parse_keeps_request_header() {
        let mut bytes = vec![0u8; 4];
        NativeEndian::write_i32(&mut bytes, -22);
        bytes.extend_from_slice(&[0xde, 0xad, 0xbe, 0xef]);
        let msg = parse_bytes(&bytes).expect("failed to parse NLMSG_ERROR");
        assert_eq!(msg.code, NonZeroI32::new(-22));
        assert_eq!(msg.header, [0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn new_checked_rejects_short_buffer() {
        let bytes = vec![0u8; 3];
        assert!(ErrorBuffer::new_checked(&bytes).is_err());
    }

    #[test]
    fn emit_then_parse_roundtrip() {
        let original = ErrorMessage { code: NonZeroI32::new(-95), header: vec![1, 2, 3, 4] };
        let mut bytes = vec![0u8; original.buffer_len()];
        original.emit(&mut bytes);
        let parsed = parse_bytes(&bytes).expect("failed to parse emitted NLMSG_ERROR");
        assert_eq!(parsed, original);
    }
}