use core::iter::FromIterator;
use core::ops::Range;

use alloc::vec::Vec;
use core::mem::MaybeUninit;
use net_types::ip::{Ip, IpVersion};
use packet::InnerPacketBuilder;
use static_assertions::const_assert;

use crate::ip::Mms;
use crate::tcp::segment::{Payload, PayloadLen};

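/// A TCP segment control flag: SYN, FIN, or RST.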
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    SYN,
    FIN,
    RST,
}

impl Control {
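    /// Returns whether the control flag consumes one unit of the sequence
    /// number space.
    ///
    /// SYN and FIN each consume one sequence number; RST does not.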
    pub fn has_sequence_no(self) -> bool {
        match self {
            Control::SYN | Control::FIN => true,
            Control::RST => false,
        }
    }
}

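/// The length of the fixed (option-less) portion of the TCP header, in bytes.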
const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;

#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
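/// The maximum segment size (MSS) of a TCP connection: the largest amount of
/// payload data, in bytes, that a single segment may carry.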
pub struct Mss(u16);

const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV4.get());
const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV6.get());
const_assert!(Mss::MIN.get() as usize >= packet_formats::tcp::MAX_OPTIONS_LEN);

impl Mss {
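    /// The smallest MSS accepted by this implementation (216 bytes); values
    /// below this are rejected by [`Mss::new`].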
    pub const MIN: Mss = Mss(216);

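    /// The default MSS for IPv4: 536 bytes, i.e. the 576-byte minimum IPv4
    /// reassembly buffer size minus 20 bytes each for the IP and TCP headers
    /// (RFC 9293).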
    pub const DEFAULT_IPV4: Mss = Mss(536);

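    /// The default MSS for IPv6: 1220 bytes, i.e. the 1280-byte IPv6 minimum
    /// MTU (RFC 8200) minus 40 bytes of IPv6 header and 20 bytes of TCP
    /// header.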
    pub const DEFAULT_IPV6: Mss = Mss(1220);

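    /// Creates a new [`Mss`], returning `None` if `mss` is below [`Mss::MIN`].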
    pub const fn new(mss: u16) -> Option<Self> {
        if mss < Self::MIN.get() { None } else { Some(Mss(mss)) }
    }

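    /// Derives an MSS from a maximum message size (MMS) by subtracting the
    /// fixed TCP header length, returning `None` if the result is smaller
    /// than [`Mss::MIN`].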
    pub fn from_mms(mms: Mms) -> Option<Self> {
        let mss =
            u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX);
        Self::new(mss)
    }

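    /// Returns the default MSS for the IP version `I`.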
    pub const fn default<I: Ip>() -> Self {
        match I::VERSION {
            IpVersion::V4 => Self::DEFAULT_IPV4,
            IpVersion::V6 => Self::DEFAULT_IPV6,
        }
    }

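    /// Returns the MSS value in bytes.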
    pub const fn get(&self) -> u16 {
        let Self(mss) = *self;
        mss
    }
}

impl From<Mss> for u32 {
    fn from(Mss(mss): Mss) -> Self {
        u32::from(mss)
    }
}

impl From<Mss> for usize {
    fn from(Mss(mss): Mss) -> Self {
        usize::from(mss)
    }
}

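/// A TCP payload held as up to `N` non-contiguous byte slices.
///
/// The occupied slices are `storage[start..end]`; all other entries are
/// empty. An illustrative sketch (everything used here is defined in this
/// module) of how a payload is built and flattened:
///
/// ```ignore
/// let payload = FragmentedPayload::new([&b"Hello, "[..], &b"World!"[..]]);
/// assert_eq!(payload.to_vec(), b"Hello, World!");
/// ```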
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct FragmentedPayload<'a, const N: usize> {
    storage: [&'a [u8]; N],
    /// The index of the first occupied slice in `storage`.
    start: usize,
    /// One past the index of the last occupied slice in `storage`.
    end: usize,
}

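/// Builds a [`FragmentedPayload`] from an iterator of slices.
///
/// Panics if the iterator yields more than `N` slices.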
impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = &'a [u8]>,
    {
        let Self { storage, start, end } = Self::new_empty();
        let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
            storage[end] = sl;
            (storage, end + 1)
        });
        Self { storage, start, end }
    }
}


impl<'a, const N: usize> FragmentedPayload<'a, N> {
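    /// Creates a new payload occupying all `N` slots with `values`.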
    pub fn new(values: [&'a [u8]; N]) -> Self {
        Self { storage: values, start: 0, end: N }
    }

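    /// Creates a new payload from a single contiguous slice.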
    pub fn new_contiguous(value: &'a [u8]) -> Self {
        core::iter::once(value).collect()
    }

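    /// Copies the payload out into a single contiguous [`Vec`].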
    pub fn to_vec(self) -> Vec<u8> {
        self.slices().concat()
    }

    /// Returns the currently occupied slices.
    fn slices(&self) -> &[&'a [u8]] {
        let Self { storage, start, end } = self;
        &storage[*start..*end]
    }

    /// Calls `apply` on successive (source, destination) chunk pairs, walking
    /// the payload from `offset` until all of `dst` has been visited.
    ///
    /// Panics if the payload does not contain enough bytes after `offset` to
    /// cover `dst`.
    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
        &self,
        mut offset: usize,
        mut dst: &mut [T],
        apply: F,
    ) {
        for sl in self.slices() {
            let l = sl.len();
            if offset >= l {
                offset -= l;
                continue;
            }
            let sl = &sl[offset..];
            let cp = sl.len().min(dst.len());
            let (target, new_dst) = dst.split_at_mut(cp);
            apply(&sl[..cp], target);

            // The destination has been filled; nothing left to do.
            if new_dst.is_empty() {
                return;
            }

            dst = new_dst;
            offset = 0;
        }
        assert_eq!(dst.len(), 0, "failed to fill dst");
    }
}

impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
    fn len(&self) -> usize {
        self.slices().iter().map(|s| s.len()).sum()
    }
}

impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
    fn slice(self, byte_range: Range<u32>) -> Self {
        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
        let Range { start: byte_start, end: byte_end } = byte_range;
        let byte_start =
            usize::try_from(byte_start).expect("range start index out of range for usize");
        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
        assert!(byte_end >= byte_start);
        let mut storage_iter =
            (&mut storage[self_start..self_end]).into_iter().scan(0, |total_len, slice| {
                let slice_len = slice.len();
                let item = Some((*total_len, slice));
                *total_len += slice_len;
                item
            });

        // The byte offset into the original payload at which the new payload
        // starts; used to validate `byte_start` below.
        let mut start_offset = None;
        let mut final_len = 0;
        while let Some((sl_offset, sl)) = storage_iter.next() {
            let orig_len = sl.len();

            // This slice ends strictly before the requested range; drop it.
            if sl_offset + orig_len < byte_start {
                *sl = &[];
                self_start += 1;
                continue;
            }
            // This slice starts at or after the end of the requested range;
            // drop it.
            if sl_offset >= byte_end {
                *sl = &[];
                self_end -= 1;
                continue;
            }

            // Trim the slice to the requested range.
            let sl_start = byte_start.saturating_sub(sl_offset);
            let sl_end = sl.len().min(byte_end - sl_offset);
            *sl = &sl[sl_start..sl_end];

            match start_offset {
                Some(_) => (),
                None => {
                    start_offset = Some(sl_offset + sl_start);
                    // If the first kept slice ends up empty, drop it from the
                    // occupied range as well.
                    if sl.is_empty() {
                        self_start += 1;
                    }
                }
            }
            final_len += sl.len();
        }
        assert_eq!(
            start_offset.unwrap_or(0),
            byte_start,
            "range start index out of range {byte_range:?}"
        );
        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");

        // Canonicalize an empty payload so it always compares equal to
        // `Self::new_empty()`.
        if self_start == self_end {
            self_start = 0;
            self_end = 0;
        }
        Self { storage, start: self_start, end: self_end }
    }

    fn new_empty() -> Self {
        Self { storage: [&[]; N], start: 0, end: 0 }
    }

    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
        self.apply_copy(offset, dst, |src, dst| {
            dst.copy_from_slice(src);
        });
    }

    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
        self.apply_copy(offset, dst, |src, dst| {
            // SAFETY: `MaybeUninit<u8>` has the same size and alignment as
            // `u8`, and it is always sound to view initialized bytes as
            // possibly-uninitialized bytes.
            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
            dst.copy_from_slice(uninit_src);
        });
    }
}

impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
    fn bytes_len(&self) -> usize {
        self.len()
    }

    fn serialize(&self, buffer: &mut [u8]) {
        self.partial_copy(0, buffer);
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }
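
    // Illustrative example (not part of the original test suite): `partial_copy`
    // starts at the requested byte offset and fills the entire destination
    // buffer, crossing slice boundaries as needed.
    #[test]
    fn partial_copy_example() {
        let payload = FragmentedPayload::new([&EXAMPLE_DATA[..4], &EXAMPLE_DATA[4..]]);
        let mut buffer = [0u8; 6];
        payload.partial_copy(2, &mut buffer);
        assert_eq!(buffer, [2, 3, 4, 5, 6, 7]);
    }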

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        // An empty range is allowed, but its start must not be past the end
        // of the payload.
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }

    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }
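
    // Illustrative example (not part of the original test suite): slicing trims
    // the first and last kept slices and drops slices that fall entirely
    // outside the requested byte range.
    #[test]
    fn slice_example() {
        let payload = FragmentedPayload::new([
            &EXAMPLE_DATA[..3],
            &EXAMPLE_DATA[3..6],
            &EXAMPLE_DATA[6..],
        ]);
        assert_eq!(payload.slice(2..7).to_vec(), [2, 3, 4, 5, 6]);
    }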

    const TEST_BYTES: &'static [u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len());
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.clone().slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end - start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }

    /// Proptest strategies that generate `FragmentedPayload`s over
    /// `TEST_BYTES`.
    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;

        /// Generates a payload that covers all of `TEST_BYTES`, split over a
        /// random number of slices.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        let len = if slice == slices - 1 {
                            // The last slice takes all of the remaining bytes.
                            TEST_BYTES.len() - *st
                        } else {
                            rng.random_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        /// Generates a payload together with a valid `(start, end)` byte range
        /// into it.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..TEST_BYTES.len()).prop_flat_map(|start| (Just(start), start..TEST_BYTES.len())),
            )
        }
    }
}