1use core::iter::FromIterator;
8use core::ops::Range;
9
10use alloc::vec::Vec;
11use core::mem::MaybeUninit;
12use core::num::NonZeroU16;
13use net_types::ip::{Ip, IpVersion};
14use packet::InnerPacketBuilder;
15use static_assertions::const_assert;
16
17use crate::ip::Mms;
18use crate::tcp::segment::{Payload, PayloadLen, SegmentOptions};
19
/// The control flag carried by a TCP segment.
///
/// A segment in this module carries at most one of these flags; segments
/// without a control flag are plain data/ACK segments.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// SYN: synchronize sequence numbers (connection establishment).
    SYN,
    /// FIN: the sender has no more data to send.
    FIN,
    /// RST: abort the connection.
    RST,
}

impl Control {
    /// Returns whether this control flag consumes one unit of sequence
    /// number space: SYN and FIN do, RST does not.
    pub fn has_sequence_no(self) -> bool {
        // Exhaustive match (no `_` arm) so adding a variant forces this
        // function to be revisited.
        match self {
            Control::RST => false,
            Control::SYN | Control::FIN => true,
        }
    }
}
41
/// Length in bytes of the fixed (option-less) TCP header prefix.
const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;
43
/// The maximum segment size (MSS) of a TCP connection, in bytes.
///
/// Invariant: the contained value is at least [`Mss::MIN`]; construct via
/// [`Mss::new`] or [`Mss::from_mms`], which enforce this.
#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
pub struct Mss(u16);

// Compile-time consistency checks on the associated constants: the minimum
// MSS must not exceed either per-IP-version default, and it must be large
// enough to fit a maximally-sized TCP options area.
const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV4.get());
const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV6.get());
const_assert!(Mss::MIN.get() as usize >= packet_formats::tcp::MAX_OPTIONS_LEN);
53
54impl Mss {
55 pub const MIN: Mss = Mss(216);
76
77 pub const DEFAULT_IPV4: Mss = Mss(536);
82
83 pub const DEFAULT_IPV6: Mss = Mss(1220);
88
89 pub const fn new(mss: u16) -> Option<Self> {
91 if mss < Self::MIN.get() { None } else { Some(Mss(mss)) }
92 }
93
94 pub fn from_mms(mms: Mms) -> Option<Self> {
96 let mss = u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX);
97 Self::new(mss)
98 }
99
100 pub const fn default<I: Ip>() -> Self {
102 match I::VERSION {
103 IpVersion::V4 => Self::DEFAULT_IPV4,
104 IpVersion::V6 => Self::DEFAULT_IPV6,
105 }
106 }
107
108 pub const fn get(&self) -> u16 {
110 let Self(mss) = *self;
111 mss
112 }
113}
114
/// An [`Mss`] paired with the size of TCP options that occupy space in
/// every outgoing segment, allowing callers to compute the payload bytes
/// actually available per segment.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct EffectiveMss {
    // The negotiated maximum segment size.
    mss: Mss,
    // Bytes reserved in each segment for fixed-size TCP options (the
    // aligned timestamp option when enabled, otherwise 0); subtracted from
    // `mss` when computing usable payload space.
    fixed_tcp_options_size: u16,
}
145
146impl EffectiveMss {
147 pub const fn from_mss(mss: Mss, size_limits: MssSizeLimiters) -> Self {
149 let MssSizeLimiters { timestamp_enabled } = size_limits;
150 let fixed_tcp_options_size = if timestamp_enabled {
153 packet_formats::tcp::options::ALIGNED_TIMESTAMP_OPTION_LENGTH as u16
154 } else {
155 0
156 };
157 EffectiveMss { mss, fixed_tcp_options_size }
158 }
159
160 pub fn payload_size(&self, options: &SegmentOptions) -> NonZeroU16 {
165 let Self { mss, fixed_tcp_options_size } = self;
166 let SegmentOptions { timestamp: _, sack_blocks } = options;
170 let tcp_options_len = if sack_blocks.is_empty() {
171 *fixed_tcp_options_size
172 } else {
173 u16::try_from(options.builder().bytes_len()).unwrap()
176 };
177
178 NonZeroU16::new(mss.get() - tcp_options_len).unwrap()
181 }
182
183 pub fn mss(&self) -> &Mss {
185 &self.mss
186 }
187
188 pub fn update_mss(&mut self, new: Mss) {
190 self.mss = new
191 }
192
193 pub const fn get(&self) -> u16 {
195 let Self { mss, fixed_tcp_options_size } = *self;
196 mss.get() - fixed_tcp_options_size
197 }
198}
199
/// Inputs that limit the payload space available under a given MSS.
pub struct MssSizeLimiters {
    /// Whether the TCP timestamp option is enabled for the connection;
    /// when true, space for it is reserved in every segment.
    pub timestamp_enabled: bool,
}
205
206impl From<EffectiveMss> for u32 {
207 fn from(mss: EffectiveMss) -> Self {
208 u32::from(mss.get())
209 }
210}
211
212impl From<EffectiveMss> for usize {
213 fn from(mss: EffectiveMss) -> Self {
214 usize::from(mss.get())
215 }
216}
217
/// A TCP payload backed by up to `N` discontiguous byte slices.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct FragmentedPayload<'a, const N: usize> {
    // Backing storage; only `storage[start..end]` holds live slices.
    storage: [&'a [u8]; N],
    // Index of the first live slice in `storage`.
    start: usize,
    // One past the index of the last live slice in `storage`.
    end: usize,
}

impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
    /// Collects up to `N` slices into a `FragmentedPayload`.
    ///
    /// Panics if the iterator yields more than `N` items (via the
    /// out-of-bounds index below).
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = &'a [u8]>,
    {
        // Start from an empty payload and fill storage front-to-back.
        let mut storage: [&'a [u8]; N] = [&[]; N];
        let mut end = 0;
        for slice in iter {
            storage[end] = slice;
            end += 1;
        }
        Self { storage, start: 0, end }
    }
}

impl<'a, const N: usize> FragmentedPayload<'a, N> {
    /// Creates a payload using all `N` provided slices.
    pub fn new(values: [&'a [u8]; N]) -> Self {
        Self { storage: values, start: 0, end: N }
    }

    /// Creates a payload holding a single contiguous slice.
    pub fn new_contiguous(value: &'a [u8]) -> Self {
        core::iter::once(value).collect()
    }

    /// Flattens the payload into a single owned byte vector.
    pub fn to_vec(self) -> Vec<u8> {
        self.slices().concat()
    }

    /// Returns the live slices (the `storage[start..end]` window).
    fn slices(&self) -> &[&'a [u8]] {
        let Self { storage, start, end } = self;
        &storage[*start..*end]
    }

    /// Feeds payload bytes starting at `offset` into `dst`, invoking
    /// `apply` once per contiguous chunk until `dst` is full.
    ///
    /// Panics if fewer than `dst.len()` bytes are available after
    /// `offset`.
    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
        &self,
        mut offset: usize,
        mut dst: &mut [T],
        apply: F,
    ) {
        // Idiomatic `for` loop instead of `while let Some(..) = iter.next()`
        // (clippy::while_let_on_iterator); behavior is unchanged.
        for sl in self.slices() {
            let l = sl.len();
            // Skip whole slices that lie entirely before `offset`.
            if offset >= l {
                offset -= l;
                continue;
            }
            let sl = &sl[offset..];
            let cp = sl.len().min(dst.len());
            let (target, new_dst) = dst.split_at_mut(cp);
            apply(&sl[..cp], target);

            if new_dst.is_empty() {
                // `dst` has been completely filled.
                return;
            }

            dst = new_dst;
            // Only the first copied slice may start mid-way.
            offset = 0;
        }
        assert_eq!(dst.len(), 0, "failed to fill dst");
    }
}
302
303impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
304 fn len(&self) -> usize {
305 self.slices().iter().map(|s| s.len()).sum()
306 }
307}
308
impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
    /// Restricts the payload to `byte_range`, expressed in bytes from the
    /// beginning of the payload.
    ///
    /// Panics (mirroring slice-indexing semantics) when the range does
    /// not fit within the payload.
    fn slice(self, byte_range: Range<u32>) -> Self {
        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
        let Range { start: byte_start, end: byte_end } = byte_range;
        let byte_start =
            usize::try_from(byte_start).expect("range start index out of range for usize");
        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
        assert!(byte_end >= byte_start);
        // Pair every live slice with its cumulative byte offset from the
        // start of the payload.
        let mut storage_iter =
            (&mut storage[self_start..self_end]).iter_mut().scan(0, |total_len, slice| {
                let slice_len = slice.len();
                let item = Some((*total_len, slice));
                *total_len += slice_len;
                item
            });

        // Byte offset where the kept bytes actually begin; checked against
        // `byte_start` after the loop.
        let mut start_offset = None;
        // Number of bytes retained; checked against `byte_end` after the
        // loop.
        let mut final_len = 0;
        while let Some((sl_offset, sl)) = storage_iter.next() {
            let orig_len = sl.len();

            // Slice ends strictly before the requested range: empty it and
            // advance the live window's start past it.
            if sl_offset + orig_len < byte_start {
                *sl = &[];
                self_start += 1;
                continue;
            }
            // Slice begins at or after the requested end: empty it and
            // shrink the live window's end.
            if sl_offset >= byte_end {
                *sl = &[];
                self_end -= 1;
                continue;
            }

            // Slice intersects the range; trim it to the intersection.
            let sl_start = byte_start.saturating_sub(sl_offset);
            let sl_end = sl.len().min(byte_end - sl_offset);
            *sl = &sl[sl_start..sl_end];

            match start_offset {
                Some(_) => (),
                None => {
                    // First intersecting slice: record where the kept
                    // bytes begin.
                    start_offset = Some(sl_offset + sl_start);
                    // The intersection may still be empty (the range
                    // starts exactly at this slice's end); skip it so the
                    // live window never starts with an empty slice.
                    if sl.len() == 0 {
                        self_start += 1;
                    }
                }
            }
            final_len += sl.len();
        }
        // Validate that the requested range was fully contained in the
        // payload.
        assert_eq!(
            start_offset.unwrap_or(0),
            byte_start,
            "range start index out of range {byte_range:?}"
        );
        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");

        // Canonicalize empty payloads to start == end == 0 so they compare
        // equal to `new_empty()`.
        if self_start == self_end {
            self_start = 0;
            self_end = 0;
        }
        Self { storage, start: self_start, end: self_end }
    }

    /// Returns the canonical empty payload.
    fn new_empty() -> Self {
        Self { storage: [&[]; N], start: 0, end: 0 }
    }

    /// Copies `dst.len()` payload bytes starting at `offset` into `dst`.
    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
        self.apply_copy(offset, dst, |src, dst| {
            dst.copy_from_slice(src);
        });
    }

    /// Like `partial_copy`, but writes into uninitialized memory.
    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
        self.apply_copy(offset, dst, |src, dst| {
            let _ = dst.write_copy_of_slice(src);
        });
    }
}
398
impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
    /// The serialized length is exactly the payload length.
    fn bytes_len(&self) -> usize {
        self.len()
    }

    /// Serializes by copying the whole payload into `buffer`.
    fn serialize(&self, buffer: &mut [u8]) {
        self.partial_copy(0, buffer);
    }
}
408
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    //! Test-only conversions exposing the raw `Mss` value as plain
    //! integers.
    use super::*;

    impl From<Mss> for u32 {
        fn from(Mss(mss): Mss) -> Self {
            u32::from(mss)
        }
    }

    impl From<Mss> for usize {
        fn from(Mss(mss): Mss) -> Self {
            usize::from(mss)
        }
    }
}
425
426#[cfg(test)]
427mod test {
428 use super::*;
429
430 use packet::Serializer as _;
431 use proptest::test_runner::Config;
432 use proptest::{prop_assert_eq, proptest};
433 use proptest_support::failed_seeds_no_std;
434 use test_case::test_case;
435
436 use crate::{SackBlock, SackBlocks, SeqNum, Timestamp, TimestampOption};
437
    // Ten distinct bytes used as the canonical payload in the tests below.
    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    // Serializing through the `InnerPacketBuilder` impl must reproduce the
    // original bytes no matter how the payload is fragmented.
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    // `from_iter` leaves unused storage slots, exercising a partially
    // filled payload.
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }
460
    // Slicing with a start index past the payload's end must panic.
    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        // Even an empty range entirely past the end validates its start.
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }
470
    // Slicing with an end index past the payload's end must panic.
    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }
