use alloc::sync::Arc;
use core::fmt::{self, Debug};
use core::hash::Hash;
use core::num::NonZeroU64;
use derivative::Derivative;
use netstack3_base::sync::{DynDebugReferences, PrimaryRc, StrongRc};
use netstack3_base::{
Device, DeviceIdentifier, DeviceWithName, StrongDeviceIdentifier, WeakDeviceIdentifier,
};
use netstack3_filter as filter;
use crate::internal::base::{
DeviceClassMatcher as _, DeviceIdAndNameMatcher as _, DeviceLayerTypes, OriginTracker,
};
use crate::internal::ethernet::EthernetLinkDevice;
use crate::internal::loopback::{LoopbackDevice, LoopbackDeviceId, LoopbackWeakDeviceId};
use crate::internal::pure_ip::{PureIpDevice, PureIpDeviceId, PureIpWeakDeviceId};
use crate::internal::state::{BaseDeviceState, DeviceStateSpec, IpLinkDeviceState, WeakCookie};
/// A weak reference to a device of any kind.
///
/// A `WeakDeviceId` can be turned back into a strong [`DeviceId`] with
/// [`WeakDeviceId::upgrade`], which returns `None` if the underlying weak
/// reference can no longer be upgraded.
#[derive(Derivative)]
#[derivative(Clone(bound = ""), Eq(bound = ""), PartialEq(bound = ""), Hash(bound = ""))]
pub enum WeakDeviceId<BT: DeviceLayerTypes> {
    /// A weak reference to an Ethernet device.
    Ethernet(EthernetWeakDeviceId<BT>),
    /// A weak reference to a loopback device.
    Loopback(LoopbackWeakDeviceId<BT>),
    /// A weak reference to a pure IP device.
    PureIp(PureIpWeakDeviceId<BT>),
}
/// Compares a weak ID with a strong one.
///
/// This mirrors `DeviceId`'s `PartialEq<WeakDeviceId>` implementation so that
/// the comparison gives the same answer regardless of operand order.
impl<BT: DeviceLayerTypes> PartialEq<DeviceId<BT>> for WeakDeviceId<BT> {
    fn eq(&self, strong: &DeviceId<BT>) -> bool {
        strong == self
    }
}
/// Wraps a weak Ethernet device ID in the [`WeakDeviceId::Ethernet`] variant.
impl<BT: DeviceLayerTypes> From<EthernetWeakDeviceId<BT>> for WeakDeviceId<BT> {
    fn from(weak: EthernetWeakDeviceId<BT>) -> Self {
        Self::Ethernet(weak)
    }
}

/// Wraps a weak loopback device ID in the [`WeakDeviceId::Loopback`] variant.
impl<BT: DeviceLayerTypes> From<LoopbackWeakDeviceId<BT>> for WeakDeviceId<BT> {
    fn from(weak: LoopbackWeakDeviceId<BT>) -> Self {
        Self::Loopback(weak)
    }
}

/// Wraps a weak pure IP device ID in the [`WeakDeviceId::PureIp`] variant.
impl<BT: DeviceLayerTypes> From<PureIpWeakDeviceId<BT>> for WeakDeviceId<BT> {
    fn from(weak: PureIpWeakDeviceId<BT>) -> Self {
        Self::PureIp(weak)
    }
}
impl<BT: DeviceLayerTypes> WeakDeviceId<BT> {
    /// Attempts to obtain a strong [`DeviceId`] for the same device.
    ///
    /// Returns `None` when the per-kind weak ID cannot be upgraded
    /// (presumably because the device has been removed).
    pub fn upgrade(&self) -> Option<DeviceId<BT>> {
        for_any_device_id!(WeakDeviceId, self, weak => weak.upgrade().map(DeviceId::from))
    }

    /// Returns a type-erased [`DynDebugReferences`] for the device.
    ///
    /// The value is derived from the weak reference stored in the ID's shared
    /// cookie.
    pub fn debug_references(&self) -> DynDebugReferences {
        for_any_device_id!(
            WeakDeviceId,
            self,
            weak => weak.cookie.weak_ref.debug_references().into_dyn()
        )
    }

