1use alloc::collections::btree_map::Entry;
8use alloc::collections::{BTreeMap, HashMap};
9use core::fmt::{self, Debug, Display};
10use core::num::NonZeroU8;
11use derivative::Derivative;
12use log::debug;
13use net_types::ip::{GenericOverIp, Ip, IpVersionMarker, Mtu};
14use net_types::{SpecifiedAddr, ZonedAddr};
15use netstack3_base::socket::{DualStackIpExt, DualStackRemoteIp, SocketZonedAddrExt as _};
16use netstack3_base::sync::{PrimaryRc, StrongRc, WeakRc};
17use netstack3_base::{
18 AnyDevice, ContextPair, DeviceIdContext, Inspector, InspectorDeviceExt, InspectorExt,
19 IpDeviceAddr, IpExt, Mark, MarkDomain, Marks, ReferenceNotifiers, ReferenceNotifiersExt as _,
20 RemoveResourceResultWithContext, ResourceCounterContext, StrongDeviceIdentifier,
21 TxMetadataBindingsTypes, WeakDeviceIdentifier, ZonedAddressError,
22};
23use netstack3_filter::RawIpBody;
24use packet::{BufferMut, SliceBufViewMut};
25use packet_formats::icmp;
26use packet_formats::ip::{DscpAndEcn, IpPacket};
27use zerocopy::SplitByteSlice;
28
29use crate::internal::raw::counters::RawIpSocketCounters;
30use crate::internal::raw::filter::RawIpSocketIcmpFilter;
31use crate::internal::raw::protocol::RawIpSocketProtocol;
32use crate::internal::raw::state::{RawIpSocketLockedState, RawIpSocketState};
33use crate::internal::socket::{SendOneShotIpPacketError, SocketHopLimits};
34use crate::socket::{
35 IpSockCreateAndSendError, IpSocketHandler, RouteResolutionOptions, SendOptions,
36};
37use crate::DEFAULT_HOP_LIMITS;
38
39mod checksum;
40pub(crate) mod counters;
41pub(crate) mod filter;
42pub(crate) mod protocol;
43pub(crate) mod state;
44
/// Types provided by bindings for use by raw IP sockets.
pub trait RawIpSocketsBindingsTypes: TxMetadataBindingsTypes {
    /// Bindings-owned state attached to each raw IP socket.
    ///
    /// Core stores this value alongside the socket and hands it back from
    /// [`RawIpSocketId::external_state`] and when the socket is closed.
    type RawIpSocketState<I: Ip>: Send + Sync + Debug;
}
50
/// The bindings execution context for raw IP sockets.
pub trait RawIpSocketsBindingsContext<I: IpExt, D: StrongDeviceIdentifier>:
    RawIpSocketsBindingsTypes + Sized
{
    /// Hands a received IP packet to bindings for the given raw IP socket.
    ///
    /// `device` is the device on which `packet` was received. Core calls this only after the
    /// packet has passed the socket's device, checksum, and ICMP-filter checks.
    fn receive_packet<B: SplitByteSlice>(
        &self,
        socket: &RawIpSocketId<I, D::Weak, Self>,
        packet: &I::Packet<B>,
        device: &D,
    );
}
63
64pub struct RawIpSocketApi<I: Ip, C> {
66 ctx: C,
67 _ip_mark: IpVersionMarker<I>,
68}
69
70impl<I: Ip, C> RawIpSocketApi<I, C> {
71 pub fn new(ctx: C) -> Self {
73 Self { ctx, _ip_mark: IpVersionMarker::new() }
74 }
75}
76
77impl<I: IpExt + DualStackIpExt, C> RawIpSocketApi<I, C>
78where
79 C: ContextPair,
80 C::BindingsContext: RawIpSocketsBindingsTypes + ReferenceNotifiers + 'static,
81 C::CoreContext: RawIpSocketMapContext<I, C::BindingsContext>
82 + RawIpSocketStateContext<I, C::BindingsContext>
83 + ResourceCounterContext<RawIpApiSocketId<I, C>, RawIpSocketCounters<I>>,
84{
85 fn core_ctx(&mut self) -> &mut C::CoreContext {
86 let Self { ctx, _ip_mark } = self;
87 ctx.core_ctx()
88 }
89
90 fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
91 let Self { ctx, _ip_mark } = self;
92 ctx.contexts()
93 }
94
95 pub fn create(
97 &mut self,
98 protocol: RawIpSocketProtocol<I>,
99 external_state: <C::BindingsContext as RawIpSocketsBindingsTypes>::RawIpSocketState<I>,
100 ) -> RawIpApiSocketId<I, C> {
101 let socket =
102 PrimaryRawIpSocketId(PrimaryRc::new(RawIpSocketState::new(protocol, external_state)));
103 let strong = self.core_ctx().with_socket_map_mut(|socket_map| socket_map.insert(socket));
104 debug!("created raw IP socket {strong:?}, on protocol {protocol:?}");
105
106 if protocol.requires_system_checksums() {
107 self.core_ctx().with_locked_state_mut(&strong, |state| state.system_checksums = true)
108 }
109
110 strong
111 }
112
113 pub fn close(
115 &mut self,
116 id: RawIpApiSocketId<I, C>,
117 ) -> RemoveResourceResultWithContext<
118 <C::BindingsContext as RawIpSocketsBindingsTypes>::RawIpSocketState<I>,
119 C::BindingsContext,
120 > {
121 let primary = self.core_ctx().with_socket_map_mut(|socket_map| socket_map.remove(id));
122 debug!("removed raw IP socket {primary:?}");
123 let PrimaryRawIpSocketId(primary) = primary;
124
125 C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
126 primary,
127 |state: RawIpSocketState<I, _, C::BindingsContext>| state.into_external_state(),
128 )
129 }
130
131 pub fn send_to<B: BufferMut>(
136 &mut self,
137 id: &RawIpApiSocketId<I, C>,
138 remote_ip: Option<
139 ZonedAddr<
140 SpecifiedAddr<I::Addr>,
141 <C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId,
142 >,
143 >,
144 mut body: B,
145 ) -> Result<(), RawIpSocketSendToError> {
146 match id.protocol() {
147 RawIpSocketProtocol::Raw => return Err(RawIpSocketSendToError::ProtocolRaw),
148 RawIpSocketProtocol::Proto(_) => {}
149 }
150 let local_ip = None;
155
156 let remote_ip = match DualStackRemoteIp::<I, _>::new(remote_ip) {
157 DualStackRemoteIp::ThisStack(addr) => addr,
158 DualStackRemoteIp::OtherStack(_addr) => {
159 return Err(RawIpSocketSendToError::MappedRemoteIp)
160 }
161 };
162 let protocol = id.protocol().proto();
163
164 let (core_ctx, bindings_ctx) = self.contexts();
165 let result = core_ctx.with_locked_state_and_socket_handler(id, |state, core_ctx| {
166 let RawIpSocketLockedState {
167 bound_device,
168 icmp_filter: _,
169 hop_limits,
170 multicast_loop,
171 system_checksums,
172 marks,
173 } = state;
174 let (remote_ip, device) = remote_ip
175 .resolve_addr_with_device(bound_device.clone())
176 .map_err(RawIpSocketSendToError::Zone)?;
177 let send_options = RawIpSocketOptions {
178 hop_limits: &hop_limits,
179 multicast_loop: *multicast_loop,
180 marks: &marks,
181 };
182
183 let build_packet_fn =
184 |src_ip: IpDeviceAddr<I::Addr>| -> Result<RawIpBody<_, _>, RawIpSocketSendToError> {
185 if *system_checksums {
186 let buf = SliceBufViewMut::new(body.as_mut());
187 if !checksum::populate_checksum::<I, _>(
188 src_ip.addr(),
189 remote_ip.addr(),
190 protocol,
191 buf,
192 ) {
193 return Err(RawIpSocketSendToError::InvalidBody);
194 }
195 }
196 Ok(RawIpBody::new(protocol, src_ip.addr(), remote_ip.addr(), body))
197 };
198
199 let tx_metadata = Default::default();
202
203 core_ctx
204 .send_oneshot_ip_packet_with_fallible_serializer(
205 bindings_ctx,
206 device.as_ref().map(|d| d.as_ref()),
207 local_ip,
208 remote_ip,
209 protocol,
210 &send_options,
211 tx_metadata,
212 build_packet_fn,
213 )
214 .map_err(|e| match e {
215 SendOneShotIpPacketError::CreateAndSendError { err } => {
216 RawIpSocketSendToError::Ip(err)
217 }
218 SendOneShotIpPacketError::SerializeError(err) => err,
219 })
220 });
221 match &result {
222 Ok(()) => core_ctx
223 .increment_both(&id, |counters: &RawIpSocketCounters<I>| &counters.tx_packets),
224 Err(RawIpSocketSendToError::InvalidBody) => core_ctx
225 .increment_both(&id, |counters: &RawIpSocketCounters<I>| {
226 &counters.tx_checksum_errors
227 }),
228 Err(_) => {}
229 }
230 result
231 }
232
233 pub fn set_device(
241 &mut self,
242 id: &RawIpApiSocketId<I, C>,
243 device: Option<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
244 ) -> Option<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
245 let device = device.map(|strong| strong.downgrade());
246 self.core_ctx()
251 .with_locked_state_mut(id, |state| core::mem::replace(&mut state.bound_device, device))
252 }
253
254 pub fn get_device(
256 &mut self,
257 id: &RawIpApiSocketId<I, C>,
258 ) -> Option<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
259 self.core_ctx().with_locked_state(id, |state| state.bound_device.clone())
260 }
261
262 pub fn set_icmp_filter(
267 &mut self,
268 id: &RawIpApiSocketId<I, C>,
269 filter: Option<RawIpSocketIcmpFilter<I>>,
270 ) -> Result<Option<RawIpSocketIcmpFilter<I>>, RawIpSocketIcmpFilterError> {
271 debug!("setting ICMP Filter on {id:?}: {filter:?}");
272 if !id.protocol().is_icmp() {
273 return Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp);
274 }
275 Ok(self
276 .core_ctx()
277 .with_locked_state_mut(id, |state| core::mem::replace(&mut state.icmp_filter, filter)))
278 }
279
280 pub fn get_icmp_filter(
285 &mut self,
286 id: &RawIpApiSocketId<I, C>,
287 ) -> Result<Option<RawIpSocketIcmpFilter<I>>, RawIpSocketIcmpFilterError> {
288 if !id.protocol().is_icmp() {
289 return Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp);
290 }
291 Ok(self.core_ctx().with_locked_state(id, |state| state.icmp_filter.clone()))
292 }
293
294 pub fn set_unicast_hop_limit(
299 &mut self,
300 id: &RawIpApiSocketId<I, C>,
301 new_limit: Option<NonZeroU8>,
302 ) -> Option<NonZeroU8> {
303 self.core_ctx().with_locked_state_mut(id, |state| {
304 core::mem::replace(&mut state.hop_limits.unicast, new_limit)
305 })
306 }
307
308 pub fn get_unicast_hop_limit(&mut self, id: &RawIpApiSocketId<I, C>) -> NonZeroU8 {
310 self.core_ctx().with_locked_state(id, |state| {
311 state.hop_limits.get_limits_with_defaults(&DEFAULT_HOP_LIMITS).unicast
312 })
313 }
314
315 pub fn set_multicast_hop_limit(
320 &mut self,
321 id: &RawIpApiSocketId<I, C>,
322 new_limit: Option<NonZeroU8>,
323 ) -> Option<NonZeroU8> {
324 self.core_ctx().with_locked_state_mut(id, |state| {
325 core::mem::replace(&mut state.hop_limits.multicast, new_limit)
326 })
327 }
328
329 pub fn get_multicast_hop_limit(&mut self, id: &RawIpApiSocketId<I, C>) -> NonZeroU8 {
331 self.core_ctx().with_locked_state(id, |state| {
332 state.hop_limits.get_limits_with_defaults(&DEFAULT_HOP_LIMITS).multicast
333 })
334 }
335
336 pub fn set_multicast_loop(&mut self, id: &RawIpApiSocketId<I, C>, value: bool) -> bool {
340 self.core_ctx()
341 .with_locked_state_mut(id, |state| core::mem::replace(&mut state.multicast_loop, value))
342 }
343
344 pub fn get_multicast_loop(&mut self, id: &RawIpApiSocketId<I, C>) -> bool {
346 self.core_ctx().with_locked_state(id, |state| state.multicast_loop)
347 }
348
349 pub fn set_mark(&mut self, id: &RawIpApiSocketId<I, C>, domain: MarkDomain, mark: Mark) {
351 self.core_ctx().with_locked_state_mut(id, |state| {
352 *state.marks.get_mut(domain) = mark;
353 })
354 }
355
356 pub fn get_mark(&mut self, id: &RawIpApiSocketId<I, C>, domain: MarkDomain) -> Mark {
358 self.core_ctx().with_locked_state(id, |state| *state.marks.get(domain))
359 }
360
361 pub fn inspect<N>(&mut self, inspector: &mut N)
363 where
364 N: Inspector
365 + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
366 {
367 self.core_ctx().with_socket_map_and_state_ctx(|socket_map, core_ctx| {
368 socket_map.iter_sockets().for_each(|socket| {
369 inspector.record_debug_child(socket, |node| {
370 node.record_display("TransportProtocol", socket.protocol().proto());
371 node.record_str("NetworkProtocol", I::NAME);
372 node.record_local_socket_addr::<
374 N,
375 I::Addr,
376 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
377 NoPortMarker,
378 >(None);
379 node.record_remote_socket_addr::<
381 N,
382 I::Addr,
383 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
384 NoPortMarker,
385 >(None);
386 core_ctx.with_locked_state(socket, |state| {
387 let RawIpSocketLockedState {
388 bound_device,
389 icmp_filter,
390 hop_limits: _,
391 multicast_loop: _,
392 marks: _,
393 system_checksums: _,
394 } = state;
395 if let Some(bound_device) = bound_device {
396 N::record_device(node, "BoundDevice", bound_device);
397 } else {
398 node.record_str("BoundDevice", "None");
399 }
400 if let Some(icmp_filter) = icmp_filter {
401 node.record_display("IcmpFilter", icmp_filter);
402 } else {
403 node.record_str("IcmpFilter", "None");
404 }
405 });
406 node.record_child("Counters", |node| {
407 node.delegate_inspectable(socket.state().counters())
408 })
409 })
410 })
411 })
412 }
413}
414
/// Errors returned by [`RawIpSocketApi::send_to`].
#[derive(Debug)]
pub enum RawIpSocketSendToError {
    /// The socket's protocol is [`RawIpSocketProtocol::Raw`]; sending on it is rejected.
    ProtocolRaw,
    /// The remote address belongs to the other IP stack. This is presumably an IPv4-mapped
    /// IPv6 address, as the name suggests.
    MappedRemoteIp,
    /// The system checksum could not be populated in the provided body.
    InvalidBody,
    /// The remote address's zone could not be reconciled with the socket's bound device.
    Zone(ZonedAddressError),
    /// The IP layer failed to create or send the packet.
    Ip(IpSockCreateAndSendError),
}
434
/// Errors returned when getting or setting a raw IP socket's ICMP filter.
#[derive(Debug, PartialEq)]
pub enum RawIpSocketIcmpFilterError {
    /// The socket's protocol is not ICMP, so ICMP filters do not apply to it.
    ProtocolNotIcmp,
}
441
442struct PrimaryRawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
444 PrimaryRc<RawIpSocketState<I, D, BT>>,
445);
446
447impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
448 for PrimaryRawIpSocketId<I, D, BT>
449{
450 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
451 let Self(rc) = self;
452 f.debug_tuple("RawIpSocketId").field(&PrimaryRc::debug_id(rc)).finish()
453 }
454}
455
456#[derive(Derivative, GenericOverIp)]
458#[derivative(Clone(bound = ""), Eq(bound = ""), Hash(bound = ""), PartialEq(bound = ""))]
459#[generic_over_ip(I, Ip)]
460pub struct RawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
461 StrongRc<RawIpSocketState<I, D, BT>>,
462);
463
464impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> RawIpSocketId<I, D, BT> {
465 pub fn external_state(&self) -> &BT::RawIpSocketState<I> {
467 let RawIpSocketId(strong_rc) = self;
468 strong_rc.external_state()
469 }
470 pub fn protocol(&self) -> &RawIpSocketProtocol<I> {
472 let RawIpSocketId(strong_rc) = self;
473 strong_rc.protocol()
474 }
475 pub fn downgrade(&self) -> WeakRawIpSocketId<I, D, BT> {
477 let Self(rc) = self;
478 WeakRawIpSocketId(StrongRc::downgrade(rc))
479 }
480 pub fn state(&self) -> &RawIpSocketState<I, D, BT> {
482 let RawIpSocketId(strong_rc) = self;
483 &*strong_rc
484 }
485}
486
487impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
488 for RawIpSocketId<I, D, BT>
489{
490 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
491 let Self(rc) = self;
492 f.debug_tuple("RawIpSocketId").field(&StrongRc::debug_id(rc)).finish()
493 }
494}
495
496pub struct WeakRawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
498 WeakRc<RawIpSocketState<I, D, BT>>,
499);
500
501impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
502 for WeakRawIpSocketId<I, D, BT>
503{
504 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
505 let Self(rc) = self;
506 f.debug_tuple("WeakRawIpSocketId").field(&WeakRc::debug_id(rc)).finish()
507 }
508}
509
/// Shorthand for a [`RawIpSocketId`] whose weak device ID and bindings types come from the
/// context pair `C`.
type RawIpApiSocketId<I, C> = RawIpSocketId<
    I,
    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
    <C as ContextPair>::BindingsContext,
