1use futures::Stream;
8use netlink_packet_core::{
9 NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkMessage, NetlinkPayload,
10 NetlinkSerializable,
11};
12use netlink_packet_route::RouteNetlinkMessageParseError;
13use netlink_packet_utils::nla::NlaError;
14use netlink_packet_utils::{DecodeError, Parseable};
15use std::fmt::Debug;
16
17use crate::multicast_groups::ModernGroup;
18use crate::netlink_packet;
19use crate::netlink_packet::errno::Errno;
20
/// A handle used to send netlink messages of type `M` out to a client.
///
/// Implementors must be `Clone + Send + Sync` so the handle can be shared.
pub trait Sender<M>: Clone + Send + Sync {
    /// Sends `message` to the client.
    ///
    /// `group` is `Some` for a delivery to a multicast group and `None` for a
    /// unicast delivery. The method returns nothing, so delivery failures are
    /// not surfaced to the caller.
    fn send(&mut self, message: NetlinkMessage<M>, group: Option<ModernGroup>);
}
31
/// A `Send` stream of inbound, not-yet-validated netlink messages whose
/// message type is `M` and whose attached credentials are of type `C`.
///
/// The trait has no items of its own; it only names this bound so it can be
/// used as an associated-type constraint. A blanket impl makes every stream
/// satisfying the bound a `Receiver`.
pub trait Receiver<M, C>:
    Stream<Item: UnvalidatedNetlinkMessage<Message = M, Credentials = C>> + Send
where
    M: Send + MessageWithPermission,
    C: Send,
{
}
42
// Blanket implementation: any `Send` stream whose items are unvalidated netlink
// messages with matching message (`M`) and credential (`C`) types is a
// `Receiver<M, C>`.
impl<M, C, S> Receiver<M, C> for S
where
    M: Send + MessageWithPermission,
    C: Send,
    S: Stream<Item: UnvalidatedNetlinkMessage<Message = M, Credentials = C>> + Send,
{
}
51
/// A permission that must be granted (via [`AccessControl::grant_assess`])
/// before a netlink request message is processed.
///
/// The derives let a permission be copied freely, compared (e.g. in tests
/// asserting which permission a message requires), and logged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Permission {
    /// Read-type access for the route netlink family.
    NetlinkRouteRead,

    /// Write-type access for the route netlink family.
    NetlinkRouteWrite,

    /// Read-type access for the sock-diag netlink family.
    NetlinkSockDiagRead,

    /// Destroy-type access for the sock-diag netlink family.
    NetlinkSockDiagDestroy,
}
67
/// Decides whether a client, identified by credentials of type `C`, holds a
/// given [`Permission`].
pub trait AccessControl<C>: Clone {
    /// Returns `Ok(())` if `creds` are granted `permission`. Otherwise returns
    /// the [`Errno`] that ends up in the error reply to the client.
    ///
    /// NOTE(review): `grant_assess` looks like a misspelling of
    /// `grant_access`. Renaming it would require updating every implementor
    /// and caller.
    fn grant_assess(&self, creds: &C, permission: Permission) -> Result<(), Errno>;
}
74
/// Implemented by netlink request message types to report which
/// [`Permission`] the sender must hold before the message is processed.
pub trait MessageWithPermission {
    /// The permission required to process this message.
    fn permission(&self) -> Permission;
}
81
/// An error produced while parsing a raw netlink message.
#[derive(Debug)]
pub struct ParseError {
    /// The underlying decode failure.
    pub error: DecodeError,
    /// The message header, if it could be parsed on its own.
    ///
    /// `None` when the netlink buffer itself could not be constructed or the
    /// header was also unparseable. Without a header no error reply can be
    /// built for the sender.
    pub header: Option<NetlinkHeader>,
}
93
/// A netlink message that is either already parsed or still raw bytes.
///
/// This lets validation code handle both cases uniformly.
pub trait MaybeParsedNetlinkMessage {
    /// The inner (payload) message type once parsed.
    type Message: MessageWithPermission;

