1use core::fmt::Debug;
12#[cfg(test)]
13use core::fmt::{self, Formatter};
14use core::num::NonZeroU16;
15use core::ops::Range;
16
17use net_types::ip::{Ip, IpAddress, IpVersionMarker};
18use packet::{
19 BufferView, BufferViewMut, ByteSliceInnerPacketBuilder, EmptyBuf, FragmentedBytesMut, FromRaw,
20 InnerPacketBuilder, MaybeParsed, NestablePacketBuilder, NoOpParsingContext,
21 NoOpSerializationContext, PacketBuilder, PacketConstraints, ParsablePacket, ParseMetadata,
22 PartialPacketBuilder, SerializationContext, SerializeTarget, Serializer,
23};
24use zerocopy::byteorder::network_endian::U16;
25use zerocopy::{
26 FromBytes, Immutable, IntoBytes, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut, Unaligned,
27};
28
29use crate::error::{ParseError, ParseResult};
30use crate::ip::IpProto;
31use crate::{
32 TransportChecksumAction, compute_transport_checksum_parts,
33 compute_transport_checksum_serialize, compute_transport_pseudo_header_partial_checksum,
34};
35
/// Size in bytes of a complete UDP header (two ports, length and checksum).
pub const HEADER_BYTES: usize = 8;

/// Byte offset of the 16-bit checksum field inside a UDP header.
pub const CHECKSUM_OFFSET: usize = 6;

/// Byte span of the checksum field inside a UDP header.
const CHECKSUM_RANGE: Range<usize> =
    CHECKSUM_OFFSET..(CHECKSUM_OFFSET + core::mem::size_of::<u16>());
43
/// The on-the-wire layout of a complete UDP header.
///
/// Port and length fields are stored in network byte order. The checksum is
/// kept as a raw byte pair because it is compared with `[0, 0]` and fed to the
/// `internet_checksum` helpers as bytes.
#[derive(Debug, KnownLayout, FromBytes, IntoBytes, Immutable, Unaligned)]
#[repr(C)]
struct Header {
    // Source port; zero means "no source port" (see `UdpPacket::src_port`).
    src_port: U16,
    // Destination port; a parsed `UdpPacket` never has zero here.
    dst_port: U16,
    // Header + body length in bytes; may be zero for IPv6 jumbograms.
    length: U16,
    // Checksum over pseudo-header, header and body; `[0, 0]` means "not computed".
    checksum: [u8; 2],
}
52
53impl Header {
54 fn checksummed(&self) -> bool {
55 self.checksum != U16::ZERO
56 }
57
58 pub fn set_src_port(&mut self, new: u16) {
59 let old = self.src_port;
60 let new = U16::from(new);
61 if old == new {
62 return; }
64
65 self.src_port = new;
66 if self.checksummed() {
67 self.checksum =
68 internet_checksum::update(self.checksum, old.as_bytes(), new.as_bytes());
69 sanitize_checksum(&mut self.checksum);
70 }
71 }
72
73 pub fn set_dst_port(&mut self, new: NonZeroU16) {
74 let old = self.dst_port;
75 let new = U16::from(new.get());
76 if old == new {
77 return; }
79
80 self.dst_port = new;
81 if self.checksummed() {
82 self.checksum =
83 internet_checksum::update(self.checksum, old.as_bytes(), new.as_bytes());
84 sanitize_checksum(&mut self.checksum);
85 }
86 }
87
88 pub fn update_checksum_pseudo_header_address<A: IpAddress>(&mut self, old: A, new: A) {
89 if old == new {
90 return; }
92
93 if self.checksummed() {
94 self.checksum = internet_checksum::update(self.checksum, old.bytes(), new.bytes());
95 sanitize_checksum(&mut self.checksum);
96 }
97 }
98}
99
/// A parsed and validated UDP packet.
///
/// Values are produced by parsing with [`UdpParseArgs`]. Parsing verifies the
/// checksum (unless the parse context opts out) and always rejects a zero
/// destination port.
pub struct UdpPacket<B> {
    // Zero-copy view of the 8-byte header at the front of the buffer.
    header: Ref<B, Header>,
    // The payload, trimmed to the length advertised in the header.
    body: B,
}
112
/// Context consulted while turning a [`UdpPacketRaw`] into a [`UdpPacket`].
pub trait UdpParseContext {
    /// Returns `true` if checksum verification should be skipped.
    ///
    /// Skipping also bypasses the IPv6 "missing checksum" check, because both
    /// checks live under the same verification branch.
    fn skip_checksum_verification(&mut self) -> bool;
}

