use alloc::collections::HashMap;
use alloc::vec::Vec;
use core::fmt::{Debug, Display};
use core::num::NonZeroU64;
use derivative::Derivative;
use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
use net_types::ethernet::Mac;
use net_types::ip::{Ip, IpVersion, Ipv4, Ipv6};
use netstack3_base::sync::RwLock;
use netstack3_base::{
Counter, Device, DeviceIdContext, HandleableTimer, Inspectable, Inspector, InstantContext,
ReferenceNotifiers, TimerBindingsTypes, TimerHandler,
};
use netstack3_filter::FilterBindingsTypes;
use netstack3_ip::nud::{LinkResolutionContext, NudCounters};
use packet::Buf;
use crate::internal::arp::ArpCounters;
use crate::internal::ethernet::{EthernetLinkDevice, EthernetTimerId};
use crate::internal::id::{
BaseDeviceId, BasePrimaryDeviceId, DeviceId, EthernetDeviceId, EthernetPrimaryDeviceId,
EthernetWeakDeviceId,
};
use crate::internal::loopback::{LoopbackDeviceId, LoopbackPrimaryDeviceId};
use crate::internal::pure_ip::{PureIpDeviceId, PureIpPrimaryDeviceId};
use crate::internal::queue::rx::ReceiveQueueBindingsContext;
use crate::internal::queue::tx::TransmitQueueBindingsContext;
use crate::internal::socket::{self, HeldSockets};
use crate::internal::state::DeviceStateSpec;
/// Iterator over every device stored in a [`Devices`] collection.
///
/// Ethernet devices are yielded first, then pure IP devices, and finally the
/// loopback device (if one is installed). Each item is a strong [`DeviceId`]
/// cloned from the stored primary ID.
pub struct DevicesIter<'s, BT: DeviceLayerTypes> {
    /// Remaining Ethernet primary IDs.
    pub(super) ethernet:
        alloc::collections::hash_map::Values<'s, EthernetDeviceId<BT>, EthernetPrimaryDeviceId<BT>>,
    /// Remaining pure IP primary IDs.
    pub(super) pure_ip:
        alloc::collections::hash_map::Values<'s, PureIpDeviceId<BT>, PureIpPrimaryDeviceId<BT>>,
    /// The loopback primary ID, if one exists and has not been yielded yet.
    pub(super) loopback: core::option::Iter<'s, LoopbackPrimaryDeviceId<BT>>,
}
impl<'s, BT: DeviceLayerTypes> Iterator for DevicesIter<'s, BT> {
    type Item = DeviceId<BT>;

    /// Yields the next device.
    ///
    /// Ethernet devices are drained first, then pure IP devices, then the
    /// loopback device. Each yielded ID is a strong clone of the stored
    /// primary ID.
    fn next(&mut self) -> Option<Self::Item> {
        let Self { ethernet, pure_ip, loopback } = self;
        // `or_else` only advances a later source once every earlier source
        // has returned `None`. No adaptor chain needs to be rebuilt per call.
        ethernet
            .next()
            .map(|primary| primary.clone_strong().into())
            .or_else(|| pure_ip.next().map(|primary| primary.clone_strong().into()))
            .or_else(|| loopback.next().map(|primary| primary.clone_strong().into()))
    }