    /// Returns the parsed message, parsing it first if necessary.
    fn try_into_parsed(self) -> Result<NetlinkMessage<Self::Message>, ParseError>;
}
103
104impl<M: MessageWithPermission> MaybeParsedNetlinkMessage for NetlinkMessage<M> {
105 type Message = M;
106 fn try_into_parsed(self) -> Result<NetlinkMessage<M>, ParseError> {
107 Ok(self)
108 }
109}
110
/// Raw bytes of a netlink message that will be parsed lazily as a
/// `NetlinkMessage<M>` (see its [`MaybeParsedNetlinkMessage`] impl).
pub struct UnparsedNetlinkMessage<B, M> {
    /// The undecoded message bytes.
    data: B,
    /// Records the target message type `M` without storing a value of it.
    _marker: std::marker::PhantomData<M>,
}
116
117impl<B, M> UnparsedNetlinkMessage<B, M> {
118 pub fn new(data: B) -> Self {
120 Self { data, _marker: std::marker::PhantomData }
121 }
122}
123
124impl<M, B> MaybeParsedNetlinkMessage for UnparsedNetlinkMessage<B, M>
125where
126 B: AsRef<[u8]>,
127 M: NetlinkDeserializable + MessageWithPermission,
128 M::Error: Into<DecodeError>,
129{
130 type Message = M;
131
132 fn try_into_parsed(self) -> Result<NetlinkMessage<M>, ParseError> {
133 let Self { data, _marker } = self;
134 let data = data.as_ref();
135 let netlink_buffer =
136 NetlinkBuffer::new(&data).map_err(|error| ParseError { error, header: None })?;
137 NetlinkMessage::<M>::parse(&netlink_buffer).map_err(|error| ParseError {
138 error,
139 header: NetlinkHeader::parse(&netlink_buffer).ok(),
142 })
143 }
144}
145
/// Why an inbound netlink message failed validation.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum ValidationError {
    /// The message bytes could not be parsed.
    Parse(ParseError),
    /// The access-control check denied the permission the message requires.
    /// `header` is the parsed message header and `error` the errno returned
    /// by the check.
    Permission { header: NetlinkHeader, error: Errno },
}
155
156fn nla_error_to_errno(error: &NlaError) -> Errno {
159 match error {
160 NlaError::BufferTooSmall { .. }
161 | NlaError::LengthMismatch { .. }
162 | NlaError::InvalidLength { .. } => Errno::EINVAL,
163 }
164}
165
166fn route_netlink_error_to_errno(error: &RouteNetlinkMessageParseError) -> Errno {
169 match error {
170 RouteNetlinkMessageParseError::ParseBuffer(decode_error)
171 | RouteNetlinkMessageParseError::InvalidLinkMessage(decode_error) => {
172 decode_error_to_errno(decode_error)
173 }
174 RouteNetlinkMessageParseError::InvalidRouteMessage(_)
175 | RouteNetlinkMessageParseError::InvalidAddrMessage(_)
176 | RouteNetlinkMessageParseError::InvalidPrefixMessage(_)
177 | RouteNetlinkMessageParseError::InvalidFibRuleMessage(_)
178 | RouteNetlinkMessageParseError::InvalidTcMessage(_)
179 | RouteNetlinkMessageParseError::InvalidNsidMessage(_)
180 | RouteNetlinkMessageParseError::InvalidNeighbourMessage(_)
181 | RouteNetlinkMessageParseError::InvalidNeighbourTableMessage(_)
182 | RouteNetlinkMessageParseError::InvalidNeighbourDiscoveryUserOptionMessage(_) => {
183 Errno::EINVAL
184 }
185 RouteNetlinkMessageParseError::UnknownMessageType(_) => Errno::ENOTSUP,
186 }
187}
188
189fn decode_error_to_errno(error: &DecodeError) -> Errno {
192 match error {
193 DecodeError::InvalidMACAddress
194 | DecodeError::InvalidIPAddress
195 | DecodeError::Utf8Error(_)
196 | DecodeError::InvalidU8
197 | DecodeError::InvalidU16
198 | DecodeError::InvalidU32
199 | DecodeError::InvalidU64
200 | DecodeError::InvalidU128
201 | DecodeError::InvalidI32
202 | DecodeError::InvalidBufferLength { .. } => Errno::EINVAL,
203 DecodeError::Nla(nla_error) => nla_error_to_errno(nla_error),
204 DecodeError::Other(error) => {
205 if let Some(error) = error.downcast_ref::<RouteNetlinkMessageParseError>() {
206 return route_netlink_error_to_errno(error);
207 }
208 if let Some(error) = error.downcast_ref::<netlink_packet_utils::DecodeError>() {
209 return decode_error_to_errno(error);
210 }
211 Errno::EINVAL
212 }
213 DecodeError::FailedToParseNlMsgError(error)
214 | DecodeError::FailedToParseNlMsgDone(error)
215 | DecodeError::FailedToParseMessageWithType { message_type: _, source: error }
216 | DecodeError::FailedToParseNetlinkHeader(error) => decode_error_to_errno(error),
217 }
218}
219
220impl ValidationError {
221 pub fn into_error_message<M: NetlinkSerializable>(self) -> Option<NetlinkMessage<M>> {
224 match self {
225 ValidationError::Parse(ParseError { error, header }) => {
226 let header = header?;
228 Some(netlink_packet::new_error(Err(decode_error_to_errno(&error)), header))
232 }
233 ValidationError::Permission { header, error } => {
234 Some(netlink_packet::new_error(Err(error), header))
235 }
236 }
237 }
238}
239
/// A netlink message paired with the credentials it will be validated
/// against (see its [`UnvalidatedNetlinkMessage`] impl).
#[derive(Clone, Debug)]
pub struct NetlinkMessageWithCreds<M, C> {
    /// The message, either already parsed or still raw
    /// (any [`MaybeParsedNetlinkMessage`]).
    message: M,
    /// Credentials checked against the permission the message requires.
    creds: C,
}
246
247impl<M, C> NetlinkMessageWithCreds<M, C> {
248 pub fn new(message: M, creds: C) -> Self {
250 Self { message, creds }
251 }
252}
253
/// An inbound netlink message that must pass an access-control check before
/// its contents may be used.
pub trait UnvalidatedNetlinkMessage {
    /// The inner (payload) message type.
    type Message;

