netlink/
netlink_packet.rs1use std::num::NonZeroI32;
8
9use log::warn;
10use net_types::ip::{Ip, Ipv4Addr, Ipv6Addr};
11use netlink_packet_core::buffer::NETLINK_HEADER_LEN;
12use netlink_packet_core::constants::NLM_F_MULTIPART;
13use netlink_packet_core::{
14 DoneMessage, ErrorMessage, NetlinkHeader, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
15};
16use netlink_packet_route::route::RouteAddress;
17use netlink_packet_utils::Emitable as _;
18
19use crate::netlink_packet::errno::Errno;
20
/// The sequence number value `0`.
///
/// NOTE(review): presumably used for messages that do not answer a specific
/// request (so there is no request sequence number to echo) — confirm at the
/// call sites.
pub(crate) const UNSPECIFIED_SEQUENCE_NUMBER: u32 = 0_u32;

/// Error code placed in the payload of the `NLMSG_DONE` messages built by
/// `new_done`; `0` signals that no error occurred.
const DONE_ERROR_CODE: i32 = 0_i32;
25
26pub(crate) fn new_done<T: NetlinkSerializable>(req_header: NetlinkHeader) -> NetlinkMessage<T> {
28 let mut done = DoneMessage::default();
29 done.code = DONE_ERROR_CODE;
30 let payload = NetlinkPayload::<T>::Done(done);
31 let mut resp_header = NetlinkHeader::default();
32 resp_header.sequence_number = req_header.sequence_number;
33 resp_header.flags |= NLM_F_MULTIPART;
34 let mut message = NetlinkMessage::new(resp_header, payload);
35 message.finalize();
37 message
38}
39
40pub(crate) fn ip_addr_from_route<I: Ip>(route_addr: &RouteAddress) -> Result<I::Addr, Errno> {
42 I::map_ip(
43 (),
44 |()| match route_addr {
45 RouteAddress::Inet(v4_addr) => Ok(Ipv4Addr::new(v4_addr.octets())),
46 RouteAddress::Inet6(_) => {
47 warn!("expected IPv4 address from route but got an IPv6 address");
48 Err(Errno::EINVAL)
49 }
50 RouteAddress::Mpls(_) | RouteAddress::Other(_) | _ => Err(Errno::ENOTSUP),
51 },
52 |()| match route_addr {
53 RouteAddress::Inet6(v6_addr) => Ok(Ipv6Addr::new(v6_addr.segments())),
54 RouteAddress::Inet(_) => {
55 warn!("expected IPv6 address from route but got an IPv4 address");
56 Err(Errno::EINVAL)
57 }
58 RouteAddress::Mpls(_) | RouteAddress::Other(_) | _ => Err(Errno::ENOTSUP),
59 },
60 )
61}
62
63pub(crate) mod errno {
64 use net_types::ip::GenericOverIp;
65
66 use super::*;
67
68 #[derive(Copy, Clone, Debug, PartialEq, GenericOverIp)]
72 #[generic_over_ip()]
73 pub struct Errno(i32);
74
75 impl Errno {
76 pub(crate) const EADDRNOTAVAIL: Errno = Errno::new(libc::EADDRNOTAVAIL).unwrap();
77 pub(crate) const EAFNOSUPPORT: Errno = Errno::new(libc::EAFNOSUPPORT).unwrap();
78 pub(crate) const EBUSY: Errno = Errno::new(libc::EBUSY).unwrap();
79 pub(crate) const EEXIST: Errno = Errno::new(libc::EEXIST).unwrap();
80 pub(crate) const EINVAL: Errno = Errno::new(libc::EINVAL).unwrap();
81 pub(crate) const ENODEV: Errno = Errno::new(libc::ENODEV).unwrap();
82 pub(crate) const ENOENT: Errno = Errno::new(libc::ENOENT).unwrap();
83 pub(crate) const ENOTSUP: Errno = Errno::new(libc::ENOTSUP).unwrap();
84 pub(crate) const ESRCH: Errno = Errno::new(libc::ESRCH).unwrap();
85 pub(crate) const ETOOMANYREFS: Errno = Errno::new(libc::ETOOMANYREFS).unwrap();
86
87 pub const fn new(code: i32) -> Option<Self> {
91 if code.is_positive() { Some(Errno(code)) } else { None }
92 }
93 }
94
95 impl From<Errno> for NonZeroI32 {
96 fn from(Errno(code): Errno) -> Self {
97 NonZeroI32::new(code).expect("Errno's code must be non-zero")
98 }
99 }
100
101 impl From<Errno> for i32 {
102 fn from(Errno(code): Errno) -> Self {
103 code
104 }
105 }
106
107 #[cfg(test)]
108 mod tests {
109 use super::*;
110 use test_case::test_case;
111
112 #[test_case(i32::MIN, None; "min")]
113 #[test_case(-10, None; "negative")]
114 #[test_case(0, None; "zero")]
115 #[test_case(10, Some(10); "positive")]
116 #[test_case(i32::MAX, Some(i32::MAX); "max")]
117 fn test_new_errno(raw_code: i32, expected_code: Option<i32>) {
118 assert_eq!(Errno::new(raw_code).map(Into::<i32>::into), expected_code)
119 }
120 }
121}
122
123pub(crate) fn new_error<T: NetlinkSerializable>(
127 error: Result<(), errno::Errno>,
128 req_header: NetlinkHeader,
129) -> NetlinkMessage<T> {
130 let error = {
131 assert_eq!(req_header.buffer_len(), NETLINK_HEADER_LEN);
132 let mut buffer = vec![0; NETLINK_HEADER_LEN];
133 req_header.emit(&mut buffer);
134
135 let code = match error {
136 Ok(()) => None,
137
138 Err(e) => Some(-NonZeroI32::from(e)),
140 };
141
142 let mut error = ErrorMessage::default();
143 error.code = code;
144 error.header = buffer;
145 error
146 };
147
148 let payload = NetlinkPayload::<T>::Error(error);
149 let mut resp_header = NetlinkHeader::default();
152 resp_header.sequence_number = req_header.sequence_number;
153 let mut message = NetlinkMessage::new(resp_header, payload);
154 message.finalize();
156 message
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;
    use net_types::ip::{Ipv4, Ipv6};
    use netlink_packet_core::{NLMSG_DONE, NLMSG_ERROR, NetlinkBuffer};
    use netlink_packet_route::RouteNetlinkMessage;
    use netlink_packet_utils::Parseable as _;
    use test_case::test_case;

    use crate::netlink_packet::errno::Errno;

    #[test_case(0, Ok(()); "ACK")]
    #[test_case(0, Err(Errno::EINVAL); "EINVAL")]
    #[test_case(1, Err(Errno::ENODEV); "ENODEV")]
    fn test_new_error(sequence_number: u32, expected_error: Result<(), Errno>) {
        // Arbitrary, distinctive field values so the echoed header can be
        // checked byte-for-byte after parsing.
        let mut expected_header = NetlinkHeader::default();
        expected_header.length = 0x01234567;
        expected_header.message_type = 0x89AB;
        expected_header.flags = 0xCDEF;
        expected_header.sequence_number = sequence_number;
        expected_header.port_number = 0x00000000;

        let error = new_error::<RouteNetlinkMessage>(expected_error, expected_header);
        // Serializing must succeed without panicking.
        let mut buf = vec![0; error.buffer_len()];
        error.serialize(&mut buf);

        let (header, payload) = error.into_parts();
        assert_eq!(header.message_type, NLMSG_ERROR);
        assert_eq!(header.sequence_number, sequence_number);
        assert_matches!(
            payload,
            NetlinkPayload::Error(ErrorMessage{ code, header, .. }) => {
                let expected_code = match expected_error {
                    Ok(()) => None,
                    Err(e) => Some(-NonZeroI32::from(e)),
                };
                assert_eq!(code, expected_code);
                assert_eq!(
                    NetlinkHeader::parse(&NetlinkBuffer::new_unchecked(&header)).unwrap(),
                    expected_header,
                );
            }
        );
    }

    #[test_case(0; "seq_0")]
    #[test_case(1; "seq_1")]
    fn test_new_done(sequence_number: u32) {
        let mut req_header = NetlinkHeader::default();
        req_header.sequence_number = sequence_number;

        let done = new_done::<RouteNetlinkMessage>(req_header);
        // Serializing must succeed without panicking.
        let mut buf = vec![0; done.buffer_len()];
        done.serialize(&mut buf);

        let (header, payload) = done.into_parts();
        assert_eq!(header.sequence_number, sequence_number);
        assert_eq!(header.message_type, NLMSG_DONE);
        assert_eq!(header.flags, NLM_F_MULTIPART);
        assert_matches!(
            payload,
            NetlinkPayload::Done(DoneMessage {code, extended_ack, ..}) => {
                assert_eq!(code, DONE_ERROR_CODE);
                assert_eq!(extended_ack, Vec::<u8>::new());
            }
        );
    }

    #[test]
    fn test_ip_addr_from_route_ipv4() {
        let v4 = std::net::Ipv4Addr::new(192, 0, 2, 1);
        assert_eq!(
            ip_addr_from_route::<Ipv4>(&RouteAddress::Inet(v4)),
            Ok(Ipv4Addr::new(v4.octets())),
        );
        // An address of the wrong IP version is rejected as invalid.
        assert_eq!(
            ip_addr_from_route::<Ipv4>(&RouteAddress::Inet6(std::net::Ipv6Addr::LOCALHOST)),
            Err(Errno::EINVAL),
        );
        // A non-IP address is unsupported.
        assert_eq!(
            ip_addr_from_route::<Ipv4>(&RouteAddress::Other(vec![0xAB, 0xCD])),
            Err(Errno::ENOTSUP),
        );
    }

    #[test]
    fn test_ip_addr_from_route_ipv6() {
        let v6 = std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        assert_eq!(
            ip_addr_from_route::<Ipv6>(&RouteAddress::Inet6(v6)),
            Ok(Ipv6Addr::new(v6.segments())),
        );
        // An address of the wrong IP version is rejected as invalid.
        assert_eq!(
            ip_addr_from_route::<Ipv6>(&RouteAddress::Inet(std::net::Ipv4Addr::LOCALHOST)),
            Err(Errno::EINVAL),
        );
        // A non-IP address is unsupported.
        assert_eq!(
            ip_addr_from_route::<Ipv6>(&RouteAddress::Other(vec![0xAB, 0xCD])),
            Err(Errno::ENOTSUP),
        );
    }
}