netstack3_base/tcp/
base.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! The Transmission Control Protocol (TCP).
6
7use core::iter::FromIterator;
8use core::num::NonZeroU16;
9use core::ops::Range;
10
11use alloc::vec::Vec;
12use core::mem::MaybeUninit;
13use net_types::ip::{Ip, IpVersion};
14use packet::InnerPacketBuilder;
15
16use crate::ip::Mms;
17use crate::tcp::segment::{Payload, PayloadLen};
18
/// Control flags carried by a TCP segment that can change the state of a TCP
/// control block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// The SYN bit of a TCP segment.
    SYN,
    /// The FIN bit of a TCP segment.
    FIN,
    /// The RST bit of a TCP segment.
    RST,
}

impl Control {
    /// Reports whether this flag occupies one position in the sequence number
    /// space: `true` for SYN and FIN, `false` for RST.
    pub fn has_sequence_no(self) -> bool {
        match self {
            Self::RST => false,
            Self::SYN | Self::FIN => true,
        }
    }
}
40
41const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;
42
43/// Maximum segment size, that is the maximum TCP payload one segment can carry.
44#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
45pub struct Mss(pub NonZeroU16);
46
47impl Mss {
48    /// Creates MSS from the maximum message size of the IP layer.
49    pub fn from_mms(mms: Mms) -> Option<Self> {
50        NonZeroU16::new(
51            u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX),
52        )
53        .map(Self)
54    }
55
56    /// Create a new [`Mss`] with the IP-version default value, as defined by RFC 9293.
57    pub const fn default<I: Ip>() -> Self {
58        // Per RFC 9293 Section 3.7.1:
59        //  If an MSS Option is not received at connection setup, TCP
60        //  implementations MUST assume a default send MSS of 536 (576 - 40) for
61        //  IPv4 or 1220 (1280 - 60) for IPv6 (MUST-15).
62        match I::VERSION {
63            IpVersion::V4 => Mss(NonZeroU16::new(536).unwrap()),
64            IpVersion::V6 => Mss(NonZeroU16::new(1220).unwrap()),
65        }
66    }
67
68    /// Gets the numeric value of the MSS.
69    pub const fn get(&self) -> NonZeroU16 {
70        let Self(mss) = *self;
71        mss
72    }
73}
74
75impl From<Mss> for u32 {
76    fn from(Mss(mss): Mss) -> Self {
77        u32::from(mss.get())
78    }
79}
80
81impl From<Mss> for usize {
82    fn from(Mss(mss): Mss) -> Self {
83        usize::from(mss.get())
84    }
85}
86
87/// An implementation of [`Payload`] backed by up to `N` byte slices.
88#[derive(Copy, Clone, Debug, PartialEq)]
89pub struct FragmentedPayload<'a, const N: usize> {
90    storage: [&'a [u8]; N],
91    // NB: Not using `Range` because it is not `Copy`.
92    //
93    // Start is inclusive, end is exclusive; so this is equivalent to
94    // `start..end` ranges.
95    start: usize,
96    end: usize,
97}
98
99/// Creates a new `FragmentedPayload` possibly without using the entire
100/// storage capacity `N`.
101///
102/// # Panics
103///
104/// Panics if the iterator contains more than `N` items.
105impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
106    fn from_iter<T>(iter: T) -> Self
107    where
108        T: IntoIterator<Item = &'a [u8]>,
109    {
110        let Self { storage, start, end } = Self::new_empty();
111        let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
112            storage[end] = sl;
113            (storage, end + 1)
114        });
115        Self { storage, start, end }
116    }
117}
118
119impl<'a, const N: usize> FragmentedPayload<'a, N> {
120    /// Creates a new `FragmentedPayload` with the slices in `values`.
121    pub fn new(values: [&'a [u8]; N]) -> Self {
122        Self { storage: values, start: 0, end: N }
123    }
124
125    /// Creates a new `FragmentedPayload` with a single contiguous slice.
126    pub fn new_contiguous(value: &'a [u8]) -> Self {
127        core::iter::once(value).collect()
128    }
129
130    /// Converts this [`FragmentedPayload`] into an owned `Vec`.
131    pub fn to_vec(self) -> Vec<u8> {
132        self.slices().concat()
133    }
134
135    fn slices(&self) -> &[&'a [u8]] {
136        let Self { storage, start, end } = self;
137        &storage[*start..*end]
138    }
139
140    /// Extracted function to implement [`Payload::partial_copy`] and
141    /// [`Payload::partial_copy_uninit`].
142    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
143        &self,
144        mut offset: usize,
145        mut dst: &mut [T],
146        apply: F,
147    ) {
148        let mut slices = self.slices().into_iter();
149        while let Some(sl) = slices.next() {
150            let l = sl.len();
151            if offset >= l {
152                offset -= l;
153                continue;
154            }
155            let sl = &sl[offset..];
156            let cp = sl.len().min(dst.len());
157            let (target, new_dst) = dst.split_at_mut(cp);
158            apply(&sl[..cp], target);
159
160            // We're done.
161            if new_dst.len() == 0 {
162                return;
163            }
164
165            dst = new_dst;
166            offset = 0;
167        }
168        assert_eq!(dst.len(), 0, "failed to fill dst");
169    }
170}
171
172impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
173    fn len(&self) -> usize {
174        self.slices().into_iter().map(|s| s.len()).sum()
175    }
176}
177
178impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
179    fn slice(self, byte_range: Range<u32>) -> Self {
180        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
181        let Range { start: byte_start, end: byte_end } = byte_range;
182        let byte_start =
183            usize::try_from(byte_start).expect("range start index out of range for usize");
184        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
185        assert!(byte_end >= byte_start);
186        let mut storage_iter =
187            (&mut storage[self_start..self_end]).into_iter().scan(0, |total_len, slice| {
188                let slice_len = slice.len();
189                let item = Some((*total_len, slice));
190                *total_len += slice_len;
191                item
192            });
193
194        // Keep track of whether the start was inside the range, we should panic
195        // even on an empty range out of start bounds.
196        let mut start_offset = None;
197        let mut final_len = 0;
198        while let Some((sl_offset, sl)) = storage_iter.next() {
199            let orig_len = sl.len();
200
201            // Advance until the start of the specified range, discarding unused
202            // slices.
203            if sl_offset + orig_len < byte_start {
204                *sl = &[];
205                self_start += 1;
206                continue;
207            }
208            // Discard any empty slices at the end.
209            if sl_offset >= byte_end {
210                *sl = &[];
211                self_end -= 1;
212                continue;
213            }
214
215            let sl_start = byte_start.saturating_sub(sl_offset);
216            let sl_end = sl.len().min(byte_end - sl_offset);
217            *sl = &sl[sl_start..sl_end];
218
219            match start_offset {
220                Some(_) => (),
221                None => {
222                    // Keep track of the start offset of the first slice.
223                    start_offset = Some(sl_offset + sl_start);
224                    // Avoid producing an empty slice if we haven't added
225                    // anything yet.
226                    if sl.len() == 0 {
227                        self_start += 1;
228                    }
229                }
230            }
231            final_len += sl.len();
232        }
233        // Verify that the entire range was consumed.
234        assert_eq!(
235            // If we didn't use start_offset the only valid value for
236            // `byte_start` is zero.
237            start_offset.unwrap_or(0),
238            byte_start,
239            "range start index out of range {byte_range:?}"
240        );
241        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");
242
243        // Canonicalize an empty payload.
244        if self_start == self_end {
245            self_start = 0;
246            self_end = 0;
247        }
248        Self { storage, start: self_start, end: self_end }
249    }
250
251    fn new_empty() -> Self {
252        Self { storage: [&[]; N], start: 0, end: 0 }
253    }
254
255    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
256        self.apply_copy(offset, dst, |src, dst| {
257            dst.copy_from_slice(src);
258        });
259    }
260
261    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
262        self.apply_copy(offset, dst, |src, dst| {
263            // TODO(https://github.com/rust-lang/rust/issues/79995): Replace unsafe
264            // with copy_from_slice when stabiliized.
265            // SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
266            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
267            dst.copy_from_slice(&uninit_src);
268        });
269    }
270}
271
272impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
273    fn bytes_len(&self) -> usize {
274        self.len()
275    }
276
277    fn serialize(&self, buffer: &mut [u8]) {
278        self.partial_copy(0, buffer);
279    }
280}
281
#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        // Like for standard slices, this shouldn't succeed if the start length
        // is out of bounds, even if the total range is empty.
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }

    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }

    #[test]
    #[should_panic]
    fn from_iter_over_capacity() {
        // Two slices cannot fit in storage sized for one.
        let _ = FragmentedPayload::<1>::from_iter([&EXAMPLE_DATA[..1], &EXAMPLE_DATA[1..]]);
    }

    #[test]
    fn partial_copy_uninit_spans_fragments() {
        let data = EXAMPLE_DATA;
        let payload = FragmentedPayload::new([&data[..3], &data[3..]]);
        let mut buffer = [MaybeUninit::<u8>::uninit(); 5];
        payload.partial_copy_uninit(2, &mut buffer);
        // SAFETY: `partial_copy_uninit` initialized every element of `buffer`
        // (it panics otherwise).
        let copied = buffer.map(|b| unsafe { b.assume_init() });
        assert_eq!(copied, [2, 3, 4, 5, 6]);
    }

    const TEST_BYTES: &[u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            // Add all failed seeds here.
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len())
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end-start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }

    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;

        /// Generates payloads that split `TEST_BYTES` into 1 to `TEST_STORAGE`
        /// (possibly empty) slices.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.gen_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        /// Generates a payload together with a byte range `(start, end)` where
        /// `0 <= start <= end <= TEST_BYTES.len()`, so that empty ranges and
        /// ranges touching the end of the payload are covered as well.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..=TEST_BYTES.len())
                    .prop_flat_map(|start| (Just(start), start..=TEST_BYTES.len())),
            )
        }
    }
}