    /// Reports the exact number of devices left to yield.
    ///
    /// Every underlying iterator knows its exact remaining length, so the sum
    /// is exact too. Consumers such as `collect::<Vec<_>>()` can therefore
    /// allocate once.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let Self { ethernet, pure_ip, loopback } = self;
        let counts = [ethernet.len(), pure_ip.len(), loopback.len()];
        // Mirror `Chain`'s overflow handling: saturate the lower bound and
        // drop the upper bound if the sum cannot be represented.
        let lower = counts.iter().fold(0usize, |acc, &n| acc.saturating_add(n));
        let upper = counts.iter().try_fold(0usize, |acc, &n| acc.checked_add(n));
        (lower, upper)
    }
}
/// A link-layer address associated with an IPv6 device.
///
/// The only kind supported today is a MAC address (`net_types::ethernet::Mac`).
/// The raw address bytes are available through the `AsRef<[u8]>` impl.
pub enum Ipv6DeviceLinkLayerAddr {
    /// A MAC address.
    Mac(Mac),
}
impl AsRef<[u8]> for Ipv6DeviceLinkLayerAddr {
fn as_ref(&self) -> &[u8] {
match self {
Ipv6DeviceLinkLayerAddr::Mac(a) => a.as_ref(),
}
}
}
/// Identifier for a timer owned by the device layer.
///
/// Values are built from link-specific timer IDs through `From` (currently only
/// Ethernet timer IDs). They are dispatched back to the matching handler by the
/// `HandleableTimer` implementation.
// `derivative` is used with `bound = ""` so the derived impls do not require
// `BT` itself to implement `Clone`, `Eq`, `PartialEq`, `Hash` or `Debug`.
#[derive(Derivative)]
#[derivative(
    Clone(bound = ""),
    Eq(bound = ""),
    PartialEq(bound = ""),
    Hash(bound = ""),
    Debug(bound = "")
)]
pub struct DeviceLayerTimerId<BT: DeviceLayerTypes>(DeviceLayerTimerIdInner<BT>);
/// Private representation of `DeviceLayerTimerId`, with one variant per kind
/// of device-layer timer.
// `bound = ""` avoids requiring the derived traits on `BT` itself.
#[derive(Derivative)]
#[derivative(
    Clone(bound = ""),
    Eq(bound = ""),
    PartialEq(bound = ""),
    Hash(bound = ""),
    Debug(bound = "")
)]
enum DeviceLayerTimerIdInner<BT: DeviceLayerTypes> {
    /// A timer for an Ethernet device, referenced through a weak device ID.
    Ethernet(EthernetTimerId<EthernetWeakDeviceId<BT>>),
}
impl<BT: DeviceLayerTypes> From<EthernetTimerId<EthernetWeakDeviceId<BT>>>
    for DeviceLayerTimerId<BT>
{
    /// Wraps an Ethernet timer ID as a device-layer timer ID.
    fn from(id: EthernetTimerId<EthernetWeakDeviceId<BT>>) -> Self {
        let inner = DeviceLayerTimerIdInner::Ethernet(id);
        Self(inner)
    }
}
impl<CC, BT> HandleableTimer<CC, BT> for DeviceLayerTimerId<BT>
where
    BT: DeviceLayerTypes,
    CC: TimerHandler<BT, EthernetTimerId<EthernetWeakDeviceId<BT>>>,
{
    /// Forwards the fired timer to the core-context handler for its concrete
    /// timer kind.
    fn handle(self, core_ctx: &mut CC, bindings_ctx: &mut BT, timer: BT::UniqueTimerId) {
        // Single-variant enum: the pattern is irrefutable today. A new
        // variant will force this code to be revisited.
        let Self(DeviceLayerTimerIdInner::Ethernet(ethernet_id)) = self;
        core_ctx.handle_timer(bindings_ctx, ethernet_id, timer);
    }
}
/// The collection of all devices held by the device layer.
///
/// Each map stores a device's primary ID as the value under its (strong) ID
/// key. Strong `DeviceId`s are cloned from the primary IDs by [`Devices::iter`].
// `bound = ""` lets an empty collection be created without `BT: Default`.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct Devices<BT: DeviceLayerTypes> {
    /// Ethernet devices.
    pub ethernet: HashMap<EthernetDeviceId<BT>, EthernetPrimaryDeviceId<BT>>,
    /// Pure IP devices.
    pub pure_ip: HashMap<PureIpDeviceId<BT>, PureIpPrimaryDeviceId<BT>>,
    /// The loopback device, if one has been installed.
    pub loopback: Option<LoopbackPrimaryDeviceId<BT>>,
}
impl<BT: DeviceLayerTypes> Devices<BT> {
    /// Returns an iterator yielding a strong ID for every device in the
    /// collection: Ethernet first, then pure IP, then loopback.
    pub fn iter(&self) -> DevicesIter<'_, BT> {
        DevicesIter {
            ethernet: self.ethernet.values(),
            pure_ip: self.pure_ip.values(),
            loopback: self.loopback.iter(),
        }
    }
}
/// State owned by the device layer: the device collection plus
/// layer-wide bookkeeping and counters.
// `bound = ""` lets this be defaulted without requiring `BT: Default`.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct DeviceLayerState<BT: DeviceLayerTypes> {
    /// All devices, behind a read-write lock. Accessed through the
    /// `OrderedLockAccess<Devices<BT>>` impl rather than directly.
    devices: RwLock<Devices<BT>>,
    /// Origin tracker for this state instance. In debug builds, distinct
    /// instances get trackers that compare unequal.
    pub origin: OriginTracker,
    /// Device-socket state of type `HeldSockets`.
    // NOTE(review): presumably sockets not bound to a single device — confirm.
    pub shared_sockets: HeldSockets<BT>,
    /// Counters common to all device types.
    pub counters: DeviceCounters,
    /// Ethernet-specific counters.
    pub ethernet_counters: EthernetDeviceCounters,
    /// Pure-IP-specific counters.
    pub pure_ip_counters: PureIpDeviceCounters,
    /// NUD counters for IPv4.
    pub nud_v4_counters: NudCounters<Ipv4>,
    /// NUD counters for IPv6.
    pub nud_v6_counters: NudCounters<Ipv6>,
    /// ARP counters.
    pub arp_counters: ArpCounters,
}
impl<BT: DeviceLayerTypes> DeviceLayerState<BT> {
    /// Returns the NUD counters for IP version `I`: the IPv4 set for `Ipv4`
    /// and the IPv6 set for `Ipv6`.
    pub fn nud_counters<I: Ip>(&self) -> &NudCounters<I> {
        let Self { nud_v4_counters, nud_v6_counters, .. } = self;
        I::map_ip((), |()| nud_v4_counters, |()| nud_v6_counters)
    }
}
impl<BT: DeviceLayerTypes> OrderedLockAccess<Devices<BT>> for DeviceLayerState<BT> {
    type Lock = RwLock<Devices<BT>>;