478
    // Every empty slice of a payload must equal `new_empty()`, i.e. empty
    // payloads are canonicalized to start == end == 0.
    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        // An empty range in the middle of the payload as well.
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }
495
    const TEST_BYTES: &'static [u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            // Replay previously failing seeds first, if any were recorded.
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        // `to_vec` must reassemble fragments into the original bytes.
        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        // `len` must report the total length across all fragments.
        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len())
        }

        // `slice` must agree with slicing the contiguous original.
        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.clone().slice(start..end).to_vec(), want);
        }

        // `partial_copy` must agree with copying out of the contiguous
        // original.
        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end-start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }
530
    // Proptest strategies producing `FragmentedPayload`s over `TEST_BYTES`.
    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;

        // Splits `TEST_BYTES` into 1..=TEST_STORAGE fragments at random
        // (possibly empty) cut points; fragments always cover all bytes.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        // The last fragment takes whatever remains so the
                        // whole of TEST_BYTES is covered.
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.random_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        // Pairs a random payload with a valid (start <= end) byte range
        // inside TEST_BYTES.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..TEST_BYTES.len()).prop_flat_map(|start| (Just(start), start..TEST_BYTES.len())),
            )
        }
    }
564
    // `EffectiveMss::get` must subtract the aligned timestamp option
    // length exactly when (and only when) timestamps are enabled.
    #[test_case(true; "timestamp_enabled")]
    #[test_case(false; "timestamp_disabled")]
    fn effective_mss_accounts_for_fixed_size_tcp_options(timestamp_enabled: bool) {
        const SIZE: u16 = 1000;
        let mss =
            EffectiveMss::from_mss(Mss::new(SIZE).unwrap(), MssSizeLimiters { timestamp_enabled });
        if timestamp_enabled {
            assert_eq!(
                mss.get(),
                SIZE - packet_formats::tcp::options::ALIGNED_TIMESTAMP_OPTION_LENGTH as u16
            )
        } else {
            assert_eq!(mss.get(), SIZE);
        }
    }
