Skip to main content

netstack3_ip/
fragmentation.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! IP fragmentation support.
6
7use core::borrow::Borrow;
8use core::fmt::Debug;
9
10use alloc::vec::Vec;
11
12use explicit::UnreachableExt;
13use net_types::ip::{GenericOverIp, Ip, IpInvariant, Ipv4, Ipv6, Mtu};
14use netstack3_base::{Counter, RngContext, Uninstantiable};
15use netstack3_filter::ForwardedPacket;
16use packet::{
17    Buf, BufferMut, EmptyBuf, FragmentedBuffer as _, InnerPacketBuilder as _, Nested,
18    PacketBuilder, PacketConstraints, ParsablePacket, SerializeError, Serializer,
19};
20use packet_formats::ip::FragmentOffset;
21use packet_formats::ipv4::options::Ipv4Option;
22use packet_formats::ipv4::{
23    Ipv4Header as _, Ipv4PacketBuilder, Ipv4PacketBuilderWithOptions, Ipv4PacketRaw,
24};
25use packet_formats::ipv6::{
26    Ipv6PacketBuilder, Ipv6PacketBuilderBeforeFragment, Ipv6PacketBuilderWithFragmentHeader,
27};
28use rand::Rng;
29
/// The maximum fragment offset that can be expressed in both IPv4 and IPv6
/// headers. The maximum transmissible body is bounded by this value plus the
/// maximum bytes transmitted in the last fragment; that bound is only reached
/// when a fragment boundary falls exactly on this offset.
// We have 13 bits to express an 8-byte multiple offset.
const MAX_FRAGMENT_OFFSET: usize = ((1 << 13) - 1) * 8;

/// Per-IP-version types needed to fragment packets of that version.
pub trait FragmentationIpExt:
    packet_formats::ip::IpExt<PacketBuilder: AsFragmentableIpPacketBuilder<Self>>
{
    /// The IP packet builder used to fragment a forwarded packet.
    type ForwardedFragmentBuilder: FragmentableIpPacketBuilder<Self>;
    /// An identifier generated at fragmentation time and shared by every
    /// fragment of the same packet.
    type FragmentationId: Copy + Debug;
}

impl FragmentationIpExt for Ipv4 {
    type ForwardedFragmentBuilder = ForwardedIpv4PacketBuilder;
    // No identifier is generated for IPv4: fragments keep the identification
    // value already set on the packet builder.
    type FragmentationId = ();
}

impl FragmentationIpExt for Ipv6 {
    // IPv6 never fragments forwarded packets, only the source node may
    // fragment.
    type ForwardedFragmentBuilder = Uninstantiable;
    // Carried in the IPv6 Fragment extension header of every fragment.
    type FragmentationId = u32;
}
56
/// Errors that can prevent a packet from being fragmented.
#[derive(Debug, Eq, PartialEq, GenericOverIp)]
#[generic_over_ip()]
pub enum FragmentationError {
    /// Fragmentation not allowed (e.g. the IPv4 Don't Fragment flag is set, or
    /// the packet is a forwarded IPv6 packet).
    NotAllowed,
    /// MTU is too small, headers don't fit.
    MtuTooSmall,
    /// Body is too long to be fragmented.
    BodyTooLong,
    /// Inner serializer reported that a size limit was exceeded.
    SizeLimitExceeded,
}
70
71/// A [`Serializer`] capable of splitting itself into a packet builder and a
72/// pre-serialized body for fragmentation.
73// TODO(https://fxbug.dev/42148826): Ideally we'd be able to generate fragments
74// without requiring the IP body to be dumped into a Vec first. Update this when
75// that support is available in packet and packet_formats.
76pub trait FragmentableIpSerializer<I: FragmentationIpExt>: Serializer {
77    /// The builder for each fragment.
78    type Builder<'a>: FragmentableIpPacketBuilder<I>
79    where
80        Self: 'a;
81    /// The body to be fragmented.
82    ///
83    /// Note that this API is not attempting to reuse buffers in any way. There
84    /// are improvements that can be made here to perhaps avoid allocations and
85    /// yield out reusable bodies, but we're constrained to taking references to
86    /// the serializers here to avoid changing the body which could interfere
87    /// with the higher layers on errors.
88    type Body<'a>: AsRef<[u8]>
89    where
90        Self: 'a;
91
92    /// Returns the inner packet builder for this IP version and a serialized
93    /// body.
94    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError>;
95}
96
97impl<I, S, B> FragmentableIpSerializer<I> for Nested<S, B>
98where
99    I: FragmentationIpExt,
100    S: Serializer,
101    B: AsFragmentableIpPacketBuilder<I> + PacketBuilder,
102{
103    type Builder<'a>
104        = B::Builder<'a>
105    where
106        Self: 'a;
107
108    type Body<'a>
109        = Buf<Vec<u8>>
110    where
111        Self: 'a;
112
113    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
114        let builder = self.outer().try_as_fragmentable()?;
115        let body = self
116            .inner()
117            .serialize_new_buf(PacketConstraints::UNCONSTRAINED, packet::new_buf_vec)
118            .map_err(|e| match e {
119                SerializeError::SizeLimitExceeded => FragmentationError::SizeLimitExceeded,
120            })?;
121        Ok((builder, body))
122    }
123}
124
/// The position of a fragment within the fragments generated for one packet.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum FragmentPosition {
    /// The first fragment, starting at body offset zero.
    First,
    /// A fragment that is neither the first nor the last.
    Middle,
    /// The final fragment; the only one emitted with the "more fragments" flag
    /// cleared.
    Last,
}

