// netlink_packet_core/message.rs
use std::convert::TryFrom;
use std::fmt::Debug;

use anyhow::Context;
use netlink_packet_utils::DecodeError;

use crate::payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN};
use crate::{
    DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer,
    NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, Parseable,
};

14#[derive(Debug, PartialEq, Eq, Clone)]
16#[non_exhaustive]
17pub struct NetlinkMessage<I> {
18 pub header: NetlinkHeader,
20 pub payload: NetlinkPayload<I>,
22}
23
24impl<I> NetlinkMessage<I> {
25 pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
27 NetlinkMessage { header, payload }
28 }
29
30 pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
32 (self.header, self.payload)
33 }
34}
35
36impl<I> NetlinkMessage<I>
37where
38 I: NetlinkDeserializable,
39 I::Error: Into<DecodeError>,
40{
41 pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
43 let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?;
44 <Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
45 }
46}
47
48impl<I> NetlinkMessage<I>
49where
50 I: NetlinkSerializable,
51{
52 pub fn buffer_len(&self) -> usize {
54 <Self as Emitable>::buffer_len(self)
55 }
56
57 pub fn serialize(&self, buffer: &mut [u8]) {
66 self.emit(buffer)
67 }
68
69 pub fn finalize(&mut self) {
81 self.header.length = self.buffer_len() as u32;
82 self.header.message_type = self.payload.message_type();
83 }
84}
85
86impl<'buffer, B, I> Parseable<NetlinkBuffer<&'buffer B>> for NetlinkMessage<I>
87where
88 B: AsRef<[u8]> + 'buffer,
89 I: NetlinkDeserializable,
90 I::Error: Into<DecodeError>,
91{
92 type Error = DecodeError;
93 fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result<Self, DecodeError> {
94 use self::NetlinkPayload::*;
95
96 let header = <NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
97 .context("failed to parse netlink header")?;
98
99 let bytes = buf.payload();
100 let payload = match header.message_type {
101 NLMSG_ERROR => {
102 let msg = ErrorBuffer::new_checked(&bytes)
103 .and_then(|buf| ErrorMessage::parse(&buf))
104 .map_err(|err| DecodeError::FailedToParseNlMsgError(err.into()))?;
105 Error(msg)
106 }
107 NLMSG_NOOP => Noop,
108 NLMSG_DONE => {
109 let msg = DoneBuffer::new_checked(&bytes)
110 .and_then(|buf| DoneMessage::parse(&buf))
111 .map_err(|err| DecodeError::FailedToParseNlMsgDone(err.into()))?;
112 Done(msg)
113 }
114 NLMSG_OVERRUN => Overrun(bytes.to_vec()),
115 message_type => {
116 let inner_msg = I::deserialize(&header, bytes).map_err(|err| {
117 DecodeError::FailedToParseMessageWithType {
118 message_type,
119 source: Box::new(err.into()),
120 }
121 })?;
122 InnerMessage(inner_msg)
123 }
124 };
125 Ok(NetlinkMessage { header, payload })
126 }
127}
128
129impl<I> Emitable for NetlinkMessage<I>
130where
131 I: NetlinkSerializable,
132{
133 fn buffer_len(&self) -> usize {
134 use self::NetlinkPayload::*;
135
136 let payload_len = match self.payload {
137 Noop => 0,
138 Done(ref msg) => msg.buffer_len(),
139 Overrun(ref bytes) => bytes.len(),
140 Error(ref msg) => msg.buffer_len(),
141 InnerMessage(ref msg) => msg.buffer_len(),
142 };
143
144 self.header.buffer_len() + payload_len
145 }
146
147 fn emit(&self, buffer: &mut [u8]) {
148 use self::NetlinkPayload::*;
149
150 self.header.emit(buffer);
151
152 let buffer = &mut buffer[self.header.buffer_len()..self.header.length as usize];
153 match self.payload {
154 Noop => {}
155 Done(ref msg) => msg.emit(buffer),
156 Overrun(ref bytes) => buffer.copy_from_slice(bytes),
157 Error(ref msg) => msg.emit(buffer),
158 InnerMessage(ref msg) => msg.serialize(buffer),
159 }
160 }
161}
162
163impl<T> From<T> for NetlinkMessage<T>
164where
165 T: Into<NetlinkPayload<T>>,
166{
167 fn from(inner_message: T) -> Self {
168 NetlinkMessage { header: NetlinkHeader::default(), payload: inner_message.into() }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    use std::mem::size_of;
    use std::num::NonZeroI32;

    /// Stand-in inner message type. The tests below only round-trip control payloads
    /// (`Done` and `Error`), so none of these trait methods are expected to be called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type Error = DecodeError;

        fn deserialize(_header: &NetlinkHeader, _payload: &[u8]) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    /// Emits a finalized `Done` message with extended-ack bytes, checks the raw encoding,
    /// and parses it back.
    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage { code: 0, extended_ack: vec![6, 7, 8, 9] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + size_of::<i32>() + done_msg.extended_ack.len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    /// Emits a finalized `Error` message with a non-zero code, checks the raw encoding,
    /// and parses it back.
    #[test]
    fn test_error() {
        // `NonZeroI32::new` is a `const fn`, so the non-zero check runs at compile time and
        // no `unsafe` is needed to build the constant.
        const ERROR_CODE: NonZeroI32 = match NonZeroI32::new(-8765) {
            Some(code) => code,
            None => panic!("ERROR_CODE must be non-zero"),
        };

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage { code: Some(ERROR_CODE), header: vec![] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(error_buf.code(), error_msg.code);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }
}