>;
516
/// A context providing access to individual raw IP sockets' locked state.
pub trait RawIpSocketStateContext<I: IpExt, BT: RawIpSocketsBindingsTypes>:
    DeviceIdContext<AnyDevice>
{
    /// The IP socket handler used to send packets on behalf of raw IP sockets.
    type SocketHandler<'a>: IpSocketHandler<
        I,
        BT,
        DeviceId = Self::DeviceId,
        WeakDeviceId = Self::WeakDeviceId,
    >;

    /// Calls `cb` with shared access to the socket's locked state.
    fn with_locked_state<O, F: FnOnce(&RawIpSocketLockedState<I, Self::WeakDeviceId>) -> O>(
        &mut self,
        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
        cb: F,
    ) -> O;

    /// Calls `cb` with shared access to the socket's locked state and a socket handler that
    /// can be used to send packets.
    fn with_locked_state_and_socket_handler<
        O,
        F: FnOnce(&RawIpSocketLockedState<I, Self::WeakDeviceId>, &mut Self::SocketHandler<'_>) -> O,
    >(
        &mut self,
        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
        cb: F,
    ) -> O;

    /// Calls `cb` with exclusive (mutable) access to the socket's locked state.
    fn with_locked_state_mut<
        O,
        F: FnOnce(&mut RawIpSocketLockedState<I, Self::WeakDeviceId>) -> O,
    >(
        &mut self,
        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
        cb: F,
    ) -> O;
}
562
/// The set of all raw IP sockets for IP version `I`, indexed by protocol.
///
/// For each protocol, the inner map is keyed by each socket's strong ID and holds that
/// socket's owning primary reference. Indexing by protocol lets the receive path visit only
/// the sockets whose protocol matches an incoming packet.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct RawIpSocketMap<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> {
    // Invariant (maintained by `remove`): no protocol maps to an empty inner map.
    sockets: BTreeMap<
        RawIpSocketProtocol<I>,
        HashMap<RawIpSocketId<I, D, BT>, PrimaryRawIpSocketId<I, D, BT>>,
    >,
}
584
585impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> RawIpSocketMap<I, D, BT> {
586 fn insert(&mut self, socket: PrimaryRawIpSocketId<I, D, BT>) -> RawIpSocketId<I, D, BT> {
587 let RawIpSocketMap { sockets } = self;
588 let PrimaryRawIpSocketId(primary) = &socket;
589 let strong = RawIpSocketId(PrimaryRc::clone_strong(primary));
590 assert!(sockets
593 .entry(*strong.protocol())
594 .or_default()
595 .insert(strong.clone(), socket)
596 .is_none());
597 strong
598 }
599
600 fn remove(&mut self, socket: RawIpSocketId<I, D, BT>) -> PrimaryRawIpSocketId<I, D, BT> {
601 let RawIpSocketMap { sockets } = self;
605 let protocol = *socket.protocol();
606 match sockets.entry(protocol) {
607 Entry::Vacant(_) => unreachable!(
608 "{socket:?} with protocol {protocol:?} must be present in the socket map"
609 ),
610 Entry::Occupied(mut entry) => {
611 let map = entry.get_mut();
612 let primary = map.remove(&socket).unwrap();
613 if map.is_empty() {
616 let _: HashMap<RawIpSocketId<I, D, BT>, PrimaryRawIpSocketId<I, D, BT>> =
617 entry.remove();
618 }
619 primary
620 }
621 }
622 }
623
624 fn iter_sockets(&self) -> impl Iterator<Item = &RawIpSocketId<I, D, BT>> {
625 let RawIpSocketMap { sockets } = self;
626 sockets.values().flat_map(|sockets_by_protocol| sockets_by_protocol.keys())
627 }
628
629 fn iter_sockets_for_protocol(
630 &self,
631 protocol: &RawIpSocketProtocol<I>,
632 ) -> impl Iterator<Item = &RawIpSocketId<I, D, BT>> {
633 let RawIpSocketMap { sockets } = self;
634 sockets.get(protocol).map(|sockets| sockets.keys()).into_iter().flatten()
635 }
636}
637
/// A context providing access to the raw IP socket map.
pub trait RawIpSocketMapContext<I: IpExt, BT: RawIpSocketsBindingsTypes>:
    DeviceIdContext<AnyDevice>
{
    /// The context used to access sockets' locked state and per-socket counters while the
    /// socket map is held.
    type StateCtx<'a>: RawIpSocketStateContext<I, BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
        + ResourceCounterContext<RawIpSocketId<I, Self::WeakDeviceId, BT>, RawIpSocketCounters<I>>;

    /// Calls `cb` with shared access to the socket map and a context for socket state.
    fn with_socket_map_and_state_ctx<
        O,
        F: FnOnce(&RawIpSocketMap<I, Self::WeakDeviceId, BT>, &mut Self::StateCtx<'_>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
    /// Calls `cb` with exclusive (mutable) access to the socket map.
    fn with_socket_map_mut<O, F: FnOnce(&mut RawIpSocketMap<I, Self::WeakDeviceId, BT>) -> O>(
        &mut self,
        cb: F,
    ) -> O;
}
661
/// Entry point used by the IP layer to deliver received packets to raw IP sockets.
pub trait RawIpSocketHandler<I: IpExt, BC>: DeviceIdContext<AnyDevice> {
    /// Delivers `packet`, received on `device`, to the raw IP sockets that accept it.
    fn deliver_packet_to_raw_ip_sockets<B: SplitByteSlice>(
        &mut self,
        bindings_ctx: &mut BC,
        packet: &I::Packet<B>,
        device: &Self::DeviceId,
    );
}
672
673impl<I, BC, CC> RawIpSocketHandler<I, BC> for CC
674where
675 I: IpExt,
676 BC: RawIpSocketsBindingsContext<I, CC::DeviceId>,
677 CC: RawIpSocketMapContext<I, BC>,
678{
679 fn deliver_packet_to_raw_ip_sockets<B: SplitByteSlice>(
680 &mut self,
681 bindings_ctx: &mut BC,
682 packet: &I::Packet<B>,
683 device: &CC::DeviceId,
684 ) {
685 let protocol = RawIpSocketProtocol::new(packet.proto());
686
687 match protocol {
690 RawIpSocketProtocol::Raw => {
691 debug!("received IP packet with raw protocol (IANA Reserved - 255); dropping");
692 return;
693 }
694 RawIpSocketProtocol::Proto(_) => {}
695 };
696
697 self.with_socket_map_and_state_ctx(|socket_map, core_ctx| {
698 socket_map.iter_sockets_for_protocol(&protocol).for_each(|socket| {
699 match core_ctx.with_locked_state(socket, |state| {
700 check_packet_for_delivery(packet, device, state)
701 }) {
702 DeliveryOutcome::Deliver => {
703 core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
704 &counters.rx_packets
705 });
706 bindings_ctx.receive_packet(socket, packet, device);
707 }
708 DeliveryOutcome::WrongChecksum => {
709 core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
710 &counters.rx_checksum_errors
711 });
712 }
713 DeliveryOutcome::WrongIcmpMessageType => {
714 core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
715 &counters.rx_icmp_filtered
716 });
717 }
718 DeliveryOutcome::WrongDevice => {}
719 }
720 })
721 })
722 }
723}
724
/// The result of checking whether a received packet should be delivered to a socket.
enum DeliveryOutcome {
    /// The packet passed every check and should be handed to bindings.
    Deliver,
    /// The socket is bound to a device other than the one the packet arrived on.
    WrongDevice,
    /// The socket validates checksums and the packet's checksum is invalid.
    WrongChecksum,
    /// The socket's ICMP filter rejected the packet. This covers both a disallowed message
    /// type and a message type that could not be parsed.
    WrongIcmpMessageType,
}
737
738fn check_packet_for_delivery<I: IpExt, D: StrongDeviceIdentifier, B: SplitByteSlice>(
740 packet: &I::Packet<B>,
741 device: &D,
742 socket: &RawIpSocketLockedState<I, D::Weak>,
743) -> DeliveryOutcome {
744 let RawIpSocketLockedState {
745 bound_device,
746 icmp_filter,
747 hop_limits: _,
748 marks: _,
749 multicast_loop: _,
750 system_checksums,
751 } = socket;
752 if bound_device.as_ref().is_some_and(|bound_device| bound_device != device) {
754 return DeliveryOutcome::WrongDevice;
755 }
756
757 if *system_checksums && !checksum::has_valid_checksum::<I, B>(packet) {
762 return DeliveryOutcome::WrongChecksum;
763 }
764
765 if icmp_filter.as_ref().is_some_and(|icmp_filter| {
767 debug_assert!(RawIpSocketProtocol::<I>::new(packet.proto()).is_icmp());
771 match icmp::peek_message_type(packet.body()) {
772 Err(_) => true,
782 Ok(message_type) => !icmp_filter.allows_type(message_type),
783 }
784 }) {
785 return DeliveryOutcome::WrongIcmpMessageType;
786 }
787
788 DeliveryOutcome::Deliver
789}
790
791struct RawIpSocketOptions<'a, I: Ip> {
793 hop_limits: &'a SocketHopLimits<I>,
794 multicast_loop: bool,
795 marks: &'a Marks,
796}
797
798impl<I: Ip> RouteResolutionOptions<I> for RawIpSocketOptions<'_, I> {
799 fn transparent(&self) -> bool {
800 false
801 }
802
803 fn marks(&self) -> &Marks {
804 self.marks
805 }
806}
807
808impl<I: IpExt> SendOptions<I> for RawIpSocketOptions<'_, I> {
809 fn hop_limit(&self, destination: &SpecifiedAddr<I::Addr>) -> Option<NonZeroU8> {
810 self.hop_limits.hop_limit_for_dst(destination)
811 }
812
813 fn multicast_loop(&self) -> bool {
814 self.multicast_loop
815 }
816
817 fn allow_broadcast(&self) -> Option<I::BroadcastMarker> {
818 None
819 }
820
821 fn dscp_and_ecn(&self) -> DscpAndEcn {
822 DscpAndEcn::default()
823 }
824
825 fn mtu(&self) -> Mtu {
826 Mtu::no_limit()
827 }
828}
829
/// Placeholder port type used when recording socket addresses to inspect.
///
/// The recorded raw IP socket addresses are always `None`, so this type only satisfies the
/// port type parameter.
struct NoPortMarker {}

