1use futures::Stream;
8use netlink_packet_core::{
9 NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkMessage, NetlinkPayload,
10 NetlinkSerializable,
11};
12use netlink_packet_route::RouteNetlinkMessageParseError;
13use netlink_packet_utils::nla::NlaError;
14use netlink_packet_utils::{DecodeError, Parseable, ParseableParametrized};
15use std::fmt::Debug;
16
17use crate::multicast_groups::ModernGroup;
18use crate::netlink_packet;
19use crate::netlink_packet::errno::Errno;
20
21pub trait Sender<M>: Clone + Send + Sync {
23 fn send(&mut self, message: NetlinkMessage<M>, group: Option<ModernGroup>);
30}
31
32pub trait Receiver<M, C>:
36 Stream<Item: UnvalidatedNetlinkMessage<Message = M, Credentials = C>> + Send
37where
38 M: Send + MessageWithPermission,
39 C: Send,
40{
41}
42
43impl<M, C, S> Receiver<M, C> for S
45where
46 M: Send + MessageWithPermission,
47 C: Send,
48 S: Stream<Item: UnvalidatedNetlinkMessage<Message = M, Credentials = C>> + Send,
49{
50}
51
/// The permission a sender must hold for a netlink message to be handled.
///
/// Produced by [`MessageWithPermission::permission`] and checked through
/// [`AccessControl::grant_assess`]. It is a plain fieldless enum, so it is
/// `Copy` and can be compared, hashed and debug-printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Permission {
    /// `NETLINK_ROUTE` read permission.
    NetlinkRouteRead,

    /// `NETLINK_ROUTE` write permission.
    NetlinkRouteWrite,

    /// `NETLINK_SOCK_DIAG` read permission.
    NetlinkSockDiagRead,

    /// `NETLINK_SOCK_DIAG` destroy permission.
    NetlinkSockDiagDestroy,
}
67
68pub trait AccessControl<C>: Clone {
70 fn grant_assess(&self, creds: &C, permission: Permission) -> Result<(), Errno>;
73}
74
75pub trait MessageWithPermission {
78 fn permission(&self) -> Permission;
80}
81
82#[derive(Debug)]
84pub struct ParseError {
85 pub error: DecodeError,
87 pub header: Option<NetlinkHeader>,
92}
93
94pub trait MaybeParsedNetlinkMessage {
97 type Message: MessageWithPermission;
99
100 fn try_into_parsed(self) -> Result<NetlinkMessage<Self::Message>, ParseError>;
102}
103
104impl<M: MessageWithPermission> MaybeParsedNetlinkMessage for NetlinkMessage<M> {
105 type Message = M;
106 fn try_into_parsed(self) -> Result<NetlinkMessage<M>, ParseError> {
107 Ok(self)
108 }
109}
110
111pub struct UnparsedNetlinkMessage<B, M: NetlinkDeserializable> {
113 data: B,
114 options: M::DeserializeOptions,
115 _marker: std::marker::PhantomData<M>,
116}
117
118impl<B, M: NetlinkDeserializable> UnparsedNetlinkMessage<B, M> {
119 pub fn new(data: B, options: M::DeserializeOptions) -> Self {
121 Self { data, options, _marker: std::marker::PhantomData }
122 }
123}
124
125impl<M, B> MaybeParsedNetlinkMessage for UnparsedNetlinkMessage<B, M>
126where
127 B: AsRef<[u8]>,
128 M: NetlinkDeserializable + MessageWithPermission,
129 M::Error: Into<DecodeError>,
130{
131 type Message = M;
132
133 fn try_into_parsed(self) -> Result<NetlinkMessage<M>, ParseError> {
134 let Self { data, options, _marker } = self;
135 let data = data.as_ref();
136 let netlink_buffer =
137 NetlinkBuffer::new(&data).map_err(|error| ParseError { error, header: None })?;
138 NetlinkMessage::<M>::parse_with_param(&netlink_buffer, options).map_err(|error| {
139 ParseError {
140 error,
141 header: NetlinkHeader::parse(&netlink_buffer).ok(),
144 }
145 })
146 }
147}
148
149#[derive(Debug)]
151#[allow(missing_docs)]
152pub enum ValidationError {
153 Parse(ParseError),
155 Permission { header: NetlinkHeader, error: Errno },
157}
158
159fn nla_error_to_errno(error: &NlaError) -> Errno {
162 match error {
163 NlaError::BufferTooSmall { .. }
164 | NlaError::LengthMismatch { .. }
165 | NlaError::InvalidLength { .. } => Errno::EINVAL,
166 }
167}
168
169fn route_netlink_error_to_errno(error: &RouteNetlinkMessageParseError) -> Errno {
172 match error {
173 RouteNetlinkMessageParseError::ParseBuffer(decode_error)
174 | RouteNetlinkMessageParseError::InvalidLinkMessage(decode_error) => {
175 decode_error_to_errno(decode_error)
176 }
177 RouteNetlinkMessageParseError::InvalidRouteMessage(_)
178 | RouteNetlinkMessageParseError::InvalidAddrMessage(_)
179 | RouteNetlinkMessageParseError::InvalidPrefixMessage(_)
180 | RouteNetlinkMessageParseError::InvalidFibRuleMessage(_)
181 | RouteNetlinkMessageParseError::InvalidTcMessage(_)
182 | RouteNetlinkMessageParseError::InvalidNsidMessage(_)
183 | RouteNetlinkMessageParseError::InvalidNeighbourMessage(_)
184 | RouteNetlinkMessageParseError::InvalidNeighbourTableMessage(_)
185 | RouteNetlinkMessageParseError::InvalidNeighbourDiscoveryUserOptionMessage(_) => {
186 Errno::EINVAL
187 }
188 RouteNetlinkMessageParseError::UnknownMessageType(_) => Errno::ENOTSUP,
189 }
190}
191
192fn decode_error_to_errno(error: &DecodeError) -> Errno {
195 match error {
196 DecodeError::InvalidMACAddress
197 | DecodeError::InvalidIPAddress
198 | DecodeError::Utf8Error(_)
199 | DecodeError::InvalidU8
200 | DecodeError::InvalidU16
201 | DecodeError::InvalidU32
202 | DecodeError::InvalidU64
203 | DecodeError::InvalidU128
204 | DecodeError::InvalidI32
205 | DecodeError::InvalidBufferLength { .. } => Errno::EINVAL,
206 DecodeError::Nla(nla_error) => nla_error_to_errno(nla_error),
207 DecodeError::Other(error) => {
208 if let Some(error) = error.downcast_ref::<RouteNetlinkMessageParseError>() {
209 return route_netlink_error_to_errno(error);
210 }
211 if let Some(error) = error.downcast_ref::<netlink_packet_utils::DecodeError>() {
212 return decode_error_to_errno(error);
213 }
214 Errno::EINVAL
215 }
216 DecodeError::FailedToParseNlMsgError(error)
217 | DecodeError::FailedToParseNlMsgDone(error)
218 | DecodeError::FailedToParseMessageWithType { message_type: _, source: error }
219 | DecodeError::FailedToParseNetlinkHeader(error) => decode_error_to_errno(error),
220 }
221}
222
223impl ValidationError {
224 pub fn into_error_message<M: NetlinkSerializable>(self) -> Option<NetlinkMessage<M>> {
227 match self {
228 ValidationError::Parse(ParseError { error, header }) => {
229 let header = header?;
231 Some(netlink_packet::new_error(Err(decode_error_to_errno(&error)), header))
235 }
236 ValidationError::Permission { header, error } => {
237 Some(netlink_packet::new_error(Err(error), header))
238 }
239 }
240 }
241}
242
/// A received message paired with the credentials it arrived with.
#[derive(Clone, Debug)]
pub struct NetlinkMessageWithCreds<M, C> {
    /// The message, which may not be parsed yet.
    message: M,
    /// Credentials checked against the message's required permission.
    creds: C,
}
249
250impl<M, C> NetlinkMessageWithCreds<M, C> {
251 pub fn new(message: M, creds: C) -> Self {
253 Self { message, creds }
254 }
255}
256
257pub trait UnvalidatedNetlinkMessage {
261 type Message;
263 type Credentials;
265
266 fn validate_creds_and_get_message<PS: AccessControl<Self::Credentials>>(
270 self,
271 access_control: &PS,
272 ) -> Result<NetlinkMessage<Self::Message>, ValidationError>;
273}
274
275impl<M, C> UnvalidatedNetlinkMessage for NetlinkMessageWithCreds<M, C>
276where
277 M: MaybeParsedNetlinkMessage,
278 M::Message: MessageWithPermission,
279{
280 type Message = M::Message;
281 type Credentials = C;
282
283 fn validate_creds_and_get_message<PS: AccessControl<C>>(
284 self,
285 access_control: &PS,
286 ) -> Result<NetlinkMessage<M::Message>, ValidationError> {
287 let Self { message, creds } = self;
288 let message = message.try_into_parsed().map_err(ValidationError::Parse)?;
289 let permission = match &message.payload {
290 NetlinkPayload::InnerMessage(msg) => msg.permission(),
291 NetlinkPayload::Done(_)
292 | NetlinkPayload::Error(_)
293 | NetlinkPayload::Noop
294 | NetlinkPayload::Overrun(_) => return Ok(message),
295 };
296
297 access_control
298 .grant_assess(&creds, permission)
299 .map_err(|error| ValidationError::Permission { header: message.header, error })?;
300 Ok(message)
301 }
302}
303
304pub trait NetlinkContext {
306 type Creds: Clone + Send + Debug;
308
309 type Sender<M: Clone + NetlinkSerializable + Send>: Sender<M>;
311
312 type Receiver<M: Send + MessageWithPermission + NetlinkDeserializable<Error: Into<DecodeError>>>: Receiver<M, Self::Creds>;
314
315 type AccessControl<'a>: AccessControl<Self::Creds>;
317}
318
#[cfg(test)]
pub(crate) mod testutil {
    //! In-memory fakes for the messaging traits, for use in unit tests.

