Skip to main content

netlink_packet_core/
message.rs

1// SPDX-License-Identifier: MIT
2
3use std::fmt::Debug;
4
5use anyhow::Context;
6use netlink_packet_utils::{DecodeError, ParseableParametrized};
7
8use crate::payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN};
9use crate::{
10    DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer,
11    NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, Parseable,
12};
13
14/// Represent a netlink message.
15#[derive(Debug, PartialEq, Eq, Clone)]
16#[non_exhaustive]
17pub struct NetlinkMessage<I> {
18    /// Message header (this is common to all the netlink protocols)
19    pub header: NetlinkHeader,
20    /// Inner message, which depends on the netlink protocol being used.
21    pub payload: NetlinkPayload<I>,
22}
23
24impl<I> NetlinkMessage<I> {
25    /// Create a new netlink message from the given header and payload
26    pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
27        NetlinkMessage { header, payload }
28    }
29
30    /// Consume this message and return its header and payload
31    pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
32        (self.header, self.payload)
33    }
34}
35
36impl<I> NetlinkMessage<I>
37where
38    I: NetlinkDeserializable,
39    I::Error: Into<DecodeError>,
40{
41    /// Parse the given buffer as a netlink message
42    pub fn deserialize(buffer: &[u8], options: I::DeserializeOptions) -> Result<Self, DecodeError> {
43        let netlink_buffer = NetlinkBuffer::new(&buffer)?;
44        <Self as ParseableParametrized<
45            NetlinkBuffer<&&[u8]>,
46            I::DeserializeOptions,
47        >>::parse_with_param(&netlink_buffer, options)
48    }
49}
50
51impl<I> NetlinkMessage<I>
52where
53    I: NetlinkSerializable,
54{
55    /// Return the length of this message in bytes
56    pub fn buffer_len(&self) -> usize {
57        <Self as Emitable>::buffer_len(self)
58    }
59
60    /// Serialize this message and write the serialized data into the
61    /// given buffer. `buffer` must big large enough for the whole
62    /// message to fit, otherwise, this method will panic. To know how
63    /// big the serialized message is, call `buffer_len()`.
64    ///
65    /// # Panic
66    ///
67    /// This method panics if the buffer is not big enough.
68    pub fn serialize(&self, buffer: &mut [u8]) {
69        self.emit(buffer)
70    }
71
72    /// Ensure the header (`NetlinkHeader`) is consistent with the payload
73    /// (`NetlinkPayload`):
74    ///
75    /// - compute the payload length and set the header's length field
76    /// - check the payload type and set the header's message type field
77    ///   accordingly
78    ///
79    /// If you are not 100% sure the header is correct, this method should be
80    /// called before calling [`Emitable::emit()`](trait.Emitable.html#
81    /// tymethod.emit), as it could panic if the header is inconsistent with
82    /// the rest of the message.
83    pub fn finalize(&mut self) {
84        self.header.length = self.buffer_len() as u32;
85        self.header.message_type = self.payload.message_type();
86    }
87}
88
89impl<'buffer, B, I> ParseableParametrized<NetlinkBuffer<&'buffer B>, I::DeserializeOptions>
90    for NetlinkMessage<I>
91where
92    B: AsRef<[u8]> + 'buffer,
93    I: NetlinkDeserializable,
94    I::Error: Into<DecodeError>,
95{
96    type Error = DecodeError;
97    fn parse_with_param(
98        buf: &NetlinkBuffer<&'buffer B>,
99        options: I::DeserializeOptions,
100    ) -> Result<Self, DecodeError> {
101        use self::NetlinkPayload::*;
102
103        let header = <NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
104            .context("failed to parse netlink header")?;
105
106        let bytes = buf.payload();
107        let payload = match header.message_type {
108            NLMSG_ERROR => {
109                let msg = ErrorBuffer::new(&bytes)
110                    .and_then(|buf| ErrorMessage::parse(&buf))
111                    .map_err(|err| DecodeError::FailedToParseNlMsgError(err.into()))?;
112                Error(msg)
113            }
114            NLMSG_NOOP => Noop,
115            NLMSG_DONE => {
116                let msg = DoneBuffer::new(&bytes)
117                    .and_then(|buf| DoneMessage::parse(&buf))
118                    .map_err(|err| DecodeError::FailedToParseNlMsgDone(err.into()))?;
119                Done(msg)
120            }
121            NLMSG_OVERRUN => Overrun(bytes.to_vec()),
122            message_type => {
123                let inner_msg = I::deserialize(&header, bytes, options).map_err(|err| {
124                    DecodeError::FailedToParseMessageWithType {
125                        message_type,
126                        source: Box::new(err.into()),
127                    }
128                })?;
129                InnerMessage(inner_msg)
130            }
131        };
132        Ok(NetlinkMessage { header, payload })
133    }
134}
135
136impl<I> Emitable for NetlinkMessage<I>
137where
138    I: NetlinkSerializable,
139{
140    fn buffer_len(&self) -> usize {
141        use self::NetlinkPayload::*;
142
143        let payload_len = match self.payload {
144            Noop => 0,
145            Done(ref msg) => msg.buffer_len(),
146            Overrun(ref bytes) => bytes.len(),
147            Error(ref msg) => msg.buffer_len(),
148            InnerMessage(ref msg) => msg.buffer_len(),
149        };
150
151        self.header.buffer_len() + payload_len
152    }
153
154    fn emit(&self, buffer: &mut [u8]) {
155        use self::NetlinkPayload::*;
156
157        self.header.emit(buffer);
158
159        let buffer = &mut buffer[self.header.buffer_len()..self.header.length as usize];
160        match self.payload {
161            Noop => {}
162            Done(ref msg) => msg.emit(buffer),
163            Overrun(ref bytes) => buffer.copy_from_slice(bytes),
164            Error(ref msg) => msg.emit(buffer),
165            InnerMessage(ref msg) => msg.serialize(buffer),
166        }
167    }
168}
169
170impl<T> From<T> for NetlinkMessage<T>
171where
172    T: Into<NetlinkPayload<T>>,
173{
174    fn from(inner_message: T) -> Self {
175        NetlinkMessage { header: NetlinkHeader::default(), payload: inner_message.into() }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    use std::mem::size_of;
    use std::num::NonZeroI32;

    /// Placeholder inner message type. These tests only exercise the `Done` and
    /// `Error` payloads, so none of its trait methods are expected to be called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type DeserializeOptions = ();
        type Error = DecodeError;

        fn deserialize(
            _header: &NetlinkHeader,
            _payload: &[u8],
            _options: (),
        ) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    /// A finalized `Done` message round-trips through `emit` and `parse_with_param`,
    /// and its emitted bytes decode correctly with `DoneBuffer`.
    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage { code: 0, extended_ack: vec![6, 7, 8, 9] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + size_of::<i32>() + done_msg.extended_ack.len());

        // Pre-fill with a non-zero byte so bytes that `emit` fails to write are noticeable.
        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]).unwrap();
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        let got = NetlinkMessage::parse_with_param(&NetlinkBuffer::new(&buf).unwrap(), ()).unwrap();
        assert_eq!(got, want);
    }

    /// A finalized `Error` message round-trips through `emit` and `parse_with_param`,
    /// and its emitted bytes decode correctly with `ErrorBuffer`.
    #[test]
    fn test_error() {
        // Built with a safe const `match` instead of `unsafe { new_unchecked(..) }`:
        // the non-zero check now happens at compile time.
        const ERROR_CODE: NonZeroI32 = match NonZeroI32::new(-8765) {
            Some(code) => code,
            None => panic!("error code must be non-zero"),
        };

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage { code: Some(ERROR_CODE), header: vec![] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

        // Pre-fill with a non-zero byte so bytes that `emit` fails to write are noticeable.
        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]).unwrap();
        assert_eq!(error_buf.code(), error_msg.code);

        let got = NetlinkMessage::parse_with_param(&NetlinkBuffer::new(&buf).unwrap(), ()).unwrap();
        assert_eq!(got, want);
    }
}