starnix_core/vfs/socket/socket_generic_netlink/
mod.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use async_trait::async_trait;
6use futures::FutureExt;
7use futures::channel::mpsc;
8use futures::future::join;
9use futures::stream::{FuturesUnordered, StreamExt};
10use linux_uapi::{NLM_F_ACK, NLM_F_CAPPED};
11use netlink::NETLINK_LOG_TAG;
12use netlink::messaging::Sender;
13use netlink::multicast_groups::ModernGroup;
14use netlink_packet_core::{
15    ErrorMessage, NETLINK_HEADER_LEN, NetlinkHeader, NetlinkMessage, NetlinkPayload,
16};
17use netlink_packet_generic::constants::GENL_ID_CTRL;
18use netlink_packet_generic::ctrl::nlas::{GenlCtrlAttrs, McastGrpAttrs};
19use netlink_packet_generic::ctrl::{GenlCtrl, GenlCtrlCmd};
20use netlink_packet_generic::{GenlHeader, GenlMessage};
21use netlink_packet_utils::Emitable;
22use starnix_logging::track_stub;
23use starnix_sync::Mutex;
24use std::collections::{HashMap, HashSet};
25use std::num::NonZero;
26use std::ops::DerefMut;
27use std::sync::Arc;
28
29use starnix_logging::{log_error, log_info, log_warn};
30use starnix_uapi::errors::Errno;
31use starnix_uapi::{ENOENT, error};
32
33mod messages;
34mod nl80211;
35mod taskstats;
36
37pub use messages::GenericMessage;
38
/// Smallest family ID handed out by `add_family`; IDs up to and including
/// `GENL_ID_CTRL` are never assigned to dynamically registered families.
const MIN_FAMILY_ID: u16 = GENL_ID_CTRL + 1;
/// Name under which the built-in ctrl family (ID `GENL_ID_CTRL`) is advertised.
const NLCTRL_FAMILY: &str = "nlctrl";
41
42#[async_trait]
43pub trait GenericNetlinkFamily<S>: Send + Sync {
44    /// Return the unique name for this generic netlink protocol.
45    ///
46    /// This name is used by the ctrl server to identify this server for clients.
47    fn name(&self) -> String;
48
49    /// Return the multicast groups that are supported by this protocol family.
50    ///
51    /// Each multicast group is assigned a unique ID by the ctrl server.
52    fn multicast_groups(&self) -> Vec<String> {
53        vec![]
54    }
55
56    /// Returns a future that pipes messages for the given multicast group into
57    /// the given sink. The parent of this server is responsible for managing
58    /// multicast group memberships and routing these messages appropriately.
59    /// The assigned family id of this multicast group is passed in to be used
60    /// appropriately when constructing messages.
61    async fn stream_multicast_messages(
62        &self,
63        group: String,
64        assigned_family_id: u16,
65        message_sink: mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>,
66    );
67
68    /// Handle a netlink message targeted to this server.
69    ///
70    /// The given payload contains the generic netlink header and all subsequent data.
71    /// Protocol servers should implement their own generic netlink families and
72    /// deserialize messages using `GenlMessage::<_>::deserialize`.
73    async fn handle_message(&self, netlink_header: NetlinkHeader, payload: Vec<u8>, sender: &mut S);
74}
75
76fn extract_family_names(genl_ctrl: GenlCtrl) -> Vec<String> {
77    genl_ctrl
78        .nlas
79        .into_iter()
80        .filter_map(
81            |attr| {
82                if let GenlCtrlAttrs::FamilyName(name) = attr { Some(name) } else { None }
83            },
84        )
85        .collect()
86}
87
/// Opaque, server-assigned identifier for a connected generic netlink client.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct ClientId(u64);
90
91/// All state required to the generic netlink server. This struct assumes
92/// synchronous access, and should be kept inside of a GenericNetlinkServer.
93struct GenericNetlinkServerState<S> {
94    /// Mapping from generic family server name to its assigned ID value.
95    family_ids: HashMap<String, u16>,
96    /// Servers for specific generic netlink families. Servers are stored in this
97    /// list by order of ID, such that protocol N is in servers[N - MIN_FAMILY_ID].
98    families: Vec<Arc<dyn GenericNetlinkFamily<S>>>,
99    /// Sink for new families to setup multicast group handling.
100    new_family_sender: mpsc::UnboundedSender<Arc<dyn GenericNetlinkFamily<S>>>,
101    /// Multicast groups, identified by (family name, group name). Multicast
102    /// group IDs are assigned uniquely across all generic families.
103    multicast_groups: HashMap<(String, String), ModernGroup>,
104    /// Counter used to generate unique ID values for all multicast groups.
105    multicast_group_id_counter: ModernGroup,
106    /// Unique internal IDs used to track clients.
107    client_id_counter: ClientId,
108    /// Senders for passing multicast traffic to clients.
109    client_senders: HashMap<ClientId, S>,
110    /// Mapping from multicast group -> list of subscribed client IDs.
111    multicast_group_memberships: HashMap<ModernGroup, HashSet<ClientId>>,
112}
113
114impl<S: Sender<GenericMessage>> GenericNetlinkServerState<S> {
115    fn new(new_family_sender: mpsc::UnboundedSender<Arc<dyn GenericNetlinkFamily<S>>>) -> Self {
116        Self {
117            family_ids: HashMap::new(),
118            families: vec![],
119            new_family_sender,
120            multicast_groups: HashMap::new(),
121            multicast_group_id_counter: ModernGroup(0),
122            client_id_counter: ClientId(0),
123            client_senders: HashMap::new(),
124            multicast_group_memberships: HashMap::new(),
125        }
126    }
127
128    fn add_family(&mut self, family: Arc<dyn GenericNetlinkFamily<S>>) {
129        match (self.families.len() as u16).checked_add(MIN_FAMILY_ID) {
130            Some(new_family_id) => {
131                self.family_ids.insert(family.name(), new_family_id);
132                if let Err(e) = self.new_family_sender.unbounded_send(Arc::clone(&family)) {
133                    log_error!(
134                        tag = NETLINK_LOG_TAG;
135                        "Failed to setup multicast group handling for new generic \
136                         netlink family {}: {}",
137                        family.name(),
138                        e
139                    );
140                }
141                self.families.push(family);
142            }
143            None => {
144                log_error!(
145                    tag = NETLINK_LOG_TAG;
146                    "Failed to add generic netlink family: too many families"
147                );
148            }
149        }
150    }
151
152    fn get_family(&self, family_id: u16) -> Option<Arc<dyn GenericNetlinkFamily<S>>> {
153        if family_id >= MIN_FAMILY_ID
154            && ((family_id - MIN_FAMILY_ID) as usize) < self.families.len()
155        {
156            Some(Arc::clone(&self.families[(family_id - MIN_FAMILY_ID) as usize]))
157        } else {
158            None
159        }
160    }
161
162    fn get_multicast_group_id(&mut self, family: String, group: String) -> ModernGroup {
163        *self.multicast_groups.entry((family, group)).or_insert_with(|| {
164            self.multicast_group_id_counter.0 += 1;
165            self.multicast_group_id_counter
166        })
167    }
168
169    fn handle_ctrl_message(
170        &mut self,
171        mut netlink_header: NetlinkHeader,
172        genl_message: GenlMessage<GenlCtrl>,
173        sender: &mut S,
174    ) {
175        let (genl_header, genl_ctrl) = genl_message.into_parts();
176        match genl_ctrl.cmd {
177            GenlCtrlCmd::GetFamily => {
178                let family_names = extract_family_names(genl_ctrl);
179                log_info!(tag = NETLINK_LOG_TAG; "Netlink GetFamily request: {:?}", family_names);
180
181                for family in &family_names {
182                    if family == NLCTRL_FAMILY {
183                        self.send_get_family_response(
184                            netlink_header,
185                            genl_header,
186                            NLCTRL_FAMILY,
187                            GENL_ID_CTRL,
188                            None,
189                            sender,
190                        );
191                    } else if let Some(id) = self.family_ids.get(family).copied() {
192                        log_info!(
193                            tag = NETLINK_LOG_TAG;
194                            "Serving requested netlink family {}",
195                            family
196                        );
197                        let mcast_groups = self
198                            .get_family(id)
199                            .expect("Known family ID should always exist")
200                            .multicast_groups()
201                            .into_iter()
202                            .map(|name| {
203                                vec![
204                                    McastGrpAttrs::Name(name.clone()),
205                                    McastGrpAttrs::Id(
206                                        self.get_multicast_group_id(family.to_string(), name).0,
207                                    ),
208                                ]
209                            })
210                            .collect();
211                        self.send_get_family_response(
212                            netlink_header,
213                            genl_header,
214                            family,
215                            id,
216                            Some(mcast_groups),
217                            sender,
218                        );
219                    } else {
220                        log_warn!(
221                            tag = NETLINK_LOG_TAG;
222                            "Cannot serve requested netlink family {}",
223                            family
224                        );
225
226                        // Send back error message
227                        let mut buffer = [0; NETLINK_HEADER_LEN];
228                        netlink_header.emit(&mut buffer[..NETLINK_HEADER_LEN]);
229                        let mut error = ErrorMessage::default();
230                        error.code = NonZero::new(-(ENOENT as i32));
231                        error.header = buffer.to_vec();
232                        netlink_header.flags = NLM_F_CAPPED as u16;
233                        let mut netlink_message =
234                            NetlinkMessage::new(netlink_header, NetlinkPayload::Error(error));
235                        netlink_message.finalize();
236                        sender.send(netlink_message, None);
237                    }
238                }
239            }
240            GenlCtrlCmd::NewFamily => {
241                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlNewFamily")
242            }
243            GenlCtrlCmd::DelFamily => {
244                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlDelFamily")
245            }
246            GenlCtrlCmd::NewOps => {
247                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlNewOps")
248            }
249            GenlCtrlCmd::DelOps => {
250                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlDelOps")
251            }
252            GenlCtrlCmd::GetOps => {
253                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlGetOps")
254            }
255            GenlCtrlCmd::NewMcastGrp => {
256                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlNewMcastGrp")
257            }
258            GenlCtrlCmd::DelMcastGrp => {
259                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlDelMcastGrp")
260            }
261            GenlCtrlCmd::GetMcastGrp => {
262                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlGetMcastGrp")
263            }
264            GenlCtrlCmd::GetPolicy => {
265                track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlGetPolicy")
266            }
267        }
268    }
269
270    fn send_get_family_response(
271        &mut self,
272        mut netlink_header: NetlinkHeader,
273        genl_header: GenlHeader,
274        family: &str,
275        id: u16,
276        mcast_groups: Option<Vec<Vec<McastGrpAttrs>>>,
277        sender: &mut S,
278    ) {
279        let mut nlas =
280            vec![GenlCtrlAttrs::FamilyId(id), GenlCtrlAttrs::FamilyName(family.to_string())];
281        mcast_groups.map(|mg| nlas.push(GenlCtrlAttrs::McastGroups(mg)));
282        let resp_ctrl = GenlCtrl { cmd: GenlCtrlCmd::NewFamily, nlas };
283        // Flags need to be cleared as we are sending a response back
284        // to the client, not requesting that the client send us
285        // data or ACKs.
286        let orig_flags = netlink_header.flags;
287        netlink_header.flags = 0;
288        let mut genl_message = GenlMessage::from_parts(genl_header, resp_ctrl);
289        genl_message.finalize();
290        let mut message = NetlinkMessage::new(
291            netlink_header,
292            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(genl_message)),
293        );
294        message.finalize();
295        sender.send(message, None);
296        // Conversion is safe because 4 < 65535.
297        if orig_flags & NLM_F_ACK as u16 != 0 {
298            // ACK requested, send ACK
299            let mut buffer = [0; NETLINK_HEADER_LEN];
300            // Conversion is safe because 256 < 65535
301            netlink_header.flags = NLM_F_CAPPED as u16;
302            netlink_header.emit(&mut buffer[..NETLINK_HEADER_LEN]);
303            let mut ack = ErrorMessage::default();
304            // Netlink uses an error payload with no error code to indicate a
305            // successful ack.
306            ack.code = None;
307            ack.header = buffer.to_vec();
308            let mut netlink_message =
309                NetlinkMessage::new(netlink_header, NetlinkPayload::Error(ack));
310            netlink_message.finalize();
311            sender.send(netlink_message, None);
312        }
313    }
314}
315
316/// Coordinates all generic netlink clients and families.
317#[derive(Clone)]
318struct GenericNetlinkServer<S> {
319    state: Arc<Mutex<GenericNetlinkServerState<S>>>,
320}
321
322impl<S: Sender<GenericMessage>> GenericNetlinkServer<S> {
323    fn new(new_family_sender: mpsc::UnboundedSender<Arc<dyn GenericNetlinkFamily<S>>>) -> Self {
324        Self { state: Arc::new(Mutex::new(GenericNetlinkServerState::new(new_family_sender))) }
325    }
326
327    async fn handle_generic_message(
328        &self,
329        message: NetlinkMessage<GenericMessage>,
330        sender: &mut S,
331    ) {
332        let (netlink_header, payload) = message.into_parts();
333        let req = match payload {
334            NetlinkPayload::InnerMessage(p) => p,
335            p => {
336                log_error!(tag = NETLINK_LOG_TAG; "Dropping unexpected netlink payload: {:?}", p);
337                return;
338            }
339        };
340        match req {
341            GenericMessage::Ctrl(ctrl_message) => {
342                self.state.lock().handle_ctrl_message(netlink_header, ctrl_message, sender)
343            }
344            GenericMessage::Other { family: family_id, payload } => {
345                let family = self.state.lock().get_family(family_id);
346                match family {
347                    Some(family) => {
348                        family.handle_message(netlink_header, payload, sender).await;
349                    }
350                    None => log_info!(
351                        tag = NETLINK_LOG_TAG;
352                        "Ignoring generic netlink message with unsupported family {}",
353                        family_id,
354                    ),
355                }
356            }
357        }
358    }
359
360    async fn run_generic_netlink_client(self, mut client: GenericNetlinkClient<S>) {
361        log_info!(tag = NETLINK_LOG_TAG; "Registered new generic netlink client");
362        loop {
363            match client.receiver.next().await {
364                Some(message) => self.handle_generic_message(message, &mut client.sender).await,
365                None => {
366                    log_info!(tag = NETLINK_LOG_TAG; "Generic netlink client exited");
367                    let mut state = self.state.lock();
368                    for memberships in state.multicast_group_memberships.values_mut() {
369                        memberships.remove(&client.client_id);
370                    }
371                    state.client_senders.remove(&client.client_id);
372                    return;
373                }
374            }
375        }
376    }
377
378    async fn pipe_single_multicast_group(
379        &self,
380        mcast_group_id: ModernGroup,
381        mcast_stream: mpsc::UnboundedReceiver<NetlinkMessage<GenericMessage>>,
382    ) {
383        if self
384            .state
385            .lock()
386            .multicast_group_memberships
387            .insert(mcast_group_id, HashSet::new())
388            .is_some()
389        {
390            log_error!(
391                tag = NETLINK_LOG_TAG;
392                "pipe_single_multicast_group called on group {} but group is already served",
393                mcast_group_id.0
394            );
395            return;
396        }
397        let fut = mcast_stream.for_each(|mcast_message| {
398            let mut state_lock = self.state.lock();
399            let state = state_lock.deref_mut();
400            for client_id in state
401                .multicast_group_memberships
402                .get(&mcast_group_id)
403                .expect("Group memberships should always be present")
404            {
405                if let Some(sender) = state.client_senders.get_mut(client_id) {
406                    sender.send(mcast_message.clone(), Some(mcast_group_id));
407                }
408            }
409            async {}
410        });
411        fut.await;
412    }
413
414    async fn pipe_multicast_traffic_for_family(self, family: Arc<dyn GenericNetlinkFamily<S>>) {
415        let unordered = FuturesUnordered::new();
416        let family_name = family.name().to_string();
417        let family_id =
418            *self.state.lock().family_ids.get(&family_name).expect("Failed to get family id");
419        for mcast_group in family.multicast_groups() {
420            let mcast_group_id =
421                self.state.lock().get_multicast_group_id(family_name.clone(), mcast_group.clone());
422            let (sink, receiver) = mpsc::unbounded();
423            unordered.push(family.stream_multicast_messages(mcast_group, family_id, sink));
424            unordered.push(Box::pin(self.pipe_single_multicast_group(mcast_group_id, receiver)));
425        }
426        unordered.collect::<Vec<()>>().await;
427    }
428}
429
430pub(crate) struct GenericNetlinkClient<S> {
431    client_id: ClientId,
432    sender: S,
433    receiver: mpsc::UnboundedReceiver<NetlinkMessage<GenericMessage>>,
434}
435
436pub struct GenericNetlinkWorkerParams<S: Sender<GenericMessage>> {
437    server: GenericNetlinkServer<S>,
438    new_client_receiver: mpsc::UnboundedReceiver<GenericNetlinkClient<S>>,
439    new_family_receiver: mpsc::UnboundedReceiver<Arc<dyn GenericNetlinkFamily<S>>>,
440}
441
442pub async fn run_generic_netlink_worker<S: Sender<GenericMessage>>(
443    params: GenericNetlinkWorkerParams<S>,
444    enable_nl80211: bool,
445) {
446    // Initialize supported families on the worker, so that they shares an
447    // executor with the main netlink future.
448    let nl80211_family = if enable_nl80211 {
449        // This boolean is tied to availability of the Wlanix protocol, so this
450        // operation will always succeed unless our product config is invalid.
451        Some(nl80211::Nl80211Family::new().expect("Failed to connect to Nl80211 netlink family"))
452    } else {
453        None
454    };
455    let taskstats_family = taskstats::TaskstatsFamily::new();
456    {
457        let mut state = params.server.state.lock();
458        if let Some(nl80211_family) = nl80211_family {
459            state.add_family(Arc::new(nl80211_family));
460        }
461        state.add_family(Arc::new(taskstats_family));
462    }
463
464    run_generic_netlink_worker_internal(params).await
465}
466
467fn run_generic_netlink_worker_internal<S: Sender<GenericMessage>>(
468    params: GenericNetlinkWorkerParams<S>,
469) -> impl std::future::Future<Output = ()> + Send {
470    let GenericNetlinkWorkerParams { server, new_client_receiver, new_family_receiver } = params;
471
472    let server_clone = server.clone();
473    let multicast_fut = new_family_receiver.for_each_concurrent(None, move |family| {
474        server_clone.clone().pipe_multicast_traffic_for_family(family)
475    });
476    let new_client_fut = new_client_receiver
477        .for_each_concurrent(None, move |client| server.clone().run_generic_netlink_client(client));
478
479    join(new_client_fut, multicast_fut).map(|_| ())
480}
481
482pub struct GenericNetlink<S> {
483    server: GenericNetlinkServer<S>,
484    new_client_sender: mpsc::UnboundedSender<GenericNetlinkClient<S>>,
485}
486
487impl<S: Sender<GenericMessage>> GenericNetlink<S> {
488    pub fn new() -> (Self, GenericNetlinkWorkerParams<S>) {
489        let (new_client_sender, new_client_receiver) = mpsc::unbounded();
490        let (new_family_sender, new_family_receiver) = mpsc::unbounded();
491        let server = GenericNetlinkServer::new(new_family_sender);
492        let generic_netlink = Self { server: server.clone(), new_client_sender };
493        let worker_params =
494            GenericNetlinkWorkerParams { server, new_client_receiver, new_family_receiver };
495        (generic_netlink, worker_params)
496    }
497
498    pub fn new_generic_client(
499        &self,
500        sender: S,
501        receiver: mpsc::UnboundedReceiver<NetlinkMessage<GenericMessage>>,
502    ) -> Result<GenericNetlinkClientHandle<S>, anyhow::Error> {
503        let mut state = self.server.state.lock();
504        let client_id = state.client_id_counter;
505        state.client_id_counter.0 += 1;
506        state.client_senders.insert(client_id, sender.clone());
507        let handle = GenericNetlinkClientHandle { client_id, server: self.server.clone() };
508        let new_client = GenericNetlinkClient { client_id, sender, receiver };
509        self.new_client_sender
510            .unbounded_send(new_client)
511            .map_err(|_| anyhow::anyhow!("Failed to connect a new generic netlink client"))?;
512        Ok(handle)
513    }
514}
515
516impl<S: Sender<GenericMessage>> GenericNetlink<S> {
517    pub fn add_family(&self, family: Arc<dyn GenericNetlinkFamily<S>>) {
518        self.server.state.lock().add_family(family)
519    }
520}
521
522pub struct GenericNetlinkClientHandle<S> {
523    client_id: ClientId,
524    server: GenericNetlinkServer<S>,
525}
526
527impl<S> GenericNetlinkClientHandle<S> {
528    pub(crate) fn add_membership(&self, group_id: ModernGroup) -> Result<(), Errno> {
529        let mut state = self.server.state.lock();
530        if let Some(memberships) = state.multicast_group_memberships.get_mut(&group_id) {
531            memberships.insert(self.client_id);
532            Ok(())
533        } else {
534            error!(EINVAL)
535        }
536    }
537}
538
#[cfg(test)]
mod test_utils {
    use super::*;
    use netlink_packet_core::NetlinkSerializable;