// The no-op context never skips verification, so checksums are always checked.
impl UdpParseContext for NoOpParsingContext {
    fn skip_checksum_verification(&mut self) -> bool {
        false
    }
}
124
125pub struct UdpParseArgs<A: IpAddress, C> {
127 src_ip: A,
128 dst_ip: A,
129 context: C,
130}
131
132impl<A: IpAddress> UdpParseArgs<A, NoOpParsingContext> {
133 pub fn new(src_ip: A, dst_ip: A) -> Self {
135 UdpParseArgs { src_ip, dst_ip, context: NoOpParsingContext }
136 }
137}
138
139impl<A: IpAddress, C> UdpParseArgs<A, C> {
140 pub fn with_context(src_ip: A, dst_ip: A, context: C) -> Self {
142 UdpParseArgs { src_ip, dst_ip, context }
143 }
144}
145
146impl<B: SplitByteSlice, A: IpAddress, C: UdpParseContext>
147 FromRaw<UdpPacketRaw<B>, UdpParseArgs<A, C>> for UdpPacket<B>
148{
149 type Error = ParseError;
150
151 fn try_from_raw_with(
152 raw: UdpPacketRaw<B>,
153 UdpParseArgs { src_ip, dst_ip, mut context }: UdpParseArgs<A, C>,
154 ) -> Result<Self, Self::Error> {
155 let header = raw
157 .header
158 .ok_or_else(|_| debug_err!(ParseError::Format, "too few bytes for header"))?;
159 let body = raw.body.ok_or_else(|_| debug_err!(ParseError::Format, "incomplete body"))?;
160
161 if !context.skip_checksum_verification() {
162 let checksum = header.checksum;
163 if checksum != [0, 0] {
167 let parts = [Ref::bytes(&header), body.deref().as_ref()];
168 let checksum = compute_transport_checksum_parts(
169 src_ip,
170 dst_ip,
171 IpProto::Udp.into(),
172 parts.iter(),
173 )
174 .ok_or_else(debug_err_fn!(ParseError::Format, "packet too large"))?;
175
176 if checksum != [0, 0] {
185 return debug_err!(
186 Err(ParseError::Checksum),
187 "invalid checksum {:X?}",
188 header.checksum,
189 );
190 }
191 } else if A::Version::VERSION.is_v6() {
192 return debug_err!(Err(ParseError::Format), "missing checksum");
193 }
194 }
195
196 if header.dst_port.get() == 0 {
197 return debug_err!(Err(ParseError::Format), "zero destination port");
198 }
199
200 Ok(UdpPacket { header, body })
201 }
202}
203
204impl<B: SplitByteSlice, A: IpAddress, C: UdpParseContext> ParsablePacket<B, UdpParseArgs<A, C>>
205 for UdpPacket<B>
206{
207 type Error = ParseError;
208
209 fn parse_metadata(&self) -> ParseMetadata {
210 ParseMetadata::from_packet(Ref::bytes(&self.header).len(), self.body.len(), 0)
211 }
212
213 fn parse<BV: BufferView<B>>(buffer: BV, args: UdpParseArgs<A, C>) -> ParseResult<Self> {
214 UdpPacketRaw::<B>::parse(buffer, IpVersionMarker::<A::Version>::default())
215 .and_then(|u| UdpPacket::try_from_raw_with(u, args))
216 }
217}
218
219impl<B: SplitByteSlice> UdpPacket<B> {
220 pub fn body(&self) -> &[u8] {
222 self.body.deref()
223 }
224
225 pub fn as_bytes(&self) -> [&[u8]; 2] {
227 [&Ref::bytes(&self.header), self.body.deref()]
228 }
229
230 pub fn into_body(self) -> B {
238 self.body
239 }
240
241 pub fn src_port(&self) -> Option<NonZeroU16> {
245 NonZeroU16::new(self.header.src_port.get())
246 }
247
248 pub fn dst_port(&self) -> NonZeroU16 {
250 NonZeroU16::new(self.header.dst_port.get()).unwrap()
252 }
253
254 pub fn checksummed(&self) -> bool {
264 self.header.checksummed()
265 }
266
267 pub fn builder<A: IpAddress>(&self, src_ip: A, dst_ip: A) -> UdpPacketBuilder<A> {
269 UdpPacketBuilder {
270 src_ip,
271 dst_ip,
272 src_port: self.src_port(),
273 dst_port: Some(self.dst_port()),
274 }
275 }
276
277 pub fn into_serializer<'a, A: IpAddress>(
291 self,
292 src_ip: A,
293 dst_ip: A,
294 ) -> impl Serializer<NoOpSerializationContext, Buffer = EmptyBuf> + Debug + 'a
295 where
296 B: 'a,
297 {
298 self.builder(src_ip, dst_ip)
299 .wrap_body(ByteSliceInnerPacketBuilder(self.body).into_serializer())
300 }
301}
302
303impl<B: SplitByteSliceMut> UdpPacket<B> {
304 pub fn set_src_port(&mut self, new: u16) {
306 self.header.set_src_port(new)
307 }
308
309 pub fn set_dst_port(&mut self, new: NonZeroU16) {
311 self.header.set_dst_port(new);
312 }
313
314 pub fn update_checksum_pseudo_header_address<A: IpAddress>(&mut self, old: A, new: A) {
316 self.header.update_checksum_pseudo_header_address(old, new);
317 }
318}
319
320impl<B: zerocopy::CloneableByteSlice + Clone> Clone for UdpPacket<B> {
321 fn clone(&self) -> Self {
322 UdpPacket { header: self.header.clone(), body: self.body.clone() }
323 }
324}
325
/// The first four bytes of a UDP header: the source and destination ports.
///
/// Used to recover the ports from a buffer that is too short to hold a
/// complete [`Header`].
#[derive(Debug, Default, KnownLayout, FromBytes, IntoBytes, Immutable, Unaligned, PartialEq)]
#[repr(C)]
struct UdpFlowHeader {
    src_port: U16,
    dst_port: U16,
}
336
/// A UDP header that was only partially present in the buffer.
#[derive(Debug)]
struct PartialHeader<B: SplitByteSlice> {
    // The source and destination ports, which were fully present.
    flow: Ref<B, UdpFlowHeader>,
    // The remaining header bytes that were available (fewer than four).
    rest: B,
}
343
/// A UDP packet parsed only as far as the available bytes allow.
///
/// No checksum or port validation has been performed.
/// - `header` is `Incomplete` when fewer than [`HEADER_BYTES`] bytes were present.
/// - `body` is `Incomplete` when the header's length field could not be honored
///   by the buffer.
///
/// Use `UdpPacket::try_from_raw_with` to validate it.
pub struct UdpPacketRaw<B: SplitByteSlice> {
    header: MaybeParsed<Ref<B, Header>, PartialHeader<B>>,
    body: MaybeParsed<B, B>,
}
361
362impl<B, I> ParsablePacket<B, IpVersionMarker<I>> for UdpPacketRaw<B>
363where
364 B: SplitByteSlice,
365 I: Ip,
366{
367 type Error = ParseError;
368
369 fn parse_metadata(&self) -> ParseMetadata {
370 let header_len = match &self.header {
371 MaybeParsed::Complete(h) => Ref::bytes(&h).len(),
372 MaybeParsed::Incomplete(h) => Ref::bytes(&h.flow).len() + h.rest.len(),
373 };
374 ParseMetadata::from_packet(header_len, self.body.len(), 0)
375 }
376
377 fn parse<BV: BufferView<B>>(mut buffer: BV, _args: IpVersionMarker<I>) -> ParseResult<Self> {
378 let header = if let Some(header) = buffer.take_obj_front::<Header>() {
381 header
382 } else {
383 let flow = buffer
384 .take_obj_front::<UdpFlowHeader>()
385 .ok_or_else(debug_err_fn!(ParseError::Format, "too few bytes for flow header"))?;
386 return Ok(UdpPacketRaw {
389 header: MaybeParsed::Incomplete(PartialHeader {
390 flow,
391 rest: buffer.take_rest_front(),
392 }),
393 body: MaybeParsed::Incomplete(buffer.into_rest()),
394 });
395 };
396 let buffer_len = buffer.len();
397
398 fn get_udp_body_length<I: Ip>(header: &Header, remaining_buff_len: usize) -> Option<usize> {
399 if I::VERSION.is_v6()
407 && header.length.get() == 0
408 && remaining_buff_len.saturating_add(HEADER_BYTES) >= (core::u16::MAX as usize)
409 {
410 return Some(remaining_buff_len);
411 }
412
413 usize::from(header.length.get()).checked_sub(HEADER_BYTES)
414 }
415
416 let body = if let Some(body_len) = get_udp_body_length::<I>(&header, buffer_len) {
417 if body_len <= buffer_len {
418 let _: B = buffer.take_back(buffer_len - body_len).unwrap();
422 MaybeParsed::Complete(buffer.into_rest())
423 } else {
424 MaybeParsed::Incomplete(buffer.into_rest())
426 }
427 } else {
428 let _: B = buffer.take_rest_back();
432 MaybeParsed::Incomplete(buffer.into_rest())
433 };
434
435 Ok(UdpPacketRaw { header: MaybeParsed::Complete(header), body })
436 }
437}
438
439impl<B: SplitByteSlice> UdpPacketRaw<B> {
440 pub fn src_port(&self) -> Option<NonZeroU16> {
444 NonZeroU16::new(
445 self.header
446 .as_ref()
447 .map(|header| header.src_port)
448 .map_incomplete(|partial_header| partial_header.flow.src_port)
449 .into_inner()
450 .get(),
451 )
452 }
453
454 pub fn dst_port(&self) -> Option<NonZeroU16> {
459 NonZeroU16::new(
460 self.header
461 .as_ref()
462 .map(|header| header.dst_port)
463 .map_incomplete(|partial_header| partial_header.flow.dst_port)
464 .into_inner()
465 .get(),
466 )
467 }
468
469 pub fn builder<A: IpAddress>(&self, src_ip: A, dst_ip: A) -> UdpPacketBuilder<A> {
476 UdpPacketBuilder { src_ip, dst_ip, src_port: self.src_port(), dst_port: self.dst_port() }
477 }
478
479 pub fn into_serializer<'a, A: IpAddress>(
498 self,
499 src_ip: A,
500 dst_ip: A,
501 ) -> Option<impl Serializer<NoOpSerializationContext, Buffer = EmptyBuf> + 'a>
502 where
503 B: 'a,
504 {
505 let builder = self.builder(src_ip, dst_ip);
506 self.body
507 .complete()
508 .ok()
509 .map(|body| builder.wrap_body(ByteSliceInnerPacketBuilder(body).into_serializer()))
510 }
511}
512
513impl<B: SplitByteSliceMut> UdpPacketRaw<B> {
514 pub fn set_src_port(&mut self, new: u16) {
516 match &mut self.header {
517 MaybeParsed::Complete(h) => h.set_src_port(new),
518 MaybeParsed::Incomplete(h) => {
519 h.flow.src_port = U16::from(new);
520
521 }
523 }
524 }
525
526 pub fn set_dst_port(&mut self, new: NonZeroU16) {
528 match &mut self.header {
529 MaybeParsed::Complete(h) => h.set_dst_port(new),
530 MaybeParsed::Incomplete(h) => {
531 h.flow.dst_port = U16::from(new.get());
532
533 }
535 }
536 }
537
538 pub fn update_checksum_pseudo_header_address<A: IpAddress>(&mut self, old: A, new: A) {
540 match &mut self.header {
541 MaybeParsed::Complete(h) => h.update_checksum_pseudo_header_address(old, new),
542 MaybeParsed::Incomplete(_) => {
543 }
545 }
546 }
547}
548
/// Marker for a UDP encapsulation.
///
/// It is passed to `UdpSerializationContext::envelope_to_state` to derive the
/// serialization context's per-packet state.
pub struct UdpEnvelope;
557
/// A serialization context that knows how to serialize UDP packets.
pub trait UdpSerializationContext: SerializationContext {
    /// Converts a UDP envelope into this context's per-packet state.
    fn envelope_to_state(envelope: UdpEnvelope) -> Self::ContextState;

