1use core::convert::Infallible as Never;
6use core::fmt::Debug;
7use core::num::NonZeroU32;
8
9use diagnostics_traits::{Inspectable, Inspector};
10use net_types::Witness;
11use net_types::ip::{GenericOverIp, Ip, Ipv4, Ipv4SourceAddr, Ipv6, Ipv6SourceAddr, Mtu};
12use packet_formats::icmp::{
13 IcmpDestUnreachable, Icmpv4DestUnreachableCode, Icmpv4ParameterProblemCode, Icmpv4RedirectCode,
14 Icmpv4TimeExceededCode, Icmpv6DestUnreachableCode, Icmpv6ParameterProblemCode,
15 Icmpv6TimeExceededCode,
16};
17use packet_formats::ip::IpProtoExt;
18use strum::{EnumCount as _, IntoEnumIterator as _};
19use strum_macros::{EnumCount, EnumIter};
20
/// Extension of [`Ip`] that associates a `BroadcastMarker` type with each IP
/// version.
///
/// The marker is `()` for [`Ipv4`] and the uninhabited `Never`
/// (`core::convert::Infallible`) for [`Ipv6`]. As a result, a
/// `BroadcastMarker` value can only ever be constructed for IPv4.
pub trait BroadcastIpExt: Ip {
    /// Per-version marker type; see the trait documentation for the choices.
    type BroadcastMarker: Debug + Copy + Clone + PartialEq + Eq + Send + Sync + 'static;
}

impl BroadcastIpExt for Ipv4 {
    // `()` is inhabited, so IPv4 code can produce a marker value.
    type BroadcastMarker = ();
}

impl BroadcastIpExt for Ipv6 {
    // `Never` has no values, so no marker can ever be constructed for IPv6.
    type BroadcastMarker = Never;
}
35
/// Newtype wrapping an `I::BroadcastMarker`.
///
/// Derives `GenericOverIp` over the IP-version parameter `I`. Presumably this
/// lets the marker flow through `GenericOverIp`-based version dispatch;
/// NOTE(review): confirm against callers.
#[derive(GenericOverIp)]
#[generic_over_ip(I, Ip)]
pub struct WrapBroadcastMarker<I: BroadcastIpExt>(pub I::BroadcastMarker);
41
42#[derive(Clone, Copy, Debug, PartialEq, Eq)]
46pub struct Mms(NonZeroU32);
47
48impl Mms {
49 pub fn from_mtu<I: IpExt>(mtu: Mtu, options_size: u32) -> Option<Self> {
51 NonZeroU32::new(mtu.get().saturating_sub(I::IP_HEADER_LENGTH.get() + options_size))
52 .map(|mms| Self(mms.min(I::IP_MAX_PAYLOAD_LENGTH)))
53 }
54
55 pub fn get(&self) -> NonZeroU32 {
57 let Self(mms) = *self;
58 mms
59 }
60}
61
/// An ICMPv4 error, identified by its message type and the type-specific code.
#[derive(Copy, Clone, Debug, PartialEq)]
#[allow(missing_docs)]
pub enum Icmpv4ErrorCode {
    // Also carries an `IcmpDestUnreachable` value alongside the code.
    // NOTE(review): presumably the Destination Unreachable message body; confirm
    // against packet_formats.
    DestUnreachable(Icmpv4DestUnreachableCode, IcmpDestUnreachable),
    Redirect(Icmpv4RedirectCode),
    TimeExceeded(Icmpv4TimeExceededCode),
    ParameterProblem(Icmpv4ParameterProblemCode),
}

// Maps `Icmpv4ErrorCode` to the error-code type `I::ErrorCode` of any
// `IcmpIpExt` version `I`.
impl<I: IcmpIpExt> GenericOverIp<I> for Icmpv4ErrorCode {
    type Type = I::ErrorCode;
}
78
/// An ICMPv6 error, identified by its message type and the type-specific code.
#[derive(Copy, Clone, Debug, PartialEq)]
#[allow(missing_docs)]
pub enum Icmpv6ErrorCode {
    DestUnreachable(Icmpv6DestUnreachableCode),
    // Carries an `Mtu` rather than a code.
    PacketTooBig(Mtu),
    TimeExceeded(Icmpv6TimeExceededCode),
    ParameterProblem(Icmpv6ParameterProblemCode),
}

