netstack3_base/tcp/base.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! The Transmission Control Protocol (TCP).
6
7use core::iter::FromIterator;
8use core::num::NonZeroU16;
9use core::ops::Range;
10
11use alloc::vec::Vec;
12use core::mem::MaybeUninit;
13use net_types::ip::{Ip, IpVersion};
14use packet::InnerPacketBuilder;
15use packet_formats::ip::IpExt;
16
17use crate::ip::Mms;
18use crate::tcp::segment::{Payload, PayloadLen};
19
/// Control flags that can alter the state of a TCP control block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// Corresponds to the SYN bit in a TCP segment.
    SYN,
    /// Corresponds to the FIN bit in a TCP segment.
    FIN,
    /// Corresponds to the RST bit in a TCP segment.
    RST,
}

impl Control {
    /// Reports whether this flag occupies one byte of sequence number space.
    ///
    /// SYN and FIN each consume a sequence number; RST consumes none.
    pub fn has_sequence_no(self) -> bool {
        // Kept as an exhaustive match so that adding a variant forces a
        // decision here.
        match self {
            Self::RST => false,
            Self::SYN | Self::FIN => true,
        }
    }
}
41
42const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;
43
44/// Maximum segment size, that is the maximum TCP payload one segment can carry.
45#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
46pub struct Mss(pub NonZeroU16);
47
48impl Mss {
49    /// Creates MSS from the maximum message size of the IP layer.
50    pub fn from_mms<I: IpExt>(mms: Mms) -> Option<Self> {
51        NonZeroU16::new(
52            u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX),
53        )
54        .map(Self)
55    }
56
57    /// Create a new [`Mss`] with the IP-version default value, as defined by RFC 9293.
58    pub const fn default<I: Ip>() -> Self {
59        // Per RFC 9293 Section 3.7.1:
60        //  If an MSS Option is not received at connection setup, TCP
61        //  implementations MUST assume a default send MSS of 536 (576 - 40) for
62        //  IPv4 or 1220 (1280 - 60) for IPv6 (MUST-15).
63        match I::VERSION {
64            IpVersion::V4 => Mss(NonZeroU16::new(536).unwrap()),
65            IpVersion::V6 => Mss(NonZeroU16::new(1220).unwrap()),
66        }
67    }
68
69    /// Gets the numeric value of the MSS.
70    pub const fn get(&self) -> NonZeroU16 {
71        let Self(mss) = *self;
72        mss
73    }
74}
75
76impl From<Mss> for u32 {
77    fn from(Mss(mss): Mss) -> Self {
78        u32::from(mss.get())
79    }
80}
81
82impl From<Mss> for usize {
83    fn from(Mss(mss): Mss) -> Self {
84        usize::from(mss.get())
85    }
86}
87
88/// An implementation of [`Payload`] backed by up to `N` byte slices.
89#[derive(Copy, Clone, Debug, PartialEq)]
90pub struct FragmentedPayload<'a, const N: usize> {
91    storage: [&'a [u8]; N],
92    // NB: Not using `Range` because it is not `Copy`.
93    //
94    // Start is inclusive, end is exclusive; so this is equivalent to
95    // `start..end` ranges.
96    start: usize,
97    end: usize,
98}
99
100/// Creates a new `FragmentedPayload` possibly without using the entire
101/// storage capacity `N`.
102///
103/// # Panics
104///
105/// Panics if the iterator contains more than `N` items.
106impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
107    fn from_iter<T>(iter: T) -> Self
108    where
109        T: IntoIterator<Item = &'a [u8]>,
110    {
111        let Self { storage, start, end } = Self::new_empty();
112        let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
113            storage[end] = sl;
114            (storage, end + 1)
115        });
116        Self { storage, start, end }
117    }
118}
119
120impl<'a, const N: usize> FragmentedPayload<'a, N> {
121    /// Creates a new `FragmentedPayload` with the slices in `values`.
122    pub fn new(values: [&'a [u8]; N]) -> Self {
123        Self { storage: values, start: 0, end: N }
124    }
125
126    /// Creates a new `FragmentedPayload` with a single contiguous slice.
127    pub fn new_contiguous(value: &'a [u8]) -> Self {
128        core::iter::once(value).collect()
129    }
130
131    /// Converts this [`FragmentedPayload`] into an owned `Vec`.
132    pub fn to_vec(self) -> Vec<u8> {
133        self.slices().concat()
134    }
135
136    fn slices(&self) -> &[&'a [u8]] {
137        let Self { storage, start, end } = self;
138        &storage[*start..*end]
139    }
140
141    /// Extracted function to implement [`Payload::partial_copy`] and
142    /// [`Payload::partial_copy_uninit`].
143    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
144        &self,
145        mut offset: usize,
146        mut dst: &mut [T],
147        apply: F,
148    ) {
149        let mut slices = self.slices().into_iter();
150        while let Some(sl) = slices.next() {
151            let l = sl.len();
152            if offset >= l {
153                offset -= l;
154                continue;
155            }
156            let sl = &sl[offset..];
157            let cp = sl.len().min(dst.len());
158            let (target, new_dst) = dst.split_at_mut(cp);
159            apply(&sl[..cp], target);
160
161            // We're done.
162            if new_dst.len() == 0 {
163                return;
164            }
165
166            dst = new_dst;
167            offset = 0;
168        }
169        assert_eq!(dst.len(), 0, "failed to fill dst");
170    }
171}
172
173impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
174    fn len(&self) -> usize {
175        self.slices().into_iter().map(|s| s.len()).sum()
176    }
177}
178
179impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
180    fn slice(self, byte_range: Range<u32>) -> Self {
181        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
182        let Range { start: byte_start, end: byte_end } = byte_range;
183        let byte_start =
184            usize::try_from(byte_start).expect("range start index out of range for usize");
185        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
186        assert!(byte_end >= byte_start);
187        let mut storage_iter =
188            (&mut storage[self_start..self_end]).into_iter().scan(0, |total_len, slice| {
189                let slice_len = slice.len();
190                let item = Some((*total_len, slice));
191                *total_len += slice_len;
192                item
193            });
194
195        // Keep track of whether the start was inside the range, we should panic
196        // even on an empty range out of start bounds.
197        let mut start_offset = None;
198        let mut final_len = 0;
199        while let Some((sl_offset, sl)) = storage_iter.next() {
200            let orig_len = sl.len();
201
202            // Advance until the start of the specified range, discarding unused
203            // slices.
204            if sl_offset + orig_len < byte_start {
205                *sl = &[];
206                self_start += 1;
207                continue;
208            }
209            // Discard any empty slices at the end.
210            if sl_offset >= byte_end {
211                *sl = &[];
212                self_end -= 1;
213                continue;
214            }
215
216            let sl_start = byte_start.saturating_sub(sl_offset);
217            let sl_end = sl.len().min(byte_end - sl_offset);
218            *sl = &sl[sl_start..sl_end];
219
220            match start_offset {
221                Some(_) => (),
222                None => {
223                    // Keep track of the start offset of the first slice.
224                    start_offset = Some(sl_offset + sl_start);
225                    // Avoid producing an empty slice if we haven't added
226                    // anything yet.
227                    if sl.len() == 0 {
228                        self_start += 1;
229                    }
230                }
231            }
232            final_len += sl.len();
233        }
234        // Verify that the entire range was consumed.
235        assert_eq!(
236            // If we didn't use start_offset the only valid value for
237            // `byte_start` is zero.
238            start_offset.unwrap_or(0),
239            byte_start,
240            "range start index out of range {byte_range:?}"
241        );
242        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");
243
244        // Canonicalize an empty payload.
245        if self_start == self_end {
246            self_start = 0;
247            self_end = 0;
248        }
249        Self { storage, start: self_start, end: self_end }
250    }
251
252    fn new_empty() -> Self {
253        Self { storage: [&[]; N], start: 0, end: 0 }
254    }
255
256    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
257        self.apply_copy(offset, dst, |src, dst| {
258            dst.copy_from_slice(src);
259        });
260    }
261
262    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
263        self.apply_copy(offset, dst, |src, dst| {
264            // TODO(https://github.com/rust-lang/rust/issues/79995): Replace unsafe
265            // with copy_from_slice when stabiliized.
266            // SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
267            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
268            dst.copy_from_slice(&uninit_src);
269        });
270    }
271}
272
273impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
274    fn bytes_len(&self) -> usize {
275        self.len()
276    }
277
278    fn serialize(&self, buffer: &mut [u8]) {
279        self.partial_copy(0, buffer);
280    }
281}
282
#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        // Like for standard slices, this shouldn't succeed if the start index
        // is out of bounds, even if the total range is empty.
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }

    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }

    const TEST_BYTES: &[u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            // Add all failed seeds here.
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len())
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            // `FragmentedPayload` is `Copy`, so no clone is needed.
            prop_assert_eq!(payload.slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end-start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }

        #[test]
        fn fragmented_payload_partial_copy_uninit((payload, (start, end)) in fragmented_payload::with_range()) {
            // Pre-initialize so reading the buffer back below is sound no
            // matter what `partial_copy_uninit` writes; `TEST_BYTES` holds no
            // zero bytes, so a missed write is still detected.
            let mut buffer = [MaybeUninit::new(0u8); TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end - start)];
            payload.partial_copy_uninit(start, buffer);
            let copied: Vec<u8> = buffer
                .iter()
                // SAFETY: every element was initialized when `buffer` was created.
                .map(|b| unsafe { b.assume_init() })
                .collect();
            prop_assert_eq!(copied, &TEST_BYTES[start..end]);
        }
    }

    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;

        /// Generates a payload holding exactly `TEST_BYTES`, split at random
        /// points into between 1 and `TEST_STORAGE` (possibly empty) slices.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.gen_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        /// Generates a payload plus a byte range `(start, end)` with
        /// `start <= end <= TEST_BYTES.len()`.
        ///
        /// Both bounds are inclusive of `TEST_BYTES.len()`. This means ranges
        /// reaching the end of the payload, and the empty range at the end,
        /// are exercised as well.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..=TEST_BYTES.len())
                    .prop_flat_map(|start| (Just(start), start..=TEST_BYTES.len())),
            )
        }
    }
}