    use super::*;
    use crate::mpsc;
    use futures::{FutureExt as _, StreamExt as _};
    use netlink_packet_core::NetlinkSerializable;

    /// A message observed by a [`FakeSender`], together with the multicast
    /// group it targeted (`None` for unicast).
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub(crate) struct SentMessage<M> {
        pub message: NetlinkMessage<M>,
        pub group: Option<ModernGroup>,
    }

    impl<M> SentMessage<M> {
        /// Record of a message sent without a multicast group.
        pub(crate) fn unicast(message: NetlinkMessage<M>) -> Self {
            SentMessage { group: None, message }
        }

        /// Record of a message multicast to `group`.
        pub(crate) fn multicast(message: NetlinkMessage<M>, group: ModernGroup) -> Self {
            SentMessage { group: Some(group), message }
        }
    }

    /// A [`Sender`] that forwards every message to its paired
    /// [`FakeSenderSink`] over an unbounded channel.
    #[derive(Clone, Debug)]
    pub(crate) struct FakeSender<M> {
        sender: futures::channel::mpsc::UnboundedSender<SentMessage<M>>,
    }

    impl<M: Clone + Send + NetlinkSerializable> Sender<M> for FakeSender<M> {
        /// Records the message. Panics if the channel rejects it (for
        /// example, once the paired sink is gone).
        fn send(&mut self, message: NetlinkMessage<M>, group: Option<ModernGroup>) {
            let record = SentMessage { message, group };
            self.sender.unbounded_send(record).expect("unable to send message");
        }
    }

