netstack3_base/tcp/
base.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! The Transmission Control Protocol (TCP).
6
7use core::iter::FromIterator;
8use core::ops::Range;
9
10use alloc::vec::Vec;
11use core::mem::MaybeUninit;
12use net_types::ip::{Ip, IpVersion};
13use packet::InnerPacketBuilder;
14use static_assertions::const_assert;
15
16use crate::ip::Mms;
17use crate::tcp::segment::{Payload, PayloadLen};
18
/// Control flags that can alter the state of a TCP control block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// Corresponds to the SYN bit in a TCP segment.
    SYN,
    /// Corresponds to the FIN bit in a TCP segment.
    FIN,
    /// Corresponds to the RST bit in a TCP segment.
    RST,
}

impl Control {
    /// Returns whether the control flag consumes one byte from the sequence
    /// number space.
    pub fn has_sequence_no(self) -> bool {
        // SYN and FIN each occupy one sequence number; RST does not.
        matches!(self, Control::SYN | Control::FIN)
    }
}
40
/// Length in bytes of the fixed TCP header prefix (i.e. without options).
const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;

/// Maximum segment size, that is the maximum TCP payload one segment can carry.
///
/// `Mss` also acts as a witness that the contained value is >= `Mss::MIN`.
#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
pub struct Mss(u16);