580
    // `payload_size` must subtract the actual serialized options length
    // for every combination of SACK blocks and timestamp option.
    #[test_case(SegmentOptions {sack_blocks: SackBlocks::EMPTY, timestamp: None}; "empty")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::from_iter([
            SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)).unwrap(),
            SackBlock::try_new(SeqNum::new(4), SeqNum::new(6)).unwrap(),
        ]),
        timestamp: None
    }; "sack_blocks")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::EMPTY,
        timestamp: Some(TimestampOption {
            ts_val: Timestamp::new(12345), ts_echo_reply: Timestamp::new(54321)
        }),
    }; "timestamp")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::from_iter([
            SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)).unwrap(),
            SackBlock::try_new(SeqNum::new(4), SeqNum::new(6)).unwrap(),
        ]),
        timestamp: Some(TimestampOption {
            ts_val: Timestamp::new(12345), ts_echo_reply: Timestamp::new(54321)
        }),
    }; "sack_blocks_and_timestamp")]

    fn effective_mss_accounts_for_variable_size_tcp_options(options: SegmentOptions) {
        const SIZE: u16 = 1000;
        let timestamp_enabled = options.timestamp.is_some();
        let mss =
            EffectiveMss::from_mss(Mss::new(SIZE).unwrap(), MssSizeLimiters { timestamp_enabled });
        // Compare against the real serialized options length so the test
        // stays valid if option encoding changes.
        let options_len = u16::try_from(options.builder().bytes_len()).unwrap();
        assert_eq!(mss.payload_size(&options).get(), SIZE - options_len);
    }
613}