    /// Returns the bindings-provided identifier of the device.
    pub fn bindings_id(&self) -> &BT::DeviceIdentifier {
        for_any_device_id!(WeakDeviceId, self, weak => weak.bindings_id())
    }
}
/// Loopback-ness is decided purely by which variant holds the device.
impl<BT: DeviceLayerTypes> DeviceIdentifier for WeakDeviceId<BT> {
    fn is_loopback(&self) -> bool {
        match self {
            Self::Ethernet(_) | Self::PureIp(_) => false,
            Self::Loopback(_) => true,
        }
    }
}

impl<BT: DeviceLayerTypes> WeakDeviceIdentifier for WeakDeviceId<BT> {
    type Strong = DeviceId<BT>;

    fn upgrade(&self) -> Option<Self::Strong> {
        // This calls the inherent `WeakDeviceId::upgrade`. Inherent items take
        // precedence over trait items, so this is not recursive.
        WeakDeviceId::upgrade(self)
    }
}

/// Formats exactly like the wrapped per-kind weak ID.
impl<BT: DeviceLayerTypes> Debug for WeakDeviceId<BT> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for_any_device_id!(WeakDeviceId, self, weak => fmt::Debug::fmt(weak, f))
    }
}
/// A strong reference to a device of any kind.
///
/// [`DeviceId::downgrade`] produces a [`WeakDeviceId`] for the same device.
// `Clone` is implemented by hand (not derived) so that it can carry
// `track_caller` under the `instrumented` feature.
#[derive(Derivative)]
#[derivative(Eq(bound = ""), PartialEq(bound = ""), Hash(bound = ""))]
pub enum DeviceId<BT: DeviceLayerTypes> {
    /// A strong reference to an Ethernet device.
    Ethernet(EthernetDeviceId<BT>),
    /// A strong reference to a loopback device.
    Loopback(LoopbackDeviceId<BT>),
    /// A strong reference to a pure IP device.
    PureIp(PureIpDeviceId<BT>),
}
/// Dispatches over every variant of a device ID enum.
///
/// Expands to a `match` on `$device_id` with one arm for each of the
/// `Loopback`, `Ethernet` and `PureIp` variants of `$device_id_enum_type`.
/// Each arm binds the variant's single field to `$variable` and evaluates
/// `$expression`. Because `$expression` is copied into every arm, it is
/// type-checked separately against each variant's payload type.
///
/// The second form also takes `$provider_trait` (a trait implemented for `()`
/// with `Loopback`, `Ethernet` and `PureIp` associated types) and
/// `$type_param`. Inside each arm it declares `type $type_param` as the
/// corresponding associated type, so `$expression` can name the matched
/// variant's device type.
///
/// ```ignore
/// for_any_device_id!(DeviceId, device_id, id => id.bindings_id())
/// ```
#[macro_export]
macro_rules! for_any_device_id {
    // Form 1: evaluate `$expression` with `$variable` bound in every arm.
    ($device_id_enum_type:ident, $device_id:expr, $variable:pat => $expression:expr) => {
        match $device_id {
            $device_id_enum_type::Loopback($variable) => $expression,
            $device_id_enum_type::Ethernet($variable) => $expression,
            $device_id_enum_type::PureIp($variable) => $expression,
        }
    };
    // Form 2: like form 1, but each arm first aliases `$type_param` to that
    // variant's type as provided by `<() as $provider_trait>`.
    (
        $device_id_enum_type:ident,
        $provider_trait:ident,
        $type_param:ident,
        $device_id:expr, $variable:pat => $expression:expr) => {
        match $device_id {
            $device_id_enum_type::Loopback($variable) => {
                type $type_param = <() as $provider_trait>::Loopback;
                $expression
            }
            $device_id_enum_type::Ethernet($variable) => {
                type $type_param = <() as $provider_trait>::Ethernet;
                $expression
            }
            $device_id_enum_type::PureIp($variable) => {
                type $type_param = <() as $provider_trait>::PureIp;
                $expression
            }
        }
    };
}
// `#[macro_export]` places the macro at the crate root. Re-exporting it here
// also makes it nameable through this module's path within the crate.
pub(crate) use crate::for_any_device_id;
/// Associates each device kind with its [`Device`] type.
///
/// Suitable as the `$provider_trait` argument of the type-aliasing form of
/// `for_any_device_id!`. That form reads the types from
/// `<() as DeviceProvider>::{Loopback, Ethernet, PureIp}`.
pub trait DeviceProvider {
    /// The device type for loopback devices.
    type Loopback: Device;
    /// The device type for Ethernet devices.
    type Ethernet: Device;
    /// The device type for pure IP devices.
    type PureIp: Device;
}