impl Display for NoPortMarker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("NoPort")
    }
}
838
839#[cfg(test)]
840mod test {
841 use super::*;
842
843 use alloc::rc::Rc;
844 use alloc::vec;
845 use alloc::vec::Vec;
846 use assert_matches::assert_matches;
847 use core::cell::RefCell;
848 use core::convert::Infallible as Never;
849 use core::marker::PhantomData;
850 use ip_test_macro::ip_test;
851 use net_types::ip::{IpVersion, Ipv4, Ipv6};
852 use netstack3_base::sync::{DynDebugReferences, Mutex};
853 use netstack3_base::testutil::{
854 FakeStrongDeviceId, FakeTxMetadata, FakeWeakDeviceId, MultipleDevicesId, TestIpExt,
855 };
856 use netstack3_base::{ContextProvider, CounterContext, CtxPair};
857 use packet::{Buf, InnerPacketBuilder as _, ParseBuffer as _, Serializer as _};
858 use packet_formats::icmp::{
859 IcmpEchoReply, IcmpMessage, IcmpPacketBuilder, IcmpZeroCode, Icmpv6MessageType,
860 };
861 use packet_formats::ip::{IpPacketBuilder, IpProto, IpProtoExt, Ipv6Proto};
862 use packet_formats::ipv6::Ipv6Packet;
863 use test_case::test_case;
864
865 use crate::internal::socket::testutil::{FakeIpSocketCtx, InnerFakeIpSocketCtx};
866 use crate::socket::testutil::FakeDeviceConfig;
867 use crate::{SendIpPacketMeta, DEFAULT_HOP_LIMITS};
868
869 #[derive(Derivative, Debug)]
870 #[derivative(Default(bound = ""))]
871 struct FakeExternalSocketState<D> {
872 received_packets: Mutex<Vec<ReceivedIpPacket<D>>>,
874 }
875
876 #[derive(Debug, PartialEq)]
877 struct ReceivedIpPacket<D> {
878 data: Vec<u8>,
879 device: D,
880 }
881
882 #[derive(Derivative)]
883 #[derivative(Default(bound = ""))]
884 struct FakeBindingsCtx<D> {
885 _device_id_type: PhantomData<D>,
886 }
887
888 struct FakeCoreCtxState<I: IpExt, D: FakeStrongDeviceId> {
890 socket_map: Rc<RefCell<RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>>>,
895 ip_socket_ctx: FakeIpSocketCtx<I, D>,
899 counters: RawIpSocketCounters<I>,
901 }
902
903 impl<I: IpExt, D: FakeStrongDeviceId> InnerFakeIpSocketCtx<I, D> for FakeCoreCtxState<I, D> {
904 fn fake_ip_socket_ctx_mut(&mut self) -> &mut FakeIpSocketCtx<I, D> {
905 &mut self.ip_socket_ctx
906 }
907 }
908
909 type FakeCoreCtx<I, D> = netstack3_base::testutil::FakeCoreCtx<
910 FakeCoreCtxState<I, D>,
911 SendIpPacketMeta<I, D, SpecifiedAddr<<I as Ip>::Addr>>,
912 D,
913 >;
914
915 impl<D: FakeStrongDeviceId> TxMetadataBindingsTypes for FakeBindingsCtx<D> {
916 type TxMetadata = FakeTxMetadata;
917 }
918
919 impl<D: FakeStrongDeviceId> RawIpSocketsBindingsTypes for FakeBindingsCtx<D> {
920 type RawIpSocketState<I: Ip> = FakeExternalSocketState<D>;
921 }
922
923 impl<I: IpExt, D: Copy + FakeStrongDeviceId> RawIpSocketsBindingsContext<I, D>
924 for FakeBindingsCtx<D>
925 {
926 fn receive_packet<B: SplitByteSlice>(
927 &self,
928 socket: &RawIpSocketId<I, D::Weak, Self>,
929 packet: &I::Packet<B>,
930 device: &D,
931 ) {
932 let packet = ReceivedIpPacket { data: packet.to_vec(), device: *device };
933 let FakeExternalSocketState { received_packets } = socket.external_state();
934 received_packets.lock().push(packet);
935 }
936 }
937
938 impl<I: IpExt, D: FakeStrongDeviceId> RawIpSocketStateContext<I, FakeBindingsCtx<D>>
939 for FakeCoreCtx<I, D>
940 {
941 type SocketHandler<'a> = FakeCoreCtx<I, D>;
942 fn with_locked_state<O, F: FnOnce(&RawIpSocketLockedState<I, D::Weak>) -> O>(
943 &mut self,
944 id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
945 cb: F,
946 ) -> O {
947 let RawIpSocketId(state_rc) = id;
948 let guard = state_rc.locked_state().read();
949 cb(&guard)
950 }
951 fn with_locked_state_and_socket_handler<
952 O,
953 F: FnOnce(&RawIpSocketLockedState<I, D::Weak>, &mut Self::SocketHandler<'_>) -> O,
954 >(
955 &mut self,
956 id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
957 cb: F,
958 ) -> O {
959 let RawIpSocketId(state_rc) = id;
960 let guard = state_rc.locked_state().read();
961 cb(&guard, self)
962 }
963 fn with_locked_state_mut<O, F: FnOnce(&mut RawIpSocketLockedState<I, D::Weak>) -> O>(
964 &mut self,
965 id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
966 cb: F,
967 ) -> O {
968 let RawIpSocketId(state_rc) = id;
969 let mut guard = state_rc.locked_state().write();
970 cb(&mut guard)
971 }
972 }
973
974 impl<I: IpExt, D: FakeStrongDeviceId> CounterContext<RawIpSocketCounters<I>> for FakeCoreCtx<I, D> {
975 fn counters(&self) -> &RawIpSocketCounters<I> {
976 &self.state.counters
977 }
978 }
979
980 impl<I: IpExt, D: FakeStrongDeviceId>
981 ResourceCounterContext<
982 RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
983 RawIpSocketCounters<I>,
984 > for FakeCoreCtx<I, D>
985 {
986 fn per_resource_counters<'a>(
987 &'a self,
988 socket: &'a RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
989 ) -> &'a RawIpSocketCounters<I> {
990 socket.state().counters()
991 }
992 }
993
994 impl<I: IpExt, D: FakeStrongDeviceId> RawIpSocketMapContext<I, FakeBindingsCtx<D>>
995 for FakeCoreCtx<I, D>
996 {
997 type StateCtx<'a> = FakeCoreCtx<I, D>;
998 fn with_socket_map_and_state_ctx<
999 O,
1000 F: FnOnce(&RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>, &mut Self::StateCtx<'_>) -> O,
1001 >(
1002 &mut self,
1003 cb: F,
1004 ) -> O {
1005 let socket_map = self.state.socket_map.clone();
1006 let borrow = socket_map.borrow();
1007 cb(&borrow, self)
1008 }
1009 fn with_socket_map_mut<
1010 O,
1011 F: FnOnce(&mut RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>) -> O,
1012 >(
1013 &mut self,
1014 cb: F,
1015 ) -> O {
1016 cb(&mut self.state.socket_map.borrow_mut())
1017 }
1018 }
1019
1020 impl<D> ContextProvider for FakeBindingsCtx<D> {
1021 type Context = FakeBindingsCtx<D>;
1022 fn context(&mut self) -> &mut Self::Context {
1023 self
1024 }
1025 }
1026
1027 impl<D> ReferenceNotifiers for FakeBindingsCtx<D> {
1028 type ReferenceReceiver<T: 'static> = Never;
1029
1030 type ReferenceNotifier<T: Send + 'static> = Never;
1031
1032 fn new_reference_notifier<T: Send + 'static>(
1033 _debug_references: DynDebugReferences,
1034 ) -> (Self::ReferenceNotifier<T>, Self::ReferenceReceiver<T>) {
1035 unimplemented!("raw IP socket removal shouldn't be deferred in tests");
1036 }
1037 }
1038
1039 fn new_raw_ip_socket_api<I: IpExt + TestIpExt>() -> RawIpSocketApi<
1040 I,
1041 CtxPair<FakeCoreCtx<I, MultipleDevicesId>, FakeBindingsCtx<MultipleDevicesId>>,
1042 > {
1043 let device_configs = [MultipleDevicesId::A, MultipleDevicesId::B, MultipleDevicesId::C]
1045 .into_iter()
1046 .map(|device| FakeDeviceConfig {
1047 device,
1048 local_ips: vec![I::TEST_ADDRS.local_ip],
1049 remote_ips: vec![I::TEST_ADDRS.remote_ip],
1050 });
1051 let state = FakeCoreCtxState {
1052 socket_map: Default::default(),
1053 ip_socket_ctx: FakeIpSocketCtx::new(device_configs),
1054 counters: Default::default(),
1055 };
1056
1057 RawIpSocketApi::new(CtxPair::with_core_ctx(FakeCoreCtx::with_state(state)))
1058 }
1059
1060 const IP_BODY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
1062
1063 fn new_ip_packet_buf<I: IpExt + TestIpExt>(
1065 ip_body: &[u8],
1066 proto: I::Proto,
1067 ) -> impl AsRef<[u8]> {
1068 const TTL: u8 = 255;
1069 ip_body
1070 .into_serializer()
1071 .encapsulate(I::PacketBuilder::new(
1072 *I::TEST_ADDRS.local_ip,
1073 *I::TEST_ADDRS.remote_ip,
1074 TTL,
1075 proto,
1076 ))
1077 .serialize_vec_outer()
1078 .unwrap()
1079 }
1080
1081 fn new_icmp_message_buf<I: IpExt + TestIpExt, M: IcmpMessage<I> + Debug>(
1083 message: M,
1084 code: M::Code,
1085 ) -> impl AsRef<[u8]> {
1086 [].into_serializer()
1087 .encapsulate(IcmpPacketBuilder::new(
1088 *I::TEST_ADDRS.local_ip,
1089 *I::TEST_ADDRS.remote_ip,
1090 code,
1091 message,
1092 ))
1093 .serialize_vec_outer()
1094 .unwrap()
1095 }
1096
1097 #[ip_test(I)]
1098 #[test_case(IpProto::Udp; "UDP")]
1099 #[test_case(IpProto::Reserved; "IPPROTO_RAW")]
1100 fn create_and_close<I: IpExt + DualStackIpExt + TestIpExt>(proto: IpProto) {
1101 let mut api = new_raw_ip_socket_api::<I>();
1102 let sock = api.create(RawIpSocketProtocol::new(proto.into()), Default::default());
1103 let FakeExternalSocketState { received_packets: _ } = api.close(sock).into_removed();
1104 }
1105
1106 #[ip_test(I)]
1107 fn set_device<I: IpExt + DualStackIpExt + TestIpExt>() {
1108 let mut api = new_raw_ip_socket_api::<I>();
1109 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1110
1111 assert_eq!(api.get_device(&sock), None);
1112 assert_eq!(api.set_device(&sock, Some(&MultipleDevicesId::A)), None);
1113 assert_eq!(api.get_device(&sock), Some(FakeWeakDeviceId(MultipleDevicesId::A)));
1114 assert_eq!(
1115 api.set_device(&sock, Some(&MultipleDevicesId::B)),
1116 Some(FakeWeakDeviceId(MultipleDevicesId::A))
1117 );
1118 assert_eq!(api.get_device(&sock), Some(FakeWeakDeviceId(MultipleDevicesId::B)));
1119 assert_eq!(api.set_device(&sock, None), Some(FakeWeakDeviceId(MultipleDevicesId::B)));
1120 assert_eq!(api.get_device(&sock), None);
1121 }
1122
1123 #[ip_test(I)]
1124 fn set_icmp_filter<I: IpExt + DualStackIpExt + TestIpExt>() {
1125 let filter1 = RawIpSocketIcmpFilter::<I>::new([123; 32]);
1126 let filter2 = RawIpSocketIcmpFilter::<I>::new([234; 32]);
1127 let mut api = new_raw_ip_socket_api::<I>();
1128
1129 let sock = api.create(RawIpSocketProtocol::new(I::ICMP_IP_PROTO), Default::default());
1130 assert_eq!(api.get_icmp_filter(&sock), Ok(None));
1131 assert_eq!(api.set_icmp_filter(&sock, Some(filter1.clone())), Ok(None));
1132 assert_eq!(api.get_icmp_filter(&sock), Ok(Some(filter1.clone())));
1133 assert_eq!(api.set_icmp_filter(&sock, Some(filter2.clone())), Ok(Some(filter1.clone())));
1134 assert_eq!(api.get_icmp_filter(&sock), Ok(Some(filter2.clone())));
1135 assert_eq!(api.set_icmp_filter(&sock, None), Ok(Some(filter2)));
1136 assert_eq!(api.get_icmp_filter(&sock), Ok(None));
1137
1138 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1140 assert_eq!(
1141 api.set_icmp_filter(&sock, Some(filter1)),
1142 Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp)
1143 );
1144 assert_eq!(api.get_icmp_filter(&sock), Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp));
1145 }
1146
1147 #[ip_test(I)]
1148 fn set_unicast_hop_limits<I: IpExt + DualStackIpExt + TestIpExt>() {
1149 let mut api = new_raw_ip_socket_api::<I>();
1150 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1151
1152 let limit1 = NonZeroU8::new(1).unwrap();
1153 let limit2 = NonZeroU8::new(2).unwrap();
1154
1155 assert_eq!(api.get_unicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.unicast);
1156 assert_eq!(api.set_unicast_hop_limit(&sock, Some(limit1)), None);
1157 assert_eq!(api.get_unicast_hop_limit(&sock), limit1);
1158 assert_eq!(api.set_unicast_hop_limit(&sock, Some(limit2)), Some(limit1));
1159 assert_eq!(api.get_unicast_hop_limit(&sock), limit2);
1160 assert_eq!(api.set_unicast_hop_limit(&sock, None), Some(limit2));
1161 assert_eq!(api.get_unicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.unicast);
1162 }
1163
1164 #[ip_test(I)]
1165 fn set_multicast_hop_limit<I: IpExt + DualStackIpExt + TestIpExt>() {
1166 let mut api = new_raw_ip_socket_api::<I>();
1167 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1168
1169 let limit1 = NonZeroU8::new(1).unwrap();
1170 let limit2 = NonZeroU8::new(2).unwrap();
1171
1172 assert_eq!(api.get_multicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.multicast);
1173 assert_eq!(api.set_multicast_hop_limit(&sock, Some(limit1)), None);
1174 assert_eq!(api.get_multicast_hop_limit(&sock), limit1);
1175 assert_eq!(api.set_multicast_hop_limit(&sock, Some(limit2)), Some(limit1));
1176 assert_eq!(api.get_multicast_hop_limit(&sock), limit2);
1177 assert_eq!(api.set_multicast_hop_limit(&sock, None), Some(limit2));
1178 assert_eq!(api.get_multicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.multicast);
1179 }
1180
1181 #[ip_test(I)]
1182 fn set_multicast_loop<I: IpExt + DualStackIpExt + TestIpExt>() {
1183 let mut api = new_raw_ip_socket_api::<I>();
1184 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1185
1186 assert_eq!(api.get_multicast_loop(&sock), true);
1188 assert_eq!(api.set_multicast_loop(&sock, false), true);
1189 assert_eq!(api.get_multicast_loop(&sock), false);
1190 assert_eq!(api.set_multicast_loop(&sock, true), false);
1191 assert_eq!(api.get_multicast_loop(&sock), true);
1192 }
1193
1194 #[ip_test(I)]
1195 fn receive_ip_packet<I: IpExt + DualStackIpExt + TestIpExt>() {
1196 let mut api = new_raw_ip_socket_api::<I>();
1197
1198 let proto: I::Proto = IpProto::Udp.into();
1201 let wrong_proto: I::Proto = IpProto::Tcp.into();
1202 let sock1 = api.create(RawIpSocketProtocol::new(proto), Default::default());
1203 let sock2 = api.create(RawIpSocketProtocol::new(proto), Default::default());
1204 let wrong_sock = api.create(RawIpSocketProtocol::new(wrong_proto), Default::default());
1205
1206 const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1208 let buf = new_ip_packet_buf::<I>(&IP_BODY, proto);
1209 let mut buf_ref = buf.as_ref();
1210 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1211 {
1212 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1213 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &DEVICE);
1214 }
1215
1216 assert_eq!(api.core_ctx().state.counters.rx_packets.get(), 2);
1218 assert_eq!(sock1.state().counters().rx_packets.get(), 1);
1219 assert_eq!(sock2.state().counters().rx_packets.get(), 1);
1220 assert_eq!(wrong_sock.state().counters().rx_packets.get(), 0);
1221
1222 let FakeExternalSocketState { received_packets: sock1_packets } =
1223 api.close(sock1).into_removed();
1224 let FakeExternalSocketState { received_packets: sock2_packets } =
1225 api.close(sock2).into_removed();
1226 let FakeExternalSocketState { received_packets: wrong_sock_packets } =
1227 api.close(wrong_sock).into_removed();
1228
1229 for packets in [sock1_packets, sock2_packets] {
1231 let lock_guard = packets.lock();
1232 let ReceivedIpPacket { data, device } =
1233 assert_matches!(&lock_guard[..], [packet] => packet);
1234 assert_eq!(&data[..], buf.as_ref());
1235 assert_eq!(*device, DEVICE);
1236 }
1237 assert_matches!(&wrong_sock_packets.lock()[..], []);
1238 }
1239
1240 #[ip_test(I)]
1243 fn cannot_receive_ip_packet_with_proto_raw<I: IpExt + DualStackIpExt + TestIpExt>() {
1244 let mut api = new_raw_ip_socket_api::<I>();
1245 let sock = api.create(RawIpSocketProtocol::Raw, Default::default());
1246
1247 let protocols_to_test = match I::VERSION {
1250 IpVersion::V4 => vec![IpProto::Udp, IpProto::Reserved],
1251 IpVersion::V6 => vec![IpProto::Udp],
1254 };
1255 for proto in protocols_to_test {
1256 let buf = new_ip_packet_buf::<I>(&IP_BODY, proto.into());
1257 let mut buf_ref = buf.as_ref();
1258 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1259 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1260 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1261 }
1262
1263 let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1264 assert_matches!(&received_packets.lock()[..], []);
1265 }
1266
1267 #[ip_test(I)]
1268 #[test_case(MultipleDevicesId::A, None, true; "no_bound_device")]
1269 #[test_case(MultipleDevicesId::A, Some(MultipleDevicesId::A), true; "bound_same_device")]
1270 #[test_case(MultipleDevicesId::A, Some(MultipleDevicesId::B), false; "bound_diff_device")]
1271 fn receive_ip_packet_with_bound_device<I: IpExt + DualStackIpExt + TestIpExt>(
1272 send_dev: MultipleDevicesId,
1273 bound_dev: Option<MultipleDevicesId>,
1274 should_deliver: bool,
1275 ) {
1276 const PROTO: IpProto = IpProto::Udp;
1277 let mut api = new_raw_ip_socket_api::<I>();
1278 let sock = api.create(RawIpSocketProtocol::new(PROTO.into()), Default::default());
1279
1280 assert_eq!(api.set_device(&sock, bound_dev.as_ref()), None);
1281
1282 let buf = new_ip_packet_buf::<I>(&IP_BODY, PROTO.into());
1284 let mut buf_ref = buf.as_ref();
1285 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1286 {
1287 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1288 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &send_dev);
1289 }
1290
1291 let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1293 if should_deliver {
1294 let lock_guard = received_packets.lock();
1295 let ReceivedIpPacket { data, device } =
1296 assert_matches!(&lock_guard[..], [packet] => packet);
1297 assert_eq!(&data[..], buf.as_ref());
1298 assert_eq!(*device, send_dev);
1299 } else {
1300 assert_matches!(&received_packets.lock()[..], []);
1301 }
1302 }
1303
1304 #[ip_test(I)]
1305 #[test_case(None, true; "no_filter")]
1308 #[test_case(Some(RawIpSocketIcmpFilter::<I>::ALLOW_ALL), true; "allow_all")]
1309 #[test_case(Some(RawIpSocketIcmpFilter::<I>::DENY_ALL), false; "deny_all")]
1310 fn receive_ip_packet_with_icmp_filter<I: IpExt + DualStackIpExt + TestIpExt>(
1311 filter: Option<RawIpSocketIcmpFilter<I>>,
1312 should_deliver: bool,
1313 ) {
1314 let mut api = new_raw_ip_socket_api::<I>();
1315 let sock = api.create(RawIpSocketProtocol::new(I::ICMP_IP_PROTO), Default::default());
1316
1317 let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1318 assert_eq!(core_ctx.state.counters.rx_icmp_filtered.get(), count);
1319 assert_eq!(sock.state().counters().rx_icmp_filtered.get(), count);
1320 };
1321 assert_counters(api.core_ctx(), 0);
1322
1323 assert_matches!(api.set_icmp_filter(&sock, filter), Ok(None));
1324
1325 let icmp_body = new_icmp_message_buf::<I, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode);
1327 let buf = new_ip_packet_buf::<I>(icmp_body.as_ref(), I::ICMP_IP_PROTO);
1328 let mut buf_ref = buf.as_ref();
1329 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1330 {
1331 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1332 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1333 }
1334
1335 assert_counters(api.core_ctx(), should_deliver.then_some(0).unwrap_or(1));
1337 let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1338 if should_deliver {
1339 let lock_guard = received_packets.lock();
1340 let ReceivedIpPacket { data, device: _ } =
1341 assert_matches!(&lock_guard[..], [packet] => packet);
1342 assert_eq!(&data[..], buf.as_ref());
1343 } else {
1344 assert_matches!(&received_packets.lock()[..], []);
1345 }
1346 }
1347
/// Delivers an ICMPv6 echo reply whose checksum has been deliberately
/// overwritten and checks that the ICMPv6 raw socket does not receive it.
///
/// The drop must be counted once in both the core context's and the socket's
/// `rx_checksum_errors` counters.
#[test]
fn do_not_receive_icmpv6_packet_with_bad_checksum() {
    let mut api = new_raw_ip_socket_api::<Ipv6>();
    let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());

    let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
        assert_eq!(core_ctx.state.counters.rx_checksum_errors.get(), count);
        assert_eq!(sock.state().counters().rx_checksum_errors.get(), count);
    };
    assert_counters(api.core_ctx(), 0);

    const CORRUPT_CHECKSUM: [u8; 2] = [123, 234];
    let mut icmp_body = new_icmp_message_buf::<Ipv6, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode)
        .as_ref()
        .to_vec();
    let original_checksum =
        packet_formats::testutil::overwrite_icmpv6_checksum(icmp_body.as_mut(), CORRUPT_CHECKSUM)
            .expect("parse should succeed");
    // Make sure the overwrite actually changed the checksum; otherwise the
    // packet would still be valid and the test would prove nothing.
    assert_ne!(original_checksum, CORRUPT_CHECKSUM);

    let ip_buf = new_ip_packet_buf::<Ipv6>(icmp_body.as_ref(), Ipv6Proto::Icmpv6);
    let mut ip_view = ip_buf.as_ref();
    let parsed = ip_view.parse::<Ipv6Packet<_>>().expect("parse should succeed");
    {
        let (core_ctx, bindings_ctx) = api.ctx.contexts();
        core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &parsed, &MultipleDevicesId::A);
    }

    assert_counters(api.core_ctx(), 1);
    let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
    assert_matches!(&received_packets.lock()[..], []);
}
1390
1391 #[ip_test(I)]
1392 #[test_case(None, None; "default_send")]
1393 #[test_case(Some(MultipleDevicesId::A), None; "with_bound_dev")]
1394 #[test_case(None, Some(123); "with_hop_limit")]
1395 fn send_to<I: IpExt + DualStackIpExt + TestIpExt>(
1396 bound_dev: Option<MultipleDevicesId>,
1397 hop_limit: Option<u8>,
1398 ) {
1399 const PROTO: IpProto = IpProto::Udp;
1400 let mut api = new_raw_ip_socket_api::<I>();
1401 let sock = api.create(RawIpSocketProtocol::new(PROTO.into()), Default::default());
1402
1403 let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1404 assert_eq!(core_ctx.state.counters.tx_packets.get(), count);
1405 assert_eq!(sock.state().counters().tx_packets.get(), count);
1406 };
1407 assert_counters(api.core_ctx(), 0);
1408
1409 assert_eq!(api.set_device(&sock, bound_dev.as_ref()), None);
1410 let hop_limit = hop_limit.and_then(NonZeroU8::new);
1411 assert_eq!(api.set_unicast_hop_limit(&sock, hop_limit), None);
1412
1413 let remote_ip = ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip);
1414 assert_matches!(&api.ctx.core_ctx().take_frames()[..], []);
1415 api.send_to(&sock, Some(remote_ip), Buf::new(IP_BODY.to_vec(), ..))
1416 .expect("send should succeed");
1417 let frames = api.core_ctx().take_frames();
1418 let (SendIpPacketMeta { device, src_ip, dst_ip, proto, mtu, ttl, .. }, data) =
1419 assert_matches!( &frames[..], [packet] => packet);
1420 assert_eq!(&data[..], &IP_BODY[..]);
1421 assert_eq!(*dst_ip, remote_ip.addr());
1422 assert_eq!(*src_ip, I::TEST_ADDRS.local_ip);
1423 if let Some(bound_dev) = bound_dev {
1424 assert_eq!(*device, bound_dev);
1425 }
1426 assert_eq!(*proto, <I as IpProtoExt>::Proto::from(PROTO));
1427 assert_eq!(*mtu, Mtu::max());
1428 assert_eq!(*ttl, hop_limit);
1429
1430 assert_counters(api.core_ctx(), 1);
1431 }
1432
1433 #[ip_test(I)]
1434 fn send_to_disallows_raw_protocol<I: IpExt + DualStackIpExt + TestIpExt>() {
1435 let mut api = new_raw_ip_socket_api::<I>();
1436 let sock = api.create(RawIpSocketProtocol::Raw, Default::default());
1437 assert_matches!(
1438 api.send_to(&sock, None, Buf::new(IP_BODY.to_vec(), ..)),
1439 Err(RawIpSocketSendToError::ProtocolRaw)
1440 );
1441 }
1442
/// Checks that an IPv6 raw socket refuses to send to an IPv4-mapped IPv6
/// address, failing with `MappedRemoteIp`.
#[test]
fn send_to_disallows_dualstack() {
    let mut api = new_raw_ip_socket_api::<Ipv6>();
    let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
    let mapped_addr = Ipv4::TEST_ADDRS.local_ip.to_ipv6_mapped();
    let body = Buf::new(IP_BODY.to_vec(), ..);
    let outcome = api.send_to(&sock, Some(ZonedAddr::Unzoned(mapped_addr)), body);
    assert_matches!(outcome, Err(RawIpSocketSendToError::MappedRemoteIp));
}
1453
/// Sends an ICMPv6 echo reply whose checksum field holds a bogus value and
/// checks that the single emitted frame carries the correctly checksummed
/// message, i.e. the ICMPv6 checksum is computed on transmit.
#[test]
fn icmpv6_send_to_generates_checksum() {
    let mut api = new_raw_ip_socket_api::<Ipv6>();
    let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());

    // A well-formed echo reply with a valid checksum: the expected wire bytes.
    let icmp_body_with_checksum =
        new_icmp_message_buf::<Ipv6, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode)
            .as_ref()
            .to_vec();
    // The same message with its checksum overwritten by a bogus value.
    const CORRUPT_CHECKSUM: [u8; 2] = [123, 234];
    let mut icmp_body_with_corrupt_checksum = icmp_body_with_checksum.clone();
    // Make sure the overwrite actually changed the checksum; otherwise the test
    // could pass without the stack recomputing anything.
    assert_ne!(
        packet_formats::testutil::overwrite_icmpv6_checksum(
            icmp_body_with_corrupt_checksum.as_mut(),
            CORRUPT_CHECKSUM,
        )
        .expect("parse should succeed"),
        CORRUPT_CHECKSUM
    );

    let remote_ip = ZonedAddr::Unzoned(Ipv6::TEST_ADDRS.remote_ip);
    assert_matches!(&api.ctx.core_ctx().take_frames()[..], []);
    // The corrupted body is not used again, so move it into the buffer
    // instead of copying it.
    api.send_to(&sock, Some(remote_ip), Buf::new(icmp_body_with_corrupt_checksum, ..))
        .expect("send should succeed");

    let frames = api.core_ctx().take_frames();
    let (_send_ip_packet_meta, data) = assert_matches!(&frames[..], [packet] => packet);
    assert_eq!(&data[..], icmp_body_with_checksum);
}
1489
1490 #[test_case(Icmpv6MessageType::DestUnreachable.into(), 4; "header-too-short")]
1492 #[test_case(0, 8; "message-type-zero-not-supported")]
1493 fn icmpv6_send_to_invalid_body(message_type: u8, header_len: usize) {
1494 let mut api = new_raw_ip_socket_api::<Ipv6>();
1495 let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());
1496
1497 let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1498 assert_eq!(core_ctx.state.counters.tx_checksum_errors.get(), count);
1499 assert_eq!(sock.state().counters().tx_checksum_errors.get(), count);
1500 };
1501
1502 let mut body = vec![0; header_len];
1503 body[0] = message_type;
1504 assert_counters(api.core_ctx(), 0);
1505
1506 let remote_ip = ZonedAddr::Unzoned(Ipv6::TEST_ADDRS.remote_ip);
1507 assert_matches!(
1508 api.send_to(&sock, Some(remote_ip), Buf::new(body, ..)),
1509 Err(RawIpSocketSendToError::InvalidBody)
1510 );
1511
1512 assert_counters(api.core_ctx(), 1);
1513 }
1514
1515 #[ip_test(I)]
1516 #[test_case::test_matrix(
1517 [MarkDomain::Mark1, MarkDomain::Mark2],
1518 [None, Some(0), Some(1)]
1519 )]
1520 fn raw_ip_socket_marks<I: TestIpExt + DualStackIpExt + IpExt>(
1521 domain: MarkDomain,
1522 mark: Option<u32>,
1523 ) {
1524 let mut api = new_raw_ip_socket_api::<I>();
1525 let socket = api.create(RawIpSocketProtocol::Raw, Default::default());
1526
1527 assert_eq!(api.get_mark(&socket, domain), Mark(None));
1529
1530 let mark = Mark(mark);
1531 api.set_mark(&socket, domain, mark);
1533 assert_eq!(api.get_mark(&socket, domain), mark);
1534 }
1535}