    /// Hands the lock guarding the device collection to the lock-ordering
    /// machinery.
    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
        let Self { devices, .. } = self;
        OrderedLockRef::new(devices)
    }
}
/// Counters specific to Ethernet devices.
#[derive(Default)]
pub struct EthernetDeviceCounters {
    /// Received frames whose ethertype is unsupported. Reported in inspect
    /// as `Ethernet/Rx/UnsupportedEthertype`.
    pub recv_unsupported_ethertype: Counter,
    /// Received frames with no ethertype. Reported in inspect as
    /// `Ethernet/Rx/NoEthertype`.
    pub recv_no_ethertype: Counter,
}
impl Inspectable for EthernetDeviceCounters {
    /// Records the counters under an `Ethernet` node, nested in an `Rx` child.
    fn record<I: Inspector>(&self, inspector: &mut I) {
        // Exhaustive destructure: adding a field forces this impl to be updated.
        let Self { recv_no_ethertype, recv_unsupported_ethertype } = self;
        inspector.record_child("Ethernet", |ethernet| {
            ethernet.record_child("Rx", |rx| {
                rx.record_counter("NoEthertype", recv_no_ethertype);
                rx.record_counter("UnsupportedEthertype", recv_unsupported_ethertype);
            })
        })
    }
}
/// Counters specific to pure IP devices.
///
/// Currently empty, so its `Inspectable` impl records nothing.
#[derive(Default)]
pub struct PureIpDeviceCounters {}
impl Inspectable for PureIpDeviceCounters {
    /// No-op: there are no pure-IP-specific counters to record.
    fn record<I: Inspector>(&self, _: &mut I) {}
}
/// Counters shared by all device types.
///
/// The inspect key each counter is reported under (by the `Inspectable` impl)
/// is noted per field.
#[derive(Default)]
pub struct DeviceCounters {
    /// Outgoing frames entering the device layer (`Tx/TotalFrames`).
    pub send_total_frames: Counter,
    /// Frames sent (`Tx/Sent`).
    ///
    /// Not to be confused with [`DeviceCounters::send_frame()`], which returns
    /// the per-IP-version counter.
    pub send_frame: Counter,
    /// Sends that hit a full queue (`Tx/QueueFull`).
    pub send_queue_full: Counter,
    /// Sends that failed serialization (`Tx/SerializeError`).
    pub send_serialize_error: Counter,
    /// Received frames (`Rx/TotalFrames`).
    pub recv_frame: Counter,
    /// Received frames that failed to parse (`Rx/Malformed`).
    pub recv_parse_error: Counter,
    /// Received frames delivered as IPv4 (`Rx/Ipv4Delivered`).
    pub recv_ipv4_delivered: Counter,
    /// Received frames delivered as IPv6 (`Rx/Ipv6Delivered`).
    pub recv_ipv6_delivered: Counter,
    /// Sent frames carrying IPv4 (`Tx/SendIpv4Frame`).
    pub send_ipv4_frame: Counter,
    /// Sent frames carrying IPv6 (`Tx/SendIpv6Frame`).
    pub send_ipv6_frame: Counter,
    /// Outgoing frames dropped for lack of a queue (`Tx/NoQueue`).
    pub send_dropped_no_queue: Counter,
    /// Outgoing frames dropped while dequeuing (`Tx/DequeueDrop`).
    pub send_dropped_dequeue: Counter,
}
impl DeviceCounters {
    /// Returns the per-IP-version sent-frame counter.
    ///
    /// This is `send_ipv4_frame` for IPv4 and `send_ipv6_frame` for IPv6. It
    /// is distinct from the `send_frame` field.
    pub fn send_frame<I: Ip>(&self) -> &Counter {
        let Self { send_ipv4_frame, send_ipv6_frame, .. } = self;
        match I::VERSION {
            IpVersion::V4 => send_ipv4_frame,
            IpVersion::V6 => send_ipv6_frame,
        }
    }
}
impl Inspectable for DeviceCounters {
    /// Records receive-side counters under an `Rx` child and transmit-side
    /// counters under a `Tx` child.
    fn record<I: Inspector>(&self, inspector: &mut I) {
        // Exhaustive destructure: adding a field forces this impl to be updated.
        let Self {
            send_total_frames,
            send_frame,
            send_queue_full,
            send_serialize_error,
            recv_frame,
            recv_parse_error,
            recv_ipv4_delivered,
            recv_ipv6_delivered,
            send_ipv4_frame,
            send_ipv6_frame,
            send_dropped_no_queue,
            send_dropped_dequeue,
        } = self;
        inspector.record_child("Rx", |rx| {
            rx.record_counter("TotalFrames", recv_frame);
            rx.record_counter("Malformed", recv_parse_error);
            rx.record_counter("Ipv4Delivered", recv_ipv4_delivered);
            rx.record_counter("Ipv6Delivered", recv_ipv6_delivered);
        });
        inspector.record_child("Tx", |tx| {
            tx.record_counter("TotalFrames", send_total_frames);
            tx.record_counter("Sent", send_frame);
            tx.record_counter("SendIpv4Frame", send_ipv4_frame);
            tx.record_counter("SendIpv6Frame", send_ipv6_frame);
            tx.record_counter("NoQueue", send_dropped_no_queue);
            tx.record_counter("QueueFull", send_queue_full);
            tx.record_counter("SerializeError", send_serialize_error);
            tx.record_counter("DequeueDrop", send_dropped_dequeue);
        });
    }
}
/// An identity token for telling apart independently created instances.
///
/// With `debug_assertions`, each tracker holds a `u64` drawn from a global
/// counter at creation. Trackers from different creations therefore compare
/// unequal, while clones compare equal. Without `debug_assertions` the struct
/// has no fields and every tracker compares equal.
// NOTE(review): presumably used to check that device IDs are used with the
// stack state that created them (see `DeviceLayerState::origin`) — confirm.
#[derive(Clone, Debug, PartialEq)]
pub struct OriginTracker(#[cfg(debug_assertions)] u64);
impl Default for OriginTracker {
fn default() -> Self {
Self::new()
}
}
impl OriginTracker {
    /// Creates a new `OriginTracker`.
    ///
    /// With `debug_assertions`, each call takes the next value from a single
    /// atomic counter shared across threads. Trackers created by separate calls
    /// therefore compare unequal, until the `u64` counter wraps around.
    /// Without `debug_assertions` the tracker carries no data and all trackers
    /// compare equal.
    #[cfg_attr(not(debug_assertions), inline)]
    fn new() -> Self {
        Self(
            // The tuple field only exists in debug builds, so its initializer
            // is compiled out otherwise.
            #[cfg(debug_assertions)]
            {
                static COUNTER: core::sync::atomic::AtomicU64 =
                    core::sync::atomic::AtomicU64::new(0);
                // `Relaxed` suffices: the atomic read-modify-write alone
                // guarantees each caller gets a distinct value, and no other
                // memory is synchronized through this counter.
                COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed)
            },
        )
    }
}
/// Provides access to an [`OriginTracker`].
pub trait OriginTrackerContext {
    /// Returns the context's origin tracker by value.
    // NOTE(review): presumably a clone of `DeviceLayerState::origin` — confirm
    // against implementors.
    fn origin_tracker(&mut self) -> OriginTracker;
}
/// A context able to add devices of kind `D` to, and remove them from, the
/// device collection.
pub trait DeviceCollectionContext<D: Device + DeviceStateSpec, BT: DeviceLayerTypes>:
    DeviceIdContext<D>
{
    /// Adds a device, given its primary ID, to the collection.
    fn insert(&mut self, device: BasePrimaryDeviceId<D, BT>);

    /// Removes the device identified by `device`, returning its primary ID.
    // NOTE(review): presumably returns `None` when the device is not present —
    // confirm against implementors.
    fn remove(&mut self, device: &BaseDeviceId<D, BT>) -> Option<BasePrimaryDeviceId<D, BT>>;
}
/// Specifies the metadata type that accompanies frames received by a device
/// kind.
pub trait DeviceReceiveFrameSpec {
    /// Metadata attached to a received frame, generic over `D`.
    // NOTE(review): `D` is presumably a device ID type — confirm at use sites.
    type FrameMetadata<D>;
}
/// Bindings-provided types stored in per-device state.
///
/// Each per-device-kind state type must be `Send + Sync` and matchable
/// against `Self::DeviceClass`. That associated type comes from a supertrait,
/// presumably `FilterBindingsTypes` — confirm.
pub trait DeviceLayerStateTypes: InstantContext + FilterBindingsTypes {
    /// State kept by bindings for each loopback device.
    type LoopbackDeviceState: Send + Sync + DeviceClassMatcher<Self::DeviceClass>;