/// The header size constraints for `FragmentableIpPacketBuilder`
/// implementations.
pub struct HeaderSizes {
    /// Header length, in bytes, of the first fragment.
    first: usize,
    /// Header length, in bytes, of every fragment after the first. May be
    /// smaller than `first` when some header content (e.g. IPv4 options not
    /// marked as copied) is only carried by the first fragment.
    remaining: usize,
}

/// A type that may be transformed into a fragmentable ip packet builder.
pub trait AsFragmentableIpPacketBuilder<I: FragmentationIpExt> {
    /// The fragmentable packet builder that can be constructed from this type.
    type Builder<'a>: FragmentableIpPacketBuilder<I>
    where
        Self: 'a;

    /// Attempts to extract a `FragmentableIpPacketBuilder` implementation from
    /// this type, returning an error if it can't be fragmented.
    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError>;
}

/// An IP packet builder that can create IP fragments.
pub trait FragmentableIpPacketBuilder<I: FragmentationIpExt> {
    /// Returns the portion of the MTU occupied by IP headers, for the first
    /// fragment and for all subsequent fragments.
    fn header_sizes(&self) -> HeaderSizes;

    /// Returns a builder for the fragment at offset `offset`.
    ///
    /// `position` carries information if this is the first or last fragment,
    /// which require special logic. `identifier` is the fragment identifier
    /// shared by every fragment of the same packet.
    fn builder_at(
        &self,
        offset: FragmentOffset,
        position: FragmentPosition,
        identifier: I::FragmentationId,
    ) -> impl PacketBuilder + '_;
}
167
168/// Blanket impl for everything that has a shape to fit in `Ipv4FragmentBuilder`
169/// as a provider for fragmentation.
170impl<B> AsFragmentableIpPacketBuilder<Ipv4> for B
171where
172    B: InnerIpv4FragmentBuilder,
173{
174    type Builder<'a>
175        = Ipv4FragmentBuilder<'a, Self>
176    where
177        Self: 'a;
178
179    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
180        can_fragment_ipv4(self.prefix())?;
181        Ok(Ipv4FragmentBuilder { builder: self })
182    }
183}
184
185/// A trait marking all the IPv4 builder types that can be fragmented with
186/// [`Ipv4FragmentBuilder`].
187trait InnerIpv4FragmentBuilder: PacketBuilder {
188    fn prefix(&self) -> &Ipv4PacketBuilder;
189    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder;
190    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder;
191    fn header_sizes(&self) -> HeaderSizes;
192}
193
194impl InnerIpv4FragmentBuilder for Ipv4PacketBuilder {
195    fn prefix(&self) -> &Ipv4PacketBuilder {
196        self
197    }
198
199    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
200        self
201    }
202
203    fn clone_for_fragment(&self, _position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
204        self.clone()
205    }
206
207    fn header_sizes(&self) -> HeaderSizes {
208        let size = self.constraints().header_len();
209        HeaderSizes { first: size, remaining: size }
210    }
211}
212
213impl<'a, I> InnerIpv4FragmentBuilder for Ipv4PacketBuilderWithOptions<'a, I>
214where
215    I: Iterator<Item: Borrow<Ipv4Option<'a>>> + Clone,
216{
217    fn prefix(&self) -> &Ipv4PacketBuilder {
218        self.prefix_builder()
219    }
220
221    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
222        self.prefix_builder_mut()
223    }
224
225    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
226        self.clone().with_fragment_options(position == FragmentPosition::First)
227    }
228
229    fn header_sizes(&self) -> HeaderSizes {
230        let first = self.constraints().header_len();
231        let remaining = self.clone().with_fragment_options(false).constraints().header_len();
232        HeaderSizes { first, remaining }
233    }
234}
235
236pub struct Ipv4FragmentBuilder<'a, B> {
237    builder: &'a B,
238}
239
240impl<'a, B> FragmentableIpPacketBuilder<Ipv4> for Ipv4FragmentBuilder<'a, B>
241where
242    B: InnerIpv4FragmentBuilder,
243{
244    fn header_sizes(&self) -> HeaderSizes {
245        self.builder.header_sizes()
246    }
247
248    fn builder_at(
249        &self,
250        offset: FragmentOffset,
251        position: FragmentPosition,
252        (): (),
253    ) -> impl PacketBuilder + '_ {
254        let mut builder = self.builder.clone_for_fragment(position);
255        set_ipv4_fragment(builder.prefix_mut(), offset, position);
256        builder
257    }
258}
259
260impl<B> AsFragmentableIpPacketBuilder<Ipv6> for B
261where
262    for<'a> &'a B: Ipv6PacketBuilderBeforeFragment,
263{
264    type Builder<'a>
265        = Ipv6FragmentBuilder<'a, Self>
266    where
267        Self: 'a;
268
269    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
270        Ok(Ipv6FragmentBuilder { builder: self })
271    }
272}
273
274pub struct Ipv6FragmentBuilder<'a, B> {
275    builder: &'a B,
276}
277
278impl<'a, B> FragmentableIpPacketBuilder<Ipv6> for Ipv6FragmentBuilder<'a, B>
279where
280    &'a B: Ipv6PacketBuilderBeforeFragment,
281{
282    fn header_sizes(&self) -> HeaderSizes {
283        // NB: We currently only support headers that need to be in all
284        // fragments, so we only need to calculate once. We might need to change
285        // the trait shape if that changes.
286        let header_len =
287            Ipv6PacketBuilderWithFragmentHeader::new(self.builder, FragmentOffset::ZERO, false, 0)
288                .constraints()
289                .header_len();
290        HeaderSizes { first: header_len, remaining: header_len }
291    }
292
293    fn builder_at(
294        &self,
295        offset: FragmentOffset,
296        position: FragmentPosition,
297        identifier: u32,
298    ) -> impl PacketBuilder + '_ {
299        Ipv6PacketBuilderWithFragmentHeader::new(
300            self.builder,
301            offset,
302            position != FragmentPosition::Last,
303            identifier,
304        )
305    }
306}
307
/// Forwarded packets are fragmented from their already-serialized bytes: the
/// IPv4 header is re-parsed into a builder and the body is borrowed in place.
///
/// Only IPv4 supports this; IPv6 yields [`FragmentationError::NotAllowed`].
impl<I, B> FragmentableIpSerializer<I> for ForwardedPacket<I, B>
where
    I: FragmentationIpExt,
    B: BufferMut,
{
    type Builder<'a>
        = I::ForwardedFragmentBuilder
    where
        Self: 'a;
    type Body<'a>
        = Buf<&'a [u8]>
    where
        Self: 'a;

    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
        // Wrapper so `map_ip` can return the IP-version-specific builder type.
        #[derive(GenericOverIp)]
        #[generic_over_ip(I, Ip)]
        struct Out<I: FragmentationIpExt>(I::ForwardedFragmentBuilder);
        I::map_ip::<_, Result<(Out<I>, IpInvariant<Buf<&[u8]>>), FragmentationError>>(
            self,
            |forwarded| {
                // Parse an IPv4 packet from the forwarded packet. We can assert
                // strongly on all of the parsing here because ForwardedPacket
                // is guaranteed to have been parsed by the IP stack already.
                let mut buffer = forwarded.buffer().as_ref();
                let packet = Ipv4PacketRaw::parse(&mut buffer, ())
                    .expect("ForwardedPacket must be parseable");
                let builder = packet.builder();
                // Honor the Don't Fragment flag of the original packet.
                can_fragment_ipv4(&builder)?;
                let raw_options_bytes = packet
                    .options()
                    .as_ref()
                    .complete()
                    .expect("unexpected incomplete IP header")
                    .bytes();

                // `ForwardedIpv4PacketBuilder` has no lifetime parameter, so the
                // options are copied into an owned fixed-capacity array whose
                // body is narrowed to the actual options length.
                let mut raw_options = Buf::new(
                    [0u8; packet_formats::ipv4::MAX_OPTIONS_LEN],
                    ..raw_options_bytes.len(),
                );
                raw_options.as_mut().copy_from_slice(raw_options_bytes);
                // The body is borrowed from the forwarded buffer, not copied.
                let body = Buf::new(
                    packet.into_body().complete().expect("unexpected incomplete IP body"),
                    ..,
                );
                Ok((Out(ForwardedIpv4PacketBuilder { builder, raw_options }), IpInvariant(body)))
            },
            // IPv6 never fragments forwarded packets.
            |_forwarded| Err(FragmentationError::NotAllowed),
        )
        .map(|(Out(builder), IpInvariant(body))| (builder, body))
    }
}
360
361pub struct ForwardedIpv4PacketBuilder {
362    builder: Ipv4PacketBuilder,
363    raw_options: Buf<[u8; packet_formats::ipv4::MAX_OPTIONS_LEN]>,
364}
365
366impl FragmentableIpPacketBuilder<Ipv4> for ForwardedIpv4PacketBuilder {
367    fn header_sizes(&self) -> HeaderSizes {
368        let Self { builder, raw_options } = self;
369        if raw_options.is_empty() {
370            builder.header_sizes()
371        } else {
372            let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
373                .expect("must hold valid options");
374            Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.iter())
375                .header_sizes()
376        }
377    }
378
379    fn builder_at(
380        &self,
381        offset: FragmentOffset,
382        position: FragmentPosition,
383        (): (),
384    ) -> impl PacketBuilder + '_ {
385        let Self { builder, raw_options } = self;
386        let mut builder = builder.clone();
387        set_ipv4_fragment(&mut builder, offset, position);
388        let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
389            .expect("must hold valid options");
390        Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.into_iter())
391            .with_fragment_options(position == FragmentPosition::First)
392    }
393}
394
/// [`Uninstantiable`] is IPv6's `ForwardedFragmentBuilder`, since forwarded
/// IPv6 packets are never fragmented. No value of this type can exist, so
/// these methods can never actually be called.
impl<I: FragmentationIpExt> FragmentableIpPacketBuilder<I> for Uninstantiable {
    fn header_sizes(&self) -> HeaderSizes {
        self.uninstantiable_unreachable()
    }