/// Maps each kind to the link-device type defined in this crate.
impl DeviceProvider for () {
    type Loopback = LoopbackDevice;
    type Ethernet = EthernetLinkDevice;
    type PureIp = PureIpDevice;
}
/// Clones the inner per-kind strong ID and rewraps it in the same variant.
impl<BT: DeviceLayerTypes> Clone for DeviceId<BT> {
    // With the `instrumented` feature, `track_caller` is presumably there so
    // that reference tracking attributes the new reference to the caller.
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        for_any_device_id!(DeviceId, self, inner => Self::from(inner.clone()))
    }
}
/// A strong ID equals a weak ID only when both are the same kind and the
/// per-kind IDs compare equal.
impl<BT: DeviceLayerTypes> PartialEq<WeakDeviceId<BT>> for DeviceId<BT> {
    fn eq(&self, weak: &WeakDeviceId<BT>) -> bool {
        match (self, weak) {
            (Self::Ethernet(s), WeakDeviceId::Ethernet(w)) => s == w,
            (Self::Loopback(s), WeakDeviceId::Loopback(w)) => s == w,
            (Self::PureIp(s), WeakDeviceId::PureIp(w)) => s == w,
            // Different kinds never match. Every strong variant is listed
            // instead of `_`, so a new variant forces this match to be revisited.
            (Self::Ethernet(_) | Self::Loopback(_) | Self::PureIp(_), _) => false,
        }
    }
}
/// Delegates to the total order given by [`Ord`].
impl<BT: DeviceLayerTypes> PartialOrd for DeviceId<BT> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders IDs first by kind, in declaration order (`Ethernet` < `Loopback` <
/// `PureIp`). IDs of the same kind are then ordered by the inner per-kind IDs.
impl<BT: DeviceLayerTypes> Ord for DeviceId<BT> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        // Position of each variant in the kind ordering.
        fn rank<BT: DeviceLayerTypes>(id: &DeviceId<BT>) -> u8 {
            match id {
                DeviceId::Ethernet(_) => 0,
                DeviceId::Loopback(_) => 1,
                DeviceId::PureIp(_) => 2,
            }
        }

        match (self, other) {
            (Self::Ethernet(lhs), Self::Ethernet(rhs)) => lhs.cmp(rhs),
            (Self::Loopback(lhs), Self::Loopback(rhs)) => lhs.cmp(rhs),
            (Self::PureIp(lhs), Self::PureIp(rhs)) => lhs.cmp(rhs),
            // Kinds differ. Listing all variants keeps this arm exhaustive by
            // construction, which would not be true of a bare wildcard.
            (lhs @ (Self::Ethernet(_) | Self::Loopback(_) | Self::PureIp(_)), rhs) => {
                rank(lhs).cmp(&rank(rhs))
            }
        }
    }
}
/// Wraps a strong Ethernet device ID in [`DeviceId::Ethernet`].
impl<BT: DeviceLayerTypes> From<EthernetDeviceId<BT>> for DeviceId<BT> {
    fn from(strong: EthernetDeviceId<BT>) -> Self {
        Self::Ethernet(strong)
    }
}

/// Wraps a strong loopback device ID in [`DeviceId::Loopback`].
impl<BT: DeviceLayerTypes> From<LoopbackDeviceId<BT>> for DeviceId<BT> {
    fn from(strong: LoopbackDeviceId<BT>) -> Self {
        Self::Loopback(strong)
    }
}

/// Wraps a strong pure IP device ID in [`DeviceId::PureIp`].
impl<BT: DeviceLayerTypes> From<PureIpDeviceId<BT>> for DeviceId<BT> {
    fn from(strong: PureIpDeviceId<BT>) -> Self {
        Self::PureIp(strong)
    }
}
impl<BT: DeviceLayerTypes> DeviceId<BT> {
    /// Returns a [`WeakDeviceId`] that refers to the same device.
    pub fn downgrade(&self) -> WeakDeviceId<BT> {
        for_any_device_id!(DeviceId, self, strong => WeakDeviceId::from(strong.downgrade()))
    }

