1use async_trait::async_trait;
6use futures::FutureExt;
7use futures::channel::mpsc;
8use futures::future::join;
9use futures::stream::{FuturesUnordered, StreamExt};
10use linux_uapi::{NLM_F_ACK, NLM_F_CAPPED};
11use netlink::NETLINK_LOG_TAG;
12use netlink::messaging::Sender;
13use netlink::multicast_groups::ModernGroup;
14use netlink_packet_core::{
15 ErrorMessage, NETLINK_HEADER_LEN, NetlinkHeader, NetlinkMessage, NetlinkPayload,
16};
17use netlink_packet_generic::constants::GENL_ID_CTRL;
18use netlink_packet_generic::ctrl::nlas::{GenlCtrlAttrs, McastGrpAttrs};
19use netlink_packet_generic::ctrl::{GenlCtrl, GenlCtrlCmd};
20use netlink_packet_generic::{GenlHeader, GenlMessage};
21use netlink_packet_utils::Emitable;
22use starnix_logging::track_stub;
23use starnix_sync::Mutex;
24use std::collections::{HashMap, HashSet};
25use std::num::NonZero;
26use std::ops::DerefMut;
27use std::sync::Arc;
28
29use starnix_logging::{log_error, log_info, log_warn};
30use starnix_uapi::errors::Errno;
31use starnix_uapi::{ENOENT, error};
32
33mod messages;
34mod nl80211;
35mod taskstats;
36
37pub use messages::GenericMessage;
38
/// Lowest family ID handed out to registered families: the first value above the
/// ID reserved for the `nlctrl` control family (`GENL_ID_CTRL`).
const MIN_FAMILY_ID: u16 = GENL_ID_CTRL + 1;
/// Name of the control family, which is answered with `GENL_ID_CTRL` in `GetFamily`
/// requests rather than being looked up among the registered families.
const NLCTRL_FAMILY: &str = "nlctrl";
41
42#[async_trait]
43pub trait GenericNetlinkFamily<S>: Send + Sync {
44 fn name(&self) -> String;
48
49 fn multicast_groups(&self) -> Vec<String> {
53 vec![]
54 }
55
56 async fn stream_multicast_messages(
62 &self,
63 group: String,
64 assigned_family_id: u16,
65 message_sink: mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>,
66 );
67
68 async fn handle_message(&self, netlink_header: NetlinkHeader, payload: Vec<u8>, sender: &mut S);
74}
75
76fn extract_family_names(genl_ctrl: GenlCtrl) -> Vec<String> {
77 genl_ctrl
78 .nlas
79 .into_iter()
80 .filter_map(
81 |attr| {
82 if let GenlCtrlAttrs::FamilyName(name) = attr { Some(name) } else { None }
83 },
84 )
85 .collect()
86}
87
/// Identifier of a connected generic netlink client; used as the key for the
/// client's sender and for its multicast group memberships.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
struct ClientId(u64);
90
/// Mutable state shared by every handle to the generic netlink server.
struct GenericNetlinkServerState<S> {
    /// Maps a family name to the ID assigned to it in `add_family`.
    family_ids: HashMap<String, u16>,
    /// Registered families; the family at index `i` has ID `MIN_FAMILY_ID + i`.
    families: Vec<Arc<dyn GenericNetlinkFamily<S>>>,
    /// Channel on which newly added families are announced so that their
    /// multicast traffic can be set up.
    new_family_sender: mpsc::UnboundedSender<Arc<dyn GenericNetlinkFamily<S>>>,
    /// Maps `(family name, group name)` to the multicast group ID assigned to it.
    multicast_groups: HashMap<(String, String), ModernGroup>,
    /// Most recently assigned multicast group ID. It is incremented before each
    /// assignment, so the first assigned ID is 1.
    multicast_group_id_counter: ModernGroup,
    /// ID that will be given to the next client.
    client_id_counter: ClientId,
    /// Sender for each connected client, used to deliver multicast messages.
    client_senders: HashMap<ClientId, S>,
    /// Clients subscribed to each multicast group. A group only has an entry
    /// once its traffic is being served.
    multicast_group_memberships: HashMap<ModernGroup, HashSet<ClientId>>,
}
113
114impl<S: Sender<GenericMessage>> GenericNetlinkServerState<S> {
115 fn new(new_family_sender: mpsc::UnboundedSender<Arc<dyn GenericNetlinkFamily<S>>>) -> Self {
116 Self {
117 family_ids: HashMap::new(),
118 families: vec![],
119 new_family_sender,
120 multicast_groups: HashMap::new(),
121 multicast_group_id_counter: ModernGroup(0),
122 client_id_counter: ClientId(0),
123 client_senders: HashMap::new(),
124 multicast_group_memberships: HashMap::new(),
125 }
126 }
127
128 fn add_family(&mut self, family: Arc<dyn GenericNetlinkFamily<S>>) {
129 match (self.families.len() as u16).checked_add(MIN_FAMILY_ID) {
130 Some(new_family_id) => {
131 self.family_ids.insert(family.name(), new_family_id);
132 if let Err(e) = self.new_family_sender.unbounded_send(Arc::clone(&family)) {
133 log_error!(
134 tag = NETLINK_LOG_TAG;
135 "Failed to setup multicast group handling for new generic \
136 netlink family {}: {}",
137 family.name(),
138 e
139 );
140 }
141 self.families.push(family);
142 }
143 None => {
144 log_error!(
145 tag = NETLINK_LOG_TAG;
146 "Failed to add generic netlink family: too many families"
147 );
148 }
149 }
150 }
151
152 fn get_family(&self, family_id: u16) -> Option<Arc<dyn GenericNetlinkFamily<S>>> {
153 if family_id >= MIN_FAMILY_ID
154 && ((family_id - MIN_FAMILY_ID) as usize) < self.families.len()
155 {
156 Some(Arc::clone(&self.families[(family_id - MIN_FAMILY_ID) as usize]))
157 } else {
158 None
159 }
160 }
161
162 fn get_multicast_group_id(&mut self, family: String, group: String) -> ModernGroup {
163 *self.multicast_groups.entry((family, group)).or_insert_with(|| {
164 self.multicast_group_id_counter.0 += 1;
165 self.multicast_group_id_counter
166 })
167 }
168
169 fn handle_ctrl_message(
170 &mut self,
171 mut netlink_header: NetlinkHeader,
172 genl_message: GenlMessage<GenlCtrl>,
173 sender: &mut S,
174 ) {
175 let (genl_header, genl_ctrl) = genl_message.into_parts();
176 match genl_ctrl.cmd {
177 GenlCtrlCmd::GetFamily => {
178 let family_names = extract_family_names(genl_ctrl);
179 log_info!(tag = NETLINK_LOG_TAG; "Netlink GetFamily request: {:?}", family_names);
180
181 for family in &family_names {
182 if family == NLCTRL_FAMILY {
183 self.send_get_family_response(
184 netlink_header,
185 genl_header,
186 NLCTRL_FAMILY,
187 GENL_ID_CTRL,
188 None,
189 sender,
190 );
191 } else if let Some(id) = self.family_ids.get(family).copied() {
192 log_info!(
193 tag = NETLINK_LOG_TAG;
194 "Serving requested netlink family {}",
195 family
196 );
197 let mcast_groups = self
198 .get_family(id)
199 .expect("Known family ID should always exist")
200 .multicast_groups()
201 .into_iter()
202 .map(|name| {
203 vec![
204 McastGrpAttrs::Name(name.clone()),
205 McastGrpAttrs::Id(
206 self.get_multicast_group_id(family.to_string(), name).0,
207 ),
208 ]
209 })
210 .collect();
211 self.send_get_family_response(
212 netlink_header,
213 genl_header,
214 family,
215 id,
216 Some(mcast_groups),
217 sender,
218 );
219 } else {
220 log_warn!(
221 tag = NETLINK_LOG_TAG;
222 "Cannot serve requested netlink family {}",
223 family
224 );
225
226 let mut buffer = [0; NETLINK_HEADER_LEN];
228 netlink_header.emit(&mut buffer[..NETLINK_HEADER_LEN]);
229 let mut error = ErrorMessage::default();
230 error.code = NonZero::new(-(ENOENT as i32));
231 error.header = buffer.to_vec();
232 netlink_header.flags = NLM_F_CAPPED as u16;
233 let mut netlink_message =
234 NetlinkMessage::new(netlink_header, NetlinkPayload::Error(error));
235 netlink_message.finalize();
236 sender.send(netlink_message, None);
237 }
238 }
239 }
240 GenlCtrlCmd::NewFamily => {
241 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlNewFamily")
242 }
243 GenlCtrlCmd::DelFamily => {
244 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlDelFamily")
245 }
246 GenlCtrlCmd::NewOps => {
247 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlNewOps")
248 }
249 GenlCtrlCmd::DelOps => {
250 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlDelOps")
251 }
252 GenlCtrlCmd::GetOps => {
253 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlGetOps")
254 }
255 GenlCtrlCmd::NewMcastGrp => {
256 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlNewMcastGrp")
257 }
258 GenlCtrlCmd::DelMcastGrp => {
259 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlDelMcastGrp")
260 }
261 GenlCtrlCmd::GetMcastGrp => {
262 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlGetMcastGrp")
263 }
264 GenlCtrlCmd::GetPolicy => {
265 track_stub!(TODO("https://fxbug.dev/297431602"), "NetlinkCtrlGetPolicy")
266 }
267 }
268 }
269
270 fn send_get_family_response(
271 &mut self,
272 mut netlink_header: NetlinkHeader,
273 genl_header: GenlHeader,
274 family: &str,
275 id: u16,
276 mcast_groups: Option<Vec<Vec<McastGrpAttrs>>>,
277 sender: &mut S,
278 ) {
279 let mut nlas =
280 vec![GenlCtrlAttrs::FamilyId(id), GenlCtrlAttrs::FamilyName(family.to_string())];
281 mcast_groups.map(|mg| nlas.push(GenlCtrlAttrs::McastGroups(mg)));
282 let resp_ctrl = GenlCtrl { cmd: GenlCtrlCmd::NewFamily, nlas };
283 let orig_flags = netlink_header.flags;
287 netlink_header.flags = 0;
288 let mut genl_message = GenlMessage::from_parts(genl_header, resp_ctrl);
289 genl_message.finalize();
290 let mut message = NetlinkMessage::new(
291 netlink_header,
292 NetlinkPayload::InnerMessage(GenericMessage::Ctrl(genl_message)),
293 );
294 message.finalize();
295 sender.send(message, None);
296 if orig_flags & NLM_F_ACK as u16 != 0 {
298 let mut buffer = [0; NETLINK_HEADER_LEN];
300 netlink_header.flags = NLM_F_CAPPED as u16;
302 netlink_header.emit(&mut buffer[..NETLINK_HEADER_LEN]);
303 let mut ack = ErrorMessage::default();
304 ack.code = None;
307 ack.header = buffer.to_vec();
308 let mut netlink_message =
309 NetlinkMessage::new(netlink_header, NetlinkPayload::Error(ack));
310 netlink_message.finalize();
311 sender.send(netlink_message, None);
312 }
313 }
314}
315
/// Cheaply cloneable handle to the generic netlink server; all clones share the
/// same mutex-protected state.
#[derive(Clone)]
struct GenericNetlinkServer<S> {
    /// State shared between every clone of this handle.
    state: Arc<Mutex<GenericNetlinkServerState<S>>>,
}
321
322impl<S: Sender<GenericMessage>> GenericNetlinkServer<S> {
323 fn new(new_family_sender: mpsc::UnboundedSender<Arc<dyn GenericNetlinkFamily<S>>>) -> Self {
324 Self { state: Arc::new(Mutex::new(GenericNetlinkServerState::new(new_family_sender))) }
325 }
326
327 async fn handle_generic_message(
328 &self,
329 message: NetlinkMessage<GenericMessage>,
330 sender: &mut S,
331 ) {
332 let (netlink_header, payload) = message.into_parts();
333 let req = match payload {
334 NetlinkPayload::InnerMessage(p) => p,
335 p => {
336 log_error!(tag = NETLINK_LOG_TAG; "Dropping unexpected netlink payload: {:?}", p);
337 return;
338 }
339 };
340 match req {
341 GenericMessage::Ctrl(ctrl_message) => {
342 self.state.lock().handle_ctrl_message(netlink_header, ctrl_message, sender)
343 }
344 GenericMessage::Other { family: family_id, payload } => {
345 let family = self.state.lock().get_family(family_id);
346 match family {
347 Some(family) => {
348 family.handle_message(netlink_header, payload, sender).await;
349 }
350 None => log_info!(
351 tag = NETLINK_LOG_TAG;
352 "Ignoring generic netlink message with unsupported family {}",
353 family_id,
354 ),
355 }
356 }
357 }
358 }
359
360 async fn run_generic_netlink_client(self, mut client: GenericNetlinkClient<S>) {
361 log_info!(tag = NETLINK_LOG_TAG; "Registered new generic netlink client");
362 loop {
363 match client.receiver.next().await {
364 Some(message) => self.handle_generic_message(message, &mut client.sender).await,
365 None => {
366 log_info!(tag = NETLINK_LOG_TAG; "Generic netlink client exited");
367 let mut state = self.state.lock();
368 for memberships in state.multicast_group_memberships.values_mut() {
369 memberships.remove(&client.client_id);
370 }
371 state.client_senders.remove(&client.client_id);
372 return;
373 }
374 }
375 }
376 }
377
378 async fn pipe_single_multicast_group(
379 &self,
380 mcast_group_id: ModernGroup,
381 mcast_stream: mpsc::UnboundedReceiver<NetlinkMessage<GenericMessage>>,
382 ) {
383 if self
384 .state
385 .lock()
386 .multicast_group_memberships
387 .insert(mcast_group_id, HashSet::new())
388 .is_some()
389 {
390 log_error!(
391 tag = NETLINK_LOG_TAG;
392 "pipe_single_multicast_group called on group {} but group is already served",
393 mcast_group_id.0
394 );
395 return;
396 }
397 let fut = mcast_stream.for_each(|mcast_message| {
398 let mut state_lock = self.state.lock();
399 let state = state_lock.deref_mut();
400 for client_id in state
401 .multicast_group_memberships
402 .get(&mcast_group_id)
403 .expect("Group memberships should always be present")
404 {
405 if let Some(sender) = state.client_senders.get_mut(client_id) {
406 sender.send(mcast_message.clone(), Some(mcast_group_id));
407 }
408 }
409 async {}
410 });
411 fut.await;
412 }
413
414 async fn pipe_multicast_traffic_for_family(self, family: Arc<dyn GenericNetlinkFamily<S>>) {
415 let unordered = FuturesUnordered::new();
416 let family_name = family.name().to_string();
417 let family_id =
418 *self.state.lock().family_ids.get(&family_name).expect("Failed to get family id");
419 for mcast_group in family.multicast_groups() {
420 let mcast_group_id =
421 self.state.lock().get_multicast_group_id(family_name.clone(), mcast_group.clone());
422 let (sink, receiver) = mpsc::unbounded();
423 unordered.push(family.stream_multicast_messages(mcast_group, family_id, sink));
424 unordered.push(Box::pin(self.pipe_single_multicast_group(mcast_group_id, receiver)));
425 }
426 unordered.collect::<Vec<()>>().await;
427 }
428}
429
/// A client connection as handed to the worker that serves it.
pub(crate) struct GenericNetlinkClient<S> {
    /// ID under which the client is registered in the server state.
    client_id: ClientId,
    /// Sender used to deliver responses to the client.
    sender: S,
    /// Stream of messages sent by the client.
    receiver: mpsc::UnboundedReceiver<NetlinkMessage<GenericMessage>>,
}
435
/// Everything the generic netlink worker needs in order to run.
pub struct GenericNetlinkWorkerParams<S: Sender<GenericMessage>> {
    /// Handle to the shared server state.
    server: GenericNetlinkServer<S>,
    /// Receives each newly connected client.
    new_client_receiver: mpsc::UnboundedReceiver<GenericNetlinkClient<S>>,
    /// Receives each newly registered family.
    new_family_receiver: mpsc::UnboundedReceiver<Arc<dyn GenericNetlinkFamily<S>>>,
}
441
442pub async fn run_generic_netlink_worker<S: Sender<GenericMessage>>(
443 params: GenericNetlinkWorkerParams<S>,
444 enable_nl80211: bool,
445) {
446 let nl80211_family = if enable_nl80211 {
449 Some(nl80211::Nl80211Family::new().expect("Failed to connect to Nl80211 netlink family"))
452 } else {
453 None
454 };
455 let taskstats_family = taskstats::TaskstatsFamily::new();
456 {
457 let mut state = params.server.state.lock();
458 if let Some(nl80211_family) = nl80211_family {
459 state.add_family(Arc::new(nl80211_family));
460 }
461 state.add_family(Arc::new(taskstats_family));
462 }
463
464 run_generic_netlink_worker_internal(params).await
465}
466
467fn run_generic_netlink_worker_internal<S: Sender<GenericMessage>>(
468 params: GenericNetlinkWorkerParams<S>,
469) -> impl std::future::Future<Output = ()> + Send {
470 let GenericNetlinkWorkerParams { server, new_client_receiver, new_family_receiver } = params;
471
472 let server_clone = server.clone();
473 let multicast_fut = new_family_receiver.for_each_concurrent(None, move |family| {
474 server_clone.clone().pipe_multicast_traffic_for_family(family)
475 });
476 let new_client_fut = new_client_receiver
477 .for_each_concurrent(None, move |client| server.clone().run_generic_netlink_client(client));
478
479 join(new_client_fut, multicast_fut).map(|_| ())
480}
481
/// Entry point for registering clients and families with the generic netlink
/// server.
pub struct GenericNetlink<S> {
    /// Handle to the shared server state.
    server: GenericNetlinkServer<S>,
    /// Channel on which new clients are handed to the worker.
    new_client_sender: mpsc::UnboundedSender<GenericNetlinkClient<S>>,
}
486
487impl<S: Sender<GenericMessage>> GenericNetlink<S> {
488 pub fn new() -> (Self, GenericNetlinkWorkerParams<S>) {
489 let (new_client_sender, new_client_receiver) = mpsc::unbounded();
490 let (new_family_sender, new_family_receiver) = mpsc::unbounded();
491 let server = GenericNetlinkServer::new(new_family_sender);
492 let generic_netlink = Self { server: server.clone(), new_client_sender };
493 let worker_params =
494 GenericNetlinkWorkerParams { server, new_client_receiver, new_family_receiver };
495 (generic_netlink, worker_params)
496 }
497
498 pub fn new_generic_client(
499 &self,
500 sender: S,
501 receiver: mpsc::UnboundedReceiver<NetlinkMessage<GenericMessage>>,
502 ) -> Result<GenericNetlinkClientHandle<S>, anyhow::Error> {
503 let mut state = self.server.state.lock();
504 let client_id = state.client_id_counter;
505 state.client_id_counter.0 += 1;
506 state.client_senders.insert(client_id, sender.clone());
507 let handle = GenericNetlinkClientHandle { client_id, server: self.server.clone() };
508 let new_client = GenericNetlinkClient { client_id, sender, receiver };
509 self.new_client_sender
510 .unbounded_send(new_client)
511 .map_err(|_| anyhow::anyhow!("Failed to connect a new generic netlink client"))?;
512 Ok(handle)
513 }
514}
515
516impl<S: Sender<GenericMessage>> GenericNetlink<S> {
517 pub fn add_family(&self, family: Arc<dyn GenericNetlinkFamily<S>>) {
518 self.server.state.lock().add_family(family)
519 }
520}
521
/// Handle returned to the creator of a client, used to manage that client's
/// multicast group memberships.
pub struct GenericNetlinkClientHandle<S> {
    /// ID of the client this handle refers to.
    client_id: ClientId,
    /// Handle to the shared server state.
    server: GenericNetlinkServer<S>,
}
526
527impl<S> GenericNetlinkClientHandle<S> {
528 pub(crate) fn add_membership(&self, group_id: ModernGroup) -> Result<(), Errno> {
529 let mut state = self.server.state.lock();
530 if let Some(memberships) = state.multicast_group_memberships.get_mut(&group_id) {
531 memberships.insert(self.client_id);
532 Ok(())
533 } else {
534 error!(EINVAL)
535 }
536 }
537}
538
#[cfg(test)]
mod test_utils {
    use super::*;
    use netlink_packet_core::NetlinkSerializable;