// Maps `Icmpv6ErrorCode` to the error-code type `I::ErrorCode` of any
// `IcmpIpExt` version `I`.
impl<I: IcmpIpExt> GenericOverIp<I> for Icmpv6ErrorCode {
    type Type = I::ErrorCode;
}
95
96#[derive(Debug, Clone, Copy)]
98pub enum IcmpErrorCode {
99 V4(Icmpv4ErrorCode),
101 V6(Icmpv6ErrorCode),
103}
104
105impl From<Icmpv4ErrorCode> for IcmpErrorCode {
106 fn from(v4_err: Icmpv4ErrorCode) -> Self {
107 IcmpErrorCode::V4(v4_err)
108 }
109}
110
111impl From<Icmpv6ErrorCode> for IcmpErrorCode {
112 fn from(v6_err: Icmpv6ErrorCode) -> Self {
113 IcmpErrorCode::V6(v6_err)
114 }
115}
116
/// ICMP extension trait tying each IP version to its ICMP error-code type.
pub trait IcmpIpExt: packet_formats::ip::IpExt + packet_formats::icmp::IcmpIpExt {
    /// The ICMP error-code type for this IP version.
    ///
    /// The `GenericOverIp` bounds require this type to:
    /// - map to itself for `Self`;
    /// - map to [`Icmpv4ErrorCode`] for [`Ipv4`];
    /// - map to [`Icmpv6ErrorCode`] for [`Ipv6`].
    ///
    /// `Into<IcmpErrorCode>` allows erasing the version.
    type ErrorCode: Debug
        + Copy
        + PartialEq
        + GenericOverIp<Self, Type = Self::ErrorCode>
        + GenericOverIp<Ipv4, Type = Icmpv4ErrorCode>
        + GenericOverIp<Ipv6, Type = Icmpv6ErrorCode>
        + Into<IcmpErrorCode>;
}

impl IcmpIpExt for Ipv4 {
    type ErrorCode = Icmpv4ErrorCode;
}

impl IcmpIpExt for Ipv6 {
    type ErrorCode = Icmpv6ErrorCode;
}
137
/// IP extension trait declaring a per-version `BroadcastMarker` type.
///
/// NOTE(review): `BroadcastIpExt` declares a `BroadcastMarker` with the same
/// per-version choices (`()` for IPv4, `Never` for IPv6). That version has
/// stronger bounds (`Send + Sync + 'static`). This trait looks redundant —
/// confirm whether both are still needed.
pub trait IpTypesIpExt: packet_formats::ip::IpExt {
    /// `()` for IPv4 and the uninhabited `Never` for IPv6.
    type BroadcastMarker: Debug + Copy + Clone + PartialEq + Eq;
}

impl IpTypesIpExt for Ipv4 {
    type BroadcastMarker = ();
}

impl IpTypesIpExt for Ipv6 {
    type BroadcastMarker = Never;
}
152
/// IP extension trait combining the ICMP, broadcast, and protocol extensions.
///
/// Also adds the receive-side source-address witness and header/payload size
/// constants.
pub trait IpExt: packet_formats::ip::IpExt + IcmpIpExt + BroadcastIpExt + IpProtoExt {
    /// Witness type for a received packet's source address
    /// ([`Ipv4SourceAddr`] for IPv4, [`Ipv6SourceAddr`] for IPv6).
    type RecvSrcAddr: Witness<Self::Addr> + Copy + Clone;
    /// Length in bytes of the fixed IP header (`HDR_PREFIX_LEN` for IPv4,
    /// `IPV6_FIXED_HDR_LEN` for IPv6); `Mms::from_mtu` subtracts it from the MTU.
    const IP_HEADER_LENGTH: NonZeroU32;
    /// Upper bound on the payload length; `Mms::from_mtu` clamps to it.
    const IP_MAX_PAYLOAD_LENGTH: NonZeroU32;
}

