1use alloc::collections::BTreeMap;
8use alloc::collections::btree_map::Entry;
9use core::fmt::{self, Debug, Display};
10use core::num::NonZeroU8;
11use derivative::Derivative;
12use log::debug;
13use net_types::ip::{GenericOverIp, Ip, IpVersionMarker, Mtu};
14use net_types::{SpecifiedAddr, ZonedAddr};
15use netstack3_base::socket::{
16 DualStackIpExt, DualStackRemoteIp, SocketCookie, SocketZonedAddrExt as _,
17};
18use netstack3_base::sync::{PrimaryRc, StrongRc, WeakRc};
19use netstack3_base::{
20 AnyDevice, ContextPair, DeviceIdContext, Inspector, InspectorDeviceExt, InspectorExt,
21 IpDeviceAddr, IpExt, Mark, MarkDomain, Marks, ReferenceNotifiers, ReferenceNotifiersExt as _,
22 RemoveResourceResultWithContext, ResourceCounterContext, StrongDeviceIdentifier,
23 TxMetadataBindingsTypes, WeakDeviceIdentifier, ZonedAddressError,
24};
25use netstack3_filter::{FilterIpExt, RawIpBody};
26use netstack3_hashmap::HashMap;
27use packet::{BufferMut, SliceBufViewMut};
28use packet_formats::icmp;
29use packet_formats::ip::{DscpAndEcn, IpPacket};
30use thiserror::Error;
31use zerocopy::SplitByteSlice;
32
33use crate::DEFAULT_HOP_LIMITS;
34use crate::internal::raw::counters::RawIpSocketCounters;
35use crate::internal::raw::filter::RawIpSocketIcmpFilter;
36use crate::internal::raw::protocol::RawIpSocketProtocol;
37use crate::internal::raw::state::{RawIpSocketLockedState, RawIpSocketState};
38use crate::internal::socket::{
39 IpSockCreateAndSendError, IpSocketArgs, IpSocketHandler, RouteResolutionOptions,
40 SendOneShotIpPacketError, SendOptions, SocketHopLimits,
41};
42
43mod checksum;
44pub(crate) mod counters;
45pub(crate) mod filter;
46pub(crate) mod protocol;
47pub(crate) mod state;
48
/// Types provided by bindings for use by the raw IP socket implementation.
pub trait RawIpSocketsBindingsTypes: TxMetadataBindingsTypes {
    /// Per-socket state owned by bindings.
    ///
    /// Core stores it alongside the socket, exposes it through
    /// `RawIpSocketId::external_state`, and hands it back when the socket is
    /// closed.
    type RawIpSocketState<I: Ip>: Send + Sync + Debug;
}
54
/// The error returned by [`RawIpSocketsBindingsContext::receive_packet`].
// `Debug` is derived so callers can log or assert on the error, as is
// customary for public error types.
#[derive(Debug)]
pub enum ReceivePacketError {
    /// The socket's receive queue is full and the packet could not be queued.
    QueueFull,
}
60
/// The bindings hook through which core hands received packets to raw IP
/// sockets.
pub trait RawIpSocketsBindingsContext<I: IpExt, D: StrongDeviceIdentifier>:
    RawIpSocketsBindingsTypes + Sized
{
    /// Called for each raw IP socket that accepts a received `packet`.
    ///
    /// `device` is the device the packet was received on. Return
    /// `Err(ReceivePacketError::QueueFull)` if the packet could not be queued;
    /// core then increments the socket's `rx_queue_full` counter.
    fn receive_packet<B: SplitByteSlice>(
        &self,
        socket: &RawIpSocketId<I, D::Weak, Self>,
        packet: &I::Packet<B>,
        device: &D,
    ) -> Result<(), ReceivePacketError>;
}
73
/// The raw IP socket API, generic over the IP version `I` and the context
/// pair `C`.
pub struct RawIpSocketApi<I: Ip, C> {
    // The core/bindings context pair the API operates on.
    ctx: C,
    // Zero-sized marker tying the API to IP version `I`.
    _ip_mark: IpVersionMarker<I>,
}
79
80impl<I: Ip, C> RawIpSocketApi<I, C> {
81 pub fn new(ctx: C) -> Self {
83 Self { ctx, _ip_mark: IpVersionMarker::new() }
84 }
85}
86
87impl<I: IpExt + DualStackIpExt + FilterIpExt, C> RawIpSocketApi<I, C>
88where
89 C: ContextPair,
90 C::BindingsContext: RawIpSocketsBindingsTypes + ReferenceNotifiers + 'static,
91 C::CoreContext: RawIpSocketMapContext<I, C::BindingsContext>
92 + RawIpSocketStateContext<I, C::BindingsContext>
93 + ResourceCounterContext<RawIpApiSocketId<I, C>, RawIpSocketCounters<I>>,
94{
95 fn core_ctx(&mut self) -> &mut C::CoreContext {
96 let Self { ctx, _ip_mark } = self;
97 ctx.core_ctx()
98 }
99
100 fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
101 let Self { ctx, _ip_mark } = self;
102 ctx.contexts()
103 }
104
105 pub fn create(
107 &mut self,
108 protocol: RawIpSocketProtocol<I>,
109 external_state: <C::BindingsContext as RawIpSocketsBindingsTypes>::RawIpSocketState<I>,
110 ) -> RawIpApiSocketId<I, C> {
111 let socket =
112 PrimaryRawIpSocketId(PrimaryRc::new(RawIpSocketState::new(protocol, external_state)));
113 let strong = self.core_ctx().with_socket_map_mut(|socket_map| socket_map.insert(socket));
114 debug!("created raw IP socket {strong:?}, on protocol {protocol:?}");
115
116 if protocol.requires_system_checksums() {
117 self.core_ctx().with_locked_state_mut(&strong, |state| state.system_checksums = true)
118 }
119
120 strong
121 }
122
123 pub fn close(
125 &mut self,
126 id: RawIpApiSocketId<I, C>,
127 ) -> RemoveResourceResultWithContext<
128 <C::BindingsContext as RawIpSocketsBindingsTypes>::RawIpSocketState<I>,
129 C::BindingsContext,
130 > {
131 let primary = self.core_ctx().with_socket_map_mut(|socket_map| socket_map.remove(id));
132 debug!("removed raw IP socket {primary:?}");
133 let PrimaryRawIpSocketId(primary) = primary;
134
135 C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
136 primary,
137 |state: RawIpSocketState<I, _, C::BindingsContext>| state.into_external_state(),
138 )
139 }
140
141 pub fn send_to<B: BufferMut>(
146 &mut self,
147 id: &RawIpApiSocketId<I, C>,
148 remote_ip: Option<
149 ZonedAddr<
150 SpecifiedAddr<I::Addr>,
151 <C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId,
152 >,
153 >,
154 mut body: B,
155 ) -> Result<(), RawIpSocketSendToError> {
156 match id.protocol() {
157 RawIpSocketProtocol::Raw => return Err(RawIpSocketSendToError::ProtocolRaw),
158 RawIpSocketProtocol::Proto(_) => {}
159 }
160 let local_ip = None;
165
166 let remote_ip = match DualStackRemoteIp::<I, _>::new(remote_ip) {
167 DualStackRemoteIp::ThisStack(addr) => addr,
168 DualStackRemoteIp::OtherStack(_addr) => {
169 return Err(RawIpSocketSendToError::MappedRemoteIp);
170 }
171 };
172 let protocol = id.protocol().proto();
173
174 let (core_ctx, bindings_ctx) = self.contexts();
175 let result = core_ctx.with_locked_state_and_socket_handler(id, |state, core_ctx| {
176 let RawIpSocketLockedState {
177 bound_device,
178 icmp_filter: _,
179 hop_limits,
180 multicast_loop,
181 system_checksums,
182 marks,
183 } = state;
184 let (remote_ip, device) = remote_ip
185 .resolve_addr_with_device(bound_device.clone())
186 .map_err(RawIpSocketSendToError::Zone)?;
187 let send_options = RawIpSocketOptions {
188 hop_limits: &hop_limits,
189 multicast_loop: *multicast_loop,
190 marks: &marks,
191 };
192
193 let build_packet_fn =
194 |src_ip: IpDeviceAddr<I::Addr>| -> Result<RawIpBody<_, _>, RawIpSocketSendToError> {
195 if *system_checksums {
196 let buf = SliceBufViewMut::new(body.as_mut());
197 if !checksum::populate_checksum::<I, _>(
198 src_ip.addr(),
199 remote_ip.addr(),
200 protocol,
201 buf,
202 ) {
203 return Err(RawIpSocketSendToError::InvalidBody);
204 }
205 }
206 Ok(RawIpBody::new(protocol, src_ip.addr(), remote_ip.addr(), body))
207 };
208
209 let tx_metadata = Default::default();
212
213 core_ctx
214 .send_oneshot_ip_packet_with_fallible_serializer(
215 bindings_ctx,
216 IpSocketArgs {
217 device: device.as_ref().map(|d| d.as_ref()),
218 local_ip,
219 remote_ip,
220 proto: protocol,
221 options: &send_options,
222 },
223 tx_metadata,
224 build_packet_fn,
225 )
226 .map_err(|e| match e {
227 SendOneShotIpPacketError::CreateAndSendError { err } => {
228 RawIpSocketSendToError::Ip(err)
229 }
230 SendOneShotIpPacketError::SerializeError(err) => err,
231 })
232 });
233 match &result {
234 Ok(()) => core_ctx
235 .increment_both(&id, |counters: &RawIpSocketCounters<I>| &counters.tx_packets),
236 Err(RawIpSocketSendToError::InvalidBody) => core_ctx
237 .increment_both(&id, |counters: &RawIpSocketCounters<I>| {
238 &counters.tx_checksum_errors
239 }),
240 Err(_) => {}
241 }
242 result
243 }
244
245 pub fn set_device(
253 &mut self,
254 id: &RawIpApiSocketId<I, C>,
255 device: Option<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
256 ) -> Option<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
257 let device = device.map(|strong| strong.downgrade());
258 self.core_ctx()
263 .with_locked_state_mut(id, |state| core::mem::replace(&mut state.bound_device, device))
264 }
265
266 pub fn get_device(
268 &mut self,
269 id: &RawIpApiSocketId<I, C>,
270 ) -> Option<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
271 self.core_ctx().with_locked_state(id, |state| state.bound_device.clone())
272 }
273
274 pub fn set_icmp_filter(
279 &mut self,
280 id: &RawIpApiSocketId<I, C>,
281 filter: Option<RawIpSocketIcmpFilter<I>>,
282 ) -> Result<Option<RawIpSocketIcmpFilter<I>>, RawIpSocketIcmpFilterError> {
283 debug!("setting ICMP Filter on {id:?}: {filter:?}");
284 if !id.protocol().is_icmp() {
285 return Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp);
286 }
287 Ok(self
288 .core_ctx()
289 .with_locked_state_mut(id, |state| core::mem::replace(&mut state.icmp_filter, filter)))
290 }
291
292 pub fn get_icmp_filter(
297 &mut self,
298 id: &RawIpApiSocketId<I, C>,
299 ) -> Result<Option<RawIpSocketIcmpFilter<I>>, RawIpSocketIcmpFilterError> {
300 if !id.protocol().is_icmp() {
301 return Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp);
302 }
303 Ok(self.core_ctx().with_locked_state(id, |state| state.icmp_filter.clone()))
304 }
305
306 pub fn set_unicast_hop_limit(
311 &mut self,
312 id: &RawIpApiSocketId<I, C>,
313 new_limit: Option<NonZeroU8>,
314 ) -> Option<NonZeroU8> {
315 self.core_ctx().with_locked_state_mut(id, |state| {
316 core::mem::replace(&mut state.hop_limits.unicast, new_limit)
317 })
318 }
319
320 pub fn get_unicast_hop_limit(&mut self, id: &RawIpApiSocketId<I, C>) -> NonZeroU8 {
322 self.core_ctx().with_locked_state(id, |state| {
323 state.hop_limits.get_limits_with_defaults(&DEFAULT_HOP_LIMITS).unicast
324 })
325 }
326
327 pub fn set_multicast_hop_limit(
332 &mut self,
333 id: &RawIpApiSocketId<I, C>,
334 new_limit: Option<NonZeroU8>,
335 ) -> Option<NonZeroU8> {
336 self.core_ctx().with_locked_state_mut(id, |state| {
337 core::mem::replace(&mut state.hop_limits.multicast, new_limit)
338 })
339 }
340
341 pub fn get_multicast_hop_limit(&mut self, id: &RawIpApiSocketId<I, C>) -> NonZeroU8 {
343 self.core_ctx().with_locked_state(id, |state| {
344 state.hop_limits.get_limits_with_defaults(&DEFAULT_HOP_LIMITS).multicast
345 })
346 }
347
348 pub fn set_multicast_loop(&mut self, id: &RawIpApiSocketId<I, C>, value: bool) -> bool {
352 self.core_ctx()
353 .with_locked_state_mut(id, |state| core::mem::replace(&mut state.multicast_loop, value))
354 }
355
356 pub fn get_multicast_loop(&mut self, id: &RawIpApiSocketId<I, C>) -> bool {
358 self.core_ctx().with_locked_state(id, |state| state.multicast_loop)
359 }
360
361 pub fn set_mark(&mut self, id: &RawIpApiSocketId<I, C>, domain: MarkDomain, mark: Mark) {
363 self.core_ctx().with_locked_state_mut(id, |state| {
364 *state.marks.get_mut(domain) = mark;
365 })
366 }
367
368 pub fn get_mark(&mut self, id: &RawIpApiSocketId<I, C>, domain: MarkDomain) -> Mark {
370 self.core_ctx().with_locked_state(id, |state| *state.marks.get(domain))
371 }
372
373 pub fn inspect<N>(&mut self, inspector: &mut N)
375 where
376 N: Inspector
377 + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
378 {
379 self.core_ctx().with_socket_map_and_state_ctx(|socket_map, core_ctx| {
380 socket_map.iter_sockets().for_each(|socket| {
381 inspector.record_debug_child(socket, |node| {
382 node.record_display("TransportProtocol", socket.protocol().proto());
383 node.record_str("NetworkProtocol", I::NAME);
384 node.record_local_socket_addr::<
386 N,
387 I::Addr,
388 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
389 NoPortMarker,
390 >(None);
391 node.record_remote_socket_addr::<
393 N,
394 I::Addr,
395 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
396 NoPortMarker,
397 >(None);
398 core_ctx.with_locked_state(socket, |state| {
399 let RawIpSocketLockedState {
400 bound_device,
401 icmp_filter,
402 hop_limits: _,
403 multicast_loop: _,
404 marks,
405 system_checksums: _,
406 } = state;
407 if let Some(bound_device) = bound_device {
408 N::record_device(node, "BoundDevice", bound_device);
409 } else {
410 node.record_str("BoundDevice", "None");
411 }
412 if let Some(icmp_filter) = icmp_filter {
413 node.record_display("IcmpFilter", icmp_filter);
414 } else {
415 node.record_str("IcmpFilter", "None");
416 }
417 node.delegate_inspectable(marks);
418 });
419 node.record_child("Counters", |node| {
420 node.delegate_inspectable(socket.state().counters())
421 })
422 })
423 })
424 })
425 }
426}
427
/// Errors that may occur when calling [`RawIpSocketApi::send_to`].
#[derive(Debug, Error)]
pub enum RawIpSocketSendToError {
    /// The socket's protocol is `RawIpSocketProtocol::Raw`, for which
    /// `send_to` is disallowed.
    #[error("send_to is disallowed for raw IP socket with protocol = raw")]
    ProtocolRaw,
    /// The remote IP resolved to the other IP stack. Raw IP sockets do not
    /// support dual-stack operation.
    #[error("dual-stack operations are not supported on raw IP sockets")]
    MappedRemoteIp,
    /// The provided packet body was invalid and could not be sent. `send_to`
    /// produces this when the transport checksum could not be populated.
    #[error("provided packet body was invalid (e.g. bad checksum)")]
    InvalidBody,
    /// The remote IP's zone could not be resolved against the socket's bound
    /// device.
    #[error("resolving remote IP's zone: {0}")]
    Zone(#[from] ZonedAddressError),
    /// The IP layer failed to create a socket for, or send, the packet.
    #[error("failed to send packet: {0}")]
    Ip(#[from] IpSockCreateAndSendError),
}
452
/// Errors that may occur when getting or setting a raw IP socket's ICMP filter.
#[derive(Debug, PartialEq, Error)]
pub enum RawIpSocketIcmpFilterError {
    /// The socket's protocol is not ICMP, so ICMP filters cannot be used.
    #[error("socket's protocol does not allow ICMP filters")]
    ProtocolNotIcmp,
}
460
/// The owning (primary) reference to a raw IP socket's state.
///
/// Held by [`RawIpSocketMap`] while the socket is installed. It is returned
/// from the map when the socket is closed.
struct PrimaryRawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
    PrimaryRc<RawIpSocketState<I, D, BT>>,
);
465
466impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
467 for PrimaryRawIpSocketId<I, D, BT>
468{
469 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
470 let Self(rc) = self;
471 f.debug_tuple("RawIpSocketId").field(&PrimaryRc::debug_id(rc)).finish()
472 }
473}
474
/// A strong reference to a raw IP socket.
///
/// Equality and hashing are derived over the underlying `StrongRc`, so an ID
/// can key the socket map.
#[derive(Derivative, GenericOverIp)]
#[derivative(Clone(bound = ""), Eq(bound = ""), Hash(bound = ""), PartialEq(bound = ""))]
#[generic_over_ip(I, Ip)]
pub struct RawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
    StrongRc<RawIpSocketState<I, D, BT>>,
);
482
483impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> RawIpSocketId<I, D, BT> {
484 pub fn socket_cookie(&self) -> SocketCookie {
486 let Self(rc) = self;
487 SocketCookie::new(rc.resource_token())
488 }
489}
490
491impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> RawIpSocketId<I, D, BT> {
492 pub fn external_state(&self) -> &BT::RawIpSocketState<I> {
494 let RawIpSocketId(strong_rc) = self;
495 strong_rc.external_state()
496 }
497 pub fn protocol(&self) -> &RawIpSocketProtocol<I> {
499 let RawIpSocketId(strong_rc) = self;
500 strong_rc.protocol()
501 }
502 pub fn downgrade(&self) -> WeakRawIpSocketId<I, D, BT> {
504 let Self(rc) = self;
505 WeakRawIpSocketId(StrongRc::downgrade(rc))
506 }
507 pub fn state(&self) -> &RawIpSocketState<I, D, BT> {
509 let RawIpSocketId(strong_rc) = self;
510 &*strong_rc
511 }
512}
513
514impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
515 for RawIpSocketId<I, D, BT>
516{
517 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
518 let Self(rc) = self;
519 f.debug_tuple("RawIpSocketId").field(&StrongRc::debug_id(rc)).finish()
520 }
521}
522
/// A weak reference to a raw IP socket, obtained via `RawIpSocketId::downgrade`.
pub struct WeakRawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
    WeakRc<RawIpSocketState<I, D, BT>>,
);
527
528impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
529 for WeakRawIpSocketId<I, D, BT>
530{
531 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
532 let Self(rc) = self;
533 f.debug_tuple("WeakRawIpSocketId").field(&WeakRc::debug_id(rc)).finish()
534 }
535}
536
/// Shorthand for the [`RawIpSocketId`] type used by [`RawIpSocketApi`] with
/// context pair `C`.
type RawIpApiSocketId<I, C> = RawIpSocketId<
    I,
    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
    <C as ContextPair>::BindingsContext,