    /// Returns the bindings-provided identifier of the device.
    pub fn bindings_id(&self) -> &BT::DeviceIdentifier {
        for_any_device_id!(DeviceId, self, strong => strong.bindings_id())
    }
}
/// Loopback-ness is decided purely by which variant holds the device.
impl<BT: DeviceLayerTypes> DeviceIdentifier for DeviceId<BT> {
    fn is_loopback(&self) -> bool {
        match self {
            Self::Ethernet(_) | Self::PureIp(_) => false,
            Self::Loopback(_) => true,
        }
    }
}

impl<BT: DeviceLayerTypes> StrongDeviceIdentifier for DeviceId<BT> {
    type Weak = WeakDeviceId<BT>;

    fn downgrade(&self) -> Self::Weak {
        // This calls the inherent `DeviceId::downgrade`. Inherent items take
        // precedence over trait items, so this is not recursive.
        DeviceId::downgrade(self)
    }
}

/// Formats exactly like the wrapped per-kind strong ID.
impl<BT: DeviceLayerTypes> Debug for DeviceId<BT> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for_any_device_id!(DeviceId, self, strong => fmt::Debug::fmt(strong, f))
    }
}

/// Name matching is delegated to the bindings-provided identifier.
impl<BT: DeviceLayerTypes> DeviceWithName for DeviceId<BT> {
    fn name_matches(&self, name: &str) -> bool {
        let bindings_id = self.bindings_id();
        bindings_id.name_matches(name)
    }
}
/// Exposes the device's ID and class to `netstack3_filter` for interface
/// matching.
impl<BT: DeviceLayerTypes> filter::InterfaceProperties<BT::DeviceClass> for DeviceId<BT> {
    fn id_matches(&self, id: &NonZeroU64) -> bool {
        // Delegated to the bindings-provided identifier.
        let bindings_id = self.bindings_id();
        bindings_id.id_matches(id)
    }

    fn device_class_matches(&self, device_class: &BT::DeviceClass) -> bool {
        // The class is checked against each device kind's external state.
        for_any_device_id!(DeviceId, self, strong => {
            let external_state = strong.external_state();
            external_state.device_class_matches(device_class)
        })
    }
}
/// A weak reference to a device of kind `T`.
///
/// Equality is decided by comparing the cookie's weak references with
/// `ptr_eq`, and hashing is delegated to that same weak reference.
#[derive(Derivative)]
#[derivative(Clone(bound = ""))]
pub struct BaseWeakDeviceId<T: DeviceStateSpec, BT: DeviceLayerTypes> {
    // Shared cookie holding the bindings ID and a weak reference to the device
    // state. The device state keeps its own `Arc` to the same cookie, and
    // `BaseDeviceId::downgrade` clones it from there.
    cookie: Arc<WeakCookie<T, BT>>,
}
/// Two weak IDs are equal when their cookies' weak references point to the
/// same allocation.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> PartialEq for BaseWeakDeviceId<T, BT> {
    fn eq(&self, other: &Self) -> bool {
        let Self { cookie: lhs } = self;
        let Self { cookie: rhs } = other;
        lhs.weak_ref.ptr_eq(&rhs.weak_ref)
    }
}

impl<T: DeviceStateSpec, BT: DeviceLayerTypes> Eq for BaseWeakDeviceId<T, BT> {}