    fn builder_at(
        &self,
        _offset: FragmentOffset,
        _position: FragmentPosition,
        _identifier: I::FragmentationId,
    ) -> impl PacketBuilder + '_ {
        // The turbofish only names a concrete `PacketBuilder` type so the
        // opaque return type can be inferred; this call never returns.
        self.uninstantiable_unreachable::<Ipv6PacketBuilder>()
    }
}
409
410/// Abstracts fragment ID generation for [`IpFragmenter`].
411///
412/// A blanket impl is provided for [`RngContext`] implementers, so the bindings
413/// context can be used to generate random IDs for IPv6.
414pub(crate) trait FragmentationIdGenContext {
415    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId;
416}
417
418#[derive(GenericOverIp)]
419#[generic_over_ip(I, Ip)]
420struct WrapFragmentationId<I: FragmentationIpExt>(I::FragmentationId);
421
422impl<BC> FragmentationIdGenContext for BC
423where
424    BC: RngContext,
425{
426    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
427        let WrapFragmentationId(identifier) = I::map_ip_out(
428            self,
429            |_| WrapFragmentationId(()),
430            |rng| {
431                // TODO(https://fxbug.dev/373428005): Perhaps we can do better
432                // than a simple RNG. This is currently copying what netstack2
433                // does. RFC 7739 calls out different strategies for fragment
434                // IDs in IPv6. We currently pick an option that is not doing a
435                // best effort to avoid collisions, but it guarantees that
436                // fragment IDs can't be tracked as an attack vector.
437                // We avoid a zero fragment ID like netstack2 does.
438                WrapFragmentationId(rng.rng().random_range(1..=u32::MAX))
439            },
440        );
441        identifier
442    }
443}
444
445pub(crate) struct IpFragmenter<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I> + 'a> {
446    builder: S::Builder<'a>,
447    body: S::Body<'a>,
448    consumed: usize,
449    max_fragment_body_first: usize,
450    max_fragment_body_remaining: usize,
451    identifier: I::FragmentationId,
452}
453
454/// Returns the biggest fragment body that can fit in `mtu` with a given IP
455/// `header` size.
456///
457/// The returned body size is rounded down to the nearest multiple of 8 to fit
458/// the IP header representation of fragment offsets.
459fn maximum_fragment_body_with_header_and_mtu(
460    mtu: Mtu,
461    header: usize,
462) -> Result<usize, FragmentationError> {
463    let v = usize::from(mtu).checked_sub(header).ok_or(FragmentationError::MtuTooSmall)?;
464    // Mask the final 8 bits since fragment offset is expressed in units
465    // of 8 octets for both IP versions.
466    let v = v & !0x07usize;
467
468    if v == 0 {
469        // Can't fragment if we don't have at least a single 8 octet
470        // of space.
471        return Err(FragmentationError::MtuTooSmall);
472    }
473    Ok(v)
474}
475
476impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
477    /// Creates a new `IpFragmenter` with some `serializer` respecting a maximum
478    /// IP layer `mtu`.
479    pub(crate) fn new<C: FragmentationIdGenContext>(
480        id_ctx: &mut C,
481        serializer: &'a S,
482        mtu: Mtu,
483    ) -> Result<Self, FragmentationError> {
484        let (builder, body) = serializer.builder_and_body()?;
485        let HeaderSizes { first, remaining } = builder.header_sizes();
486        let max_fragment_body_first = maximum_fragment_body_with_header_and_mtu(mtu, first)?;
487        let max_fragment_body_remaining =
488            maximum_fragment_body_with_header_and_mtu(mtu, remaining)?;
489
490        if body.as_ref().len() > MAX_FRAGMENT_OFFSET + max_fragment_body_remaining {
491            return Err(FragmentationError::BodyTooLong);
492        }
493
494        let identifier = id_ctx.generate_id::<I>();
495
496        Ok(Self {
497            builder,
498            body,
499            consumed: 0,
500            max_fragment_body_first,
501            max_fragment_body_remaining,
502            identifier,
503        })
504    }
505
506    /// Returns the serializer for the next segment and a boolean indicating
507    /// whether more fragments are pending, or `None` if all segments have been
508    /// produced.
509    ///
510    /// # Panics
511    ///
512    /// Panics if fragmentation is not necessary for the `serializer` that
513    /// created this `IpFragmenter`.
514    pub(crate) fn next(&mut self) -> Option<(impl Serializer<Buffer = EmptyBuf>, bool)> {
515        let Self {
516            builder,
517            body,
518            consumed,
519            max_fragment_body_first,
520            max_fragment_body_remaining,
521            identifier,
522        } = self;
523        let body = &AsRef::as_ref(body)[*consumed..];
524        if body.is_empty() {
525            return None;
526        }
527        let first = *consumed == 0;
528        let max_fragment_body =
529            if first { max_fragment_body_first } else { max_fragment_body_remaining };
530        let take = body.len().min(*max_fragment_body);
531        let last = take == body.len();
532        let position = match (first, last) {
533            (true, true) => {
534                panic!("unnecessary fragmentation");
535            }
536            (true, false) => FragmentPosition::First,
537            (false, false) => FragmentPosition::Middle,
538            (false, true) => FragmentPosition::Last,
539        };
540        // Upon construction IpFragmenter verifies that we won't go over the
541        // maximum offset since the body length is known.
542        let fragment_offset = u16::try_from(*consumed).expect("fragment offset too large");
543        // Care is taken above to always take 8-byte multiples to be added to
544        // consumed, so we should always have a good representation for
545        // FragmentOffset.
546        let fragment_offset =
547            FragmentOffset::new_with_bytes(fragment_offset).expect("invalid offset");
548        let fragment_builder = builder.builder_at(fragment_offset, position, *identifier);
549        let end = *consumed + take;
550        let has_more = body.len() > take;
551        let fragment_body = &body[..take];
552        *consumed = end;
553        Some((fragment_builder.wrap_body(fragment_body.into_serializer()), has_more))
554    }
555}
556
557fn can_fragment_ipv4(builder: &Ipv4PacketBuilder) -> Result<(), FragmentationError> {
558    if builder.read_df_flag() {
559        return Err(FragmentationError::NotAllowed);
560    }
561    Ok(())
562}
563
564fn set_ipv4_fragment(
565    builder: &mut Ipv4PacketBuilder,
566    offset: FragmentOffset,
567    position: FragmentPosition,
568) {
569    builder.mf_flag(position != FragmentPosition::Last);
570    builder.fragment_offset(offset);
571}
572
/// Counters kept by the IP stack pertaining to fragmentation.
///
/// Generic over the counter type `C`, which defaults to [`Counter`].
#[derive(Default, Debug)]
#[cfg_attr(
    any(test, feature = "testutils"),
    derive(PartialEq, netstack3_macros::CounterCollection)
)]
pub struct FragmentationCounters<C = Counter> {
    /// The number of IP frames requiring fragmentation on egress.
    pub fragmentation_required: C,
    /// The total number of fragments sent.
    pub fragments: C,
    /// The number of [`FragmentationError::NotAllowed`] errors encountered.
    pub error_not_allowed: C,
    /// The number of [`FragmentationError::MtuTooSmall`] errors encountered.
    pub error_mtu_too_small: C,
    /// The number of [`FragmentationError::BodyTooLong`] errors encountered.
    pub error_body_too_long: C,
    /// The number of [`FragmentationError::SizeLimitExceeded`] errors
    /// encountered.
    pub error_inner_size_limit_exceeded: C,
    /// Counts the number of times fragmentation was short-circuited due to a
    /// fragment serialization error.
    pub error_fragmented_serializer: C,
}
596
597impl FragmentationCounters {
598    pub(crate) fn error_counter(&self, error: &FragmentationError) -> &Counter {
599        match error {
600            FragmentationError::NotAllowed => &self.error_not_allowed,
601            FragmentationError::MtuTooSmall => &self.error_mtu_too_small,
602            FragmentationError::BodyTooLong => &self.error_body_too_long,
603            FragmentationError::SizeLimitExceeded => &self.error_inner_size_limit_exceeded,
604        }
605    }
606}
607
608#[cfg(test)]
609mod tests {
610    use super::*;
611
612    use assert_matches::assert_matches;
613    use net_types::Witness as _;
614    use netstack3_base::testutil::{TEST_ADDRS_V4, TEST_ADDRS_V6};
615    use netstack3_filter::FilterIpExt;
616    use packet::{Buffer, BufferView, GrowBuffer};
617    use packet_formats::ip::IpProto;
618    use packet_formats::ipv4::Ipv4Packet;
619    use packet_formats::ipv6::ext_hdrs::Ipv6ExtensionHeaderData;
620    use packet_formats::ipv6::{Ipv6Header, Ipv6Packet};
621    use test_case::test_case;
622
623    const TEST_MTU: Mtu = Ipv6::MINIMUM_LINK_MTU;
624
625    fn gen_body(len: usize) -> Vec<u8> {
626        // Cycle bytes until 251 which is the largest prime that can fit in a
627        // u8. Unlikely this aligns poorly and hides fragmentation bugs.
628        (0u8..=251).cycle().take(len).collect::<Vec<u8>>()
629    }
630
631    impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
632        fn next_serialized(&mut self) -> Buf<Vec<u8>> {
633            self.next()
634                .expect("no more fragments")
635                .0
636                .serialize_vec_outer()
637                .map_err(|(err, _serializer)| err)
638                .unwrap()
639                .unwrap_b()
640        }
641    }
642
643    trait FragmentationTestEnv<I: FragmentationIpExt> {
644        fn new_serializer<'a>(
645            &self,
646            body: &'a [u8],
647        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a;
648        fn check_fragment(
649            &self,
650            fragment: &mut Buf<Vec<u8>>,
651            position: FragmentPosition,
652            offset: usize,
653        );
654    }
655
656    #[derive(Default)]
657    struct Ipv4TestEnv {
658        dont_frag: bool,
659    }
660
661    impl Ipv4TestEnv {
662        const fn dont_frag() -> Self {
663            Self { dont_frag: true }
664        }
665    }
666
667    const IPV4_ID: u16 = 0x1234;
668    fn new_ipv4_packet_builder(dont_frag: bool) -> Ipv4PacketBuilder {
669        let mut builder = Ipv4PacketBuilder::new(
670            TEST_ADDRS_V4.local_ip,
671            TEST_ADDRS_V4.remote_ip,
672            1,
673            IpProto::Udp.into(),
674        );
675        builder.id(IPV4_ID);
676        builder.df_flag(dont_frag);
677        builder
678    }
679
680    fn parse_and_check_ipv4_packet(
681        fragment: &mut Buf<Vec<u8>>,
682        position: FragmentPosition,
683        offset: usize,
684    ) -> Ipv4Packet<&[u8]> {
685        let packet = Ipv4Packet::parse(fragment.buffer_view(), ()).expect("parse fragment");
686        assert_eq!(packet.src_ip(), TEST_ADDRS_V4.local_ip.get());
687        assert_eq!(packet.dst_ip(), TEST_ADDRS_V4.remote_ip.get());
688        assert_eq!(packet.ttl(), 1);
689        assert_eq!(packet.id(), IPV4_ID);
690        assert_eq!(packet.proto(), IpProto::Udp.into());
691        assert_eq!(packet.mf_flag(), position != FragmentPosition::Last);
692        assert_eq!(usize::from(packet.fragment_offset().into_bytes()), offset);
693        packet
694    }
695
696    impl FragmentationTestEnv<Ipv4> for Ipv4TestEnv {
697        fn new_serializer<'a>(
698            &self,
699            body: &'a [u8],
700        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
701            let Self { dont_frag } = self;
702            new_ipv4_packet_builder(*dont_frag).wrap_body(body.into_serializer())
703        }
704
705        fn check_fragment(
706            &self,
707            fragment: &mut Buf<Vec<u8>>,
708            position: FragmentPosition,
709            offset: usize,
710        ) {
711            let _ = parse_and_check_ipv4_packet(fragment, position, offset);
712        }
713    }
714
715    #[derive(Default)]
716    struct Ipv4WithOptionsTestEnv(Ipv4TestEnv);
717
718    // The MSB of an option kind determines if it should be copied.
719    const FAKE_OPTION_COPIED_KIND: u8 = 255;
720    const FAKE_OPTION_COPIED: [u8; 1] = [255];
721    const FAKE_OPTION_NOT_COPIED_KIND: u8 = 127;
722    const FAKE_OPTION_NOT_COPIED: [u8; 1] = [127];
723
724    impl FragmentationTestEnv<Ipv4> for Ipv4WithOptionsTestEnv {
725        fn new_serializer<'a>(
726            &self,
727            body: &'a [u8],
728        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
729            let Self(Ipv4TestEnv { dont_frag }) = self;
730
731            Ipv4PacketBuilderWithOptions::new(
732                new_ipv4_packet_builder(*dont_frag),
733                [
734                    Ipv4Option::Unrecognized {
735                        kind: FAKE_OPTION_COPIED_KIND,
736                        data: &FAKE_OPTION_COPIED[..],
737                    },
738                    Ipv4Option::Unrecognized {
739                        kind: FAKE_OPTION_NOT_COPIED_KIND,
740                        data: &FAKE_OPTION_NOT_COPIED[..],
741                    },
742                ],
743            )
744            .unwrap()
745            .wrap_body(body.into_serializer())
746        }
747
748        fn check_fragment(
749            &self,
750            fragment: &mut Buf<Vec<u8>>,
751            position: FragmentPosition,
752            offset: usize,
753        ) {
754            let packet = parse_and_check_ipv4_packet(fragment, position, offset);
755            let (copied, not_copied) = packet.iter_options().fold(
756                (false, false),
757                |(mut copied, mut not_copied), option| {
758                    let (kind, data) = assert_matches!(option,
759                        Ipv4Option::Unrecognized{ kind, data } => (kind, data)
760                    );
761                    assert_eq!(data.len(), 1);
762                    assert_eq!(data[0], kind);
763                    let seen = match kind {
764                        FAKE_OPTION_COPIED_KIND => &mut copied,
765                        FAKE_OPTION_NOT_COPIED_KIND => &mut not_copied,
766                        k => panic!("unexpected option {k}"),
767                    };
768                    assert_eq!(core::mem::replace(seen, true), false);
769                    (copied, not_copied)
770                },
771            );
772            assert_eq!(copied, true, "must be copied on all fragments {position:?}");
773            assert_eq!(
774                not_copied,
775                position == FragmentPosition::First,
776                "must only be in first fragment {position:?}"
777            );
778        }
779    }
780
781    struct ForwardingTestEnv<E>(E);
782    impl<I: FragmentationIpExt + FilterIpExt, E: FragmentationTestEnv<I>> FragmentationTestEnv<I>
783        for ForwardingTestEnv<E>
784    {
785        fn new_serializer<'a>(
786            &self,
787            body: &'a [u8],
788        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a {
789            use packet_formats::ip::IpPacket as _;
790            let Self(inner) = self;
791            let mut buffer = inner
792                .new_serializer(body)
793                .serialize_outer(packet::NoReuseBufferProvider(packet::new_buf_vec))
794                .map_err(|(err, _)| err)
795                .unwrap();
796            let packet =
797                <I::Packet<_> as ParsablePacket<_, _>>::parse(buffer.buffer_view(), ()).unwrap();
798            let src_addr = packet.src_ip();
799            let dst_addr = packet.dst_ip();
800            let proto = packet.proto();
801            let meta = packet.parse_metadata();
802            drop(packet);
803            ForwardedPacket::new(src_addr, dst_addr, proto, meta, buffer)
804        }
805        fn check_fragment(
806            &self,
807            fragment: &mut Buf<Vec<u8>>,
808            position: FragmentPosition,
809            offset: usize,
810        ) {
811            let Self(inner) = self;
812            inner.check_fragment(fragment, position, offset)
813        }
814    }
815
816    struct Ipv6TestEnv;
817
818    const IPV6_ID: u32 = 0x1234ABCD;
819
820    impl FragmentationTestEnv<Ipv6> for Ipv6TestEnv {
821        fn new_serializer<'a>(
822            &self,
823            body: &'a [u8],
824        ) -> impl FragmentableIpSerializer<Ipv6, Buffer: Buffer> + 'a {
825            Ipv6PacketBuilder::new(
826                TEST_ADDRS_V6.local_ip,
827                TEST_ADDRS_V6.remote_ip,
828                1,
829                IpProto::Udp.into(),
830            )
831            .wrap_body(body.into_serializer())
832        }
833
834        fn check_fragment(
835            &self,
836            fragment: &mut Buf<Vec<u8>>,
837            position: FragmentPosition,
838            offset: usize,
839        ) {
840            let packet = Ipv6Packet::parse(fragment.buffer_view(), ()).unwrap();
841            assert_eq!(packet.src_ip(), TEST_ADDRS_V6.local_ip.get());
842            assert_eq!(packet.dst_ip(), TEST_ADDRS_V6.remote_ip.get());
843            assert_eq!(packet.hop_limit(), 1);
844            assert_eq!(packet.proto(), IpProto::Udp.into());
845            let fragment = packet
846                .iter_extension_hdrs()
847                .find_map(|h| match h.into_data() {
848                    Ipv6ExtensionHeaderData::Fragment { fragment_data } => Some(fragment_data),
849                    _ => None,
850                })
851                .expect("no fragment header");
852            assert_eq!(fragment.identification(), IPV6_ID);
853            assert_eq!(usize::from(fragment.fragment_offset().into_bytes()), offset);
854            assert_eq!(fragment.m_flag(), position != FragmentPosition::Last);
855        }
856    }
857
858    struct FixedIdContext;
859    impl FragmentationIdGenContext for FixedIdContext {
860        fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
861            let WrapFragmentationId(id) =
862                I::map_ip_out((), |()| WrapFragmentationId(()), |()| WrapFragmentationId(IPV6_ID));
863            id
864        }
865    }
866
867    #[test_case::test_matrix(
868        [
869            Ipv4TestEnv::default(),
870            Ipv4WithOptionsTestEnv::default(),
871            ForwardingTestEnv(Ipv4TestEnv::default()),
872            ForwardingTestEnv(Ipv4WithOptionsTestEnv::default()),
873            Ipv6TestEnv,
874        ],
875        0..=2
876    )]
877    fn fragment<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(
878        env: E,
879        middle_fragments: usize,
880    ) {
881        // NB: We're using the fact that MTU is larger than the header sizes
882        // here to end up obtaining the right number of middle fragments as
883        // expected. This makes this test sensitive to the relation between the
884        // picked MTU and the header sizes for the multiple serializers.
885        let full_body = gen_body(usize::from(TEST_MTU) * (1 + middle_fragments));
886        let mut body_view = Buf::new(&full_body[..], ..);
887        let serializer = env.new_serializer(&full_body[..]);
888        let mut fragmenter = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU)
889            .expect("create fragmenter");
890
891        let mut frag = fragmenter.next_serialized();
892        env.check_fragment(&mut frag, FragmentPosition::First, body_view.prefix_len());
893        assert_eq!(
894            frag.as_ref(),
895            body_view.buffer_view().take_front(fragmenter.max_fragment_body_first).unwrap()
896        );
897
898        for _ in 0..middle_fragments {
899            let mut frag = fragmenter.next_serialized();
900            env.check_fragment(&mut frag, FragmentPosition::Middle, body_view.prefix_len());
901            assert_eq!(
902                frag.as_ref(),
903                body_view.buffer_view().take_front(fragmenter.max_fragment_body_remaining).unwrap()
904            );
905        }
906
907        let mut frag = fragmenter.next_serialized();
908        env.check_fragment(&mut frag, FragmentPosition::Last, body_view.prefix_len());
909        assert_eq!(frag.as_ref(), body_view.buffer_view().into_rest());
910
911        // No more fragments.
912        assert!(fragmenter.next().is_none());
913    }
914
915    #[test_case(Ipv4TestEnv::dont_frag())]
916    #[test_case(Ipv4WithOptionsTestEnv(Ipv4TestEnv::dont_frag()))]
917    #[test_case(ForwardingTestEnv(Ipv4TestEnv::dont_frag()))]
918    #[test_case(ForwardingTestEnv(Ipv6TestEnv))]
919    fn not_allowed<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
920        let body = gen_body(usize::from(TEST_MTU));
921        let serializer = env.new_serializer(&body[..]);
922        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
923        assert_eq!(result, Err(FragmentationError::NotAllowed))
924    }
925
926    #[test_case(Ipv4TestEnv::default())]
927    #[test_case(Ipv4WithOptionsTestEnv::default())]
928    #[test_case(ForwardingTestEnv(Ipv4TestEnv::default()))]
929    #[test_case(Ipv6TestEnv)]
930    fn mtu_too_small<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
931        let body = gen_body(usize::from(TEST_MTU));
932        let serializer = env.new_serializer(&body[..]);
933        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, Mtu::new(10)).map(|_| ());
934        assert_eq!(result, Err(FragmentationError::MtuTooSmall));
935    }
936
937    #[test_case(Ipv4TestEnv::default())]
938    #[test_case(Ipv4WithOptionsTestEnv::default())]
939    #[test_case(Ipv6TestEnv)]
940    fn body_too_long<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
941        let body = gen_body(MAX_FRAGMENT_OFFSET + usize::from(TEST_MTU));
942        let serializer = env.new_serializer(&body[..]);
943        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
944        assert_eq!(result, Err(FragmentationError::BodyTooLong));
945    }
946}