    /// A `Sender` for tests that records every message it is asked to send.
    ///
    /// Clones share one message log, so a test can keep a clone to inspect
    /// what was sent through another.
    #[derive(Clone)]
    pub(crate) struct TestSender<M> {
        pub messages: Arc<Mutex<Vec<NetlinkMessage<M>>>>,
    }

    impl<M: Clone + NetlinkSerializable + Send + Sync> Sender<M> for TestSender<M> {
        /// Appends `message` to the shared log; the multicast group is ignored.
        fn send(&mut self, message: NetlinkMessage<M>, _group: Option<ModernGroup>) {
            let mut log = self.messages.lock();
            log.push(message);
        }
    }
}
555
556#[cfg(test)]
557mod tests {
558    use super::test_utils::*;
559    use super::*;
560    use assert_matches::assert_matches;
561    use fuchsia_async::TestExecutor;
562    use futures::future::Future;
563    use futures::pin_mut;
564    use netlink_packet_generic::GenlHeader;
565    use std::task::Poll;
566
    /// Name reported by `TestFamily`.
    const TEST_FAMILY: &str = "test_family";
    /// The two multicast groups advertised by `TestFamily`.
    const MCAST_GROUP_1: &str = "m1";
    const MCAST_GROUP_2: &str = "m2";
570
571    fn getfamily_request() -> NetlinkMessage<GenericMessage> {
572        let getfamily_ctrl = GenlCtrl {
573            cmd: GenlCtrlCmd::GetFamily,
574            nlas: vec![GenlCtrlAttrs::FamilyName(TEST_FAMILY.to_string())],
575        };
576        let mut genl_message =
577            GenlMessage::new(GenlHeader { cmd: 0, version: 0 }, getfamily_ctrl, GENL_ID_CTRL);
578        genl_message.finalize();
579        let mut netlink_message = NetlinkMessage::new(
580            Default::default(),
581            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(genl_message)),
582        );
583        netlink_message.finalize();
584        netlink_message
585    }
586
587    #[derive(Default)]
588    struct TestFamily {
589        messages_to_server: Mutex<Vec<Vec<u8>>>,
590        multicast_message_sinks:
591            Mutex<HashMap<String, mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>>>,
592    }
593
594    #[async_trait]
595    impl<S> GenericNetlinkFamily<S> for TestFamily {
596        fn name(&self) -> String {
597            TEST_FAMILY.into()
598        }
599
600        fn multicast_groups(&self) -> Vec<String> {
601            vec![MCAST_GROUP_1.to_string(), MCAST_GROUP_2.to_string()]
602        }
603
604        async fn stream_multicast_messages(
605            &self,
606            group: String,
607            _assigned_family_id: u16,
608            message_sink: mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>,
609        ) {
610            self.multicast_message_sinks.lock().insert(group, message_sink);
611        }
612
613        async fn handle_message(
614            &self,
615            _netlink_header: NetlinkHeader,
616            payload: Vec<u8>,
617            _sender: &mut S,
618        ) {
619            self.messages_to_server.lock().push(payload);
620        }
621    }
622
623    fn start_test_netlink()
624    -> (GenericNetlink<TestSender<GenericMessage>>, impl Future<Output = ()> + Send) {
625        let (netlink, worker_params) = GenericNetlink::new();
626        let worker = run_generic_netlink_worker_internal(worker_params);
627        (netlink, worker)
628    }
629
630    fn netlink_with_test_family() -> (
631        GenericNetlink<TestSender<GenericMessage>>,
632        Arc<TestFamily>,
633        impl Future<Output = ()> + Send,
634    ) {
635        let test_family = Arc::new(TestFamily::default());
636        let (netlink, worker) = start_test_netlink();
637        netlink.server.state.lock().add_family(Arc::clone(&test_family) as _);
638        (netlink, test_family, worker)
639    }
640
641    fn new_client(
642        netlink: &GenericNetlink<TestSender<GenericMessage>>,
643    ) -> (
644        Arc<Mutex<Vec<NetlinkMessage<GenericMessage>>>>,
645        mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>,
646        GenericNetlinkClientHandle<TestSender<GenericMessage>>,
647    ) {
648        let messages_to_client = Arc::new(Mutex::new(vec![]));
649        let (netlink_sender, receiver) = mpsc::unbounded();
650        let sender = TestSender { messages: messages_to_client.clone() };
651        let client_handle =
652            netlink.new_generic_client(sender, receiver).expect("Failed to add new generic client");
653        (messages_to_client, netlink_sender, client_handle)
654    }
655
656    #[test]
657    fn test_ctrl_getfamily_missing() {
658        let mut exec = TestExecutor::new();
659        let (netlink, worker) = start_test_netlink();
660        pin_mut!(worker);
661        let (messages_to_client, sender, _client_handle) = new_client(&netlink);
662
663        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
664        assert!(exec.run_until_stalled(&mut worker) == Poll::Pending);
665
666        // The family doesn't exist, so an error should be returned.
667        assert!(messages_to_client.lock().len() == 1);
668
669        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
670        let err_msg = assert_matches!(payload, NetlinkPayload::Error(m) => m);
671        assert_eq!(err_msg.code, NonZero::new(-2));
672    }
673
674    #[test]
675    fn test_ctrl_getfamily() {
676        let mut exec = TestExecutor::new();
677        let (netlink, _test_family, fut) = netlink_with_test_family();
678        pin_mut!(fut);
679        let (messages_to_client, sender, _client_handle) = new_client(&netlink);
680
681        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
682        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
683
684        // Verify that we got all expected information in the response.
685        assert!(messages_to_client.lock().len() == 1);
686        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
687        let (_genl_header, ctrl_payload) = assert_matches!(
688            payload,
689            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(m)) => m.into_parts());
690        assert_eq!(ctrl_payload.cmd, GenlCtrlCmd::NewFamily);
691        assert!(
692            ctrl_payload
693                .nlas
694                .iter()
695                .any(|nla| *nla == GenlCtrlAttrs::FamilyName(TEST_FAMILY.into()))
696        );
697        assert!(ctrl_payload.nlas.iter().any(|nla| matches!(nla, GenlCtrlAttrs::FamilyId(_))));
698        let multicast_groups = ctrl_payload
699            .nlas
700            .iter()
701            .filter_map(
702                |nla| if let GenlCtrlAttrs::McastGroups(vec) = nla { Some(vec) } else { None },
703            )
704            .next()
705            .expect("No multicast groups");
706        assert_eq!(multicast_groups.len(), 2);
707        assert!(multicast_groups.iter().any(|group| {
708            group.iter().any(|attr| matches!(attr, McastGrpAttrs::Id(_)));
709            group.iter().any(|attr| *attr == McastGrpAttrs::Name(MCAST_GROUP_1.into()))
710        }));
711        assert!(multicast_groups.iter().any(|group| {
712            group.iter().any(|attr| matches!(attr, McastGrpAttrs::Id(_)));
713            group.iter().any(|attr| *attr == McastGrpAttrs::Name(MCAST_GROUP_2.into()))
714        }));
715    }
716
717    #[test]
718    fn test_ctrl_getfamily_before_and_after_add_family() {
719        let mut exec = TestExecutor::new();
720        let (netlink, fut) = start_test_netlink();
721        pin_mut!(fut);
722        let (messages_to_client, sender, _client_handle) = new_client(&netlink);
723
724        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
725        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
726
727        // The family doesn't exist, so an error should be returned.
728        assert!(messages_to_client.lock().len() == 1);
729
730        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
731        let err_msg = assert_matches!(payload, NetlinkPayload::Error(m) => m);
732        assert_eq!(err_msg.code, NonZero::new(-2));
733
734        // Add the test family and try again.
735        let test_family = Arc::new(TestFamily::default());
736        netlink.add_family(test_family);
737
738        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
739        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
740
741        // Verify that we got all expected information in the response.
742        assert!(messages_to_client.lock().len() == 1);
743        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
744        let (_genl_header, ctrl_payload) = assert_matches!(
745            payload,
746            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(m)) => m.into_parts());
747        assert_eq!(ctrl_payload.cmd, GenlCtrlCmd::NewFamily);
748        assert!(
749            ctrl_payload
750                .nlas
751                .iter()
752                .any(|nla| *nla == GenlCtrlAttrs::FamilyName(TEST_FAMILY.into()))
753        );
754        assert!(ctrl_payload.nlas.iter().any(|nla| matches!(nla, GenlCtrlAttrs::FamilyId(_))));
755        let multicast_groups = ctrl_payload
756            .nlas
757            .iter()
758            .filter_map(
759                |nla| if let GenlCtrlAttrs::McastGroups(vec) = nla { Some(vec) } else { None },
760            )
761            .next()
762            .expect("No multicast groups");
763        assert_eq!(multicast_groups.len(), 2);
764        assert!(multicast_groups.iter().any(|group| {
765            group.iter().any(|attr| matches!(attr, McastGrpAttrs::Id(_)));
766            group.iter().any(|attr| *attr == McastGrpAttrs::Name(MCAST_GROUP_1.into()))
767        }));
768        assert!(multicast_groups.iter().any(|group| {
769            group.iter().any(|attr| matches!(attr, McastGrpAttrs::Id(_)));
770            group.iter().any(|attr| *attr == McastGrpAttrs::Name(MCAST_GROUP_2.into()))
771        }));
772    }
773
774    #[test]
775    fn test_send_family_message() {
776        let mut exec = TestExecutor::new();
777        let (netlink, test_family, fut) = netlink_with_test_family();
778        pin_mut!(fut);
779        let (messages_to_client, sender, _client_handle) = new_client(&netlink);
780
781        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
782        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
783
784        assert!(messages_to_client.lock().len() == 1);
785        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
786        let (_genl_header, ctrl_payload) = assert_matches!(
787            payload,
788            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(m)) => m.into_parts());
789        let family_id = *ctrl_payload
790            .nlas
791            .iter()
792            .filter_map(|nla| if let GenlCtrlAttrs::FamilyId(id) = nla { Some(id) } else { None })
793            .next()
794            .expect("Could not find family id");
795
796        let mut netlink_message = NetlinkMessage::new(
797            Default::default(),
798            NetlinkPayload::InnerMessage(GenericMessage::Other {
799                family: family_id,
800                payload: vec![0, 1, 2, 3],
801            }),
802        );
803        netlink_message.finalize();
804        sender.unbounded_send(netlink_message).expect("Failed to send test family message");
805
806        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
807        assert_eq!(test_family.messages_to_server.lock().len(), 1);
808    }
809
810    #[test]
811    fn test_send_invalid_family_message() {
812        let mut exec = TestExecutor::new();
813        let (netlink, test_family, fut) = netlink_with_test_family();
814        pin_mut!(fut);
815        let (messages_to_client, sender, _client_handle) = new_client(&netlink);
816
817        let mut netlink_message = NetlinkMessage::new(
818            Default::default(),
819            NetlinkPayload::InnerMessage(GenericMessage::Other {
820                family: 1337,
821                payload: vec![0, 1, 2, 3],
822            }),
823        );
824        netlink_message.finalize();
825        sender.unbounded_send(netlink_message).expect("Failed to send test family message");
826        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
827        assert!(test_family.messages_to_server.lock().is_empty());
828        assert!(messages_to_client.lock().is_empty());
829    }
830
831    #[test]
832    fn test_server_gets_multicast_messages() {
833        let mut exec = TestExecutor::new();
834        let (_netlink, test_family, fut) = netlink_with_test_family();
835        pin_mut!(fut);
836        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
837        assert_eq!(test_family.multicast_message_sinks.lock().len(), 2);
838        assert!(test_family.multicast_message_sinks.lock().contains_key(MCAST_GROUP_1));
839        assert!(test_family.multicast_message_sinks.lock().contains_key(MCAST_GROUP_2));
840    }
841
842    #[test]
843    fn test_bad_multicast_subscription_fails() {
844        let mut exec = TestExecutor::new();
845        let (netlink, fut) = start_test_netlink();
846        pin_mut!(fut);
847        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
848
849        let (_messages_to_client, _sender, client_handle) = new_client(&netlink);
850        client_handle
851            .add_membership(ModernGroup(1337))
852            .expect_err("Should not be able to add invalid multicast membership");
853    }
854
855    #[test]
856    fn test_multicast_subscriptions() {
857        let mut exec = TestExecutor::new();
858        let (netlink, test_family, fut) = netlink_with_test_family();
859        pin_mut!(fut);
860        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
861        let (messages_to_client_1, _sender_1, client_handle_1) = new_client(&netlink);
862        let (messages_to_client_2, _sender_2, client_handle_2) = new_client(&netlink);
863        let (messages_to_client_3, _sender_3, _client_handle_3) = new_client(&netlink);
864
865        let mcast_group_1_id = netlink
866            .server
867            .state
868            .lock()
869            .get_multicast_group_id(TEST_FAMILY.to_string(), MCAST_GROUP_1.to_string());
870        client_handle_1.add_membership(mcast_group_1_id).expect("add_membership failed");
871        client_handle_2.add_membership(mcast_group_1_id).expect("add_membership failed");
872
873        assert!(messages_to_client_1.lock().is_empty());
874        assert!(messages_to_client_2.lock().is_empty());
875        assert!(messages_to_client_3.lock().is_empty());
876
877        let message_sink = test_family
878            .multicast_message_sinks
879            .lock()
880            .get(MCAST_GROUP_1)
881            .expect("Failed to find multicast message sender")
882            .clone();
883        let netlink_message = NetlinkMessage::new(NetlinkHeader::default(), NetlinkPayload::Noop);
884        message_sink.unbounded_send(netlink_message).expect("Failed to send message");
885        assert!(exec.run_until_stalled(&mut fut) == Poll::Pending);
886
887        // All subscribed clients receive the message.
888        assert_eq!(messages_to_client_1.lock().len(), 1);
889        assert_eq!(messages_to_client_2.lock().len(), 1);
890        // Client 3 did not subscribe and should not receive the message.
891        assert!(messages_to_client_3.lock().is_empty());
892    }
893}