/// Hashes only the cookie's weak reference, the same field equality inspects.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> Hash for BaseWeakDeviceId<T, BT> {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        let Self { cookie } = self;
        cookie.weak_ref.hash(state)
    }
}
/// Weak-vs-strong comparison. It mirrors `BaseDeviceId`'s
/// `PartialEq<BaseWeakDeviceId>` so operand order does not matter.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> PartialEq<BaseDeviceId<T, BT>>
    for BaseWeakDeviceId<T, BT>
{
    fn eq(&self, strong: &BaseDeviceId<T, BT>) -> bool {
        strong == self
    }
}
/// Formats as `Weak<DEBUG_TYPE>(<bindings id>)`.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> Debug for BaseWeakDeviceId<T, BT> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bindings_id = self.bindings_id();
        write!(f, "Weak{}({:?})", T::DEBUG_TYPE, bindings_id)
    }
}
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> DeviceIdentifier for BaseWeakDeviceId<T, BT> {
fn is_loopback(&self) -> bool {
T::IS_LOOPBACK
}
}
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> WeakDeviceIdentifier for BaseWeakDeviceId<T, BT> {
type Strong = BaseDeviceId<T, BT>;
fn upgrade(&self) -> Option<Self::Strong> {
self.upgrade()
}
}
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> BaseWeakDeviceId<T, BT> {
    /// Attempts to obtain a strong [`BaseDeviceId`] for the same device.
    ///
    /// Returns `None` if the cookie's weak reference can no longer be upgraded.
    pub fn upgrade(&self) -> Option<BaseDeviceId<T, BT>> {
        let rc = self.cookie.weak_ref.upgrade()?;
        Some(BaseDeviceId { rc })
    }

    /// Returns the bindings-provided identifier stored in the shared cookie.
    pub fn bindings_id(&self) -> &BT::DeviceIdentifier {
        let Self { cookie } = self;
        &cookie.bindings_id
    }
}
/// A strong reference to a device of kind `T`.
///
/// `Hash`, `Eq` and `PartialEq` are derived from the inner `StrongRc`. They are
/// presumably based on pointer identity; TODO confirm against `StrongRc`.
#[derive(Derivative)]
#[derivative(Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
pub struct BaseDeviceId<T: DeviceStateSpec, BT: DeviceLayerTypes> {
    // Strong handle to the device's shared state.
    rc: StrongRc<BaseDeviceState<T, BT>>,
}
/// Produces another strong reference to the same device state.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> Clone for BaseDeviceId<T, BT> {
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        let rc = StrongRc::clone(&self.rc);
        Self { rc }
    }
}
/// A strong ID equals a weak ID when the weak reference in the weak ID's
/// cookie points at this ID's state.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> PartialEq<BaseWeakDeviceId<T, BT>>
    for BaseDeviceId<T, BT>
{
    fn eq(&self, weak: &BaseWeakDeviceId<T, BT>) -> bool {
        StrongRc::weak_ptr_eq(&self.rc, &weak.cookie.weak_ref)
    }
}

/// A strong ID equals the primary ID whose `PrimaryRc` points at the same state.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> PartialEq<BasePrimaryDeviceId<T, BT>>
    for BaseDeviceId<T, BT>
{
    fn eq(&self, primary: &BasePrimaryDeviceId<T, BT>) -> bool {
        PrimaryRc::ptr_eq(&primary.rc, &self.rc)
    }
}
/// Delegates to the total order given by [`Ord`].
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> PartialOrd for BaseDeviceId<T, BT> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders IDs by comparing their `StrongRc` pointers with `StrongRc::ptr_cmp`.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> Ord for BaseDeviceId<T, BT> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        StrongRc::ptr_cmp(&self.rc, &other.rc)
    }
}
/// Formats as `<DEBUG_TYPE>(<bindings id>)`.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> Debug for BaseDeviceId<T, BT> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bindings_id = self.bindings_id();
        write!(f, "{}({:?})", T::DEBUG_TYPE, bindings_id)
    }
}
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> DeviceIdentifier for BaseDeviceId<T, BT> {
fn is_loopback(&self) -> bool {
T::IS_LOOPBACK
}
}
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> StrongDeviceIdentifier for BaseDeviceId<T, BT> {
type Weak = BaseWeakDeviceId<T, BT>;
fn downgrade(&self) -> Self::Weak {
self.downgrade()
}
}
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> BaseDeviceId<T, BT> {
    /// Returns the device's IP link state.
    ///
    /// Debug builds assert that `tracker` equals the origin recorded in that
    /// state. This is presumably a guard against using an ID with a different
    /// stack instance; TODO confirm.
    pub fn device_state(&self, tracker: &OriginTracker) -> &IpLinkDeviceState<T, BT> {
        let Self { rc } = self;
        let state = &rc.ip;
        debug_assert_eq!(tracker, &state.origin);
        state
    }

    /// Returns the device's external state (`T::External<BT>`).
    pub fn external_state(&self) -> &T::External<BT> {
        let Self { rc } = self;
        &rc.external_state
    }

    /// Returns the bindings-provided identifier of the device.
    pub fn bindings_id(&self) -> &BT::DeviceIdentifier {
        let Self { rc } = self;
        &rc.weak_cookie.bindings_id
    }

    /// Returns a weak ID sharing this device's weak cookie.
    pub fn downgrade(&self) -> BaseWeakDeviceId<T, BT> {
        let cookie = Arc::clone(&self.rc.weak_cookie);
        BaseWeakDeviceId { cookie }
    }
}
/// The primary reference to a device of kind `T`.
///
/// It is built with `BasePrimaryDeviceId::new`. Strong IDs are obtained with
/// `clone_strong`, and the underlying `PrimaryRc` is released by `into_inner`.
pub struct BasePrimaryDeviceId<T: DeviceStateSpec, BT: DeviceLayerTypes> {
    // Primary handle to the device's shared state.
    rc: PrimaryRc<BaseDeviceState<T, BT>>,
}
/// Formats as `Primary<DEBUG_TYPE>(<bindings id>)`.
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> Debug for BasePrimaryDeviceId<T, BT> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bindings_id = &self.rc.weak_cookie.bindings_id;
        write!(f, "Primary{}({:?})", T::DEBUG_TYPE, bindings_id)
    }
}
impl<T: DeviceStateSpec, BT: DeviceLayerTypes> BasePrimaryDeviceId<T, BT> {
    /// Returns a new strong ID for this device.
    #[cfg_attr(feature = "instrumented", track_caller)]
    pub fn clone_strong(&self) -> BaseDeviceId<T, BT> {
        BaseDeviceId { rc: PrimaryRc::clone_strong(&self.rc) }
    }