    /// Chooses how the UDP checksum is produced during serialization.
    ///
    /// The options are a full checksum, or only the pseudo-header partial sum.
    fn checksum_action(&mut self) -> TransportChecksumAction;
}
566
567impl UdpSerializationContext for NoOpSerializationContext {
568 fn envelope_to_state(_envelope: UdpEnvelope) -> Self::ContextState {
569 ()
570 }
571
572 fn checksum_action(&mut self) -> TransportChecksumAction {
573 TransportChecksumAction::ComputeFull
574 }
575}
576
/// A builder for UDP packets.
///
/// The IP addresses are not written into the UDP header. They are used only
/// to compute the checksum pseudo-header. A `None` port is serialized as zero.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct UdpPacketBuilder<A: IpAddress> {
    src_ip: A,
    dst_ip: A,
    src_port: Option<NonZeroU16>,
    dst_port: Option<NonZeroU16>,
}
585
586impl<A: IpAddress> UdpPacketBuilder<A> {
587 pub fn new(
589 src_ip: A,
590 dst_ip: A,
591 src_port: Option<NonZeroU16>,
592 dst_port: NonZeroU16,
593 ) -> UdpPacketBuilder<A> {
594 UdpPacketBuilder { src_ip, dst_ip, src_port, dst_port: Some(dst_port) }
595 }
596
597 pub fn src_port(&self) -> Option<NonZeroU16> {
599 self.src_port
600 }
601
602 pub fn dst_port(&self) -> Option<NonZeroU16> {
604 self.dst_port
605 }
606
607 pub fn set_src_ip(&mut self, addr: A) {
609 self.src_ip = addr;
610 }
611
612 pub fn set_dst_ip(&mut self, addr: A) {
614 self.dst_ip = addr;
615 }
616
617 pub fn set_src_port(&mut self, port: u16) {
619 self.src_port = NonZeroU16::new(port);
620 }
621
622 pub fn set_dst_port(&mut self, port: NonZeroU16) {
624 self.dst_port = Some(port);
625 }
626
627 fn serialize_header(&self, body_len: usize, mut buffer: &mut [u8]) {
628 let total_len = buffer.len() + body_len;
631
632 (&mut buffer)
638 .write_obj_front(&Header {
639 src_port: U16::new(self.src_port.map_or(0, NonZeroU16::get)),
640 dst_port: U16::new(self.dst_port.map_or(0, NonZeroU16::get)),
641 length: U16::new(total_len.try_into().unwrap_or_else(|_| {
642 if A::Version::VERSION.is_v6() {
643 0u16
645 } else {
646 panic!(
647 "total UDP packet length of {total_len} bytes \
648 overflows 16-bit length field of UDP header"
649 )
650 }
651 })),
652 checksum: [0, 0],
655 })
656 .expect("too few bytes for UDP header");
657 }
658}
659
660impl<A: IpAddress> NestablePacketBuilder for UdpPacketBuilder<A> {
661 fn constraints(&self) -> PacketConstraints {
662 PacketConstraints::new(
663 HEADER_BYTES,
664 0,
665 0,
666 if A::Version::VERSION.is_v4() {
667 (1 << 16) - 1
668 } else {
669 core::usize::MAX
675 },
676 )
677 }
678}
679
680impl<A: IpAddress, C: UdpSerializationContext> PacketBuilder<C> for UdpPacketBuilder<A> {
681 fn context_state(&self) -> C::ContextState {
682 C::envelope_to_state(UdpEnvelope)
683 }
684
685 fn serialize(
686 &self,
687 context: &mut C,
688 target: &mut SerializeTarget<'_>,
689 body: FragmentedBytesMut<'_, '_>,
690 ) {
691 self.serialize_header(body.len(), target.header);
692
693 let checksum = match context.checksum_action() {
694 TransportChecksumAction::ComputeFull => compute_transport_checksum_serialize(
695 self.src_ip,
696 self.dst_ip,
697 IpProto::Udp.into(),
698 target,
699 body,
700 )
701 .map(|mut c| {
702 sanitize_checksum(&mut c);
703 c
704 }),
705 TransportChecksumAction::ComputePartial => {
706 compute_transport_pseudo_header_partial_checksum(
707 self.src_ip,
708 self.dst_ip,
709 IpProto::Udp.into(),
710 target,
711 body,
712 )
713 }
714 }
715 .unwrap(); target.header[CHECKSUM_RANGE].copy_from_slice(&checksum[..]);
718 }
719}
720
721impl<A: IpAddress, C: UdpSerializationContext> PartialPacketBuilder<C> for UdpPacketBuilder<A> {
722 fn partial_serialize(&self, _context: &mut C, body_len: usize, buffer: &mut [u8]) {
723 self.serialize_header(body_len, buffer);
724 }
725}
726
/// Rewrites an all-zero checksum as `[0xFF, 0xFF]`.
///
/// On the wire a zero UDP checksum means "no checksum". A computed checksum of
/// zero is therefore transmitted as all ones, its ones'-complement equivalent
/// (RFC 768). Any other value is left unchanged.
#[inline]
fn sanitize_checksum(checksum_bytes: &mut [u8; 2]) {
    let is_zero = checksum_bytes.iter().all(|&byte| byte == 0);
    if is_zero {
        *checksum_bytes = [u8::MAX; 2];
    }
}
736
// Test-only `Debug` impl. Tests call `unwrap_err` on parse results containing
// a `UdpPacket`, which requires the `Ok` type to implement `Debug`.
#[cfg(test)]
impl<B> Debug for UdpPacket<B> {
    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
        fmt.write_str("UdpPacket")
    }
}
744
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use byteorder::{ByteOrder, NetworkEndian};
    use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr};
    use packet::{Buf, NestableSerializer as _, ParseBuffer, ParseBufferMut};
    use test_case::test_case;

    use super::*;
    use crate::ethernet::{EthernetFrame, EthernetFrameLengthCheck};
    use crate::ipv4::{Ipv4Header, Ipv4Packet};
    use crate::ipv6::{Ipv6Header, Ipv6Packet};
    use crate::testutil::*;
    use crate::update_transport_checksum_pseudo_header;
    use packet::NoOpSerializationContext;

    // Fixed pseudo-header addresses shared by most tests below.
    const TEST_SRC_IPV4: Ipv4Addr = Ipv4Addr::new([1, 2, 3, 4]);
    const TEST_DST_IPV4: Ipv4Addr = Ipv4Addr::new([5, 6, 7, 8]);
    const TEST_SRC_IPV6: Ipv6Addr =
        Ipv6Addr::from_bytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
    const TEST_DST_IPV6: Ipv6Addr =
        Ipv6Addr::from_bytes([17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]);

    /// Parses a captured Ethernet/IPv4/UDP DNS request layer by layer. It then
    /// re-serializes the packet using builders taken from the parsed layers and
    /// checks that the output equals the capture byte for byte.
    #[test]
    fn test_parse_serialize_full_ipv4() {
        use crate::testdata::dns_request_v4::*;

        let mut buf = ETHERNET_FRAME.bytes;
        let frame = buf.parse_with::<_, EthernetFrame<_>>(EthernetFrameLengthCheck::Check).unwrap();
        verify_ethernet_frame(&frame, ETHERNET_FRAME);

        let mut body = frame.body();
        let ip_packet = body.parse::<Ipv4Packet<_>>().unwrap();
        verify_ipv4_packet(&ip_packet, IPV4_PACKET);

        let mut body = ip_packet.body();
        let udp_packet = body
            .parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(
                ip_packet.src_ip(),
                ip_packet.dst_ip(),
            ))
            .unwrap();
        verify_udp_packet(&udp_packet, UDP_PACKET);

        let buffer = udp_packet
            .body()
            .into_serializer()
            .wrap_in(udp_packet.builder(ip_packet.src_ip(), ip_packet.dst_ip()))
            .wrap_in(ip_packet.builder())
            .wrap_in(frame.builder())
            .serialize_vec_outer(&mut NoOpSerializationContext)
            .unwrap();
        assert_eq!(buffer.as_ref(), ETHERNET_FRAME.bytes);
    }

    /// Same round trip as the IPv4 test, using a captured IPv6 DNS request.
    #[test]
    fn test_parse_serialize_full_ipv6() {
        use crate::testdata::dns_request_v6::*;

        let mut buf = ETHERNET_FRAME.bytes;
        let frame = buf.parse_with::<_, EthernetFrame<_>>(EthernetFrameLengthCheck::Check).unwrap();
        verify_ethernet_frame(&frame, ETHERNET_FRAME);

        let mut body = frame.body();
        let ip_packet = body.parse::<Ipv6Packet<_>>().unwrap();
        verify_ipv6_packet(&ip_packet, IPV6_PACKET);

        let mut body = ip_packet.body();
        let udp_packet = body
            .parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(
                ip_packet.src_ip(),
                ip_packet.dst_ip(),
            ))
            .unwrap();
        verify_udp_packet(&udp_packet, UDP_PACKET);

        let buffer = udp_packet
            .body()
            .into_serializer()
            .wrap_in(udp_packet.builder(ip_packet.src_ip(), ip_packet.dst_ip()))
            .wrap_in(ip_packet.builder())
            .wrap_in(frame.builder())
            .serialize_vec_outer(&mut NoOpSerializationContext)
            .unwrap();
        assert_eq!(buffer.as_ref(), ETHERNET_FRAME.bytes);
    }

    #[test]
    fn test_parse() {
        // IPv4 header with a zero source port and a zero (absent) checksum.
        // Both are legal for IPv4.
        let mut buf = &[0, 0, 1, 2, 0, 8, 0, 0][..];
        let packet = buf
            .parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(TEST_SRC_IPV4, TEST_DST_IPV4))
            .unwrap();
        assert!(packet.src_port().is_none());
        assert_eq!(packet.dst_port().get(), NetworkEndian::read_u16(&[1, 2]));
        assert!(!packet.checksummed());
        assert!(packet.body().is_empty());

        // IPv6 jumbogram: the length field is zero and the packet is longer
        // than u16::MAX, so the whole remaining buffer is the body. The
        // checksum 0xBF12 is precomputed for this exact input.
        let mut buf = vec![0_u8, 0, 1, 2, 0, 0, 0xBF, 0x12];
        buf.extend((0..core::u16::MAX).into_iter().map(|p| p as u8));
        let bv = &mut &buf[..];
        let packet = bv
            .parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(TEST_SRC_IPV6, TEST_DST_IPV6))
            .unwrap();
        assert!(packet.src_port().is_none());
        assert_eq!(packet.dst_port().get(), NetworkEndian::read_u16(&[1, 2]));
        assert!(packet.checksummed());
        assert_eq!(packet.body().len(), core::u16::MAX as usize);
    }

    /// An IPv4 builder with source port 1 and destination port 2.
    fn new_test_udp_builder() -> UdpPacketBuilder<Ipv4Addr> {
        UdpPacketBuilder::new(
            TEST_SRC_IPV4,
            TEST_DST_IPV4,
            NonZeroU16::new(1),
            NonZeroU16::new(2).unwrap(),
        )
    }

    /// Serializes an empty-body packet, checks the exact bytes (including the
    /// checksum), then parses it back.
    #[test]
    fn test_serialize() {
        let mut buf = new_test_udp_builder()
            .wrap_body(EmptyBuf)
            .serialize_vec_outer(&mut NoOpSerializationContext)
            .unwrap();
        assert_eq!(buf.as_ref(), [0, 1, 0, 2, 0, 8, 239, 199]);
        let packet = buf
            .parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(TEST_SRC_IPV4, TEST_DST_IPV4))
            .unwrap();
        assert_eq!(packet.src_port().unwrap().get(), 1);
        assert_eq!(packet.dst_port().get(), 2);
        assert!(packet.checksummed());
    }

    /// Serializing must overwrite whatever bytes already occupy the header
    /// space. Prefilling with 0x00 or 0xFF yields identical output.
    #[test]
    fn test_serialize_zeroes() {
        let mut buf_0 = [0; HEADER_BYTES];
        let _: Buf<&mut [u8]> = new_test_udp_builder()
            .wrap_body(Buf::new(&mut buf_0[..], HEADER_BYTES..))
            .serialize_vec_outer(&mut NoOpSerializationContext)
            .unwrap()
            .unwrap_a();
        let mut buf_1 = [0xFF; HEADER_BYTES];
        let _: Buf<&mut [u8]> = new_test_udp_builder()
            .wrap_body(Buf::new(&mut buf_1[..], HEADER_BYTES..))
            .serialize_vec_outer(&mut NoOpSerializationContext)
            .unwrap()
            .unwrap_a();
        assert_eq!(buf_0, buf_1);
    }

    #[test]
    fn test_parse_error() {
        // Starts from a valid-looking header with a zero (absent) checksum.
        // Optionally checks that it parses, then zeroes the bytes at the
        // given indices and expects parsing to fail with `err`.
        fn test_zero<I: IpAddress>(
            src: I,
            dst: I,
            succeeds: bool,
            zero: &[usize],
            err: ParseError,
        ) {
            let mut buf = [1, 2, 3, 4, 0, 8, 0, 0];
            if succeeds {
                let mut buf = &buf[..];
                assert!(buf.parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(src, dst)).is_ok());
            }
            for idx in zero {
                buf[*idx] = 0;
            }
            let mut buf = &buf[..];
            assert_eq!(
                buf.parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(src, dst)).unwrap_err(),
                err
            );
        }

        // Zero destination port.
        test_zero(TEST_SRC_IPV4, TEST_DST_IPV4, true, &[2, 3], ParseError::Format);
        // Length field smaller than the header, so the body is incomplete.
        test_zero(TEST_SRC_IPV4, TEST_DST_IPV4, true, &[4, 5], ParseError::Format);
        // Zero checksum is rejected for IPv6.
        test_zero(TEST_SRC_IPV6, TEST_DST_IPV6, false, &[], ParseError::Format);

        // A 4 GiB IPv6 jumbogram is expected to be too large for the checksum
        // computation ("packet too large"), which yields a Format error. Only
        // run where such an allocation is addressable.
        #[cfg(target_pointer_width = "64")]
        {
            let mut buf = vec![0u8; 1 << 32];
            (&mut buf[..HEADER_BYTES]).copy_from_slice(&[0, 0, 1, 2, 0, 0, 0xFF, 0xE4]);
            assert_eq!(
                (&buf[..])
                    .parse_with::<_, UdpPacket<_>>(UdpParseArgs::new(TEST_SRC_IPV6, TEST_DST_IPV6))
                    .unwrap_err(),
                ParseError::Format
            );
        }
    }

    /// Corrupts the checksum of a freshly serialized packet. Parsing must fail
    /// with `ParseError::Checksum` unless the parse context asks to skip
    /// verification.
    #[test_case(TEST_SRC_IPV4, TEST_DST_IPV4, true; "ipv4 skip")]
    #[test_case(TEST_SRC_IPV4, TEST_DST_IPV4, false; "ipv4 validate")]
    #[test_case(TEST_SRC_IPV6, TEST_DST_IPV6, true; "ipv6 skip")]
    #[test_case(TEST_SRC_IPV6, TEST_DST_IPV6, false; "ipv6 validate")]
    fn test_parse_invalid_checksum<A: IpAddress>(src: A, dst: A, skip: bool) {
        let mut buf =
            UdpPacketBuilder::new(src, dst, NonZeroU16::new(1), NonZeroU16::new(2).unwrap())
                .wrap_body(EmptyBuf)
                .serialize_vec_outer(&mut NoOpSerializationContext)
                .unwrap()
                .as_ref()
                .to_vec();

        // Flip every checksum bit so the checksum is wrong but still non-zero.
        buf[CHECKSUM_OFFSET] ^= 0xFF;
        buf[CHECKSUM_OFFSET + 1] ^= 0xFF;

        let mut bv = &buf[..];
        let res = bv.parse_with::<_, UdpPacket<_>>(UdpParseArgs::with_context(
            src,
            dst,
            ForceSkipChecksumValidation(skip),
        ));
        if skip {
            assert_matches!(res, Ok(_));
        } else {
            assert_matches!(res, Err(ParseError::Checksum));
        }
    }

    /// A 7-byte header region is one byte short, so writing the header panics.
    #[test]
    #[should_panic(expected = "too few bytes for UDP header")]
    fn test_serialize_fail_header_too_short() {
        let mut buf = [0u8; 7];
        let mut buf = [&mut buf[..]];
        let buf = FragmentedBytesMut::new(&mut buf[..]);
        let (header, body, footer) = buf.try_split_contiguous(..).unwrap();
        let builder =
            UdpPacketBuilder::new(TEST_SRC_IPV4, TEST_DST_IPV4, None, NonZeroU16::new(1).unwrap());
        builder.serialize(
            &mut NoOpSerializationContext,
            &mut SerializeTarget { header, footer },
            body,
        );
    }

    /// On IPv4 a header plus body of 65536 bytes cannot be encoded in the
    /// 16-bit length field, so serialization panics.
    #[test]
    #[should_panic(expected = "total UDP packet length of 65536 bytes overflows 16-bit length \
                               field of UDP header")]
    fn test_serialize_fail_packet_too_long_ipv4() {
        let ser =
            UdpPacketBuilder::new(TEST_SRC_IPV4, TEST_DST_IPV4, None, NonZeroU16::new(1).unwrap())
                .wrap_body((&[0; (1 << 16) - HEADER_BYTES][..]).into_serializer());
        let _ = ser.serialize_vec_outer(&mut NoOpSerializationContext);
    }

    #[test]
    fn test_partial_parse() {
        use core::ops::Deref as _;

        // Six bytes: the ports plus two header bytes. The header is partial,
        // the body is empty and incomplete, and promotion to UdpPacket fails.
        let buf = [0, 0, 1, 2, 10, 20];
        let mut bv = &buf[..];
        let packet =
            bv.parse_with::<_, UdpPacketRaw<_>>(IpVersionMarker::<Ipv4>::default()).unwrap();
        let UdpPacketRaw { header, body } = &packet;
        let PartialHeader { flow, rest } = header.as_ref().incomplete().unwrap();
        assert_eq!(
            flow.deref(),
            &UdpFlowHeader { src_port: U16::new(0), dst_port: U16::new(0x0102) }
        );
        assert_eq!(*rest, &buf[4..]);
        assert_eq!(body.incomplete().unwrap(), []);
        assert!(
            UdpPacket::try_from_raw_with(packet, UdpParseArgs::new(TEST_SRC_IPV4, TEST_DST_IPV4))
                .is_err()
        );

        // Three bytes cannot even hold the two ports, so raw parsing fails.
        let mut buf = &[0, 0, 1][..];
        assert!(buf.parse_with::<_, UdpPacketRaw<_>>(IpVersionMarker::<Ipv4>::default()).is_err());

        // The length field (30) exceeds the available bytes, so the body is incomplete.
        let buf = [0, 0, 1, 2, 0, 30, 0, 0, 10, 20];
        let mut bv = &buf[..];
        let packet =
            bv.parse_with::<_, UdpPacketRaw<_>>(IpVersionMarker::<Ipv4>::default()).unwrap();
        let UdpPacketRaw { header, body } = &packet;
        assert_eq!(Ref::bytes(&header.as_ref().complete().unwrap()), &buf[..8]);
        assert_eq!(body.incomplete().unwrap(), &buf[8..]);
        assert!(
            UdpPacket::try_from_raw_with(packet, UdpParseArgs::new(TEST_SRC_IPV4, TEST_DST_IPV4))
                .is_err()
        );

        // The length field (6) is smaller than the header. Trailing bytes are
        // dropped and the body is empty and incomplete.
        let buf = [0, 0, 1, 2, 0, 6, 0, 0, 10, 20];
        let mut bv = &buf[..];
        let packet =
            bv.parse_with::<_, UdpPacketRaw<_>>(IpVersionMarker::<Ipv4>::default()).unwrap();
        let UdpPacketRaw { header, body } = &packet;
        assert_eq!(Ref::bytes(&header.as_ref().complete().unwrap()), &buf[..8]);
        assert_eq!(body.incomplete().unwrap(), []);
        assert!(
            UdpPacket::try_from_raw_with(packet, UdpParseArgs::new(TEST_SRC_IPV4, TEST_DST_IPV4))
                .is_err()
        );

        // IPv6 with a zero length field on a short packet is not a jumbogram,
        // so the body is empty and incomplete.
        let buf = [0, 0, 1, 2, 0, 0, 0, 0, 10, 20];
        let mut bv = &buf[..];
        let packet =
            bv.parse_with::<_, UdpPacketRaw<_>>(IpVersionMarker::<Ipv6>::default()).unwrap();
        let UdpPacketRaw { header, body } = &packet;
        assert_eq!(Ref::bytes(&header.as_ref().complete().unwrap()), &buf[..8]);
        assert_eq!(body.incomplete().unwrap(), []);
        // A zero length field on a packet longer than u16::MAX is a jumbogram,
        // and the whole remainder becomes a complete body.
        let mut buf = vec![0, 0, 1, 2, 0, 0, 0, 0, 10, 20];
        buf.extend((0..core::u16::MAX).into_iter().map(|x| x as u8));
        let bv = &mut &buf[..];
        let packet =
            bv.parse_with::<_, UdpPacketRaw<_>>(IpVersionMarker::<Ipv6>::default()).unwrap();
        let UdpPacketRaw { header, body } = &packet;
        assert_eq!(Ref::bytes(header.as_ref().complete().unwrap()), &buf[..8]);
        assert_eq!(body.complete().unwrap(), &buf[8..]);
    }

    /// Checks both checksum actions.
    /// - ComputePartial stores the pseudo-header partial sum, which is the
    ///   bitwise complement of the pseudo-header checksum.
    /// - ComputeFull stores a checksum under which pseudo-header plus packet
    ///   sums to zero.
    #[test]
    fn test_serialization_checksum_actions() {
        let body = [0x12, 0x34];
        let serializer = new_test_udp_builder().wrap_body(body.into_serializer());

        // Independently compute the pseudo-header contribution.
        let mut c = internet_checksum::Checksum::new();
        update_transport_checksum_pseudo_header::<Ipv4>(
            &mut c,
            TEST_SRC_IPV4,
            TEST_DST_IPV4,
            IpProto::Udp.into(),
            HEADER_BYTES + body.len(),
        )
        .expect("failed to update checksum");

        let buf = serializer
            .serialize_vec_outer(&mut ForceChecksumAction(TransportChecksumAction::ComputePartial))
            .unwrap();
        let [c0, c1] = c.checksum();
        assert_eq!(&buf.as_ref()[CHECKSUM_OFFSET..CHECKSUM_OFFSET + 2], [!c0, !c1]);

        let buf = serializer
            .serialize_vec_outer(&mut ForceChecksumAction(TransportChecksumAction::ComputeFull))
            .unwrap();

        c.add_bytes(buf.as_ref());
        assert_eq!(c.checksum(), [0, 0]);
    }

    /// The body is chosen so that the computed checksum is zero. It must be
    /// transmitted as 0xFFFF, and the packet must still verify.
    #[test]
    fn test_udp_checksum_0xffff() {
        let serializer = UdpPacketBuilder::new(
            Ipv4Addr::new([0, 0, 0, 0]),
            Ipv4Addr::new([0, 0, 0, 0]),
            None,
            NonZeroU16::new(1).unwrap(),
        )
        .wrap_body((&[0xFF, 0xD9]).into_serializer());
        let buf = serializer.serialize_vec_outer(&mut NoOpSerializationContext).unwrap();
        assert_eq!(&buf.as_ref()[CHECKSUM_OFFSET..CHECKSUM_OFFSET + 2], [0xFF, 0xFF]);

        // Pseudo-header fields are zero addresses, zero padding, protocol 17
        // and UDP length 10.
        let mut c = internet_checksum::Checksum::new();
        c.add_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 10]);
        c.add_bytes(buf.as_ref());
        assert!(c.checksum() == [0, 0]);
    }

    /// Incremental checksum updates must keep the sanitized 0xFFFF form when
    /// the fields return to their original values.
    #[test]
    fn test_udp_checksum_partial_update_0xffff() {
        const DST_PORT: NonZeroU16 = NonZeroU16::new(1).unwrap();
        const ADDR: Ipv4Addr = Ipv4::UNSPECIFIED_ADDRESS;
        let serializer = UdpPacketBuilder::new(ADDR, ADDR, None, DST_PORT)
            .wrap_body((&[0xff, 0xd9]).into_serializer());
        let mut buf = serializer.serialize_vec_outer(&mut NoOpSerializationContext).unwrap();
        let mut packet = buf
            .parse_with_mut::<_, UdpPacket<_>>(UdpParseArgs::new(ADDR, ADDR))
            .expect("parse should succeed");
        assert_eq!(packet.header.checksum, [0xFF, 0xFF]);

        // Source port: no-op, change, then restore.
        packet.set_src_port(0);
        assert_eq!(packet.header.checksum, [0xFF, 0xFF]);
        packet.set_src_port(1234);
        assert_ne!(packet.header.checksum, [0xFF, 0xFF]);
        packet.set_src_port(0);
        assert_eq!(packet.header.checksum, [0xFF, 0xFF]);

        // Destination port: no-op, change, then restore.
        packet.set_dst_port(DST_PORT);
        assert_eq!(packet.header.checksum, [0xFF, 0xFF]);
        packet.set_dst_port(NonZeroU16::new(1234).unwrap());
        assert_ne!(packet.header.checksum, [0xFF, 0xFF]);
        packet.set_dst_port(DST_PORT);
        assert_eq!(packet.header.checksum, [0xFF, 0xFF]);

        // Pseudo-header address: no-op, change, then restore.
        packet.update_checksum_pseudo_header_address(ADDR, ADDR);
        assert_eq!(packet.header.checksum, [0xFF, 0xFF]);
        const OTHER_ADDR: Ipv4Addr = Ipv4Addr::new([123, 124, 125, 126]);
        packet.update_checksum_pseudo_header_address(ADDR, OTHER_ADDR);
        assert_ne!(packet.header.checksum, [0xFF, 0xFF]);
        packet.update_checksum_pseudo_header_address(OTHER_ADDR, ADDR);
        assert_eq!(packet.header.checksum, [0xFF, 0xFF]);
    }
}