// Compile-time checks: the enforced minimum must not exceed the RFC 9293
// defaults, and must leave room for a maximally-sized TCP options area.
const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV4.get());
const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV6.get());
const_assert!(Mss::MIN.get() as usize >= packet_formats::tcp::MAX_OPTIONS_LEN);
52
53impl Mss {
54    /// The minimum MSS allowed by TCP.
55    ///
56    /// Although enforcing a minimum MSS is outside the recommendations of any
57    /// RFC, it is a common practice on other platforms and has multiple
58    /// benefits:
59    ///   1) Ensures there is enough space to transmit TCP Options & IP Options.
60    ///      See RFC 6691 section 2, which clarifies that
61    ///          The TCP MSS OPTION [...] SHOULD NOT be decreased to account for
62    ///          any possible IP or TCP options; conversely, the sender MUST
63    ///          reduce the TCP data length to account for any IP or TCP options
64    ///          that it is including in the packets that it sends.
65    ///   2) Protects against DOS attacks in which the attacker initiates TCP
66    ///      connections with an intentionally small MSS to incur additional
67    ///      packet processing overhead on the victim. See
68    ///      * FreeBSD: https://www.cve.org/CVERecord?id=CVE-2004-0002
69    ///      * Linux: https://www.cve.org/CVERecord?id=CVE-2019-11479
70    ///
71    /// Here, the value 216 is inspired by FreeBSD. It's large enough to satisfy
72    /// points 1 & 2 from above, while remaining small enough to support all
73    /// link-layer technologies on the open Internet.
74    pub const MIN: Mss = Mss(216);
75
76    /// Per RFC 9293 Section 3.7.1:
77    ///  If an MSS Option is not received at connection setup, TCP
78    ///  implementations MUST assume a default send MSS of 536 (576 - 40) for
79    ///  IPv4.
80    pub const DEFAULT_IPV4: Mss = Mss(536);
81
82    /// Per RFC 9293 Section 3.7.1:
83    ///  If an MSS Option is not received at connection setup, TCP
84    ///  implementations MUST assume a default send MSS of [...] 1220
85    /// (1280 - 60) for IPv6 (MUST-15).
86    pub const DEFAULT_IPV6: Mss = Mss(1220);
87
88    /// Creates `Mss`, provided the given value satisfies the requirements.
89    pub const fn new(mss: u16) -> Option<Self> {
90        if mss < Self::MIN.get() { None } else { Some(Mss(mss)) }
91    }
92
93    /// Creates MSS from the maximum message size of the IP layer.
94    pub fn from_mms(mms: Mms) -> Option<Self> {
95        let mss = u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX);
96        Self::new(mss)
97    }
98
99    /// Create a new [`Mss`] with the IP-version default value, as defined by RFC 9293.
100    pub const fn default<I: Ip>() -> Self {
101        match I::VERSION {
102            IpVersion::V4 => Self::DEFAULT_IPV4,
103            IpVersion::V6 => Self::DEFAULT_IPV6,
104        }
105    }
106
107    /// Gets the numeric value of the MSS.
108    pub const fn get(&self) -> u16 {
109        let Self(mss) = *self;
110        mss
111    }
112}
113
114impl From<Mss> for u32 {
115    fn from(Mss(mss): Mss) -> Self {
116        u32::from(mss)
117    }
118}
119
120impl From<Mss> for usize {
121    fn from(Mss(mss): Mss) -> Self {
122        usize::from(mss)
123    }
124}
125
/// An implementation of [`Payload`] backed by up to `N` byte slices.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct FragmentedPayload<'a, const N: usize> {
    // Fixed-capacity backing storage; only `storage[start..end]` holds live
    // slices, entries outside that window are empty placeholders.
    storage: [&'a [u8]; N],
    // NB: Not using `Range` because it is not `Copy`.
    //
    // Start is inclusive, end is exclusive; so this is equivalent to
    // `start..end` ranges.
    start: usize,
    end: usize,
}
137
138/// Creates a new `FragmentedPayload` possibly without using the entire
139/// storage capacity `N`.
140///
141/// # Panics
142///
143/// Panics if the iterator contains more than `N` items.
144impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
145    fn from_iter<T>(iter: T) -> Self
146    where
147        T: IntoIterator<Item = &'a [u8]>,
148    {
149        let Self { storage, start, end } = Self::new_empty();
150        let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
151            storage[end] = sl;
152            (storage, end + 1)
153        });
154        Self { storage, start, end }
155    }
156}
157
158impl<'a, const N: usize> FragmentedPayload<'a, N> {
159    /// Creates a new `FragmentedPayload` with the slices in `values`.
160    pub fn new(values: [&'a [u8]; N]) -> Self {
161        Self { storage: values, start: 0, end: N }
162    }
163
164    /// Creates a new `FragmentedPayload` with a single contiguous slice.
165    pub fn new_contiguous(value: &'a [u8]) -> Self {
166        core::iter::once(value).collect()
167    }
168
169    /// Converts this [`FragmentedPayload`] into an owned `Vec`.
170    pub fn to_vec(self) -> Vec<u8> {
171        self.slices().concat()
172    }
173
174    fn slices(&self) -> &[&'a [u8]] {
175        let Self { storage, start, end } = self;
176        &storage[*start..*end]
177    }
178
179    /// Extracted function to implement [`Payload::partial_copy`] and
180    /// [`Payload::partial_copy_uninit`].
181    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
182        &self,
183        mut offset: usize,
184        mut dst: &mut [T],
185        apply: F,
186    ) {
187        let mut slices = self.slices().into_iter();
188        while let Some(sl) = slices.next() {
189            let l = sl.len();
190            if offset >= l {
191                offset -= l;
192                continue;
193            }
194            let sl = &sl[offset..];
195            let cp = sl.len().min(dst.len());
196            let (target, new_dst) = dst.split_at_mut(cp);
197            apply(&sl[..cp], target);
198
199            // We're done.
200            if new_dst.len() == 0 {
201                return;
202            }
203
204            dst = new_dst;
205            offset = 0;
206        }
207        assert_eq!(dst.len(), 0, "failed to fill dst");
208    }
209}
210
211impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
212    fn len(&self) -> usize {
213        self.slices().into_iter().map(|s| s.len()).sum()
214    }
215}
216
impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
    /// Narrows the payload to `byte_range`, trimming or discarding backing
    /// slices in place and panicking (like slice indexing) when the range is
    /// out of bounds.
    fn slice(self, byte_range: Range<u32>) -> Self {
        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
        let Range { start: byte_start, end: byte_end } = byte_range;
        let byte_start =
            usize::try_from(byte_start).expect("range start index out of range for usize");
        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
        assert!(byte_end >= byte_start);
        // Pair each slice with its cumulative byte offset into the payload.
        let mut storage_iter =
            (&mut storage[self_start..self_end]).into_iter().scan(0, |total_len, slice| {
                let slice_len = slice.len();
                let item = Some((*total_len, slice));
                *total_len += slice_len;
                item
            });

        // Keep track of whether the start was inside the range, we should panic
        // even on an empty range out of start bounds.
        let mut start_offset = None;
        let mut final_len = 0;
        while let Some((sl_offset, sl)) = storage_iter.next() {
            let orig_len = sl.len();

            // Advance until the start of the specified range, discarding unused
            // slices.
            if sl_offset + orig_len < byte_start {
                *sl = &[];
                self_start += 1;
                continue;
            }
            // Discard any empty slices at the end.
            if sl_offset >= byte_end {
                *sl = &[];
                self_end -= 1;
                continue;
            }

            // Trim this slice to the portion that intersects the range.
            let sl_start = byte_start.saturating_sub(sl_offset);
            let sl_end = sl.len().min(byte_end - sl_offset);
            *sl = &sl[sl_start..sl_end];

            match start_offset {
                Some(_) => (),
                None => {
                    // Keep track of the start offset of the first slice.
                    start_offset = Some(sl_offset + sl_start);
                    // Avoid producing an empty slice if we haven't added
                    // anything yet.
                    if sl.len() == 0 {
                        self_start += 1;
                    }
                }
            }
            final_len += sl.len();
        }
        // Verify that the entire range was consumed.
        assert_eq!(
            // If we didn't use start_offset the only valid value for
            // `byte_start` is zero.
            start_offset.unwrap_or(0),
            byte_start,
            "range start index out of range {byte_range:?}"
        );
        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");

        // Canonicalize an empty payload.
        if self_start == self_end {
            self_start = 0;
            self_end = 0;
        }
        Self { storage, start: self_start, end: self_end }
    }

    /// Returns the canonical empty payload: all-empty storage, `0..0` window.
    fn new_empty() -> Self {
        Self { storage: [&[]; N], start: 0, end: 0 }
    }

    /// Copies bytes starting at `offset` into `dst`; see [`Self::apply_copy`].
    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
        self.apply_copy(offset, dst, |src, dst| {
            dst.copy_from_slice(src);
        });
    }

    /// Like [`Self::partial_copy`] but writes into uninitialized memory.
    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
        self.apply_copy(offset, dst, |src, dst| {
            // TODO(https://github.com/rust-lang/rust/issues/79995): Replace unsafe
            // with copy_from_slice when stabilized.
            // SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
            dst.copy_from_slice(&uninit_src);
        });
    }
}
310
311impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
312    fn bytes_len(&self) -> usize {
313        self.len()
314    }
315
316    fn serialize(&self, buffer: &mut [u8]) {
317        self.partial_copy(0, buffer);
318    }
319}
320
#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    // Serializing any fragmentation of EXAMPLE_DATA must reproduce the
    // original bytes exactly.
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        // Like for standard slices, this shouldn't succeed if the start length
        // is out of bounds, even if the total range is empty.
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }

    // Empty slices must canonicalize to the `new_empty()` representation so
    // that `PartialEq` treats all empty payloads as equal.
    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }

    const TEST_BYTES: &'static [u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            // Add all failed seeds here.
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len())
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.clone().slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end-start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }

    // Proptest strategies producing arbitrary fragmentations of TEST_BYTES
    // (and sub-ranges thereof) for the property tests above.
    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;
        // Splits TEST_BYTES into 1..=TEST_STORAGE randomly-sized contiguous
        // fragments that always cover the full byte string.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        // The last fragment takes whatever remains.
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.random_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        // A payload paired with a valid (start <= end) sub-range of it.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..TEST_BYTES.len()).prop_flat_map(|start| (Just(start), start..TEST_BYTES.len())),
            )
        }
    }
}