    /// The credentials type checked by the access control.
    type Credentials;

    /// Parses the message if needed and checks its credentials with
    /// `access_control`.
    ///
    /// On success returns the parsed message. On failure returns a
    /// [`ValidationError`] describing whether parsing or the permission check
    /// failed.
    fn validate_creds_and_get_message<PS: AccessControl<Self::Credentials>>(
        self,
        access_control: &PS,
    ) -> Result<NetlinkMessage<Self::Message>, ValidationError>;
}
271
272impl<M, C> UnvalidatedNetlinkMessage for NetlinkMessageWithCreds<M, C>
273where
274 M: MaybeParsedNetlinkMessage,
275 M::Message: MessageWithPermission,
276{
277 type Message = M::Message;
278 type Credentials = C;
279
280 fn validate_creds_and_get_message<PS: AccessControl<C>>(
281 self,
282 access_control: &PS,
283 ) -> Result<NetlinkMessage<M::Message>, ValidationError> {
284 let Self { message, creds } = self;
285 let message = message.try_into_parsed().map_err(ValidationError::Parse)?;
286 let permission = match &message.payload {
287 NetlinkPayload::InnerMessage(msg) => msg.permission(),
288 NetlinkPayload::Done(_)
289 | NetlinkPayload::Error(_)
290 | NetlinkPayload::Noop
291 | NetlinkPayload::Overrun(_) => return Ok(message),
292 };
293
294 access_control
295 .grant_assess(&creds, permission)
296 .map_err(|error| ValidationError::Permission { header: message.header, error })?;
297 Ok(message)
298 }
299}
300
/// Bundles the concrete types a netlink implementation is built from.
///
/// Different implementations can be plugged in; for example,
/// `testutil::TestNetlinkContext` in this file supplies fakes.
pub trait NetlinkContext {
    /// Credentials attached to each inbound message.
    type Creds: Clone + Send + Debug;

    /// Sender used to deliver outbound messages of type `M`.
    type Sender<M: Clone + NetlinkSerializable + Send>: Sender<M>;

    /// Stream of inbound messages of type `M`, each carrying `Self::Creds`.
    type Receiver<M: Send + MessageWithPermission + NetlinkDeserializable<Error: Into<DecodeError>>>: Receiver<M, Self::Creds>;

    /// Access control that checks `Self::Creds` against required permissions.
    type AccessControl<'a>: AccessControl<Self::Creds>;
}
315
/// Fakes for exercising netlink messaging code in tests.
#[cfg(test)]
pub(crate) mod testutil {
    use super::*;
    use crate::mpsc;
    use futures::{FutureExt as _, StreamExt as _};
    use netlink_packet_core::NetlinkSerializable;