    /// Receiving end paired with a [`FakeSender`], used to inspect what was
    /// sent.
    pub(crate) struct FakeSenderSink<M> {
        receiver: futures::channel::mpsc::UnboundedReceiver<SentMessage<M>>,
    }

    impl<M> FakeSenderSink<M> {
        /// Drains every message that is already queued, without waiting.
        pub(crate) fn take_messages(&mut self) -> Vec<SentMessage<M>> {
            // `now_or_never` gives `None` when nothing is ready yet and
            // `Some(None)` once the stream has ended. `flatten` folds both
            // into `None`, which stops the iterator.
            std::iter::from_fn(|| self.receiver.next().now_or_never().flatten()).collect()
        }

        /// Waits for the next sent message. Panics if the channel closes
        /// first.
        pub(crate) async fn next_message(&mut self) -> SentMessage<M> {
            let next = self.receiver.next().await;
            next.expect("receiver unexpectedly closed")
        }
    }

    /// Creates a connected [`FakeSender`] / [`FakeSenderSink`] pair.
    pub(crate) fn fake_sender_with_sink<M>() -> (FakeSender<M>, FakeSenderSink<M>) {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let sender = FakeSender { sender: tx };
        let sink = FakeSenderSink { receiver: rx };
        (sender, sink)
    }

    /// Fake credentials. When `error` is set, [`FakeAccessControl`] fails
    /// every check against them with that errno.
    #[derive(Default, Debug, Clone)]
    pub(crate) struct FakeCreds {
        error: Option<Errno>,
    }

    impl FakeCreds {
        /// Credentials whose access checks all fail with `error`.
        pub fn with_error(error: Errno) -> Self {
            Self { error: Some(error) }
        }
    }

    /// Access control that ignores the requested permission and defers
    /// entirely to the [`FakeCreds`] being checked.
    #[derive(Default, Clone)]
    pub(crate) struct FakeAccessControl {}

    impl AccessControl<FakeCreds> for FakeAccessControl {
        fn grant_assess(&self, creds: &FakeCreds, _perm: Permission) -> Result<(), Errno> {
            match creds.error {
                Some(errno) => Err(errno),
                None => Ok(()),
            }
        }
    }

    /// [`NetlinkContext`] wired to the fakes in this module.
    pub(crate) struct TestNetlinkContext;

    impl NetlinkContext for TestNetlinkContext {
        type Creds = FakeCreds;
        type Sender<M: Clone + NetlinkSerializable + Send> = FakeSender<M>;
        type Receiver<
            M: Send + MessageWithPermission + NetlinkDeserializable<Error: Into<DecodeError>>,
        > = mpsc::Receiver<NetlinkMessageWithCreds<NetlinkMessage<M>, Self::Creds>>;
        type AccessControl<'a> = FakeAccessControl;
    }
}