use core::iter::FromIterator;
use core::mem::MaybeUninit;
use core::num::NonZeroU16;
use core::ops::Range;

use alloc::vec::Vec;

use net_types::ip::{Ip, IpVersion};
use packet::InnerPacketBuilder;

use crate::ip::Mms;
use crate::tcp::segment::{Payload, PayloadLen};

/// Control flags that can alter the state of a TCP connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// Corresponds to the SYN bit in a TCP segment.
    SYN,
    /// Corresponds to the FIN bit in a TCP segment.
    FIN,
    /// Corresponds to the RST bit in a TCP segment.
    RST,
}

impl Control {
    /// Returns whether the control flag consumes one byte of the sequence
    /// number space.
    pub fn has_sequence_no(self) -> bool {
        match self {
            Control::SYN | Control::FIN => true,
            Control::RST => false,
        }
    }
}

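// Usage sketch (illustrative only): SYN and FIN each consume one unit of
// sequence-number space so that they can be acknowledged, while RST does
// not.
//
//     assert!(Control::SYN.has_sequence_no());
//     assert!(Control::FIN.has_sequence_no());
//     assert!(!Control::RST.has_sequence_no());
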
/// The length of the fixed (option-less) TCP header, in bytes.
const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;

/// The maximum segment size (MSS): the largest amount of payload data that
/// will be sent in a single TCP segment.
#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
pub struct Mss(pub NonZeroU16);

impl Mss {
    /// Derives an MSS from the given maximum message size (MMS) by
    /// subtracting the fixed TCP header length, saturating at `u16::MAX`.
    /// Returns `None` if the resulting MSS would be zero.
    pub fn from_mms(mms: Mms) -> Option<Self> {
        NonZeroU16::new(
            u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX),
        )
        .map(Self)
    }

    /// Returns the default MSS for the given IP version: 536 bytes for IPv4
    /// and 1220 bytes for IPv6, per RFC 9293 section 3.7.1.
    pub const fn default<I: Ip>() -> Self {
        match I::VERSION {
            IpVersion::V4 => Mss(NonZeroU16::new(536).unwrap()),
            IpVersion::V6 => Mss(NonZeroU16::new(1220).unwrap()),
        }
    }

    /// Returns the MSS as a `NonZeroU16`.
    pub const fn get(&self) -> NonZeroU16 {
        let Self(mss) = *self;
        mss
    }
}

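// Worked example (illustrative, using the IPv6 minimum link MTU of 1280
// bytes): the MMS is 1280 - 40 (IPv6 header) = 1240 bytes, so the derived
// MSS is 1240 - 20 (fixed TCP header) = 1220 bytes, matching the IPv6
// default returned by `Mss::default`.
//
//     // Hypothetical `mms` describing a minimum-MTU IPv6 link:
//     // assert_eq!(Mss::from_mms(mms), NonZeroU16::new(1220).map(Mss));
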
impl From<Mss> for u32 {
    fn from(Mss(mss): Mss) -> Self {
        u32::from(mss.get())
    }
}

impl From<Mss> for usize {
    fn from(Mss(mss): Mss) -> Self {
        usize::from(mss.get())
    }
}

/// A TCP payload backed by up to `N` byte slices, allowing a logically
/// contiguous payload to be assembled from multiple buffers without copying.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct FragmentedPayload<'a, const N: usize> {
    storage: [&'a [u8]; N],
    /// The index of the first occupied slice in `storage`. A `Range` is not
    /// used here so the type can remain `Copy`.
    start: usize,
    /// One past the index of the last occupied slice in `storage`.
    end: usize,
}

impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
    /// Builds a `FragmentedPayload` from the slices yielded by `iter`.
    ///
    /// # Panics
    ///
    /// Panics if the iterator yields more than `N` slices.
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = &'a [u8]>,
    {
        let Self { storage, start, end } = Self::new_empty();
        let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
            storage[end] = sl;
            (storage, end + 1)
        });
        Self { storage, start, end }
    }
}

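// Construction sketch (illustrative): a payload can be collected from any
// iterator of byte slices, or built directly from a single slice; both
// forms represent the same logical byte sequence.
//
//     let split: FragmentedPayload<'_, 3> =
//         [&b"Hello, "[..], &b"World!"[..]].into_iter().collect();
//     let contiguous = FragmentedPayload::<3>::new_contiguous(b"Hello, World!");
//     assert_eq!(split.to_vec(), contiguous.to_vec());
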
impl<'a, const N: usize> FragmentedPayload<'a, N> {
    /// Creates a new `FragmentedPayload` with the given slices.
    pub fn new(values: [&'a [u8]; N]) -> Self {
        Self { storage: values, start: 0, end: N }
    }

    /// Creates a new `FragmentedPayload` from a single contiguous slice.
    pub fn new_contiguous(value: &'a [u8]) -> Self {
        core::iter::once(value).collect()
    }

    /// Copies the payload into a single contiguous `Vec`.
    pub fn to_vec(self) -> Vec<u8> {
        self.slices().concat()
    }

    /// Returns the occupied slices in `storage`.
    fn slices(&self) -> &[&'a [u8]] {
        let Self { storage, start, end } = self;
        &storage[*start..*end]
    }

    /// Fills `dst` from the payload starting at `offset` bytes into it,
    /// invoking `apply` on each pair of equal-length source and destination
    /// chunks.
    ///
    /// # Panics
    ///
    /// Panics if there are not enough bytes after `offset` to fill `dst`.
    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
        &self,
        mut offset: usize,
        mut dst: &mut [T],
        apply: F,
    ) {
        for sl in self.slices() {
            let l = sl.len();
            // Skip slices that lie entirely before `offset`.
            if offset >= l {
                offset -= l;
                continue;
            }
            let sl = &sl[offset..];
            let cp = sl.len().min(dst.len());
            let (target, new_dst) = dst.split_at_mut(cp);
            apply(&sl[..cp], target);

            if new_dst.is_empty() {
                return;
            }

            dst = new_dst;
            offset = 0;
        }
        assert_eq!(dst.len(), 0, "failed to fill dst");
    }
}

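// Copy sketch (illustrative): `apply_copy` above skips whole slices until
// `offset` is consumed, then fills `dst` chunk by chunk. For the payload
// ["ab", "cdef", "g"], an offset of 3 and a 3-byte destination copy "def"
// out of the middle slice and return before ever touching "g".
//
//     let payload = FragmentedPayload::new([&b"ab"[..], &b"cdef"[..], &b"g"[..]]);
//     let mut dst = [0u8; 3];
//     payload.partial_copy(3, &mut dst);
//     assert_eq!(&dst, b"def");
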
impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
    fn len(&self) -> usize {
        self.slices().iter().map(|s| s.len()).sum()
    }
}

impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
    fn slice(self, byte_range: Range<u32>) -> Self {
        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
        let Range { start: byte_start, end: byte_end } = byte_range;
        let byte_start =
            usize::try_from(byte_start).expect("range start index out of range for usize");
        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
        assert!(byte_end >= byte_start);
        let mut storage_iter =
            storage[self_start..self_end].iter_mut().scan(0, |total_len, slice| {
                let slice_len = slice.len();
                let item = Some((*total_len, slice));
                *total_len += slice_len;
                item
            });

        // Keep track of the payload offset at which the returned payload
        // starts, so the validity of `byte_range` can be asserted below.
        let mut start_offset = None;
        let mut final_len = 0;
        while let Some((sl_offset, sl)) = storage_iter.next() {
            let orig_len = sl.len();

            // This slice ends before the requested range starts; drop it.
            if sl_offset + orig_len < byte_start {
                *sl = &[];
                self_start += 1;
                continue;
            }
            // This slice starts at or after the requested range ends; drop it.
            if sl_offset >= byte_end {
                *sl = &[];
                self_end -= 1;
                continue;
            }

            // Trim the slice to the requested range.
            let sl_start = byte_start.saturating_sub(sl_offset);
            let sl_end = sl.len().min(byte_end - sl_offset);
            *sl = &sl[sl_start..sl_end];

            match start_offset {
                Some(_) => (),
                None => {
                    start_offset = Some(sl_offset + sl_start);
                    // If the first retained slice was trimmed to nothing,
                    // drop it as well so no empty slices are kept around.
                    if sl.is_empty() {
                        self_start += 1;
                    }
                }
            }
            final_len += sl.len();
        }
        assert_eq!(
            start_offset.unwrap_or(0),
            byte_start,
            "range start index out of range {byte_range:?}"
        );
        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");

        // Canonicalize an empty payload so all empty payloads compare equal.
        if self_start == self_end {
            self_start = 0;
            self_end = 0;
        }
        Self { storage, start: self_start, end: self_end }
    }

    fn new_empty() -> Self {
        Self { storage: [&[]; N], start: 0, end: 0 }
    }

    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
        self.apply_copy(offset, dst, |src, dst| {
            dst.copy_from_slice(src);
        });
    }

    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
        self.apply_copy(offset, dst, |src, dst| {
            // SAFETY: `MaybeUninit<u8>` has the same size, alignment, and ABI
            // as `u8`, and it is always sound to treat initialized bytes as
            // possibly-uninitialized ones.
            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
            dst.copy_from_slice(uninit_src);
        });
    }
}

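// Slicing sketch (illustrative): `slice` trims the backing slices in place
// and keeps the representation canonical, so a range that crosses a fragment
// boundary simply narrows the first and last retained slices.
//
//     let payload = FragmentedPayload::<2>::new([&b"abcd"[..], &b"efgh"[..]]);
//     assert_eq!(payload.slice(2..6).to_vec(), b"cdef");
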
impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
    fn bytes_len(&self) -> usize {
        self.len()
    }

    fn serialize(&self, buffer: &mut [u8]) {
        self.partial_copy(0, buffer);
    }
}

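// Serialization sketch (illustrative): implementing `InnerPacketBuilder`
// lets a `FragmentedPayload` sit at the innermost position of a `packet`
// crate serializer stack without first being flattened into a contiguous
// buffer; the `fragmented_payload_serializer_data` test below exercises
// exactly this path.
//
//     let serializer = FragmentedPayload::<2>::new([&b"ab"[..], &b"cd"[..]])
//         .into_serializer();
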
#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];

    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }

    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }

    const TEST_BYTES: &[u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len());
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.clone().slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end - start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }

    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;

        /// Generates a payload covering all of `TEST_BYTES`, split at random
        /// points across 1..=TEST_STORAGE slices.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        // The last slice takes all of the remaining bytes.
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.gen_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        /// Generates a payload together with a valid `(start, end)` byte
        /// range into it.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..TEST_BYTES.len()).prop_flat_map(|start| (Just(start), start..TEST_BYTES.len())),
            )
        }
    }
}
419}