    /// A message captured by [`FakeSender`], together with the multicast group
    /// it was sent to (`None` for unicast).
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub(crate) struct SentMessage<M> {
        pub message: NetlinkMessage<M>,
        pub group: Option<ModernGroup>,
    }

    impl<M> SentMessage<M> {
        /// Builds the record of a unicast send.
        pub(crate) fn unicast(message: NetlinkMessage<M>) -> Self {
            SentMessage { group: None, message }
        }

        /// Builds the record of a send to multicast `group`.
        pub(crate) fn multicast(message: NetlinkMessage<M>, group: ModernGroup) -> Self {
            SentMessage { group: Some(group), message }
        }
    }

    /// A [`Sender`] that pushes every sent message into an unbounded channel,
    /// which the paired [`FakeSenderSink`] drains.
    #[derive(Clone, Debug)]
    pub(crate) struct FakeSender<M> {
        sender: futures::channel::mpsc::UnboundedSender<SentMessage<M>>,
    }

    impl<M: Clone + Send + NetlinkSerializable> Sender<M> for FakeSender<M> {
        /// Records the message. Panics if the channel rejects it.
        fn send(&mut self, message: NetlinkMessage<M>, group: Option<ModernGroup>) {
            let record = SentMessage { message, group };
            self.sender.unbounded_send(record).expect("unable to send message");
        }
    }

    /// Receiving half paired with a [`FakeSender`]; lets tests inspect what
    /// was sent.
    pub(crate) struct FakeSenderSink<M> {
        receiver: futures::channel::mpsc::UnboundedReceiver<SentMessage<M>>,
    }

    impl<M> FakeSenderSink<M> {
        /// Drains and returns every message that is immediately available,
        /// without waiting. Stops early if the channel has closed.
        pub(crate) fn take_messages(&mut self) -> Vec<SentMessage<M>> {
            let mut drained = Vec::new();
            loop {
                match self.receiver.next().now_or_never() {
                    Some(Some(sent)) => drained.push(sent),
                    // `None`: nothing is ready right now; `Some(None)`: channel closed.
                    None | Some(None) => return drained,
                }
            }
        }

        /// Waits for the next sent message, panicking if the channel closes first.
        pub(crate) async fn next_message(&mut self) -> SentMessage<M> {
            let next = self.receiver.next().await;
            next.expect("receiver unexpectedly closed")
        }
    }

    /// Creates a connected [`FakeSender`] / [`FakeSenderSink`] pair.
    pub(crate) fn fake_sender_with_sink<M>() -> (FakeSender<M>, FakeSenderSink<M>) {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let sender = FakeSender { sender: tx };
        let sink = FakeSenderSink { receiver: rx };
        (sender, sink)
    }

    /// Fake credentials. When `error` is set, [`FakeAccessControl`] denies
    /// every permission with that errno.
    #[derive(Default, Debug, Clone)]
    pub(crate) struct FakeCreds {
        error: Option<Errno>,
    }

    impl FakeCreds {
        /// Credentials that are denied with `error` on every access check.
        pub fn with_error(error: Errno) -> Self {
            Self { error: Some(error) }
        }
    }

    /// An [`AccessControl`] that grants everything unless the credentials
    /// carry an error, in which case that error is returned.
    #[derive(Default, Clone)]
    pub(crate) struct FakeAccessControl {}

    impl AccessControl<FakeCreds> for FakeAccessControl {
        fn grant_assess(&self, creds: &FakeCreds, _perm: Permission) -> Result<(), Errno> {
            match creds.error {
                None => Ok(()),
                Some(denied) => Err(denied),
            }
        }
    }

    /// A [`NetlinkContext`] wired up with the fakes above.
    pub(crate) struct TestNetlinkContext;

    impl NetlinkContext for TestNetlinkContext {
        type Creds = FakeCreds;
        type Sender<M: Clone + NetlinkSerializable + Send> = FakeSender<M>;
        type Receiver<
            M: Send + MessageWithPermission + NetlinkDeserializable<Error: Into<DecodeError>>,
        > = mpsc::Receiver<NetlinkMessageWithCreds<NetlinkMessage<M>, Self::Creds>>;
        type AccessControl<'a> = FakeAccessControl;
    }
}