impl IpExt for Ipv4 {
    type RecvSrcAddr = Ipv4SourceAddr;
    // `unwrap` runs during const evaluation, so a zero length would fail to
    // compile rather than panic at runtime.
    const IP_HEADER_LENGTH: NonZeroU32 =
        NonZeroU32::new(packet_formats::ipv4::HDR_PREFIX_LEN as u32).unwrap();
    // `u16::MAX` minus the fixed header. NOTE(review): presumably because the
    // IPv4 16-bit total-length field counts the header, unlike IPv6's
    // payload-length field — confirm.
    const IP_MAX_PAYLOAD_LENGTH: NonZeroU32 =
        NonZeroU32::new(u16::MAX as u32 - Self::IP_HEADER_LENGTH.get()).unwrap();
}

impl IpExt for Ipv6 {
    type RecvSrcAddr = Ipv6SourceAddr;
    const IP_HEADER_LENGTH: NonZeroU32 =
        NonZeroU32::new(packet_formats::ipv6::IPV6_FIXED_HDR_LEN as u32).unwrap();
    // Full `u16` range; the fixed header is not subtracted here (see the
    // IPv4 note). NOTE(review): jumbograms are presumably not considered.
    const IP_MAX_PAYLOAD_LENGTH: NonZeroU32 = NonZeroU32::new(u16::MAX as u32).unwrap();
}
180
/// A single optional 32-bit mark value.
///
/// `Mark(None)` is the unmarked state, and is also the `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Mark(pub Option<u32>);

impl From<Option<u32>> for Mark {
    /// Wraps the raw optional mark value unchanged.
    fn from(value: Option<u32>) -> Self {
        Mark(value)
    }
}
193
/// Identifies one of the independent mark slots held by a `MarkStorage`.
///
/// The `EnumCount` and `EnumIter` derives (from `strum_macros`) provide
/// `MarkDomain::COUNT` and `MarkDomain::iter()`. These are used to size and
/// traverse `MarkStorage`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumCount, EnumIter)]
pub enum MarkDomain {
    /// The first mark domain.
    Mark1,
    /// The second mark domain.
    Mark2,
}

/// Number of mark domains, which is the length of every `MarkStorage` backing array.
const MARK_DOMAINS: usize = MarkDomain::COUNT;
204
205#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
207pub struct MarkStorage<T>([T; MARK_DOMAINS]);
208
209impl<T> MarkStorage<T> {
210 pub fn new<U, IntoIter>(iter: IntoIter) -> Self
217 where
218 IntoIter: IntoIterator<Item = (MarkDomain, U)>,
219 T: From<Option<U>> + Copy,
220 {
221 let mut storage = MarkStorage([None.into(); MARK_DOMAINS]);
222 for (domain, value) in iter.into_iter() {
223 *storage.get_mut(domain) = Some(value).into();
224 }
225 storage
226 }
227
228 fn domain_as_index(domain: MarkDomain) -> usize {
229 match domain {
230 MarkDomain::Mark1 => 0,
231 MarkDomain::Mark2 => 1,
232 }
233 }
234
235 pub fn get(&self, domain: MarkDomain) -> &T {
237 let Self(inner) = self;
238 &inner[Self::domain_as_index(domain)]
239 }
240
241 pub fn get_mut(&mut self, domain: MarkDomain) -> &mut T {
243 let Self(inner) = self;
244 &mut inner[Self::domain_as_index(domain)]
245 }
246
247 pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &T)> {
249 let Self(inner) = self;
250 MarkDomain::iter().map(move |domain| (domain, &inner[Self::domain_as_index(domain)]))
251 }
252
253 pub fn zip_with<'a, U>(
255 &'a self,
256 MarkStorage(other): &'a MarkStorage<U>,
257 ) -> impl Iterator<Item = (MarkDomain, &'a T, &'a U)> + 'a {
258 let Self(this) = self;
259 MarkDomain::iter().zip(this.iter().zip(other.iter())).map(|(d, (t, u))| (d, t, u))
260 }
261}
262
263pub type Marks = MarkStorage<Mark>;
265
266impl Marks {
267 pub const UNMARKED: Self = MarkStorage([Mark(None), Mark(None)]);
269}
270
271impl Inspectable for Marks {
272 fn record<I: Inspector>(&self, inspector: &mut I) {
273 for (domain, Mark(mark)) in self.iter() {
274 if let Some(mark) = mark {
275 let domain_name = match domain {
276 MarkDomain::Mark1 => "Mark1",
277 MarkDomain::Mark2 => "Mark2",
278 };
279 inspector.record_uint(domain_name, *mark);
280 }
281 }
282 }
283}