    /// State kept by bindings for each Ethernet device.
    type EthernetDeviceState: Send + Sync + DeviceClassMatcher<Self::DeviceClass>;

    /// State kept by bindings for each pure IP device.
    type PureIpDeviceState: Send + Sync + DeviceClassMatcher<Self::DeviceClass>;

    /// A printable identifier for a device that can be matched by ID and name.
    type DeviceIdentifier: Send + Sync + Debug + Display + DeviceIdAndNameMatcher;
}
/// Matches a value against a device class of type `DeviceClass`.
pub trait DeviceClassMatcher<DeviceClass> {
    /// Returns whether `self` matches `device_class`.
    fn device_class_matches(&self, device_class: &DeviceClass) -> bool;
}
/// Matches a device identifier against a numeric ID or a name.
pub trait DeviceIdAndNameMatcher {
    /// Returns whether `self` matches the numeric device ID `id`.
    fn id_matches(&self, id: &NonZeroU64) -> bool;

    /// Returns whether `self` matches the device name `name`.
    fn name_matches(&self, name: &str) -> bool;
}
/// The full set of bindings types the device layer depends on.
///
/// This trait has no items of its own. It bundles its supertraits and is
/// blanket-implemented for every `'static` type that satisfies them.
pub trait DeviceLayerTypes:
    DeviceLayerStateTypes
    + socket::DeviceSocketTypes
    + LinkResolutionContext<EthernetLinkDevice>
    + TimerBindingsTypes
    + ReferenceNotifiers
    + 'static
{
}
/// Every type satisfying all of `DeviceLayerTypes`' supertraits gets the
/// trait for free.
impl<BC> DeviceLayerTypes for BC where
    BC: DeviceLayerStateTypes
        + socket::DeviceSocketTypes
        + LinkResolutionContext<EthernetLinkDevice>
        + TimerBindingsTypes
        + ReferenceNotifiers
        + 'static
{
}
/// Bindings hooks the device layer uses to hand outgoing frames and packets to
/// devices.
///
/// Implementors also provide receive-queue bindings for loopback devices and
/// transmit-queue bindings for Ethernet, loopback and pure IP devices.
pub trait DeviceLayerEventDispatcher:
    DeviceLayerTypes
    + ReceiveQueueBindingsContext<LoopbackDeviceId<Self>>
    + TransmitQueueBindingsContext<EthernetDeviceId<Self>>
    + TransmitQueueBindingsContext<LoopbackDeviceId<Self>>
    + TransmitQueueBindingsContext<PureIpDeviceId<Self>>
    + Sized
{
    /// Bindings-defined context optionally passed along with each send.
    // NOTE(review): presumably supplied when sending while draining a transmit
    // queue — confirm against callers.
    type DequeueContext;

    /// Sends `frame` on the Ethernet device `device`.
    ///
    /// # Errors
    ///
    /// Returns a [`DeviceSendFrameError`] if bindings cannot send the frame.
    fn send_ethernet_frame(
        &mut self,
        device: &EthernetDeviceId<Self>,
        frame: Buf<Vec<u8>>,
        dequeue_context: Option<&mut Self::DequeueContext>,
    ) -> Result<(), DeviceSendFrameError>;

    /// Sends the IP `packet`, of version `ip_version`, on the pure IP device
    /// `device`.
    ///
    /// # Errors
    ///
    /// Returns a [`DeviceSendFrameError`] if bindings cannot send the packet.
    fn send_ip_packet(
        &mut self,
        device: &PureIpDeviceId<Self>,
        packet: Buf<Vec<u8>>,
        ip_version: IpVersion,
        dequeue_context: Option<&mut Self::DequeueContext>,
    ) -> Result<(), DeviceSendFrameError>;
}
/// Error returned by bindings when a frame or packet cannot be sent.
#[derive(Debug, PartialEq, Eq)]
pub enum DeviceSendFrameError {
    /// No buffers were available for the send.
    // NOTE(review): presumably a transient backpressure condition — confirm
    // with bindings implementations.
    NoBuffers,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Trackers are distinct per construction in debug builds, identical in
    /// release builds, and always equal to their own clones.
    #[test]
    fn origin_tracker() {
        let first = OriginTracker::new();
        let second = OriginTracker::new();
        if cfg!(debug_assertions) {
            assert_ne!(first, second);
        } else {
            assert_eq!(first, second);
        }
        assert_eq!(first.clone(), first);
    }
}