1use std::fmt::Debug;
4
5use anyhow::Context;
6use netlink_packet_utils::{DecodeError, ParseableParametrized};
7
8use crate::payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN};
9use crate::{
10 DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer,
11 NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, Parseable,
12};
13
14#[derive(Debug, PartialEq, Eq, Clone)]
16#[non_exhaustive]
17pub struct NetlinkMessage<I> {
18 pub header: NetlinkHeader,
20 pub payload: NetlinkPayload<I>,
22}
23
24impl<I> NetlinkMessage<I> {
25 pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
27 NetlinkMessage { header, payload }
28 }
29
30 pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
32 (self.header, self.payload)
33 }
34}
35
36impl<I> NetlinkMessage<I>
37where
38 I: NetlinkDeserializable,
39 I::Error: Into<DecodeError>,
40{
41 pub fn deserialize(buffer: &[u8], options: I::DeserializeOptions) -> Result<Self, DecodeError> {
43 let netlink_buffer = NetlinkBuffer::new(&buffer)?;
44 <Self as ParseableParametrized<
45 NetlinkBuffer<&&[u8]>,
46 I::DeserializeOptions,
47 >>::parse_with_param(&netlink_buffer, options)
48 }
49}
50
51impl<I> NetlinkMessage<I>
52where
53 I: NetlinkSerializable,
54{
55 pub fn buffer_len(&self) -> usize {
57 <Self as Emitable>::buffer_len(self)
58 }
59
60 pub fn serialize(&self, buffer: &mut [u8]) {
69 self.emit(buffer)
70 }
71
72 pub fn finalize(&mut self) {
84 self.header.length = self.buffer_len() as u32;
85 self.header.message_type = self.payload.message_type();
86 }
87}
88
89impl<'buffer, B, I> ParseableParametrized<NetlinkBuffer<&'buffer B>, I::DeserializeOptions>
90 for NetlinkMessage<I>
91where
92 B: AsRef<[u8]> + 'buffer,
93 I: NetlinkDeserializable,
94 I::Error: Into<DecodeError>,
95{
96 type Error = DecodeError;
97 fn parse_with_param(
98 buf: &NetlinkBuffer<&'buffer B>,
99 options: I::DeserializeOptions,
100 ) -> Result<Self, DecodeError> {
101 use self::NetlinkPayload::*;
102
103 let header = <NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
104 .context("failed to parse netlink header")?;
105
106 let bytes = buf.payload();
107 let payload = match header.message_type {
108 NLMSG_ERROR => {
109 let msg = ErrorBuffer::new(&bytes)
110 .and_then(|buf| ErrorMessage::parse(&buf))
111 .map_err(|err| DecodeError::FailedToParseNlMsgError(err.into()))?;
112 Error(msg)
113 }
114 NLMSG_NOOP => Noop,
115 NLMSG_DONE => {
116 let msg = DoneBuffer::new(&bytes)
117 .and_then(|buf| DoneMessage::parse(&buf))
118 .map_err(|err| DecodeError::FailedToParseNlMsgDone(err.into()))?;
119 Done(msg)
120 }
121 NLMSG_OVERRUN => Overrun(bytes.to_vec()),
122 message_type => {
123 let inner_msg = I::deserialize(&header, bytes, options).map_err(|err| {
124 DecodeError::FailedToParseMessageWithType {
125 message_type,
126 source: Box::new(err.into()),
127 }
128 })?;
129 InnerMessage(inner_msg)
130 }
131 };
132 Ok(NetlinkMessage { header, payload })
133 }
134}
135
136impl<I> Emitable for NetlinkMessage<I>
137where
138 I: NetlinkSerializable,
139{
140 fn buffer_len(&self) -> usize {
141 use self::NetlinkPayload::*;
142
143 let payload_len = match self.payload {
144 Noop => 0,
145 Done(ref msg) => msg.buffer_len(),
146 Overrun(ref bytes) => bytes.len(),
147 Error(ref msg) => msg.buffer_len(),
148 InnerMessage(ref msg) => msg.buffer_len(),
149 };
150
151 self.header.buffer_len() + payload_len
152 }
153
154 fn emit(&self, buffer: &mut [u8]) {
155 use self::NetlinkPayload::*;
156
157 self.header.emit(buffer);
158
159 let buffer = &mut buffer[self.header.buffer_len()..self.header.length as usize];
160 match self.payload {
161 Noop => {}
162 Done(ref msg) => msg.emit(buffer),
163 Overrun(ref bytes) => buffer.copy_from_slice(bytes),
164 Error(ref msg) => msg.emit(buffer),
165 InnerMessage(ref msg) => msg.serialize(buffer),
166 }
167 }
168}
169
170impl<T> From<T> for NetlinkMessage<T>
171where
172 T: Into<NetlinkPayload<T>>,
173{
174 fn from(inner_message: T) -> Self {
175 NetlinkMessage { header: NetlinkHeader::default(), payload: inner_message.into() }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    use std::mem::size_of;
    use std::num::NonZeroI32;

    /// Stand-in inner message type. The tests only exercise the built-in
    /// payload kinds, so none of its methods are ever called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type DeserializeOptions = ();
        type Error = DecodeError;

        fn deserialize(
            _header: &NetlinkHeader,
            _payload: &[u8],
            _options: (),
        ) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    /// Emits `msg` into a buffer pre-filled with `1`s and parses it back.
    /// Returns the emitted bytes together with the parsed message.
    fn emit_and_parse(
        msg: &NetlinkMessage<FakeNetlinkInnerMessage>,
    ) -> (Vec<u8>, NetlinkMessage<FakeNetlinkInnerMessage>) {
        let mut buf = vec![1; msg.buffer_len()];
        msg.emit(&mut buf);
        let got: NetlinkMessage<FakeNetlinkInnerMessage> =
            NetlinkMessage::parse_with_param(&NetlinkBuffer::new(&buf).unwrap(), ()).unwrap();
        (buf, got)
    }

    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage { code: 0, extended_ack: vec![6, 7, 8, 9] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        assert_eq!(
            want.buffer_len(),
            header.buffer_len() + size_of::<i32>() + done_msg.extended_ack.len()
        );

        let (buf, got) = emit_and_parse(&want);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]).unwrap();
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        assert_eq!(got, want);
    }

    #[test]
    fn test_error() {
        // Safe construction; the previous `unsafe { new_unchecked(..) }` was
        // unnecessary for a literal known to be non-zero.
        let error_code = NonZeroI32::new(-8765).expect("-8765 is non-zero");

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage { code: Some(error_code), header: vec![] };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        assert_eq!(want.buffer_len(), header.buffer_len() + error_msg.buffer_len());

        let (buf, got) = emit_and_parse(&want);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]).unwrap();
        assert_eq!(error_buf.code(), error_msg.code);

        assert_eq!(got, want);
    }

    #[test]
    fn test_noop() {
        let header = NetlinkHeader::default();
        let mut want =
            NetlinkMessage::new(header, NetlinkPayload::<FakeNetlinkInnerMessage>::Noop);
        want.finalize();

        assert_eq!(want.buffer_len(), header.buffer_len());

        let (_, got) = emit_and_parse(&want);
        assert_eq!(got, want);
    }

    #[test]
    fn test_overrun() {
        let header = NetlinkHeader::default();
        let raw = vec![0xde, 0xad, 0xbe, 0xef];
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Overrun(raw.clone()),
        );
        want.finalize();

        assert_eq!(want.buffer_len(), header.buffer_len() + raw.len());

        let (buf, got) = emit_and_parse(&want);
        assert_eq!(&buf[header.buffer_len()..], &raw[..]);

        assert_eq!(got, want);
    }
}