    /// Creates the primary ID for a new device.
    ///
    /// The state is built inside `PrimaryRc::new_cyclic`. A shared weak cookie
    /// (bindings ID plus weak reference) is created first. `ip` is then called
    /// with a weak ID wrapping that cookie and must return the device's IP link
    /// state. The weak ID presumably cannot be upgraded until construction
    /// finishes, as with `Arc::new_cyclic`; TODO confirm.
    pub(crate) fn new<F: FnOnce(BaseWeakDeviceId<T, BT>) -> IpLinkDeviceState<T, BT>>(
        ip: F,
        external_state: T::External<BT>,
        bindings_id: BT::DeviceIdentifier,
    ) -> Self {
        let rc = PrimaryRc::new_cyclic(move |weak_ref| {
            let weak_cookie = Arc::new(WeakCookie { bindings_id, weak_ref });
            let self_weak_id = BaseWeakDeviceId { cookie: Arc::clone(&weak_cookie) };
            BaseDeviceState { ip: ip(self_weak_id), external_state, weak_cookie }
        });
        Self { rc }
    }

    /// Consumes the ID and returns the underlying `PrimaryRc`.
    pub(crate) fn into_inner(self) -> PrimaryRc<BaseDeviceState<T, BT>> {
        let Self { rc } = self;
        rc
    }
}
/// A strong ID for an Ethernet device.
pub type EthernetDeviceId<BT> = BaseDeviceId<EthernetLinkDevice, BT>;
/// A weak ID for an Ethernet device.
pub type EthernetWeakDeviceId<BT> = BaseWeakDeviceId<EthernetLinkDevice, BT>;
/// The primary ID for an Ethernet device.
pub type EthernetPrimaryDeviceId<BT> = BasePrimaryDeviceId<EthernetLinkDevice, BT>;
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use super::*;

    /// Narrows a [`DeviceId`] to an Ethernet ID. For any other kind, the
    /// original ID is returned unchanged as the error.
    impl<BT: DeviceLayerTypes> TryFrom<DeviceId<BT>> for EthernetDeviceId<BT> {
        type Error = DeviceId<BT>;

        fn try_from(device: DeviceId<BT>) -> Result<Self, Self::Error> {
            match device {
                DeviceId::Loopback(_) | DeviceId::PureIp(_) => Err(device),
                DeviceId::Ethernet(ethernet) => Ok(ethernet),
            }
        }
    }

    impl<BT: DeviceLayerTypes> DeviceId<BT> {
        /// Returns the wrapped Ethernet ID.
        ///
        /// # Panics
        ///
        /// Panics if `self` is not [`DeviceId::Ethernet`].
        pub fn unwrap_ethernet(self) -> EthernetDeviceId<BT> {
            // The pattern text is part of the panic message; keep it as-is.
            assert_matches::assert_matches!(self, DeviceId::Ethernet(e) => e)
        }
    }
}