>;
543
/// Provides access to a raw IP socket's locked state, optionally together with
/// an [`IpSocketHandler`] for sending.
pub trait RawIpSocketStateContext<I: IpExt + FilterIpExt, BT: RawIpSocketsBindingsTypes>:
    DeviceIdContext<AnyDevice>
{
    /// The [`IpSocketHandler`] handed to
    /// [`Self::with_locked_state_and_socket_handler`] callbacks.
    type SocketHandler<'a>: IpSocketHandler<
            I,
            BT,
            DeviceId = Self::DeviceId,
            WeakDeviceId = Self::WeakDeviceId,
        >;

    /// Calls `cb` with shared access to the socket's locked state.
    fn with_locked_state<O, F: FnOnce(&RawIpSocketLockedState<I, Self::WeakDeviceId>) -> O>(
        &mut self,
        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
        cb: F,
    ) -> O;

    /// Calls `cb` with shared access to the socket's locked state and with
    /// [`Self::SocketHandler`]. `RawIpSocketApi::send_to` uses this to send a
    /// packet with the socket's options.
    fn with_locked_state_and_socket_handler<
        O,
        F: FnOnce(&RawIpSocketLockedState<I, Self::WeakDeviceId>, &mut Self::SocketHandler<'_>) -> O,
    >(
        &mut self,
        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
        cb: F,
    ) -> O;

    /// Calls `cb` with exclusive access to the socket's locked state.
    fn with_locked_state_mut<
        O,
        F: FnOnce(&mut RawIpSocketLockedState<I, Self::WeakDeviceId>) -> O,
    >(
        &mut self,
        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
        cb: F,
    ) -> O;
}
584
/// The collection of all installed raw IP sockets.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct RawIpSocketMap<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> {
    // Sockets keyed first by protocol, then by strong ID.
    //
    // Keying by protocol lets inbound delivery visit only the sockets created
    // with the packet's protocol. The inner map owns each socket's primary
    // reference, and empty per-protocol maps are removed.
    sockets: BTreeMap<
        RawIpSocketProtocol<I>,
        HashMap<RawIpSocketId<I, D, BT>, PrimaryRawIpSocketId<I, D, BT>>,
    >,
}
606
607impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> RawIpSocketMap<I, D, BT> {
608 fn insert(&mut self, socket: PrimaryRawIpSocketId<I, D, BT>) -> RawIpSocketId<I, D, BT> {
609 let RawIpSocketMap { sockets } = self;
610 let PrimaryRawIpSocketId(primary) = &socket;
611 let strong = RawIpSocketId(PrimaryRc::clone_strong(primary));
612 assert!(
615 sockets.entry(*strong.protocol()).or_default().insert(strong.clone(), socket).is_none()
616 );
617 strong
618 }
619
620 fn remove(&mut self, socket: RawIpSocketId<I, D, BT>) -> PrimaryRawIpSocketId<I, D, BT> {
621 let RawIpSocketMap { sockets } = self;
625 let protocol = *socket.protocol();
626 match sockets.entry(protocol) {
627 Entry::Vacant(_) => unreachable!(
628 "{socket:?} with protocol {protocol:?} must be present in the socket map"
629 ),
630 Entry::Occupied(mut entry) => {
631 let map = entry.get_mut();
632 let primary = map.remove(&socket).unwrap();
633 if map.is_empty() {
636 let _: HashMap<RawIpSocketId<I, D, BT>, PrimaryRawIpSocketId<I, D, BT>> =
637 entry.remove();
638 }
639 primary
640 }
641 }
642 }
643
644 fn iter_sockets(&self) -> impl Iterator<Item = &RawIpSocketId<I, D, BT>> {
645 let RawIpSocketMap { sockets } = self;
646 sockets.values().flat_map(|sockets_by_protocol| sockets_by_protocol.keys())
647 }
648
649 fn iter_sockets_for_protocol(
650 &self,
651 protocol: &RawIpSocketProtocol<I>,
652 ) -> impl Iterator<Item = &RawIpSocketId<I, D, BT>> {
653 let RawIpSocketMap { sockets } = self;
654 sockets.get(protocol).map(|sockets| sockets.keys()).into_iter().flatten()
655 }
656}
657
/// Provides access to the [`RawIpSocketMap`].
pub trait RawIpSocketMapContext<I: IpExt + FilterIpExt, BT: RawIpSocketsBindingsTypes>:
    DeviceIdContext<AnyDevice>
{
    /// The context handed out alongside the socket map. It is used to reach
    /// each socket's locked state and counters.
    type StateCtx<'a>: RawIpSocketStateContext<
            I,
            BT,
            DeviceId = Self::DeviceId,
            WeakDeviceId = Self::WeakDeviceId,
        > + ResourceCounterContext<RawIpSocketId<I, Self::WeakDeviceId, BT>, RawIpSocketCounters<I>>;

    /// Calls `cb` with shared access to the socket map and with the state
    /// context.
    fn with_socket_map_and_state_ctx<
        O,
        F: FnOnce(&RawIpSocketMap<I, Self::WeakDeviceId, BT>, &mut Self::StateCtx<'_>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
    /// Calls `cb` with exclusive access to the socket map.
    fn with_socket_map_mut<O, F: FnOnce(&mut RawIpSocketMap<I, Self::WeakDeviceId, BT>) -> O>(
        &mut self,
        cb: F,
    ) -> O;
}
681
/// Core-facing interface for handing received IP packets to raw IP sockets.
pub trait RawIpSocketHandler<I: IpExt, BC>: DeviceIdContext<AnyDevice> {
    /// Delivers `packet`, received on `device`, to every raw IP socket that
    /// accepts it.
    fn deliver_packet_to_raw_ip_sockets<B: SplitByteSlice>(
        &mut self,
        bindings_ctx: &mut BC,
        packet: &I::Packet<B>,
        device: &Self::DeviceId,
    );
}
692
693impl<I, BC, CC> RawIpSocketHandler<I, BC> for CC
694where
695 I: IpExt + FilterIpExt,
696 BC: RawIpSocketsBindingsContext<I, CC::DeviceId>,
697 CC: RawIpSocketMapContext<I, BC>,
698{
699 fn deliver_packet_to_raw_ip_sockets<B: SplitByteSlice>(
700 &mut self,
701 bindings_ctx: &mut BC,
702 packet: &I::Packet<B>,
703 device: &CC::DeviceId,
704 ) {
705 let protocol = RawIpSocketProtocol::new(packet.proto());
706
707 match protocol {
710 RawIpSocketProtocol::Raw => {
711 debug!("received IP packet with raw protocol (IANA Reserved - 255); dropping");
712 return;
713 }
714 RawIpSocketProtocol::Proto(_) => {}
715 };
716
717 self.with_socket_map_and_state_ctx(|socket_map, core_ctx| {
718 socket_map.iter_sockets_for_protocol(&protocol).for_each(|socket| {
719 match core_ctx.with_locked_state(socket, |state| {
720 check_packet_for_delivery(packet, device, state)
721 }) {
722 DeliveryOutcome::Deliver => {
723 core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
724 &counters.rx_packets
725 });
726 match bindings_ctx.receive_packet(socket, packet, device) {
727 Ok(()) => {}
728 Err(ReceivePacketError::QueueFull) => {
729 core_ctx.increment_both(
730 &socket,
731 |counters: &RawIpSocketCounters<I>| &counters.rx_queue_full,
732 );
733 }
734 }
735 }
736 DeliveryOutcome::WrongChecksum => {
737 core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
738 &counters.rx_checksum_errors
739 });
740 }
741 DeliveryOutcome::WrongIcmpMessageType => {
742 core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
743 &counters.rx_icmp_filtered
744 });
745 }
746 DeliveryOutcome::WrongDevice => {}
747 }
748 })
749 })
750 }
751}
752
/// The result of checking whether a received packet should be delivered to a
/// particular raw IP socket.
enum DeliveryOutcome {
    /// The packet should be delivered to the socket.
    Deliver,
    /// The socket is bound to a device other than the one the packet arrived
    /// on.
    WrongDevice,
    /// The socket requires system checksums and the packet's checksum is
    /// invalid.
    WrongChecksum,
    /// The socket's ICMP filter rejected the packet. This includes packets
    /// whose ICMP message type could not be determined.
    WrongIcmpMessageType,
}
765
766fn check_packet_for_delivery<I: IpExt, D: StrongDeviceIdentifier, B: SplitByteSlice>(
768 packet: &I::Packet<B>,
769 device: &D,
770 socket: &RawIpSocketLockedState<I, D::Weak>,
771) -> DeliveryOutcome {
772 let RawIpSocketLockedState {
773 bound_device,
774 icmp_filter,
775 hop_limits: _,
776 marks: _,
777 multicast_loop: _,
778 system_checksums,
779 } = socket;
780 if bound_device.as_ref().is_some_and(|bound_device| bound_device != device) {
782 return DeliveryOutcome::WrongDevice;
783 }
784
785 if *system_checksums && !checksum::has_valid_checksum::<I, B>(packet) {
790 return DeliveryOutcome::WrongChecksum;
791 }
792
793 if icmp_filter.as_ref().is_some_and(|icmp_filter| {
795 debug_assert!(RawIpSocketProtocol::<I>::new(packet.proto()).is_icmp());
799 match icmp::peek_message_type(packet.body()) {
800 Err(_) => true,
810 Ok(message_type) => !icmp_filter.allows_type(message_type),
811 }
812 }) {
813 return DeliveryOutcome::WrongIcmpMessageType;
814 }
815
816 DeliveryOutcome::Deliver
817}
818
/// Send and route-resolution options for a raw IP socket.
///
/// The options are borrowed from, or copied out of, the socket's locked state.
struct RawIpSocketOptions<'a, I: Ip> {
    hop_limits: &'a SocketHopLimits<I>,
    multicast_loop: bool,
    marks: &'a Marks,
}
825
826impl<I: Ip> RouteResolutionOptions<I> for RawIpSocketOptions<'_, I> {
827 fn transparent(&self) -> bool {
828 false
829 }
830
831 fn marks(&self) -> &Marks {
832 self.marks
833 }
834}
835
836impl<I: IpExt> SendOptions<I> for RawIpSocketOptions<'_, I> {
837 fn hop_limit(&self, destination: &SpecifiedAddr<I::Addr>) -> Option<NonZeroU8> {
838 self.hop_limits.hop_limit_for_dst(destination)
839 }
840
841 fn multicast_loop(&self) -> bool {
842 self.multicast_loop
843 }
844
845 fn allow_broadcast(&self) -> Option<I::BroadcastMarker> {
846 None
847 }
848
849 fn dscp_and_ecn(&self) -> DscpAndEcn {
850 DscpAndEcn::default()
851 }
852
853 fn mtu(&self) -> Mtu {
854 Mtu::no_limit()
855 }
856}
857
/// A placeholder port type that renders as `NoPort`. It is used when recording
/// socket addresses for raw IP sockets, which have no ports.
struct NoPortMarker {}
860
861impl Display for NoPortMarker {
862 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
863 write!(f, "NoPort")
864 }
865}
866
867#[cfg(test)]
868mod test {
869 use super::*;
870
871 use alloc::rc::Rc;
872 use alloc::vec;
873 use alloc::vec::Vec;
874 use assert_matches::assert_matches;
875 use core::cell::RefCell;
876 use core::convert::Infallible as Never;
877 use core::marker::PhantomData;
878 use core::ops::DerefMut;
879 use ip_test_macro::ip_test;
880 use net_types::ip::{IpVersion, Ipv4, Ipv6};
881 use netstack3_base::sync::{DynDebugReferences, Mutex};
882 use netstack3_base::testutil::{
883 FakeStrongDeviceId, FakeTxMetadata, FakeWeakDeviceId, MultipleDevicesId, TestIpExt,
884 };
885 use netstack3_base::{ContextProvider, CounterContext, CtxPair};
886 use packet::{
887 Buf, InnerPacketBuilder as _, PacketBuilder as _, ParseBuffer as _, Serializer as _,
888 };
889 use packet_formats::icmp::{
890 IcmpEchoReply, IcmpMessage, IcmpPacketBuilder, IcmpZeroCode, Icmpv6MessageType,
891 };
892 use packet_formats::ip::{IpPacketBuilder, IpProto, IpProtoExt, Ipv6Proto};
893 use packet_formats::ipv6::Ipv6Packet;
894 use test_case::test_case;
895
896 use crate::internal::socket::testutil::{FakeIpSocketCtx, InnerFakeIpSocketCtx};
897 use crate::socket::testutil::FakeDeviceConfig;
898 use crate::{DEFAULT_HOP_LIMITS, SendIpPacketMeta};
899
900 #[derive(Derivative, Debug)]
901 #[derivative(Default(bound = ""))]
902 struct FakeExternalSocketState<D> {
903 received_packets: Mutex<RxQueue<D>>,
905 }
906
907 #[derive(Derivative, Debug)]
908 #[derivative(Default(bound = ""))]
909 struct RxQueue<D> {
910 packets: Vec<ReceivedIpPacket<D>>,
911 #[derivative(Default(value = "usize::MAX"))]
912 max_size: usize,
913 }
914
915 #[derive(Debug, PartialEq)]
916 struct ReceivedIpPacket<D> {
917 data: Vec<u8>,
918 device: D,
919 }
920
921 #[derive(Derivative)]
922 #[derivative(Default(bound = ""))]
923 struct FakeBindingsCtx<D> {
924 _device_id_type: PhantomData<D>,
925 }
926
927 struct FakeCoreCtxState<I: IpExt, D: FakeStrongDeviceId> {
929 socket_map: Rc<RefCell<RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>>>,
934 ip_socket_ctx: FakeIpSocketCtx<I, D>,
938 counters: RawIpSocketCounters<I>,
940 }
941
942 impl<I: IpExt, D: FakeStrongDeviceId> InnerFakeIpSocketCtx<I, D> for FakeCoreCtxState<I, D> {
943 fn fake_ip_socket_ctx_mut(&mut self) -> &mut FakeIpSocketCtx<I, D> {
944 &mut self.ip_socket_ctx
945 }
946 }
947
948 type FakeCoreCtx<I, D> = netstack3_base::testutil::FakeCoreCtx<
949 FakeCoreCtxState<I, D>,
950 SendIpPacketMeta<I, D, SpecifiedAddr<<I as Ip>::Addr>>,
951 D,
952 >;
953
954 impl<D: FakeStrongDeviceId> TxMetadataBindingsTypes for FakeBindingsCtx<D> {
955 type TxMetadata = FakeTxMetadata;
956 }
957
958 impl<D: FakeStrongDeviceId> RawIpSocketsBindingsTypes for FakeBindingsCtx<D> {
959 type RawIpSocketState<I: Ip> = FakeExternalSocketState<D>;
960 }
961
962 impl<I: IpExt, D: Copy + FakeStrongDeviceId> RawIpSocketsBindingsContext<I, D>
963 for FakeBindingsCtx<D>
964 {
965 fn receive_packet<B: SplitByteSlice>(
966 &self,
967 socket: &RawIpSocketId<I, D::Weak, Self>,
968 packet: &I::Packet<B>,
969 device: &D,
970 ) -> Result<(), ReceivePacketError> {
971 let packet = ReceivedIpPacket { data: packet.to_vec(), device: *device };
972 let FakeExternalSocketState { received_packets } = socket.external_state();
973 let mut lock_guard = received_packets.lock();
974 let RxQueue { packets, max_size } = lock_guard.deref_mut();
975 if packets.len() < *max_size {
976 packets.push(packet);
977 Ok(())
978 } else {
979 Err(ReceivePacketError::QueueFull)
980 }
981 }
982 }
983
984 impl<I: IpExt + FilterIpExt, D: FakeStrongDeviceId>
985 RawIpSocketStateContext<I, FakeBindingsCtx<D>> for FakeCoreCtx<I, D>
986 {
987 type SocketHandler<'a> = FakeCoreCtx<I, D>;
988 fn with_locked_state<O, F: FnOnce(&RawIpSocketLockedState<I, D::Weak>) -> O>(
989 &mut self,
990 id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
991 cb: F,
992 ) -> O {
993 let RawIpSocketId(state_rc) = id;
994 let guard = state_rc.locked_state().read();
995 cb(&guard)
996 }
997 fn with_locked_state_and_socket_handler<
998 O,
999 F: FnOnce(&RawIpSocketLockedState<I, D::Weak>, &mut Self::SocketHandler<'_>) -> O,
1000 >(
1001 &mut self,
1002 id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
1003 cb: F,
1004 ) -> O {
1005 let RawIpSocketId(state_rc) = id;
1006 let guard = state_rc.locked_state().read();
1007 cb(&guard, self)
1008 }
1009 fn with_locked_state_mut<O, F: FnOnce(&mut RawIpSocketLockedState<I, D::Weak>) -> O>(
1010 &mut self,
1011 id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
1012 cb: F,
1013 ) -> O {
1014 let RawIpSocketId(state_rc) = id;
1015 let mut guard = state_rc.locked_state().write();
1016 cb(&mut guard)
1017 }
1018 }
1019
1020 impl<I: IpExt, D: FakeStrongDeviceId> CounterContext<RawIpSocketCounters<I>> for FakeCoreCtx<I, D> {
1021 fn counters(&self) -> &RawIpSocketCounters<I> {
1022 &self.state.counters
1023 }
1024 }
1025
1026 impl<I: IpExt, D: FakeStrongDeviceId>
1027 ResourceCounterContext<
1028 RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
1029 RawIpSocketCounters<I>,
1030 > for FakeCoreCtx<I, D>
1031 {
1032 fn per_resource_counters<'a>(
1033 &'a self,
1034 socket: &'a RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
1035 ) -> &'a RawIpSocketCounters<I> {
1036 socket.state().counters()
1037 }
1038 }
1039
1040 impl<I: IpExt + FilterIpExt, D: FakeStrongDeviceId> RawIpSocketMapContext<I, FakeBindingsCtx<D>>
1041 for FakeCoreCtx<I, D>
1042 {
1043 type StateCtx<'a> = FakeCoreCtx<I, D>;
1044 fn with_socket_map_and_state_ctx<
1045 O,
1046 F: FnOnce(&RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>, &mut Self::StateCtx<'_>) -> O,
1047 >(
1048 &mut self,
1049 cb: F,
1050 ) -> O {
1051 let socket_map = self.state.socket_map.clone();
1052 let borrow = socket_map.borrow();
1053 cb(&borrow, self)
1054 }
1055 fn with_socket_map_mut<
1056 O,
1057 F: FnOnce(&mut RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>) -> O,
1058 >(
1059 &mut self,
1060 cb: F,
1061 ) -> O {
1062 cb(&mut self.state.socket_map.borrow_mut())
1063 }
1064 }
1065
1066 impl<D> ContextProvider for FakeBindingsCtx<D> {
1067 type Context = FakeBindingsCtx<D>;
1068 fn context(&mut self) -> &mut Self::Context {
1069 self
1070 }
1071 }
1072
1073 impl<D> ReferenceNotifiers for FakeBindingsCtx<D> {
1074 type ReferenceReceiver<T: 'static> = Never;
1075
1076 type ReferenceNotifier<T: Send + 'static> = Never;
1077
1078 fn new_reference_notifier<T: Send + 'static>(
1079 _debug_references: DynDebugReferences,
1080 ) -> (Self::ReferenceNotifier<T>, Self::ReferenceReceiver<T>) {
1081 unimplemented!("raw IP socket removal shouldn't be deferred in tests");
1082 }
1083 }
1084
1085 fn new_raw_ip_socket_api<I: IpExt + TestIpExt>() -> RawIpSocketApi<
1086 I,
1087 CtxPair<FakeCoreCtx<I, MultipleDevicesId>, FakeBindingsCtx<MultipleDevicesId>>,
1088 > {
1089 let device_configs = [MultipleDevicesId::A, MultipleDevicesId::B, MultipleDevicesId::C]
1091 .into_iter()
1092 .map(|device| FakeDeviceConfig {
1093 device,
1094 local_ips: vec![I::TEST_ADDRS.local_ip],
1095 remote_ips: vec![I::TEST_ADDRS.remote_ip],
1096 });
1097 let state = FakeCoreCtxState {
1098 socket_map: Default::default(),
1099 ip_socket_ctx: FakeIpSocketCtx::new(device_configs),
1100 counters: Default::default(),
1101 };
1102
1103 RawIpSocketApi::new(CtxPair::with_core_ctx(FakeCoreCtx::with_state(state)))
1104 }
1105
1106 const IP_BODY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
1108
1109 fn new_ip_packet_buf<I: IpExt + TestIpExt>(
1111 ip_body: &[u8],
1112 proto: I::Proto,
1113 ) -> impl AsRef<[u8]> {
1114 const TTL: u8 = 255;
1115 I::PacketBuilder::new(*I::TEST_ADDRS.local_ip, *I::TEST_ADDRS.remote_ip, TTL, proto)
1116 .wrap_body(ip_body.into_serializer())
1117 .serialize_vec_outer()
1118 .unwrap()
1119 }
1120
1121 fn new_icmp_message_buf<I: IpExt + TestIpExt, M: IcmpMessage<I> + Debug>(
1123 message: M,
1124 code: M::Code,
1125 ) -> impl AsRef<[u8]> {
1126 IcmpPacketBuilder::new(*I::TEST_ADDRS.local_ip, *I::TEST_ADDRS.remote_ip, code, message)
1127 .wrap_body([].into_serializer())
1128 .serialize_vec_outer()
1129 .unwrap()
1130 }
1131
1132 #[ip_test(I)]
1133 #[test_case(IpProto::Udp; "UDP")]
1134 #[test_case(IpProto::Reserved; "IPPROTO_RAW")]
1135 fn create_and_close<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>(proto: IpProto) {
1136 let mut api = new_raw_ip_socket_api::<I>();
1137 let sock = api.create(RawIpSocketProtocol::new(proto.into()), Default::default());
1138 let FakeExternalSocketState { received_packets: _ } = api.close(sock).into_removed();
1139 }
1140
1141 #[ip_test(I)]
1142 fn set_device<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1143 let mut api = new_raw_ip_socket_api::<I>();
1144 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1145
1146 assert_eq!(api.get_device(&sock), None);
1147 assert_eq!(api.set_device(&sock, Some(&MultipleDevicesId::A)), None);
1148 assert_eq!(api.get_device(&sock), Some(FakeWeakDeviceId(MultipleDevicesId::A)));
1149 assert_eq!(
1150 api.set_device(&sock, Some(&MultipleDevicesId::B)),
1151 Some(FakeWeakDeviceId(MultipleDevicesId::A))
1152 );
1153 assert_eq!(api.get_device(&sock), Some(FakeWeakDeviceId(MultipleDevicesId::B)));
1154 assert_eq!(api.set_device(&sock, None), Some(FakeWeakDeviceId(MultipleDevicesId::B)));
1155 assert_eq!(api.get_device(&sock), None);
1156 }
1157
1158 #[ip_test(I)]
1159 fn set_icmp_filter<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1160 let filter1 = RawIpSocketIcmpFilter::<I>::new([123; 32]);
1161 let filter2 = RawIpSocketIcmpFilter::<I>::new([234; 32]);
1162 let mut api = new_raw_ip_socket_api::<I>();
1163
1164 let sock = api.create(RawIpSocketProtocol::new(I::ICMP_IP_PROTO), Default::default());
1165 assert_eq!(api.get_icmp_filter(&sock), Ok(None));
1166 assert_eq!(api.set_icmp_filter(&sock, Some(filter1.clone())), Ok(None));
1167 assert_eq!(api.get_icmp_filter(&sock), Ok(Some(filter1.clone())));
1168 assert_eq!(api.set_icmp_filter(&sock, Some(filter2.clone())), Ok(Some(filter1.clone())));
1169 assert_eq!(api.get_icmp_filter(&sock), Ok(Some(filter2.clone())));
1170 assert_eq!(api.set_icmp_filter(&sock, None), Ok(Some(filter2)));
1171 assert_eq!(api.get_icmp_filter(&sock), Ok(None));
1172
1173 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1175 assert_eq!(
1176 api.set_icmp_filter(&sock, Some(filter1)),
1177 Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp)
1178 );
1179 assert_eq!(api.get_icmp_filter(&sock), Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp));
1180 }
1181
1182 #[ip_test(I)]
1183 fn set_unicast_hop_limits<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1184 let mut api = new_raw_ip_socket_api::<I>();
1185 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1186
1187 let limit1 = NonZeroU8::new(1).unwrap();
1188 let limit2 = NonZeroU8::new(2).unwrap();
1189
1190 assert_eq!(api.get_unicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.unicast);
1191 assert_eq!(api.set_unicast_hop_limit(&sock, Some(limit1)), None);
1192 assert_eq!(api.get_unicast_hop_limit(&sock), limit1);
1193 assert_eq!(api.set_unicast_hop_limit(&sock, Some(limit2)), Some(limit1));
1194 assert_eq!(api.get_unicast_hop_limit(&sock), limit2);
1195 assert_eq!(api.set_unicast_hop_limit(&sock, None), Some(limit2));
1196 assert_eq!(api.get_unicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.unicast);
1197 }
1198
1199 #[ip_test(I)]
1200 fn set_multicast_hop_limit<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1201 let mut api = new_raw_ip_socket_api::<I>();
1202 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1203
1204 let limit1 = NonZeroU8::new(1).unwrap();
1205 let limit2 = NonZeroU8::new(2).unwrap();
1206
1207 assert_eq!(api.get_multicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.multicast);
1208 assert_eq!(api.set_multicast_hop_limit(&sock, Some(limit1)), None);
1209 assert_eq!(api.get_multicast_hop_limit(&sock), limit1);
1210 assert_eq!(api.set_multicast_hop_limit(&sock, Some(limit2)), Some(limit1));
1211 assert_eq!(api.get_multicast_hop_limit(&sock), limit2);
1212 assert_eq!(api.set_multicast_hop_limit(&sock, None), Some(limit2));
1213 assert_eq!(api.get_multicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.multicast);
1214 }
1215
1216 #[ip_test(I)]
1217 fn set_multicast_loop<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1218 let mut api = new_raw_ip_socket_api::<I>();
1219 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1220
1221 assert_eq!(api.get_multicast_loop(&sock), true);
1223 assert_eq!(api.set_multicast_loop(&sock, false), true);
1224 assert_eq!(api.get_multicast_loop(&sock), false);
1225 assert_eq!(api.set_multicast_loop(&sock, true), false);
1226 assert_eq!(api.get_multicast_loop(&sock), true);
1227 }
1228
1229 #[ip_test(I)]
1230 fn receive_ip_packet<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1231 let mut api = new_raw_ip_socket_api::<I>();
1232
1233 let proto: I::Proto = IpProto::Udp.into();
1236 let wrong_proto: I::Proto = IpProto::Tcp.into();
1237 let sock1 = api.create(RawIpSocketProtocol::new(proto), Default::default());
1238 let sock2 = api.create(RawIpSocketProtocol::new(proto), Default::default());
1239 let wrong_sock = api.create(RawIpSocketProtocol::new(wrong_proto), Default::default());
1240
1241 const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1243 let buf = new_ip_packet_buf::<I>(&IP_BODY, proto);
1244 let mut buf_ref = buf.as_ref();
1245 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1246 {
1247 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1248 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &DEVICE);
1249 }
1250
1251 assert_eq!(api.core_ctx().state.counters.rx_packets.get(), 2);
1253 assert_eq!(sock1.state().counters().rx_packets.get(), 1);
1254 assert_eq!(sock2.state().counters().rx_packets.get(), 1);
1255 assert_eq!(wrong_sock.state().counters().rx_packets.get(), 0);
1256
1257 let FakeExternalSocketState { received_packets: sock1_packets } =
1258 api.close(sock1).into_removed();
1259 let FakeExternalSocketState { received_packets: sock2_packets } =
1260 api.close(sock2).into_removed();
1261 let FakeExternalSocketState { received_packets: wrong_sock_packets } =
1262 api.close(wrong_sock).into_removed();
1263
1264 for packets in [sock1_packets, sock2_packets] {
1266 let lock_guard = packets.lock();
1267 let ReceivedIpPacket { data, device } =
1268 assert_matches!(&lock_guard.packets[..], [packet] => packet);
1269 assert_eq!(&data[..], buf.as_ref());
1270 assert_eq!(*device, DEVICE);
1271 }
1272 assert_matches!(&wrong_sock_packets.lock().packets[..], []);
1273 }
1274
1275 #[ip_test(I)]
1278 fn cannot_receive_ip_packet_with_proto_raw<
1279 I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt,
1280 >() {
1281 let mut api = new_raw_ip_socket_api::<I>();
1282 let sock = api.create(RawIpSocketProtocol::Raw, Default::default());
1283
1284 let protocols_to_test = match I::VERSION {
1287 IpVersion::V4 => vec![IpProto::Udp, IpProto::Reserved],
1288 IpVersion::V6 => vec![IpProto::Udp],
1291 };
1292 for proto in protocols_to_test {
1293 let buf = new_ip_packet_buf::<I>(&IP_BODY, proto.into());
1294 let mut buf_ref = buf.as_ref();
1295 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1296 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1297 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1298 }
1299
1300 let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1301 assert_matches!(&received_packets.lock().packets[..], []);
1302 }
1303
1304 #[ip_test(I)]
1305 #[test_case(MultipleDevicesId::A, None, true; "no_bound_device")]
1306 #[test_case(MultipleDevicesId::A, Some(MultipleDevicesId::A), true; "bound_same_device")]
1307 #[test_case(MultipleDevicesId::A, Some(MultipleDevicesId::B), false; "bound_diff_device")]
1308 fn receive_ip_packet_with_bound_device<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>(
1309 send_dev: MultipleDevicesId,
1310 bound_dev: Option<MultipleDevicesId>,
1311 should_deliver: bool,
1312 ) {
1313 const PROTO: IpProto = IpProto::Udp;
1314 let mut api = new_raw_ip_socket_api::<I>();
1315 let sock = api.create(RawIpSocketProtocol::new(PROTO.into()), Default::default());
1316
1317 assert_eq!(api.set_device(&sock, bound_dev.as_ref()), None);
1318
1319 let buf = new_ip_packet_buf::<I>(&IP_BODY, PROTO.into());
1321 let mut buf_ref = buf.as_ref();
1322 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1323 {
1324 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1325 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &send_dev);
1326 }
1327
1328 let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1330 if should_deliver {
1331 let lock_guard = received_packets.lock();
1332 let ReceivedIpPacket { data, device } =
1333 assert_matches!(&lock_guard.packets[..], [packet] => packet);
1334 assert_eq!(&data[..], buf.as_ref());
1335 assert_eq!(*device, send_dev);
1336 } else {
1337 assert_matches!(&received_packets.lock().packets[..], []);
1338 }
1339 }
1340
1341 #[ip_test(I)]
1342 #[test_case(None, true; "no_filter")]
1345 #[test_case(Some(RawIpSocketIcmpFilter::<I>::ALLOW_ALL), true; "allow_all")]
1346 #[test_case(Some(RawIpSocketIcmpFilter::<I>::DENY_ALL), false; "deny_all")]
1347 fn receive_ip_packet_with_icmp_filter<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>(
1348 filter: Option<RawIpSocketIcmpFilter<I>>,
1349 should_deliver: bool,
1350 ) {
1351 let mut api = new_raw_ip_socket_api::<I>();
1352 let sock = api.create(RawIpSocketProtocol::new(I::ICMP_IP_PROTO), Default::default());
1353
1354 let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1355 assert_eq!(core_ctx.state.counters.rx_icmp_filtered.get(), count);
1356 assert_eq!(sock.state().counters().rx_icmp_filtered.get(), count);
1357 };
1358 assert_counters(api.core_ctx(), 0);
1359
1360 assert_matches!(api.set_icmp_filter(&sock, filter), Ok(None));
1361
1362 let icmp_body = new_icmp_message_buf::<I, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode);
1364 let buf = new_ip_packet_buf::<I>(icmp_body.as_ref(), I::ICMP_IP_PROTO);
1365 let mut buf_ref = buf.as_ref();
1366 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1367 {
1368 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1369 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1370 }
1371
1372 assert_counters(api.core_ctx(), should_deliver.then_some(0).unwrap_or(1));
1374 let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1375 if should_deliver {
1376 let lock_guard = received_packets.lock();
1377 let ReceivedIpPacket { data, device: _ } =
1378 assert_matches!(&lock_guard.packets[..], [packet] => packet);
1379 assert_eq!(&data[..], buf.as_ref());
1380 } else {
1381 assert_matches!(&received_packets.lock().packets[..], []);
1382 }
1383 }
1384
1385 #[test]
1389 fn do_not_receive_icmpv6_packet_with_bad_checksum() {
1390 let mut api = new_raw_ip_socket_api::<Ipv6>();
1391 let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());
1392
1393 let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1394 assert_eq!(core_ctx.state.counters.rx_checksum_errors.get(), count);
1395 assert_eq!(sock.state().counters().rx_checksum_errors.get(), count);
1396 };
1397 assert_counters(api.core_ctx(), 0);
1398
1399 let mut icmp_body = new_icmp_message_buf::<Ipv6, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode)
1402 .as_ref()
1403 .to_vec();
1404 const CORRUPT_CHECKSUM: [u8; 2] = [123, 234];
1405 assert_ne!(
1406 packet_formats::testutil::overwrite_icmpv6_checksum(
1407 icmp_body.as_mut(),
1408 CORRUPT_CHECKSUM
1409 )
1410 .expect("parse should succeed"),
1411 CORRUPT_CHECKSUM
1412 );
1413
1414 let buf = new_ip_packet_buf::<Ipv6>(icmp_body.as_ref(), Ipv6Proto::Icmpv6);
1415 let mut buf_ref = buf.as_ref();
1416 let packet = buf_ref.parse::<Ipv6Packet<_>>().expect("parse should succeed");
1417 {
1418 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1419 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1420 }
1421
1422 assert_counters(api.core_ctx(), 1);
1424 let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1425 assert_matches!(&received_packets.lock().packets[..], []);
1426 }
1427
1428 #[ip_test(I)]
1429 fn receive_ip_packet_queue_full<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1430 let mut api = new_raw_ip_socket_api::<I>();
1431
1432 let proto: I::Proto = IpProto::Udp.into();
1433 let sock1 = api.create(
1435 RawIpSocketProtocol::new(proto),
1436 FakeExternalSocketState {
1437 received_packets: Mutex::new(RxQueue { packets: vec![], max_size: 0 }),
1438 },
1439 );
1440 let sock2 = api.create(RawIpSocketProtocol::new(proto), Default::default());
1441
1442 const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1444 let buf = new_ip_packet_buf::<I>(&IP_BODY, proto);
1445 let mut buf_ref = buf.as_ref();
1446 let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1447 {
1448 let (core_ctx, bindings_ctx) = api.ctx.contexts();
1449 core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &DEVICE);
1450 }
1451
1452 assert_eq!(api.core_ctx().state.counters.rx_packets.get(), 2);
1454 assert_eq!(api.core_ctx().state.counters.rx_queue_full.get(), 1);
1455 assert_eq!(sock1.state().counters().rx_packets.get(), 1);
1456 assert_eq!(sock1.state().counters().rx_queue_full.get(), 1);
1457 assert_eq!(sock2.state().counters().rx_packets.get(), 1);
1458 assert_eq!(sock2.state().counters().rx_queue_full.get(), 0);
1459 }
1460
1461 #[ip_test(I)]
1462 #[test_case(None, None; "default_send")]
1463 #[test_case(Some(MultipleDevicesId::A), None; "with_bound_dev")]
1464 #[test_case(None, Some(123); "with_hop_limit")]
1465 fn send_to<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>(
1466 bound_dev: Option<MultipleDevicesId>,
1467 hop_limit: Option<u8>,
1468 ) {
1469 const PROTO: IpProto = IpProto::Udp;
1470 let mut api = new_raw_ip_socket_api::<I>();
1471 let sock = api.create(RawIpSocketProtocol::new(PROTO.into()), Default::default());
1472
1473 let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1474 assert_eq!(core_ctx.state.counters.tx_packets.get(), count);
1475 assert_eq!(sock.state().counters().tx_packets.get(), count);
1476 };
1477 assert_counters(api.core_ctx(), 0);
1478
1479 assert_eq!(api.set_device(&sock, bound_dev.as_ref()), None);
1480 let hop_limit = hop_limit.and_then(NonZeroU8::new);
1481 assert_eq!(api.set_unicast_hop_limit(&sock, hop_limit), None);
1482
1483 let remote_ip = ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip);
1484 assert_matches!(&api.ctx.core_ctx().take_frames()[..], []);
1485 api.send_to(&sock, Some(remote_ip), Buf::new(IP_BODY.to_vec(), ..))
1486 .expect("send should succeed");
1487 let frames = api.core_ctx().take_frames();
1488 let (SendIpPacketMeta { device, src_ip, dst_ip, proto, mtu, ttl, .. }, data) =
1489 assert_matches!( &frames[..], [packet] => packet);
1490 assert_eq!(&data[..], &IP_BODY[..]);
1491 assert_eq!(*dst_ip, remote_ip.addr());
1492 assert_eq!(*src_ip, I::TEST_ADDRS.local_ip);
1493 if let Some(bound_dev) = bound_dev {
1494 assert_eq!(*device, bound_dev);
1495 }
1496 assert_eq!(*proto, <I as IpProtoExt>::Proto::from(PROTO));
1497 assert_eq!(*mtu, Mtu::max());
1498 assert_eq!(*ttl, hop_limit);
1499
1500 assert_counters(api.core_ctx(), 1);
1501 }
1502
1503 #[ip_test(I)]
1504 fn send_to_disallows_raw_protocol<I: IpExt + DualStackIpExt + TestIpExt + FilterIpExt>() {
1505 let mut api = new_raw_ip_socket_api::<I>();
1506 let sock = api.create(RawIpSocketProtocol::Raw, Default::default());
1507 assert_matches!(
1508 api.send_to(&sock, None, Buf::new(IP_BODY.to_vec(), ..)),
1509 Err(RawIpSocketSendToError::ProtocolRaw)
1510 );
1511 }
1512
1513 #[test]
1514 fn send_to_disallows_dualstack() {
1515 let mut api = new_raw_ip_socket_api::<Ipv6>();
1516 let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1517 let mapped_remote_ip = ZonedAddr::Unzoned(Ipv4::TEST_ADDRS.local_ip.to_ipv6_mapped());
1518 assert_matches!(
1519 api.send_to(&sock, Some(mapped_remote_ip), Buf::new(IP_BODY.to_vec(), ..)),
1520 Err(RawIpSocketSendToError::MappedRemoteIp)
1521 );
1522 }
1523
1524 #[test]
1527 fn icmpv6_send_to_generates_checksum() {
1528 let mut api = new_raw_ip_socket_api::<Ipv6>();
1529 let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());
1530
1531 let icmp_body_with_checksum =
1534 new_icmp_message_buf::<Ipv6, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode)
1535 .as_ref()
1536 .to_vec();
1537 const CORRUPT_CHECKSUM: [u8; 2] = [123, 234];
1538 let mut icmp_body_without_checksum = icmp_body_with_checksum.clone();
1539 assert_ne!(
1540 packet_formats::testutil::overwrite_icmpv6_checksum(
1541 icmp_body_without_checksum.as_mut(),
1542 CORRUPT_CHECKSUM,
1543 )
1544 .expect("parse should succeed"),
1545 CORRUPT_CHECKSUM
1546 );
1547
1548 let remote_ip = ZonedAddr::Unzoned(Ipv6::TEST_ADDRS.remote_ip);
1550 assert_matches!(&api.ctx.core_ctx().take_frames()[..], []);
1551 api.send_to(&sock, Some(remote_ip), Buf::new(icmp_body_without_checksum.to_vec(), ..))
1552 .expect("send should succeed");
1553
1554 let frames = api.core_ctx().take_frames();
1556 let (_send_ip_packet_meta, data) = assert_matches!( &frames[..], [packet] => packet);
1557 assert_eq!(&data[..], icmp_body_with_checksum);
1558 }
1559
1560 #[test_case(Icmpv6MessageType::DestUnreachable.into(), 4; "header-too-short")]
1562 #[test_case(0, 8; "message-type-zero-not-supported")]
1563 fn icmpv6_send_to_invalid_body(message_type: u8, header_len: usize) {
1564 let mut api = new_raw_ip_socket_api::<Ipv6>();
1565 let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());
1566
1567 let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1568 assert_eq!(core_ctx.state.counters.tx_checksum_errors.get(), count);
1569 assert_eq!(sock.state().counters().tx_checksum_errors.get(), count);
1570 };
1571
1572 let mut body = vec![0; header_len];
1573 body[0] = message_type;
1574 assert_counters(api.core_ctx(), 0);
1575
1576 let remote_ip = ZonedAddr::Unzoned(Ipv6::TEST_ADDRS.remote_ip);
1577 assert_matches!(
1578 api.send_to(&sock, Some(remote_ip), Buf::new(body, ..)),
1579 Err(RawIpSocketSendToError::InvalidBody)
1580 );
1581
1582 assert_counters(api.core_ctx(), 1);
1583 }
1584
1585 #[ip_test(I)]
1586 #[test_case::test_matrix(
1587 [MarkDomain::Mark1, MarkDomain::Mark2],
1588 [None, Some(0), Some(1)]
1589 )]
1590 fn raw_ip_socket_marks<I: TestIpExt + DualStackIpExt + IpExt + FilterIpExt>(
1591 domain: MarkDomain,
1592 mark: Option<u32>,
1593 ) {
1594 let mut api = new_raw_ip_socket_api::<I>();
1595 let socket = api.create(RawIpSocketProtocol::Raw, Default::default());
1596
1597 assert_eq!(api.get_mark(&socket, domain), Mark(None));
1599
1600 let mark = Mark(mark);
1601 api.set_mark(&socket, domain, mark);
1603 assert_eq!(api.get_mark(&socket, domain), mark);
1604 }
1605}