netlink_packet_core/
message.rs

// SPDX-License-Identifier: MIT
2
3use std::fmt::Debug;
4
5use anyhow::Context;
6use netlink_packet_utils::DecodeError;
7
8use crate::payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN};
9use crate::{
10    DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer,
11    NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, Parseable,
12};
13
14/// Represent a netlink message.
15#[derive(Debug, PartialEq, Eq, Clone)]
16#[non_exhaustive]
17pub struct NetlinkMessage<I> {
18    /// Message header (this is common to all the netlink protocols)
19    pub header: NetlinkHeader,
20    /// Inner message, which depends on the netlink protocol being used.
21    pub payload: NetlinkPayload<I>,
22}
23
24impl<I> NetlinkMessage<I> {
25    /// Create a new netlink message from the given header and payload
26    pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
27        NetlinkMessage { header, payload }
28    }
29
30    /// Consume this message and return its header and payload
31    pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
32        (self.header, self.payload)
33    }
34}
35
36impl<I> NetlinkMessage<I>
37where
38    I: NetlinkDeserializable,
39    I::Error: Into<DecodeError>,
40{
41    /// Parse the given buffer as a netlink message
42    pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
43        let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?;
44        <Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
45    }
46}
47
48impl<I> NetlinkMessage<I>
49where
50    I: NetlinkSerializable,
51{
52    /// Return the length of this message in bytes
53    pub fn buffer_len(&self) -> usize {
54        <Self as Emitable>::buffer_len(self)
55    }
56
57    /// Serialize this message and write the serialized data into the
58    /// given buffer. `buffer` must big large enough for the whole
59    /// message to fit, otherwise, this method will panic. To know how
60    /// big the serialized message is, call `buffer_len()`.
61    ///
62    /// # Panic
63    ///
64    /// This method panics if the buffer is not big enough.
65    pub fn serialize(&self, buffer: &mut [u8]) {
66        self.emit(buffer)
67    }
68
69    /// Ensure the header (`NetlinkHeader`) is consistent with the payload
70    /// (`NetlinkPayload`):
71    ///
72    /// - compute the payload length and set the header's length field
73    /// - check the payload type and set the header's message type field
74    ///   accordingly
75    ///
76    /// If you are not 100% sure the header is correct, this method should be
77    /// called before calling [`Emitable::emit()`](trait.Emitable.html#
78    /// tymethod.emit), as it could panic if the header is inconsistent with
79    /// the rest of the message.
80    pub fn finalize(&mut self) {
81        self.header.length = self.buffer_len() as u32;
82        self.header.message_type = self.payload.message_type();
83    }
84}
85
86impl<'buffer, B, I> Parseable<NetlinkBuffer<&'buffer B>> for NetlinkMessage<I>
87where
88    B: AsRef<[u8]> + 'buffer,
89    I: NetlinkDeserializable,
90    I::Error: Into<DecodeError>,
91{
92    type Error = DecodeError;
93    fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result<Self, DecodeError> {
94        use self::NetlinkPayload::*;
95
96        let header = <NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
97            .context("failed to parse netlink header")?;
98
99        let bytes = buf.payload();
100        let payload = match header.message_type {
101            NLMSG_ERROR => {
102                let msg = ErrorBuffer::new_checked(&bytes)
103                    .and_then(|buf| ErrorMessage::parse(&buf))
104                    .map_err(|err| DecodeError::FailedToParseNlMsgError(err.into()))?;
105                Error(msg)
106            }
107            NLMSG_NOOP => Noop,
108            NLMSG_DONE => {
109                let msg = DoneBuffer::new_checked(&bytes)
110                    .and_then(|buf| DoneMessage::parse(&buf))
111                    .map_err(|err| DecodeError::FailedToParseNlMsgDone(err.into()))?;
112                Done(msg)
113            }
114            NLMSG_OVERRUN => Overrun(bytes.to_vec()),
115            message_type => {
116                let inner_msg = I::deserialize(&header, bytes).map_err(|err| {
117                    DecodeError::FailedToParseMessageWithType {
118                        message_type,
119                        source: Box::new(err.into()),
120                    }
121                })?;
122                InnerMessage(inner_msg)
123            }
124        };
125        Ok(NetlinkMessage { header, payload })
126    }
127}
128
129impl<I> Emitable for NetlinkMessage<I>
130where
131    I: NetlinkSerializable,
132{
133    fn buffer_len(&self) -> usize {
134        use self::NetlinkPayload::*;
135
136        let payload_len = match self.payload {
137            Noop => 0,
138            Done(ref msg) => msg.buffer_len(),
139            Overrun(ref bytes) => bytes.len(),
140            Error(ref msg) => msg.buffer_len(),
141            InnerMessage(ref msg) => msg.buffer_len(),
142        };
143
144        self.header.buffer_len() + payload_len
145    }
146
147    fn emit(&self, buffer: &mut [u8]) {
148        use self::NetlinkPayload::*;
149
150        self.header.emit(buffer);
151
152        let buffer = &mut buffer[self.header.buffer_len()..self.header.length as usize];
153        match self.payload {
154            Noop => {}
155            Done(ref msg) => msg.emit(buffer),
156            Overrun(ref bytes) => buffer.copy_from_slice(bytes),
157            Error(ref msg) => msg.emit(buffer),
158            InnerMessage(ref msg) => msg.serialize(buffer),
159        }
160    }
161}
162
163impl<T> From<T> for NetlinkMessage<T>
164where
165    T: Into<NetlinkPayload<T>>,
166{
167    fn from(inner_message: T) -> Self {
168        NetlinkMessage { header: NetlinkHeader::default(), payload: inner_message.into() }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    use std::mem::size_of;
    use std::num::NonZeroI32;

    /// Stand-in protocol message. These tests only exercise the built-in control payloads,
    /// so none of its trait methods are ever called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type Error = DecodeError;

        fn deserialize(_header: &NetlinkHeader, _payload: &[u8]) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage { code: 0, extended_ack: vec![6, 7, 8, 9] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + size_of::<i32>() + done_msg.extended_ack.len());

        // Pre-fill with a non-zero byte so the test notices any bytes `emit` fails to write.
        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_error() {
        // Checked at compile time, so no `unsafe` is needed to build the constant.
        const ERROR_CODE: NonZeroI32 = match NonZeroI32::new(-8765) {
            Some(code) => code,
            None => panic!("error code must be non-zero"),
        };

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage { code: Some(ERROR_CODE), header: vec![] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(error_buf.code(), error_msg.code);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_noop() {
        let header = NetlinkHeader::default();
        let mut want =
            NetlinkMessage::new(header, NetlinkPayload::<FakeNetlinkInnerMessage>::Noop);
        want.finalize();

        // A noop message is just the header.
        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_overrun() {
        let header = NetlinkHeader::default();
        let raw: Vec<u8> = vec![0xab, 0xcd, 0xef, 0x01];
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Overrun(raw.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + raw.len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        // Overrun payloads are copied through verbatim.
        assert_eq!(&buf[header.buffer_len()..], raw.as_slice());

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }
}