    /// A `Sender` that records every message it is asked to send.
    #[derive(Clone)]
    pub(crate) struct TestSender<M> {
        /// Messages sent so far, oldest first. Shared between clones.
        pub messages: Arc<Mutex<Vec<NetlinkMessage<M>>>>,
    }

    impl<M: Clone + NetlinkSerializable + Send + Sync> Sender<M> for TestSender<M> {
        /// Appends `message` to the recorded messages; the group is ignored.
        fn send(&mut self, message: NetlinkMessage<M>, _group: Option<ModernGroup>) {
            let mut recorded = self.messages.lock();
            recorded.push(message);
        }
    }
}
555
#[cfg(test)]
mod tests {
    use super::test_utils::*;
    use super::*;
    use assert_matches::assert_matches;
    use fuchsia_async::TestExecutor;
    use futures::future::Future;
    use futures::pin_mut;
    use netlink_packet_generic::GenlHeader;
    use std::task::Poll;

    const TEST_FAMILY: &str = "test_family";
    const MCAST_GROUP_1: &str = "m1";
    const MCAST_GROUP_2: &str = "m2";

    /// Builds a `GetFamily` control request asking for `TEST_FAMILY`.
    fn getfamily_request() -> NetlinkMessage<GenericMessage> {
        let getfamily_ctrl = GenlCtrl {
            cmd: GenlCtrlCmd::GetFamily,
            nlas: vec![GenlCtrlAttrs::FamilyName(TEST_FAMILY.to_string())],
        };
        let mut genl_message =
            GenlMessage::new(GenlHeader { cmd: 0, version: 0 }, getfamily_ctrl, GENL_ID_CTRL);
        genl_message.finalize();
        let mut netlink_message = NetlinkMessage::new(
            Default::default(),
            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(genl_message)),
        );
        netlink_message.finalize();
        netlink_message
    }

    /// Family that records the payloads it is asked to handle and the multicast
    /// sinks it is given.
    #[derive(Default)]
    struct TestFamily {
        /// Payloads received through `handle_message`.
        messages_to_server: Mutex<Vec<Vec<u8>>>,
        /// Sink received for each multicast group, keyed by group name.
        multicast_message_sinks:
            Mutex<HashMap<String, mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>>>,
    }

    #[async_trait]
    impl<S> GenericNetlinkFamily<S> for TestFamily {
        fn name(&self) -> String {
            TEST_FAMILY.into()
        }

        fn multicast_groups(&self) -> Vec<String> {
            vec![MCAST_GROUP_1.to_string(), MCAST_GROUP_2.to_string()]
        }

        async fn stream_multicast_messages(
            &self,
            group: String,
            _assigned_family_id: u16,
            message_sink: mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>,
        ) {
            self.multicast_message_sinks.lock().insert(group, message_sink);
        }

        async fn handle_message(
            &self,
            _netlink_header: NetlinkHeader,
            payload: Vec<u8>,
            _sender: &mut S,
        ) {
            self.messages_to_server.lock().push(payload);
        }
    }

    /// Creates a server without any registered family, plus its worker future.
    fn start_test_netlink()
    -> (GenericNetlink<TestSender<GenericMessage>>, impl Future<Output = ()> + Send) {
        let (netlink, worker_params) = GenericNetlink::new();
        let worker = run_generic_netlink_worker_internal(worker_params);
        (netlink, worker)
    }

    /// Creates a server with a `TestFamily` registered, plus its worker future.
    fn netlink_with_test_family() -> (
        GenericNetlink<TestSender<GenericMessage>>,
        Arc<TestFamily>,
        impl Future<Output = ()> + Send,
    ) {
        let test_family = Arc::new(TestFamily::default());
        let (netlink, worker) = start_test_netlink();
        netlink.server.state.lock().add_family(Arc::clone(&test_family) as _);
        (netlink, test_family, worker)
    }

    /// Connects a new client and returns the messages delivered to it, the
    /// channel for sending requests as that client, and the client handle.
    fn new_client(
        netlink: &GenericNetlink<TestSender<GenericMessage>>,
    ) -> (
        Arc<Mutex<Vec<NetlinkMessage<GenericMessage>>>>,
        mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>,
        GenericNetlinkClientHandle<TestSender<GenericMessage>>,
    ) {
        let messages_to_client = Arc::new(Mutex::new(vec![]));
        let (netlink_sender, receiver) = mpsc::unbounded();
        let sender = TestSender { messages: messages_to_client.clone() };
        let client_handle =
            netlink.new_generic_client(sender, receiver).expect("Failed to add new generic client");
        (messages_to_client, netlink_sender, client_handle)
    }

    /// Asserts that `payload` is a `NewFamily` response for `TEST_FAMILY` that
    /// carries a family ID and both multicast groups, each with a name and an ID.
    fn assert_test_family_response(payload: NetlinkPayload<GenericMessage>) {
        let (_genl_header, ctrl_payload) = assert_matches!(
            payload,
            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(m)) => m.into_parts());
        assert_eq!(ctrl_payload.cmd, GenlCtrlCmd::NewFamily);
        assert!(
            ctrl_payload
                .nlas
                .iter()
                .any(|nla| *nla == GenlCtrlAttrs::FamilyName(TEST_FAMILY.into()))
        );
        assert!(ctrl_payload.nlas.iter().any(|nla| matches!(nla, GenlCtrlAttrs::FamilyId(_))));
        let multicast_groups = ctrl_payload
            .nlas
            .iter()
            .filter_map(
                |nla| if let GenlCtrlAttrs::McastGroups(vec) = nla { Some(vec) } else { None },
            )
            .next()
            .expect("No multicast groups");
        assert_eq!(multicast_groups.len(), 2);
        for expected_name in [MCAST_GROUP_1, MCAST_GROUP_2] {
            // Both conditions must hold for the same group: it is named as
            // expected and it also carries an ID.
            assert!(multicast_groups.iter().any(|group| {
                group.iter().any(|attr| matches!(attr, McastGrpAttrs::Id(_)))
                    && group.iter().any(|attr| *attr == McastGrpAttrs::Name(expected_name.into()))
            }));
        }
    }

    #[test]
    fn test_ctrl_getfamily_missing() {
        let mut exec = TestExecutor::new();
        let (netlink, worker) = start_test_netlink();
        pin_mut!(worker);
        let (messages_to_client, sender, _client_handle) = new_client(&netlink);

        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
        assert_eq!(exec.run_until_stalled(&mut worker), Poll::Pending);

        assert_eq!(messages_to_client.lock().len(), 1);

        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
        let err_msg = assert_matches!(payload, NetlinkPayload::Error(m) => m);
        assert_eq!(err_msg.code, NonZero::new(-2));
    }

    #[test]
    fn test_ctrl_getfamily() {
        let mut exec = TestExecutor::new();
        let (netlink, _test_family, fut) = netlink_with_test_family();
        pin_mut!(fut);
        let (messages_to_client, sender, _client_handle) = new_client(&netlink);

        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);

        assert_eq!(messages_to_client.lock().len(), 1);
        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
        assert_test_family_response(payload);
    }

    #[test]
    fn test_ctrl_getfamily_before_and_after_add_family() {
        let mut exec = TestExecutor::new();
        let (netlink, fut) = start_test_netlink();
        pin_mut!(fut);
        let (messages_to_client, sender, _client_handle) = new_client(&netlink);

        // Before the family is added the request fails with -ENOENT.
        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);

        assert_eq!(messages_to_client.lock().len(), 1);

        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
        let err_msg = assert_matches!(payload, NetlinkPayload::Error(m) => m);
        assert_eq!(err_msg.code, NonZero::new(-2));

        let test_family = Arc::new(TestFamily::default());
        netlink.add_family(test_family);

        // After the family is added the same request succeeds.
        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);

        assert_eq!(messages_to_client.lock().len(), 1);
        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
        assert_test_family_response(payload);
    }

    #[test]
    fn test_send_family_message() {
        let mut exec = TestExecutor::new();
        let (netlink, test_family, fut) = netlink_with_test_family();
        pin_mut!(fut);
        let (messages_to_client, sender, _client_handle) = new_client(&netlink);

        sender.unbounded_send(getfamily_request()).expect("Failed to send getfamily request");
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);

        assert_eq!(messages_to_client.lock().len(), 1);
        let (_netlink_header, payload) = messages_to_client.lock().pop().unwrap().into_parts();
        let (_genl_header, ctrl_payload) = assert_matches!(
            payload,
            NetlinkPayload::InnerMessage(GenericMessage::Ctrl(m)) => m.into_parts());
        let family_id = *ctrl_payload
            .nlas
            .iter()
            .filter_map(|nla| if let GenlCtrlAttrs::FamilyId(id) = nla { Some(id) } else { None })
            .next()
            .expect("Could not find family id");

        let mut netlink_message = NetlinkMessage::new(
            Default::default(),
            NetlinkPayload::InnerMessage(GenericMessage::Other {
                family: family_id,
                payload: vec![0, 1, 2, 3],
            }),
        );
        netlink_message.finalize();
        sender.unbounded_send(netlink_message).expect("Failed to send test family message");

        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);
        assert_eq!(test_family.messages_to_server.lock().len(), 1);
    }

    #[test]
    fn test_send_invalid_family_message() {
        let mut exec = TestExecutor::new();
        let (netlink, test_family, fut) = netlink_with_test_family();
        pin_mut!(fut);
        let (messages_to_client, sender, _client_handle) = new_client(&netlink);

        let mut netlink_message = NetlinkMessage::new(
            Default::default(),
            NetlinkPayload::InnerMessage(GenericMessage::Other {
                family: 1337,
                payload: vec![0, 1, 2, 3],
            }),
        );
        netlink_message.finalize();
        sender.unbounded_send(netlink_message).expect("Failed to send test family message");
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);
        assert!(test_family.messages_to_server.lock().is_empty());
        assert!(messages_to_client.lock().is_empty());
    }

    #[test]
    fn test_server_gets_multicast_messages() {
        let mut exec = TestExecutor::new();
        let (_netlink, test_family, fut) = netlink_with_test_family();
        pin_mut!(fut);
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);
        assert_eq!(test_family.multicast_message_sinks.lock().len(), 2);
        assert!(test_family.multicast_message_sinks.lock().contains_key(MCAST_GROUP_1));
        assert!(test_family.multicast_message_sinks.lock().contains_key(MCAST_GROUP_2));
    }

    #[test]
    fn test_bad_multicast_subscription_fails() {
        let mut exec = TestExecutor::new();
        let (netlink, fut) = start_test_netlink();
        pin_mut!(fut);
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);

        let (_messages_to_client, _sender, client_handle) = new_client(&netlink);
        client_handle
            .add_membership(ModernGroup(1337))
            .expect_err("Should not be able to add invalid multicast membership");
    }

    #[test]
    fn test_multicast_subscriptions() {
        let mut exec = TestExecutor::new();
        let (netlink, test_family, fut) = netlink_with_test_family();
        pin_mut!(fut);
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);
        let (messages_to_client_1, _sender_1, client_handle_1) = new_client(&netlink);
        let (messages_to_client_2, _sender_2, client_handle_2) = new_client(&netlink);
        let (messages_to_client_3, _sender_3, _client_handle_3) = new_client(&netlink);

        // Clients 1 and 2 subscribe to the first group; client 3 does not.
        let mcast_group_1_id = netlink
            .server
            .state
            .lock()
            .get_multicast_group_id(TEST_FAMILY.to_string(), MCAST_GROUP_1.to_string());
        client_handle_1.add_membership(mcast_group_1_id).expect("add_membership failed");
        client_handle_2.add_membership(mcast_group_1_id).expect("add_membership failed");

        assert!(messages_to_client_1.lock().is_empty());
        assert!(messages_to_client_2.lock().is_empty());
        assert!(messages_to_client_3.lock().is_empty());

        let message_sink = test_family
            .multicast_message_sinks
            .lock()
            .get(MCAST_GROUP_1)
            .expect("Failed to find multicast message sender")
            .clone();
        let netlink_message = NetlinkMessage::new(NetlinkHeader::default(), NetlinkPayload::Noop);
        message_sink.unbounded_send(netlink_message).expect("Failed to send message");
        assert_eq!(exec.run_until_stalled(&mut fut), Poll::Pending);

        // Only the subscribed clients receive the multicast message.
        assert_eq!(messages_to_client_1.lock().len(), 1);
        assert_eq!(messages_to_client_2.lock().len(), 1);
        assert!(messages_to_client_3.lock().is_empty());
    }
}