1use std::num::NonZeroI32;
8
9use log::warn;
10use net_types::ip::{Ip, Ipv4Addr, Ipv6Addr};
11use netlink_packet_core::buffer::NETLINK_HEADER_LEN;
12use netlink_packet_core::constants::{
13 NLM_F_ACK, NLM_F_APPEND, NLM_F_ATOMIC, NLM_F_CREATE, NLM_F_DUMP, NLM_F_ECHO, NLM_F_EXCL,
14 NLM_F_MATCH, NLM_F_MULTIPART, NLM_F_REPLACE, NLM_F_REQUEST, NLM_F_ROOT,
15};
16use netlink_packet_core::{
17 DoneMessage, ErrorMessage, NetlinkHeader, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
18};
19use netlink_packet_route::route::RouteAddress;
20use netlink_packet_utils::Emitable as _;
21
22use crate::netlink_packet::errno::Errno;
23
24pub(crate) const UNSPECIFIED_SEQUENCE_NUMBER: u32 = 0;
25
26const DONE_ERROR_CODE: i32 = 0;
28
29pub(crate) fn new_done<T: NetlinkSerializable>(req_header: NetlinkHeader) -> NetlinkMessage<T> {
31 let mut done = DoneMessage::default();
32 done.code = DONE_ERROR_CODE;
33 let payload = NetlinkPayload::<T>::Done(done);
34 let mut resp_header = NetlinkHeader::default();
35 resp_header.sequence_number = req_header.sequence_number;
36 resp_header.flags |= NLM_F_MULTIPART;
37 let mut message = NetlinkMessage::new(resp_header, payload);
38 message.finalize();
40 message
41}
42
43pub(crate) fn ip_addr_from_route<I: Ip>(route_addr: &RouteAddress) -> Result<I::Addr, Errno> {
45 I::map_ip(
46 (),
47 |()| match route_addr {
48 RouteAddress::Inet(v4_addr) => Ok(Ipv4Addr::new(v4_addr.octets())),
49 RouteAddress::Inet6(_) => {
50 warn!("expected IPv4 address from route but got an IPv6 address");
51 Err(Errno::EINVAL)
52 }
53 RouteAddress::Mpls(_) | RouteAddress::Other(_) | _ => Err(Errno::ENOTSUP),
54 },
55 |()| match route_addr {
56 RouteAddress::Inet6(v6_addr) => Ok(Ipv6Addr::new(v6_addr.segments())),
57 RouteAddress::Inet(_) => {
58 warn!("expected IPv6 address from route but got an IPv4 address");
59 Err(Errno::EINVAL)
60 }
61 RouteAddress::Mpls(_) | RouteAddress::Other(_) | _ => Err(Errno::ENOTSUP),
62 },
63 )
64}
65
66pub(crate) mod errno {
67 use net_types::ip::GenericOverIp;
68
69 use super::*;
70
71 #[derive(Copy, Clone, Debug, PartialEq, GenericOverIp)]
75 #[generic_over_ip()]
76 pub struct Errno(i32);
77
78 impl Errno {
79 pub(crate) const EADDRNOTAVAIL: Errno = Errno::new(libc::EADDRNOTAVAIL).unwrap();
80 pub(crate) const EAFNOSUPPORT: Errno = Errno::new(libc::EAFNOSUPPORT).unwrap();
81 pub(crate) const EBUSY: Errno = Errno::new(libc::EBUSY).unwrap();
82 pub(crate) const EEXIST: Errno = Errno::new(libc::EEXIST).unwrap();
83 pub(crate) const EINVAL: Errno = Errno::new(libc::EINVAL).unwrap();
84 pub(crate) const ENODEV: Errno = Errno::new(libc::ENODEV).unwrap();
85 pub(crate) const ENOENT: Errno = Errno::new(libc::ENOENT).unwrap();
86 pub(crate) const ENOTSUP: Errno = Errno::new(libc::ENOTSUP).unwrap();
87 pub(crate) const ESRCH: Errno = Errno::new(libc::ESRCH).unwrap();
88 pub(crate) const ETOOMANYREFS: Errno = Errno::new(libc::ETOOMANYREFS).unwrap();
89
90 pub const fn new(code: i32) -> Option<Self> {
94 if code.is_positive() { Some(Errno(code)) } else { None }
95 }
96 }
97
98 impl From<Errno> for NonZeroI32 {
99 fn from(Errno(code): Errno) -> Self {
100 NonZeroI32::new(code).expect("Errno's code must be non-zero")
101 }
102 }
103
104 impl From<Errno> for i32 {
105 fn from(Errno(code): Errno) -> Self {
106 code
107 }
108 }
109
110 #[cfg(test)]
111 mod tests {
112 use super::*;
113 use test_case::test_case;
114
115 #[test_case(i32::MIN, None; "min")]
116 #[test_case(-10, None; "negative")]
117 #[test_case(0, None; "zero")]
118 #[test_case(10, Some(10); "positive")]
119 #[test_case(i32::MAX, Some(i32::MAX); "max")]
120 fn test_new_errno(raw_code: i32, expected_code: Option<i32>) {
121 assert_eq!(Errno::new(raw_code).map(Into::<i32>::into), expected_code)
122 }
123 }
124}
125
126pub(crate) fn new_error<T: NetlinkSerializable>(
130 error: Result<(), errno::Errno>,
131 req_header: NetlinkHeader,
132) -> NetlinkMessage<T> {
133 let error = {
134 assert_eq!(req_header.buffer_len(), NETLINK_HEADER_LEN);
135 let mut buffer = vec![0; NETLINK_HEADER_LEN];
136 req_header.emit(&mut buffer);
137
138 let code = match error {
139 Ok(()) => None,
140
141 Err(e) => Some(-NonZeroI32::from(e)),
143 };
144
145 let mut error = ErrorMessage::default();
146 error.code = code;
147 error.header = buffer;
148 error
149 };
150
151 let payload = NetlinkPayload::<T>::Error(error);
152 let mut resp_header = NetlinkHeader::default();
155 resp_header.sequence_number = req_header.sequence_number;
156 let mut message = NetlinkMessage::new(resp_header, payload);
157 message.finalize();
159 message
160}
161
/// The kind of netlink request whose header flags are being described.
///
/// Some `NLM_F_*` flags are only meaningful for a particular kind of request (for
/// example, the same bit is `NLM_F_ROOT` on a get request and `NLM_F_REPLACE` on a
/// new request). [`netlink_flags_debug_string`] uses this type to decide which
/// request-specific flags to decode.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum NetlinkRequestType {
    /// A "new" request; `REPLACE`, `EXCL`, `CREATE` and `APPEND` are decoded for it.
    New,
    /// A "get" request; `DUMP` (or `ROOT`/`MATCH`) and `ATOMIC` are decoded for it.
    Get,
    /// A "set" request; no request-specific flags are decoded for it.
    Set,
    /// A "delete" request; no request-specific flags are decoded for it.
    Del,
}
174
175pub(crate) fn netlink_flags_debug_string(flags: u16, request_type: NetlinkRequestType) -> String {
180 let mut flags_dbg = vec![];
181 if (flags & NLM_F_REQUEST) == NLM_F_REQUEST {
182 flags_dbg.push("REQUEST");
183 }
184 if (flags & NLM_F_MULTIPART) == NLM_F_MULTIPART {
185 flags_dbg.push("MULTI");
186 }
187 if (flags & NLM_F_ACK) == NLM_F_ACK {
188 flags_dbg.push("ACK");
189 }
190 if (flags & NLM_F_ECHO) == NLM_F_ECHO {
191 flags_dbg.push("ECHO");
192 }
193 match request_type {
194 NetlinkRequestType::Get => {
195 if (flags & NLM_F_DUMP) == NLM_F_DUMP {
196 flags_dbg.push("DUMP");
197 } else {
198 if (flags & NLM_F_ROOT) == NLM_F_ROOT {
200 flags_dbg.push("ROOT");
201 }
202 if (flags & NLM_F_MATCH) == NLM_F_MATCH {
203 flags_dbg.push("MATCH");
204 }
205 }
206 if (flags & NLM_F_ATOMIC) == NLM_F_ATOMIC {
207 flags_dbg.push("ATOMIC");
208 }
209 }
210 NetlinkRequestType::New => {
211 if (flags & NLM_F_REPLACE) == NLM_F_REPLACE {
212 flags_dbg.push("REPLACE");
213 }
214 if (flags & NLM_F_EXCL) == NLM_F_EXCL {
215 flags_dbg.push("EXCL");
216 }
217 if (flags & NLM_F_CREATE) == NLM_F_CREATE {
218 flags_dbg.push("CREATE");
219 }
220 if (flags & NLM_F_APPEND) == NLM_F_APPEND {
221 flags_dbg.push("APPEND");
222 }
223 }
224 NetlinkRequestType::Set | NetlinkRequestType::Del => {}
225 }
226 flags_dbg.join("|")
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;
    use net_types::ip::{Ipv4, Ipv6};
    use netlink_packet_core::{NLMSG_DONE, NLMSG_ERROR, NetlinkBuffer};
    use netlink_packet_route::RouteNetlinkMessage;
    use netlink_packet_utils::Parseable as _;
    use test_case::test_case;

    use crate::netlink_packet::errno::Errno;

    #[test_case(0, Ok(()); "ACK")]
    #[test_case(0, Err(Errno::EINVAL); "EINVAL")]
    #[test_case(1, Err(Errno::ENODEV); "ENODEV")]
    fn test_new_error(sequence_number: u32, expected_error: Result<(), Errno>) {
        // Distinctive field values make it easy to verify the header is echoed verbatim.
        let mut expected_header = NetlinkHeader::default();
        expected_header.length = 0x01234567;
        expected_header.message_type = 0x89AB;
        expected_header.flags = 0xCDEF;
        expected_header.sequence_number = sequence_number;
        expected_header.port_number = 0x00000000;

        let error = new_error::<RouteNetlinkMessage>(expected_error, expected_header);
        // Make sure the finalized message can actually be serialized.
        let mut buf = vec![0; error.buffer_len()];
        error.serialize(&mut buf);

        let (header, payload) = error.into_parts();
        assert_eq!(header.message_type, NLMSG_ERROR);
        assert_eq!(header.sequence_number, sequence_number);
        assert_matches!(
            payload,
            NetlinkPayload::Error(ErrorMessage{ code, header, .. }) => {
                let expected_code = match expected_error {
                    Ok(()) => None,
                    Err(e) => Some(-NonZeroI32::from(e)),
                };
                assert_eq!(code, expected_code);
                assert_eq!(
                    NetlinkHeader::parse(&NetlinkBuffer::new_unchecked(&header)).unwrap(),
                    expected_header,
                );
            }
        );
    }

    #[test_case(0; "seq_0")]
    #[test_case(1; "seq_1")]
    fn test_new_done(sequence_number: u32) {
        let mut req_header = NetlinkHeader::default();
        req_header.sequence_number = sequence_number;

        let done = new_done::<RouteNetlinkMessage>(req_header);
        // Make sure the finalized message can actually be serialized.
        let mut buf = vec![0; done.buffer_len()];
        done.serialize(&mut buf);

        let (header, payload) = done.into_parts();
        assert_eq!(header.sequence_number, sequence_number);
        assert_eq!(header.message_type, NLMSG_DONE);
        assert_eq!(header.flags, NLM_F_MULTIPART);
        assert_matches!(
            payload,
            NetlinkPayload::Done(DoneMessage {code, extended_ack, ..}) => {
                assert_eq!(code, DONE_ERROR_CODE);
                assert_eq!(extended_ack, Vec::<u8>::new());
            }
        );
    }

    #[test]
    fn test_ip_addr_from_route_v4() {
        let v4 = std::net::Ipv4Addr::new(192, 0, 2, 1);
        let v6 = std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);

        assert_eq!(
            ip_addr_from_route::<Ipv4>(&RouteAddress::Inet(v4)),
            Ok(Ipv4Addr::new(v4.octets()))
        );
        // Wrong IP version is an invalid argument...
        assert_eq!(ip_addr_from_route::<Ipv4>(&RouteAddress::Inet6(v6)), Err(Errno::EINVAL));
        // ...while non-IP address kinds are unsupported.
        assert_eq!(
            ip_addr_from_route::<Ipv4>(&RouteAddress::Other(vec![1, 2, 3, 4])),
            Err(Errno::ENOTSUP)
        );
    }

    #[test]
    fn test_ip_addr_from_route_v6() {
        let v4 = std::net::Ipv4Addr::new(192, 0, 2, 1);
        let v6 = std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);

        assert_eq!(
            ip_addr_from_route::<Ipv6>(&RouteAddress::Inet6(v6)),
            Ok(Ipv6Addr::new(v6.segments()))
        );
        assert_eq!(ip_addr_from_route::<Ipv6>(&RouteAddress::Inet(v4)), Err(Errno::EINVAL));
        assert_eq!(
            ip_addr_from_route::<Ipv6>(&RouteAddress::Other(vec![1, 2, 3, 4])),
            Err(Errno::ENOTSUP)
        );
    }

    #[test_case(
        0,
        NetlinkRequestType::Get => "";
        "no flags"
    )]
    #[test_case(
        NLM_F_REQUEST,
        NetlinkRequestType::Get => "REQUEST";
        "request only"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_MULTIPART|NLM_F_ACK|NLM_F_ECHO,
        NetlinkRequestType::Get => "REQUEST|MULTI|ACK|ECHO";
        "all generic flags"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_DUMP,
        NetlinkRequestType::Get => "REQUEST|DUMP";
        "dump request"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_MATCH|NLM_F_ROOT,
        NetlinkRequestType::Get => "REQUEST|DUMP";
        "dump is alias for match|root"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_ROOT,
        NetlinkRequestType::Get => "REQUEST|ROOT";
        "root without match"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_MATCH,
        NetlinkRequestType::Get => "REQUEST|MATCH";
        "match without root"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_ATOMIC,
        NetlinkRequestType::Get => "REQUEST|ATOMIC";
        "other Get flags"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_REPLACE|NLM_F_EXCL|NLM_F_CREATE|NLM_F_APPEND,
        NetlinkRequestType::New
        => "REQUEST|REPLACE|EXCL|CREATE|APPEND";
        "New flags"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_CREATE|NLM_F_EXCL,
        NetlinkRequestType::New => "REQUEST|EXCL|CREATE";
        "partial New flags keep canonical order"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_REPLACE|NLM_F_EXCL|NLM_F_CREATE|NLM_F_APPEND,
        NetlinkRequestType::Del => "REQUEST";
        "type-inappropriate flags ignored"
    )]
    #[test_case(
        NLM_F_REQUEST|NLM_F_DUMP|NLM_F_ATOMIC,
        NetlinkRequestType::Set => "REQUEST";
        "Set decodes no request-specific flags"
    )]
    fn netlink_flags_debug_string_tests(flags: u16, request_type: NetlinkRequestType) -> String {
        netlink_flags_debug_string(flags, request_type)
    }
}