use core::iter::FromIterator;
use core::num::NonZeroU16;
use core::ops::Range;

use alloc::vec::Vec;
use core::mem::MaybeUninit;
use net_types::ip::{Ip, IpVersion};
use packet::InnerPacketBuilder;
use packet_formats::ip::IpExt;

use crate::ip::Mms;
use crate::tcp::segment::{Payload, PayloadLen};

/// The control flags of a TCP segment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// Synchronize sequence numbers.
    SYN,
    /// No more data from the sender.
    FIN,
    /// Reset the connection.
    RST,
}

impl Control {
    /// Returns whether the control flag consumes one byte of the sequence
    /// number space.
    pub fn has_sequence_no(self) -> bool {
        match self {
            Control::SYN | Control::FIN => true,
            Control::RST => false,
        }
    }
}

/// The length of the fixed (option-less) TCP header, in bytes.
const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;

/// The maximum segment size (MSS): the largest TCP payload, in bytes, that
/// can be carried in a single segment.
#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
pub struct Mss(pub NonZeroU16);

impl Mss {
    /// Creates an MSS from the maximum message size (MMS) of the IP layer by
    /// subtracting the fixed TCP header length; returns `None` if the result
    /// would be zero.
    pub fn from_mms<I: IpExt>(mms: Mms) -> Option<Self> {
        NonZeroU16::new(
            u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX),
        )
        .map(Self)
    }

    /// Returns the default MSS to assume when the peer does not advertise
    /// one: 536 bytes for IPv4 and 1220 bytes for IPv6, per RFC 9293.
    pub const fn default<I: Ip>() -> Self {
        match I::VERSION {
            IpVersion::V4 => Mss(NonZeroU16::new(536).unwrap()),
            IpVersion::V6 => Mss(NonZeroU16::new(1220).unwrap()),
        }
    }

    /// Returns the maximum segment size as a [`NonZeroU16`].
    pub const fn get(&self) -> NonZeroU16 {
        let Self(mss) = *self;
        mss
    }
}

impl From<Mss> for u32 {
    fn from(Mss(mss): Mss) -> Self {
        u32::from(mss.get())
    }
}

impl From<Mss> for usize {
    fn from(Mss(mss): Mss) -> Self {
        usize::from(mss.get())
    }
}

/// A TCP payload backed by up to `N` byte slices.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct FragmentedPayload<'a, const N: usize> {
    storage: [&'a [u8]; N],
    /// The index of the first occupied slot in `storage`.
    start: usize,
    /// One past the index of the last occupied slot in `storage`.
    end: usize,
}

impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
    /// Collects the yielded slices into a new payload.
    ///
    /// # Panics
    ///
    /// Panics if the iterator yields more than `N` slices.
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = &'a [u8]>,
    {
        let Self { storage, start, end } = Self::new_empty();
        let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
            storage[end] = sl;
            (storage, end + 1)
        });
        Self { storage, start, end }
    }
}

impl<'a, const N: usize> FragmentedPayload<'a, N> {
    /// Creates a new payload from exactly `N` slices.
    pub fn new(values: [&'a [u8]; N]) -> Self {
        Self { storage: values, start: 0, end: N }
    }

    /// Creates a new payload from a single contiguous slice.
    pub fn new_contiguous(value: &'a [u8]) -> Self {
        core::iter::once(value).collect()
    }

    /// Copies the payload into a newly allocated `Vec<u8>`.
    pub fn to_vec(self) -> Vec<u8> {
        self.slices().concat()
    }

    /// Returns the occupied slots in `storage`.
    fn slices(&self) -> &[&'a [u8]] {
        let Self { storage, start, end } = self;
        &storage[*start..*end]
    }

    /// Fills `dst` with the payload's bytes starting at `offset`, calling
    /// `apply` once for each pair of source and destination chunks.
    ///
    /// # Panics
    ///
    /// Panics if the payload does not contain enough bytes after `offset` to
    /// fill `dst`.
    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
        &self,
        mut offset: usize,
        mut dst: &mut [T],
        apply: F,
    ) {
        for sl in self.slices() {
            let l = sl.len();
            // Skip slices that end before `offset`.
            if offset >= l {
                offset -= l;
                continue;
            }
            let sl = &sl[offset..];
            let cp = sl.len().min(dst.len());
            let (target, new_dst) = dst.split_at_mut(cp);
            apply(&sl[..cp], target);

            // `dst` has been filled completely.
            if new_dst.is_empty() {
                return;
            }

            dst = new_dst;
            offset = 0;
        }
        assert_eq!(dst.len(), 0, "failed to fill dst");
    }
}

impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
    fn len(&self) -> usize {
        self.slices().iter().map(|s| s.len()).sum()
    }
}

impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
    fn slice(self, byte_range: Range<u32>) -> Self {
        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
        let Range { start: byte_start, end: byte_end } = byte_range;
        let byte_start =
            usize::try_from(byte_start).expect("range start index out of range for usize");
        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
        assert!(byte_end >= byte_start);
        let mut storage_iter =
            storage[self_start..self_end].iter_mut().scan(0, |total_len, slice| {
                let slice_len = slice.len();
                let item = Some((*total_len, slice));
                *total_len += slice_len;
                item
            });

        // The byte offset into the original payload where the kept range
        // starts; used to validate `byte_start` below.
        let mut start_offset = None;
        let mut final_len = 0;
        while let Some((sl_offset, sl)) = storage_iter.next() {
            let orig_len = sl.len();

            // This slice ends before the requested range; drop it and advance
            // the payload's start.
            if sl_offset + orig_len < byte_start {
                *sl = &[];
                self_start += 1;
                continue;
            }
            // This slice begins at or after the end of the requested range;
            // drop it and pull in the payload's end.
            if sl_offset >= byte_end {
                *sl = &[];
                self_end -= 1;
                continue;
            }

            let sl_start = byte_start.saturating_sub(sl_offset);
            let sl_end = sl.len().min(byte_end - sl_offset);
            *sl = &sl[sl_start..sl_end];

            match start_offset {
                Some(_) => (),
                None => {
                    start_offset = Some(sl_offset + sl_start);
                    // The first slice that reaches the range may still end up
                    // empty (e.g. for an empty range at its boundary); skip
                    // over it so empty payloads stay canonical.
                    if sl.is_empty() {
                        self_start += 1;
                    }
                }
            }
            final_len += sl.len();
        }
        // If no slice overlapped the range, the only valid request is an
        // empty range starting at offset zero.
        assert_eq!(
            start_offset.unwrap_or(0),
            byte_start,
            "range start index out of range {byte_range:?}"
        );
        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");

        // Canonicalize an empty payload so that all empty payloads compare
        // equal, regardless of which slots they previously occupied.
        if self_start == self_end {
            self_start = 0;
            self_end = 0;
        }
        Self { storage, start: self_start, end: self_end }
    }

    fn new_empty() -> Self {
        Self { storage: [&[]; N], start: 0, end: 0 }
    }

    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
        self.apply_copy(offset, dst, |src, dst| {
            dst.copy_from_slice(src);
        });
    }

    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
        self.apply_copy(offset, dst, |src, dst| {
            // SAFETY: `MaybeUninit<u8>` has the same size and alignment as
            // `u8` and imposes no validity requirements, so reinterpreting
            // `&[u8]` as `&[MaybeUninit<u8>]` is sound.
            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
            dst.copy_from_slice(uninit_src);
        });
    }
}

impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
    fn bytes_len(&self) -> usize {
        self.len()
    }

    fn serialize(&self, buffer: &mut [u8]) {
        self.partial_copy(0, buffer);
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
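
    // Illustrative sketch (added check, derived from the `Control` definition
    // above): SYN and FIN occupy sequence number space, while RST does not.
    #[test]
    fn control_sequence_number_space() {
        assert!(Control::SYN.has_sequence_no());
        assert!(Control::FIN.has_sequence_no());
        assert!(!Control::RST.has_sequence_no());
    }
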
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }
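
    // Illustrative sketch (added example; the payload split and the expected
    // ranges are chosen here for demonstration): `slice` keeps only the
    // requested byte range even when it spans fragment boundaries, and
    // `partial_copy` starts copying at the given byte offset.
    #[test]
    fn slice_and_partial_copy_example() {
        let data = &EXAMPLE_DATA;
        let payload = FragmentedPayload::new([&data[..4], &data[4..]]);

        // Bytes 2..7 straddle both fragments.
        assert_eq!(payload.slice(2..7).to_vec(), &data[2..7]);

        // Copy 5 bytes starting at offset 3.
        let mut buf = [0u8; 5];
        payload.partial_copy(3, &mut buf);
        assert_eq!(&buf[..], &data[3..8]);
    }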

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }
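
    // Illustrative sketch (added check, derived from the `FromIterator` impl
    // above): building from fewer slices than the storage capacity `N` leaves
    // the trailing slots unused while still covering all the bytes.
    #[test]
    fn from_iter_partial_storage() {
        let data = &EXAMPLE_DATA;
        let payload = FragmentedPayload::<4>::from_iter([&data[..6], &data[6..]]);
        assert_eq!(payload.len(), data.len());
        assert_eq!(payload.to_vec(), &data[..]);
    }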

    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }
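
    // Illustrative sketch (added check; assumes the RFC 9293 defaults of 536
    // bytes for IPv4 and 1220 bytes for IPv6 that `Mss::default` encodes).
    #[test]
    fn default_mss_values() {
        use net_types::ip::{Ipv4, Ipv6};

        assert_eq!(u32::from(Mss::default::<Ipv4>()), 536);
        assert_eq!(u32::from(Mss::default::<Ipv6>()), 1220);
    }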

    const TEST_BYTES: &'static [u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len())
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.clone().slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end - start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }

    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;

        /// Generates a payload that covers all of `TEST_BYTES`, split across a
        /// random number of slices.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.gen_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        /// Generates a payload together with a valid `(start, end)` byte range
        /// into `TEST_BYTES`.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..TEST_BYTES.len()).prop_flat_map(|start| (Just(start), start..TEST_BYTES.len())),
            )
        }
    }
}