mdns/
protocol.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Protocol contains functions and traits for mDNS protocol parsing and packet
6//! creation.
7
8use std::collections::HashSet;
9use std::fmt::{Debug, Display, Formatter};
10use std::mem;
11use std::net::IpAddr;
12
13use packet::{BufferView, BufferViewMut, InnerPacketBuilder, ParsablePacket, ParseMetadata};
14use zerocopy::byteorder::network_endian::{U16, U32};
15use zerocopy::{
16    FromBytes, Immutable, IntoBytes, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut, Unaligned,
17};
18
/// Size in bytes of an A record's RDATA (one IPv4 address).
const IPV4_SIZE: usize = 4;
/// Size in bytes of an AAAA record's RDATA (one IPv6 address).
const IPV6_SIZE: usize = 16;
/// Size in bytes of the fixed SRV RDATA fields (priority, weight, port) that
/// precede the target name.
const SRV_PAYLOAD_SIZE_OCTETS: u16 = 6;
/// A label-length byte with both top bits set introduces a compression
/// pointer (RFC 1035 §4.1.4).
const DOMAIN_COMPRESSION_MASK_U8: u8 = 0xC0;
/// The same two marker bits, positioned in the high byte of a 16-bit pointer.
const DOMAIN_COMPRESSION_MASK_U16: u16 = (DOMAIN_COMPRESSION_MASK_U8 as u16) << 8;
/// QR bit of the header flags: set for responses, clear for queries.
const IS_RESPONSE_MASK: u16 = 0x8000;

// Size limits: https://tools.ietf.org/html/rfc1035#section-3.1
/// Maximum length of an encoded domain name, in bytes.
const MAX_DOMAIN_SIZE: usize = 255;
/// Maximum length of a single label, in bytes.
const MAX_LABEL_SIZE: usize = 63;

/// Returns true if `byte`, read where a label length is expected, marks the
/// start of a compression pointer (both marker bits set).
fn is_compression_byte(byte: u8) -> bool {
    (byte & DOMAIN_COMPRESSION_MASK_U8) == DOMAIN_COMPRESSION_MASK_U8
}

/// Removes the compression marker bits from a 16-bit pointer, leaving the
/// offset it refers to.
///
/// The marker bits are toggled rather than masked, so this is only meaningful
/// for values where `is_compression_byte` held for the high byte.
fn unwrap_domain_pointer(raw: u16) -> u16 {
    DOMAIN_COMPRESSION_MASK_U16 ^ raw
}
37
38/// Packet builder that doesn't do any auxiliary storage.
39pub trait EmbeddedPacketBuilder {
40    fn bytes_len(&self) -> usize;
41
42    fn serialize<B: SplitByteSliceMut, BV: BufferViewMut<B>>(&self, bv: &mut BV);
43
44    /// Return the output of packet building as a Vec<u8>, useful for tests that don't care about
45    /// zerocopy resource constraints.
46    fn bytes(&self) -> Vec<u8> {
47        let mut vec = vec![0; self.bytes_len()];
48        vec.resize(self.bytes_len(), 0u8);
49        self.serialize(&mut &mut vec.as_mut_slice());
50        vec
51    }
52}
53
54struct BufferViewWrapper<B>(B);
55
56impl<B: SplitByteSlice + Clone> BufferView<B> for BufferViewWrapper<B> {
57    fn into_rest(self) -> B {
58        self.0
59    }
60
61    fn take_front(&mut self, n: usize) -> Option<B> {
62        if self.len() >= n {
63            let (ret, next) = self.0.clone().split_at(n).ok().unwrap();
64            self.0 = next;
65            Some(ret)
66        } else {
67            None
68        }
69    }
70
71    /// This isn't implemented as it currently is not used in this
72    /// implementation.
73    fn take_back(&mut self, _n: usize) -> Option<B> {
74        unimplemented!()
75    }
76}
77
78impl<B: SplitByteSlice> AsRef<[u8]> for BufferViewWrapper<B> {
79    fn as_ref(&self) -> &[u8] {
80        &self.0
81    }
82}
83
/// Errors produced while parsing (or validating input for) an mDNS message.
///
/// Variants that carry a number report the offending value that was
/// encountered.
#[derive(Debug, PartialEq)]
pub enum ParseError {
    /// The RDATA length does not fit the record type. For [`Type::A`] it is
    /// not the size of an IPv4 address, for [`Type::Aaaa`] it is not the size
    /// of an IPv6 address, and for [`Type::Srv`] it is too short to hold the
    /// fixed SRV fields plus a target name. Carries the type and the length.
    RDataLen(Type, u16),
    /// Catch-all for structurally invalid input, e.g. a truncated buffer or
    /// bytes left over after a name that should fill its region.
    Malformed,
    /// A name started with the zero terminator, or a label contained a zero
    /// byte.
    UnexpectedZeroCharacter,
    /// Following compression pointers revisited a pointer already seen.
    PointerCycle,
    /// A compression pointer could not be followed because it points past the
    /// available packet buffer (or no parent buffer was supplied). Carries
    /// that pointer.
    BadPointerIndex(u16),
    /// A domain name exceeded the maximum size of 255 bytes. Carries the
    /// length reached.
    DomainTooLong(usize),
    /// A label length byte exceeded the maximum label size. Carries that
    /// length.
    LabelTooLong(usize),
    /// The type field is not one of the supported [`Type`]s.
    UnknownType(u16),
    /// The class field (with its top flag bit removed) is not one of the
    /// supported [`Class`]es.
    UnknownClass(u16),
}

/// Standard mDNS record types supported in this protocol library.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Type {
    A,
    Aaaa,
    Ptr,
    Srv,
    Txt,
}

impl From<Type> for u16 {
    /// Returns the on-the-wire TYPE value.
    fn from(value: Type) -> u16 {
        match value {
            Type::A => 1,
            Type::Ptr => 12,
            Type::Txt => 16,
            Type::Aaaa => 28,
            Type::Srv => 33,
        }
    }
}

impl TryFrom<u16> for Type {
    type Error = ParseError;

    /// Decodes an on-the-wire TYPE value; unsupported values yield
    /// `ParseError::UnknownType`.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        match value {
            1 => Ok(Self::A),
            12 => Ok(Self::Ptr),
            16 => Ok(Self::Txt),
            28 => Ok(Self::Aaaa),
            33 => Ok(Self::Srv),
            unknown => Err(ParseError::UnknownType(unknown)),
        }
    }
}

/// Standard DNS classes supported by this protocol library.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Class {
    In,
    Any,
}

impl Class {
    /// Encodes this class with `b` stored in the most significant bit, which
    /// carries the mDNS flush (records) or unicast (questions) flag.
    fn into_u16_with_bool(self, b: bool) -> u16 {
        let flag_bit = u16::from(b) << 15;
        u16::from(self) | flag_bit
    }
}

impl From<Class> for u16 {
    /// Returns the on-the-wire CLASS value, without any flag bit.
    fn from(value: Class) -> u16 {
        match value {
            Class::Any => 255,
            Class::In => 1,
        }
    }
}

impl TryFrom<u16> for Class {
    type Error = ParseError;

    /// Decodes a CLASS value whose flag bit has already been removed;
    /// unsupported values yield `ParseError::UnknownClass`.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        match value {
            255 => Ok(Self::Any),
            1 => Ok(Self::In),
            unknown => Err(ParseError::UnknownClass(unknown)),
        }
    }
}
188
189/// Represents an mDNS packet header.
190#[repr(C)]
191#[derive(KnownLayout, FromBytes, IntoBytes, Immutable, Unaligned)]
192pub struct Header {
193    id: U16,
194    flags: U16,
195    question_count: U16,
196    answer_count: U16,
197    authority_count: U16,
198    additional_count: U16,
199}
200
201impl Header {
202    /// Returns true if this is a query (the first bit is zero).
203    pub fn is_query(&self) -> bool {
204        !self.is_response()
205    }
206
207    /// Returns true if this is a response (the first bit is 1).
208    pub fn is_response(&self) -> bool {
209        self.flags.get() & IS_RESPONSE_MASK != 0
210    }
211
212    /// Returns the question count of this header.
213    pub fn question_count(&self) -> usize {
214        self.question_count.get().into()
215    }
216
217    /// Returns the answer count of this header.
218    pub fn answer_count(&self) -> usize {
219        self.answer_count.get().into()
220    }
221
222    /// Returns the authority count of this header.
223    pub fn authority_count(&self) -> usize {
224        self.authority_count.get().into()
225    }
226
227    /// Returns the additional record count of this header.
228    pub fn additional_count(&self) -> usize {
229        self.additional_count.get().into()
230    }
231}
232
233/// Represents a parsed mDNS question.
234pub struct Question<B: SplitByteSlice> {
235    pub domain: Domain<B>,
236    pub qtype: Type,
237    pub class: Class,
238    pub unicast: bool,
239}
240
241impl<B: SplitByteSlice + Copy> Question<B> {
242    fn parse<BV: BufferView<B>>(buffer: &mut BV, parent: Option<B>) -> Result<Self, ParseError> {
243        let domain = Domain::parse(buffer, parent)?;
244        let qtype = buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?;
245        let class_and_ucast = buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?;
246        let unicast: bool = class_and_ucast.get() & (1u16 << 15) != 0;
247        let class: u16 = class_and_ucast.get() & 0x7fff;
248        Ok(Self { domain, qtype: qtype.get().try_into()?, class: class.try_into()?, unicast })
249    }
250}
251
252impl<B: SplitByteSlice + Copy> Display for Question<B> {
253    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
254        write!(f, "{:?}", self)
255    }
256}
257
258impl<B: SplitByteSlice + Copy> Debug for Question<B> {
259    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
260        write!(
261            f,
262            "{}: type {:?}, class {:?}, unicast {}",
263            self.domain, self.qtype, self.class, self.unicast,
264        )
265    }
266}
267
268/// A packet builder for an mDNS question.
269pub struct QuestionBuilder {
270    domain: DomainBuilder,
271    qtype: Type,
272    class: Class,
273    unicast: bool,
274}
275
276impl QuestionBuilder {
277    /// Constructs a QuestionBuilder.
278    pub fn new(domain: DomainBuilder, qtype: Type, class: Class, unicast: bool) -> Self {
279        Self { domain, qtype, class, unicast }
280    }
281}
282
283impl EmbeddedPacketBuilder for QuestionBuilder {
284    fn bytes_len(&self) -> usize {
285        self.domain.bytes_len()
286            + mem::size_of::<U16>() // type
287            + mem::size_of::<U16>() // class + unicast
288    }
289
290    fn serialize<B: SplitByteSliceMut, BV: BufferViewMut<B>>(&self, bv: &mut BV) {
291        self.domain.serialize(bv);
292        bv.take_obj_front::<U16>().unwrap().set(self.qtype.into());
293        bv.take_obj_front::<U16>().unwrap().set(self.class.into_u16_with_bool(self.unicast));
294    }
295}
296
297/// A parsed AAAA type record.
298#[derive(KnownLayout, FromBytes, Immutable)]
299pub struct Aaaa([u8; IPV6_SIZE]);
300
301/// A parsed A type record.
302#[derive(KnownLayout, FromBytes, Immutable)]
303pub struct A([u8; IPV4_SIZE]);
304
305/// A parsed SRV type record.
306pub struct SrvRecord<B: SplitByteSlice> {
307    priority: u16,
308    weight: u16,
309    port: u16,
310    domain: Domain<B>,
311}
312
313impl<B: SplitByteSlice + Copy> SrvRecord<B> {
314    fn parse<BV: BufferView<B>>(
315        buffer: &mut BV,
316        parent: Option<B>,
317        len_limit: u16,
318    ) -> Result<Self, ParseError> {
319        let priority = buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?.get();
320        let weight = buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?.get();
321        let port = buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?.get();
322        let domain_buf = buffer.take_front(len_limit as usize).ok_or(ParseError::Malformed)?;
323        // Needs a length limit as the SRV record necessitates it.
324        let mut bv = BufferViewWrapper(domain_buf);
325        let domain = Domain::parse(&mut bv, parent)?;
326        // The domain should have consumed the entire buffer view.
327        if bv.as_ref().len() != 0 {
328            return Err(ParseError::Malformed);
329        }
330        Ok(Self { priority, weight, port, domain })
331    }
332}
333
334/// A parsed RData (can be one of several types). If this has been parsed in a
335/// PTR type, this will always be a `RData::Domain`. In a SRV type packet, this
336/// will always be a `RData::Srv`, anything else, currently, will be converted
337/// into `RData::Bytes`.
338pub enum RData<B: SplitByteSlice> {
339    A(A),
340    Aaaa(Aaaa),
341    Bytes(B),
342    Domain(Domain<B>),
343    Srv(SrvRecord<B>),
344}
345
346impl<B: SplitByteSlice> RData<B> {
347    /// Returns a reference to a `SrvRecord` if possible `None` otherwise.
348    fn srv(&self) -> Option<&SrvRecord<B>> {
349        match self {
350            RData::Srv(s) => Some(s),
351            _ => None,
352        }
353    }
354
355    /// Returns a `Domain` if possible, `None` otherwise. If this is a
356    /// `RData::Srv` then this returns a reference to its internal `Domain`.
357    fn domain(&self) -> Option<&Domain<B>> {
358        match self {
359            RData::Domain(d) => Some(d),
360            RData::Srv(s) => Some(&s.domain),
361            _ => None,
362        }
363    }
364
365    /// Returns a `IpAddr` if possible, `None` otherwise.
366    pub fn ip_addr(&self) -> Option<IpAddr> {
367        match self {
368            RData::Aaaa(aaaa) => Some(IpAddr::from(aaaa.0)),
369            RData::A(a) => Some(IpAddr::from(a.0)),
370            _ => None,
371        }
372    }
373
374    // TODO(awdavies): This is used in tests, and will be useful for getting
375    // strings out of Txt data later when there is an actual client
376    // implementation.
377    #[allow(unused)]
378    pub fn bytes(&self) -> Option<&B> {
379        match self {
380            RData::Bytes(b) => Some(b),
381            _ => None,
382        }
383    }
384}
385
386/// Record is the catch-all container for Answer, Authority, and Additional
387/// Records sections of an MDNS packet. This is the parsed version that is
388/// created when parsing a packet.
389pub struct Record<B: SplitByteSlice> {
390    pub domain: Domain<B>,
391    pub rtype: Type,
392    pub class: Class,
393    pub ttl: u32,
394    pub flush: bool,
395    pub rdata: RData<B>,
396}
397
398impl<B: SplitByteSlice + Copy> Display for Record<B> {
399    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
400        write!(f, "{:?}", self)
401    }
402}
403
404impl<B: SplitByteSlice + Copy> Debug for Record<B> {
405    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
406        write!(f, "{}: type {:?}, class {:?}", self.domain, self.rtype, self.class)?;
407        match self.rtype {
408            Type::Srv => {
409                let srv = self.rdata.srv().unwrap();
410                write!(
411                    f,
412                    ", priority {}, weight {}, port {}, target {}",
413                    srv.priority, srv.weight, srv.port, srv.domain
414                )
415            }
416            Type::Ptr => write!(f, ", {}", self.rdata.domain().unwrap()),
417            // Don't print anything unless it's guaranteed that a certain type
418            // will have certain data (Srv and Ptr will always be their
419            // respective types for now).
420            _ => Ok(()),
421        }
422    }
423}
424
425fn valid_rdata_len(r: Type, len: u16) -> Result<u16, ParseError> {
426    match r {
427        Type::A => {
428            if len != IPV4_SIZE as u16 {
429                return Err(ParseError::RDataLen(r, len));
430            }
431        }
432        Type::Aaaa => {
433            if len != IPV6_SIZE as u16 {
434                return Err(ParseError::RDataLen(r, len));
435            }
436        }
437        Type::Srv => {
438            // Minimum size of SRV is the payload and enough for a domain
439            // pointer or domain of a single character.
440            if len < SRV_PAYLOAD_SIZE_OCTETS + 2 {
441                return Err(ParseError::RDataLen(r, len));
442            }
443        }
444        _ => (),
445    }
446
447    Ok(len)
448}
449
450impl<B: SplitByteSlice + Copy> Record<B> {
451    fn parse<BV: BufferView<B>>(buffer: &mut BV, parent: Option<B>) -> Result<Self, ParseError> {
452        let domain = Domain::parse(buffer, parent)?;
453        let rtype: Type =
454            buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?.get().try_into()?;
455        let class_and_flush = buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?;
456        let flush = class_and_flush.get() & (1u16 << 15) != 0;
457        let class: Class = (class_and_flush.get() & 0x7fff).try_into()?;
458        let ttl = buffer.take_obj_front::<U32>().ok_or(ParseError::Malformed)?.get();
459        let rdata_len = valid_rdata_len(
460            rtype,
461            buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?.get(),
462        )?;
463        let rdata = match rtype {
464            Type::Srv => {
465                RData::Srv(SrvRecord::parse(buffer, parent, rdata_len - SRV_PAYLOAD_SIZE_OCTETS)?)
466            }
467            Type::Ptr => {
468                let ptr_domain_buf =
469                    buffer.take_front(rdata_len.into()).ok_or(ParseError::Malformed)?;
470                let mut ptr_domain_bv = BufferViewWrapper(ptr_domain_buf);
471                let ptr_domain = Domain::parse(&mut ptr_domain_bv, parent)?;
472                // Must consume the whole buffer.
473                if ptr_domain_bv.as_ref().len() != 0 {
474                    return Err(ParseError::Malformed);
475                }
476                RData::Domain(ptr_domain)
477            }
478            Type::A => {
479                let buf = buffer.take_front(IPV4_SIZE).ok_or(ParseError::Malformed)?;
480                RData::A(A::read_from_bytes(&buf).map_err(|_| ParseError::Malformed)?)
481            }
482            Type::Aaaa => {
483                let buf = buffer.take_front(IPV6_SIZE).ok_or(ParseError::Malformed)?;
484                RData::Aaaa(Aaaa::read_from_bytes(&buf).map_err(|_| ParseError::Malformed)?)
485            }
486            _ => RData::Bytes(buffer.take_front(rdata_len.into()).ok_or(ParseError::Malformed)?),
487        };
488
489        Ok(Self { domain, rtype, class, ttl, flush, rdata })
490    }
491}
492
493/// A record builder for creating a serialized version of an mDNS Record, which
494/// is the catch-all type for Answers, Additional Records, and Authority
495/// records.
496pub struct RecordBuilder<'a> {
497    domain: DomainBuilder,
498    rtype: Type,
499    class: Class,
500    flush: bool,
501    ttl: u32,
502    rdata: &'a [u8],
503}
504
505impl<'a> RecordBuilder<'a> {
506    /// Constructs a `RecordBuilder`. Inputs must be valid for constructing an
507    /// mDNS message or this will panic.
508    ///
509    /// # Panics
510    ///
511    /// Will panic if `rdata` is too large to have its length stored in a `u16`,
512    /// which is necessary for successful serialization.
513    ///
514    /// Will panic if `rdata` is empty, as this is not supported.
515    pub fn new(
516        domain: DomainBuilder,
517        rtype: Type,
518        class: Class,
519        flush: bool,
520        ttl: u32,
521        rdata: &'a [u8],
522    ) -> Self {
523        // Will panic if attempting to create too large of a message.
524        let len = u16::try_from(rdata.len()).unwrap();
525        if len == 0 {
526            panic!("empty rdata not supported");
527        }
528        Self { domain, rtype, class, ttl, flush, rdata }
529    }
530}
531
532impl EmbeddedPacketBuilder for RecordBuilder<'_> {
533    fn bytes_len(&self) -> usize {
534        self.domain.bytes_len()
535            + mem::size_of::<U16>()  // type
536            + mem::size_of::<U16>()  // class + flush
537            + mem::size_of::<U32>()  // ttl
538            + mem::size_of::<U16>()  // rdata_len
539            + self.rdata.len()
540    }
541
542    fn serialize<B: SplitByteSliceMut, BV: BufferViewMut<B>>(&self, bv: &mut BV) {
543        self.domain.serialize(bv);
544        bv.take_obj_front::<U16>().unwrap().set(self.rtype.into());
545        bv.take_obj_front::<U16>().unwrap().set(self.class.into_u16_with_bool(self.flush));
546        bv.take_obj_front::<U32>().unwrap().set(self.ttl);
547        bv.take_obj_front::<U16>().unwrap().set(u16::try_from(self.rdata.len()).unwrap());
548        bv.take_front(self.rdata.len()).unwrap().copy_from_slice(self.rdata);
549    }
550}
551
552/// A parsed mDNS message in its entirety.
553pub struct Message<B: SplitByteSlice> {
554    pub header: Ref<B, Header>,
555    pub questions: Vec<Question<B>>,
556    pub answers: Vec<Record<B>>,
557    pub authority: Vec<Record<B>>,
558    pub additional: Vec<Record<B>>,
559}
560
561impl<B: SplitByteSlice + Copy> Message<B> {
562    #[inline]
563    fn parse_records<BV: BufferView<B>>(
564        buffer: &mut BV,
565        parent: Option<B>,
566        count: usize,
567    ) -> Result<Vec<Record<B>>, ParseError> {
568        let mut records = Vec::<Record<B>>::with_capacity(count);
569        for _ in 0..count {
570            records.push(Record::parse(buffer, parent)?);
571        }
572        Ok(records)
573    }
574}
575
576impl<B: SplitByteSlice + Copy> ParsablePacket<B, ()> for Message<B> {
577    type Error = ParseError;
578
579    fn parse<BV: BufferView<B>>(buffer: BV, _args: ()) -> Result<Self, Self::Error> {
580        let body = buffer.into_rest();
581        let mut buffer = BufferViewWrapper(body);
582
583        let header = buffer.take_obj_front::<Header>().ok_or(ParseError::Malformed)?;
584        let mut questions: Vec<Question<B>> = Vec::with_capacity(header.question_count());
585        for _ in 0..header.question_count.get() {
586            questions.push(Question::parse(&mut buffer, Some(body))?);
587        }
588        let answers = Message::parse_records(&mut buffer, Some(body), header.answer_count())?;
589        let authority = Message::parse_records(&mut buffer, Some(body), header.authority_count())?;
590        let additional =
591            Message::parse_records(&mut buffer, Some(body), header.additional_count())?;
592        Ok(Self { header, questions, answers, authority, additional })
593    }
594
595    fn parse_metadata(&self) -> ParseMetadata {
596        // ParseMetadata is only needed if we do undo parse.
597        unimplemented!()
598    }
599}
600
601/// A builder for creating an mDNS message.
602pub struct MessageBuilder<'a> {
603    pub id: u16,
604    pub flags: u16,
605
606    questions: Vec<QuestionBuilder>,
607    answers: Vec<RecordBuilder<'a>>,
608    authority: Vec<RecordBuilder<'a>>,
609    additional: Vec<RecordBuilder<'a>>,
610}
611
612impl<'a> MessageBuilder<'a> {
613    pub fn new(id: u16, is_query: bool) -> Self {
614        let mut flags = 0u16;
615        if !is_query {
616            flags |= IS_RESPONSE_MASK;
617        }
618        Self {
619            id,
620            flags,
621            questions: Vec::new(),
622            answers: Vec::new(),
623            authority: Vec::new(),
624            additional: Vec::new(),
625        }
626    }
627
628    pub fn add_question(&mut self, q: QuestionBuilder) {
629        self.questions.push(q);
630    }
631
632    pub fn add_answer(&mut self, a: RecordBuilder<'a>) {
633        self.answers.push(a);
634    }
635
636    pub fn add_authority(&mut self, a: RecordBuilder<'a>) {
637        self.authority.push(a);
638    }
639
640    pub fn add_additional(&mut self, a: RecordBuilder<'a>) {
641        self.additional.push(a);
642    }
643}
644
645impl InnerPacketBuilder for MessageBuilder<'_> {
646    fn bytes_len(&self) -> usize {
647        mem::size_of::<Header>()
648            + self.questions.iter().fold(0, |r, s| r + s.bytes_len())
649            + self.answers.iter().fold(0, |r, s| r + s.bytes_len())
650            + self.authority.iter().fold(0, |r, s| r + s.bytes_len())
651            + self.additional.iter().fold(0, |r, s| r + s.bytes_len())
652    }
653
654    fn serialize(&self, mut buffer: &mut [u8]) {
655        // Inherits BufferViewMut trait.
656        let mut bv = &mut buffer;
657        let mut header = bv.take_obj_front_zero::<Header>().unwrap();
658        header.id.set(self.id);
659        header.flags.set(self.flags);
660        header.question_count.set(self.questions.len() as u16);
661        header.answer_count.set(self.answers.len() as u16);
662        header.authority_count.set(self.authority.len() as u16);
663        header.additional_count.set(self.additional.len() as u16);
664        self.questions.iter().for_each(|e| e.serialize(&mut bv));
665        self.answers.iter().for_each(|e| e.serialize(&mut bv));
666        self.authority.iter().for_each(|e| e.serialize(&mut bv));
667        self.additional.iter().for_each(|e| e.serialize(&mut bv));
668    }
669}
670
671/// A parsed mDNS domain. There is no need to worry about message compression
672/// when comparing against a string, and can be treated as a contiguous domain.
673#[derive(PartialEq, Eq)]
674pub struct Domain<B: SplitByteSlice> {
675    fragments: Vec<B>,
676}
677
678enum DomainData<B: SplitByteSlice> {
679    Domain(B),
680    Pointer(Option<B>, u16),
681}
682
683impl<B: SplitByteSlice + Copy> Domain<B> {
684    fn fmt_byte_slice(f: &mut Formatter<'_>, b: &B) -> std::fmt::Result {
685        let mut iter = b.as_ref().iter();
686        let mut idx = 0;
687        loop {
688            let opt = iter.next();
689            // Here it's possible that there's no null terminator (in the case
690            // of pointers), but for any valid domain that has passed the
691            // `parse` method, this should not be an issue.
692            let c = match opt {
693                Some(v) => v,
694                None => break,
695            };
696            if *c == 0 {
697                break;
698            }
699            if idx > 0 {
700                f.write_str(".")?;
701            }
702            let skip = *c as usize;
703            for _ in 0..skip {
704                write!(f, "{}", *iter.next().unwrap() as char)?;
705            }
706            idx += 1;
707        }
708        Ok(())
709    }
710
711    fn partial_eq_helper_slice<BV: BufferView<B>>(
712        other_bv: &mut BV,
713        b: &B,
714    ) -> Result<bool, ParseError> {
715        // TODO(awdavies): This comparison and builder logic should probably
716        // abstracted a bit.
717        let mut dref = &mut b.as_ref();
718        // Gets BufferView trait.
719        let bv = &mut dref;
720        loop {
721            let domain_len = match bv.take_byte_front() {
722                Some(d) => d,
723                None => break,
724            };
725            if domain_len == 0 {
726                break;
727            }
728            let mut other_len = 0u8;
729            loop {
730                match other_bv.take_byte_front() {
731                    // At end of string or a '.' symbol.
732                    Some(46) | None => {
733                        if domain_len != other_len {
734                            return Ok(false);
735                        }
736                        break;
737                    }
738                    Some(c) => {
739                        if c != bv.take_byte_front().ok_or(ParseError::Malformed)? {
740                            return Ok(false);
741                        }
742                        other_len += 1;
743                    }
744                }
745            }
746        }
747        if bv.len() > 0 {
748            return Ok(false);
749        }
750        Ok(true)
751    }
752
753    fn partial_eq_helper_str(&self, other: &str) -> Result<bool, ParseError> {
754        let mut domain_bv = BufferViewWrapper(other.as_bytes());
755        for d in self.fragments.iter() {
756            if !Domain::partial_eq_helper_slice(&mut domain_bv, &d.as_ref())? {
757                return Ok(false);
758            }
759        }
760        return Ok(domain_bv.as_ref().len() == 0);
761    }
762
763    fn parse_domain_helper<BV: BufferView<B>>(
764        buffer: &mut BV,
765    ) -> Result<DomainData<B>, ParseError> {
766        let mut iter = buffer.as_ref().iter();
767        let mut idx = 0;
768        loop {
769            let domain_len = *iter.next().ok_or(ParseError::Malformed)?;
770            idx += 1;
771            // If this is a compression byte, then either we're at the end of
772            // the domain, or this is the first byte of the domain and the whole
773            // thing is determined by the pointer.
774            if is_compression_byte(domain_len) {
775                return match idx {
776                    1 => {
777                        let location =
778                            buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?.get();
779                        Ok(DomainData::Pointer(None, unwrap_domain_pointer(location)))
780                    }
781                    _ => {
782                        let data = buffer.take_front(idx - 1).ok_or(ParseError::Malformed)?;
783                        let location =
784                            buffer.take_obj_front::<U16>().ok_or(ParseError::Malformed)?.get();
785                        Ok(DomainData::Pointer(Some(data), unwrap_domain_pointer(location)))
786                    }
787                };
788            }
789            // If this is the null terminator, then we're either done iterating,
790            // or this is a malformed label (if it is the first byte).
791            if domain_len == 0 {
792                if idx == 1 {
793                    return Err(ParseError::UnexpectedZeroCharacter);
794                }
795                break;
796            }
797            if domain_len as usize > MAX_LABEL_SIZE {
798                return Err(ParseError::LabelTooLong(domain_len.into()));
799            }
800            for _ in 0..domain_len {
801                if *iter.next().ok_or(ParseError::Malformed)? == 0 {
802                    return Err(ParseError::UnexpectedZeroCharacter)?;
803                }
804                idx += 1;
805            }
806        }
807        if idx > MAX_DOMAIN_SIZE {
808            return Err(ParseError::DomainTooLong(idx))?;
809        }
810        Ok(DomainData::Domain(buffer.take_front(idx).ok_or(ParseError::Malformed)?))
811    }
812
813    /// Parse the provided record.
814    ///
815    /// `parent` is used for resolving compressed names as specified by [RFC
816    /// 1035 Section 4.1.4]. If `parent` is None, a reference to previous data
817    /// will be treated as an error.
818    ///
819    /// [RFC 1035 Section 4.1.4]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.4
820    pub fn parse<BV: BufferView<B>>(
821        buffer: &mut BV,
822        parent: Option<B>,
823    ) -> Result<Self, ParseError> {
824        let mut fragments = Vec::<B>::new();
825        let mut pointer_set = HashSet::<u16>::new();
826        let mut result = Domain::parse_domain_helper(buffer)?;
827        loop {
828            match result {
829                DomainData::Domain(data) => {
830                    fragments.push(data);
831                    return Ok(Self { fragments });
832                }
833                DomainData::Pointer(data, pointer) => {
834                    if let Some(d) = data {
835                        fragments.push(d);
836                    }
837                    if pointer_set.contains(&pointer) {
838                        return Err(ParseError::PointerCycle);
839                    }
840                    pointer_set.insert(pointer);
841                    let mut bv = parent
842                        .and_then(|parent| {
843                            let mut bv = BufferViewWrapper(parent.clone());
844                            bv.take_front(pointer.into()).map(|_: B| bv)
845                        })
846                        .ok_or(ParseError::BadPointerIndex(pointer))?;
847                    result = Domain::parse_domain_helper(&mut bv)?;
848                }
849            }
850        }
851    }
852}
853
854/// Implementation of PartialEq to make it possible to compare a parsed domain
855/// with the initial string that was used to construct it.
856impl<B: SplitByteSlice + Copy> PartialEq<&str> for Domain<B> {
857    fn eq(&self, other: &&str) -> bool {
858        self.partial_eq_helper_str(other).or::<bool>(Ok(false)).unwrap()
859    }
860}
861
862impl<B: SplitByteSlice + Copy> Display for Domain<B> {
863    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
864        write!(f, "{:?}", self)
865    }
866}
867
868impl<B: SplitByteSlice + Copy> Debug for Domain<B> {
869    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
870        let mut iter = self.fragments.iter();
871        let data = iter.next().unwrap();
872        Domain::fmt_byte_slice(f, data)?;
873        loop {
874            if let Some(next) = iter.next() {
875                f.write_str(".")?;
876                Domain::fmt_byte_slice(f, next)?;
877            } else {
878                return Ok(());
879            }
880        }
881    }
882}
883
/// Builder for an mDNS-compliant domain name. Message compression is not
/// supported; the name is always written out in full.
#[derive(Clone, Debug, PartialEq)]
pub struct DomainBuilder {
    // Wire-format name: length-prefixed labels followed by a zero byte.
    // TODO(awdavies): This can probably use the Buf struct instead.
    data: Vec<u8>,
}
890
891impl DomainBuilder {
892    /// Attempts to construct a domain from a string formatted according to DNS
893    /// standards.
894    ///
895    /// Example usage:
896    /// ```rust
897    /// let domain = DomainBuilder::from_str("_fuchsia._udp.local")?;
898    /// ```
899    ///
900    /// # Errors
901    ///
902    /// If the domain you supply is larger than `MAX_DOMAIN_SIZE` this will
903    /// return an error. It is also an error if any individual label
904    /// (the section of string between dots) is longer than 63 bytes.
905    ///
906    /// Currently, terminating a string with a dot is not supported.
907    pub fn from_str(domain: &str) -> Result<Self, ParseError> {
908        let mut data = Vec::<u8>::with_capacity(MAX_DOMAIN_SIZE);
909        let mut domain_iter = domain.as_bytes().as_ref().iter();
910        loop {
911            data.push(0);
912            // When copying is complete there will be one extra byte on the
913            // beginning and end of the string, so the last_len_idx will be
914            // equal to the total number of characters in the domain string plus
915            // one.
916            let last_len_idx = data.len() - 1;
917            if last_len_idx == domain.len() + 1 {
918                break;
919            }
920            let mut str_len = 0u8;
921            loop {
922                match domain_iter.next() {
923                    // At end of string or a '.' symbol.
924                    Some(46) | None => {
925                        if str_len > MAX_LABEL_SIZE as u8 {
926                            return Err(ParseError::Malformed);
927                        }
928                        data[last_len_idx] = str_len;
929                        break;
930                    }
931                    Some(&c) => {
932                        data.push(c);
933                        str_len += 1;
934                    }
935                }
936            }
937        }
938        if data.len() == 0 || data.len() > MAX_DOMAIN_SIZE {
939            return Err(ParseError::Malformed);
940        }
941        Ok(Self { data })
942    }
943}
944
945impl EmbeddedPacketBuilder for DomainBuilder {
946    fn bytes_len(&self) -> usize {
947        self.data.len()
948    }
949
950    fn serialize<B: SplitByteSliceMut, BV: BufferViewMut<B>>(&self, bv: &mut BV) {
951        bv.take_front(self.data.len()).unwrap().copy_from_slice(self.data.as_slice());
952    }
953}
954
#[cfg(test)]
mod tests {
    use super::*;

    use packet::{ParseBuffer, Serializer};
    use std::fmt::Write;

    trait EmbeddedPacketBuilderTestExt: EmbeddedPacketBuilder {
        /// Convenience method for testing.
        fn serialize_to_buf(&self, mut buf: &mut [u8]) {
            let mut bv = &mut buf;
            self.serialize(&mut bv);
        }
    }
    impl<B: EmbeddedPacketBuilder> EmbeddedPacketBuilderTestExt for B {}

    /// A domain to parse at `parsing_offset` within `packet` (which acts as the
    /// parent buffer for compression pointers), and its expected rendering.
    struct DomainParseTest {
        packet: Vec<u8>,
        parsing_offset: usize,
        expected_result: &'static str,
    }

    // Some standard-looking domains gathered from the real world.
    const DOMAIN_STRING: &str = "_fuchsia._udp.local";
    const NODENAME_DOMAIN_STRING: &str = "thumb-set-human-shred._fuchsia._udp.local";
    const DOMAIN_BYTES: [u8; 21] = [
        0x08, 0x5f, 0x66, 0x75, 0x63, 0x68, 0x73, 0x69, 0x61, 0x04, 0x5f, 0x75, 0x64, 0x70, 0x05,
        0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x00,
    ];

    /// Returns a zero-filled buffer of `size` bytes.
    fn make_buf(size: usize) -> Vec<u8> {
        vec![0; size]
    }

    #[test]
    fn test_embedded_packet_builder_bytes() {
        assert_eq!(
            &[3, 'f' as u8, 'o' as u8, 'o' as u8, 0][..],
            DomainBuilder::from_str("foo").unwrap().bytes()
        );
    }

    #[test]
    fn test_parse_type() {
        const TYPES: [Type; 5] = [Type::A, Type::Aaaa, Type::Ptr, Type::Srv, Type::Txt];
        for t in TYPES.iter() {
            match Type::try_from(u16::from(*t)) {
                Ok(parsed_type) => assert_eq!(*t, parsed_type),
                Err(e) => panic!("parse error {:?}", e),
            }
        }
    }

    #[test]
    fn test_parse_class() {
        const CLASSES: [Class; 2] = [Class::In, Class::Any];
        for c in CLASSES.iter() {
            match Class::try_from(u16::from(*c)) {
                Ok(parsed_class) => assert_eq!(*c, parsed_class),
                Err(e) => panic!("parse error {:?}", e),
            }
        }
    }

    #[test]
    fn test_domain_parse() {
        let mut bv = BufferViewWrapper(&DOMAIN_BYTES[..]);
        let _ = Domain::parse(&mut bv, None).expect("Failed to parse");
    }

    #[test]
    fn test_domain_roundtrip() {
        for example in [DOMAIN_STRING, NODENAME_DOMAIN_STRING] {
            let domain = DomainBuilder::from_str(example).unwrap();
            let mut buf = make_buf(domain.bytes_len());
            domain.serialize_to_buf(buf.as_mut_slice());

            let mut bv = BufferViewWrapper(buf.as_slice());
            let parsed = Domain::parse(&mut bv, None).unwrap();
            assert_eq!(example, format!("{}", parsed));
        }
    }

    #[test]
    fn test_ipv4_parse() {
        const ADDR: [u8; IPV4_SIZE] = [192, 168, 0, 2];
        let a = RData::<&[u8]>::A(A::read_from_bytes(&ADDR[..]).unwrap());
        assert_eq!(a.ip_addr(), Some(IpAddr::from(ADDR)));
    }

    #[test]
    fn test_ipv6_parse() {
        const ADDR: [u8; IPV6_SIZE] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
        let aaaa = RData::<&[u8]>::Aaaa(Aaaa::read_from_bytes(&ADDR[..]).unwrap());
        assert_eq!(aaaa.ip_addr(), Some(IpAddr::from(ADDR)));
    }

    #[test]
    fn test_domain_build_and_parse() {
        const BAD_DOMAIN_SHORT: &'static str = "_fuchsia._udp.loca";
        const BAD_DOMAIN_LONG: &'static str = "_fuchsia._udp.local.whatever.toooooolong";
        let domain = DomainBuilder::from_str(DOMAIN_STRING).unwrap();
        let mut buf = make_buf(domain.bytes_len());
        domain.serialize_to_buf(buf.as_mut_slice());

        let mut bv = BufferViewWrapper(buf.as_slice());
        let parsed = Domain::parse(&mut bv, None).unwrap();
        let mut s = String::new();
        write!(&mut s, "{}", parsed).unwrap();
        assert_eq!(s, DOMAIN_STRING);
        assert_eq!(parsed, DOMAIN_STRING);
        assert_ne!(parsed, BAD_DOMAIN_SHORT);
        assert_ne!(parsed, BAD_DOMAIN_LONG);
    }

    #[test]
    fn test_message_build_and_parse_one_question_one_record() {
        let domain = DomainBuilder::from_str(DOMAIN_STRING).unwrap();
        let nodename = DomainBuilder::from_str(NODENAME_DOMAIN_STRING).unwrap();
        let question = QuestionBuilder::new(domain, Type::Aaaa, Class::In, true);
        let record = RecordBuilder::new(
            nodename,
            Type::Ptr,
            Class::Any,
            true,
            4500,
            &[0x03, 'f' as u8, 'o' as u8, 'o' as u8, 0],
        );
        let mut message = MessageBuilder::new(0, true);
        message.add_question(question);
        message.add_additional(record);
        let mut msg_bytes = message
            .into_serializer()
            .serialize_vec_outer()
            .unwrap_or_else(|_| panic!("Failed to serialize"));
        let parsed = msg_bytes.parse::<Message<_>>().expect("Failed to parse!");
        // TODO(awdavies): These checks can probably be abstracted a bit.
        let q = &parsed.questions[0];
        assert_eq!(q.domain, DOMAIN_STRING);
        assert_eq!(q.qtype, Type::Aaaa);
        assert_eq!(q.class, Class::In);
        assert_eq!(q.unicast, true);
        assert_eq!(parsed.header.is_query(), true);
        assert_eq!(parsed.questions.len(), 1);
        assert_eq!(parsed.answers.len(), 0);
        assert_eq!(parsed.authority.len(), 0);
        assert_eq!(parsed.additional.len(), 1);
        let additional = &parsed.additional[0];
        assert_eq!(additional.domain, NODENAME_DOMAIN_STRING);
        assert_eq!(additional.ttl, 4500);
        assert_eq!(additional.flush, true);
        assert_eq!(additional.rtype, Type::Ptr);
        assert_eq!(additional.class, Class::Any);
        assert_eq!(additional.rdata.domain().unwrap(), &"foo");
    }

    #[test]
    fn test_question_build_and_parse() {
        let domain = DomainBuilder::from_str(DOMAIN_STRING).unwrap();
        let unicast = false;
        let qtype = Type::Aaaa;
        let class = Class::In;
        let question = QuestionBuilder::new(domain, qtype, class, unicast);
        let mut buf = make_buf(question.bytes_len());
        question.serialize_to_buf(buf.as_mut_slice());
        let mut bv = BufferViewWrapper(buf.as_ref());
        let parsed = Question::parse(&mut bv, None).unwrap();
        assert_eq!(parsed.unicast, unicast);
        assert_eq!(parsed.qtype, qtype);
        assert_eq!(parsed.class, class);
    }

    #[test]
    fn test_record_build_and_parse() {
        for r in [
            RecordBuilder {
                rdata: &[127, 0, 0, 1],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                ttl: 3500,
                rtype: Type::A,
                class: Class::In,
                flush: true,
            },
            RecordBuilder {
                rdata: &[
                    0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0x8e, 0xae, 0x4c, 0xff, 0xfe, 0xe9, 0xc9, 0xd3,
                ],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                ttl: 1,
                rtype: Type::Aaaa,
                class: Class::In,
                flush: false,
            },
            RecordBuilder {
                rdata: &[1, 2, 3, 4, 5, 6, 0x3, 'f' as u8, 'o' as u8, 'o' as u8, 0],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                ttl: 5000,
                rtype: Type::Srv,
                class: Class::In,
                flush: true,
            },
            RecordBuilder {
                rdata: &[0x04, 'q' as u8, 'u' as u8, 'u' as u8, 'x' as u8, 0x00],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                ttl: 10,
                rtype: Type::Ptr,
                class: Class::In,
                flush: false,
            },
        ]
        .iter()
        {
            let mut buf = make_buf(r.bytes_len());
            r.serialize_to_buf(buf.as_mut_slice());
            let mut bv = BufferViewWrapper(buf.as_ref());
            let parsed = Record::parse(&mut bv, None).unwrap();
            assert_eq!(parsed.domain, DOMAIN_STRING);
            assert_eq!(r.rtype, parsed.rtype);
            assert_eq!(r.ttl, parsed.ttl);
            assert_eq!(r.class, parsed.class);
            assert_eq!(r.flush, parsed.flush);
            match parsed.rtype {
                Type::Srv => {
                    assert_eq!(parsed.rdata.domain().unwrap(), &"foo");
                    let srv = parsed.rdata.srv().unwrap();
                    assert_eq!(srv.domain, "foo");
                    assert_eq!(srv.priority, 0x0102);
                    assert_eq!(srv.weight, 0x0304);
                    assert_eq!(srv.port, 0x0506);
                }
                Type::A => {
                    if let IpAddr::V4(addr) = parsed.rdata.ip_addr().unwrap() {
                        assert_eq!(&addr.octets(), &r.rdata);
                    } else {
                        panic!("expected IpAddr::V4");
                    }
                }
                Type::Aaaa => {
                    if let IpAddr::V6(addr) = parsed.rdata.ip_addr().unwrap() {
                        assert_eq!(&addr.octets(), &r.rdata);
                    } else {
                        panic!("expected IpAddr::V6");
                    }
                }
                Type::Ptr => assert_eq!(parsed.rdata.domain().unwrap(), &"quux"),
                _ => (),
            }
        }
    }

    #[test]
    fn test_srv_record_bad_sizing() {
        for r in [
            // RData with extra after null terminator.
            RecordBuilder {
                rdata: &[1, 2, 3, 4, 5, 6, 0x3, 'f' as u8, 'o' as u8, 'o' as u8, 0, 1, 2, 3, 4],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                class: Class::Any,
                flush: true,
                ttl: 2,
                rtype: Type::Srv,
            },
            // One byte too short.
            RecordBuilder {
                rdata: &[1, 2, 3, 4, 5],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                class: Class::Any,
                flush: true,
                ttl: 2,
                rtype: Type::Srv,
            },
            // Empty RData.
            RecordBuilder {
                rdata: &[],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                class: Class::Any,
                flush: true,
                ttl: 1,
                rtype: Type::Srv,
            },
            // Null domain.
            RecordBuilder {
                rdata: &[1, 2, 3, 4, 5, 6, 0],
                domain: DomainBuilder::from_str(DOMAIN_STRING).unwrap(),
                class: Class::Any,
                flush: true,
                ttl: 1,
                rtype: Type::Srv,
            },
        ]
        .iter()
        {
            let mut buf = make_buf(r.bytes_len());
            r.serialize_to_buf(buf.as_mut_slice());
            let mut bv = BufferViewWrapper(buf.as_ref());
            // Will panic if there is not an error.
            let _ = Record::parse(&mut bv, None).unwrap_err();
        }
    }

    #[test]
    fn test_domain_parse_no_trailing() {
        let mut bv = BufferViewWrapper(&DOMAIN_BYTES[..DOMAIN_BYTES.len() - 1]);
        assert_eq!(Domain::parse(&mut bv, None).unwrap_err(), ParseError::Malformed);
    }

    #[test]
    fn test_domain_parse_middle() {
        let packet = &mut DOMAIN_BYTES.to_vec();
        packet[3] = 0;
        let mut bv = BufferViewWrapper(&packet[..]);
        assert_eq!(Domain::parse(&mut bv, None).unwrap_err(), ParseError::UnexpectedZeroCharacter);
    }

    #[test]
    fn test_domain_parse_label_too_long() {
        let packet = &mut DOMAIN_BYTES.to_vec();
        let bad_len = 65u8;
        packet[0] = bad_len;
        let mut bv = BufferViewWrapper(&packet[..]);
        assert_eq!(
            Domain::parse(&mut bv, None).unwrap_err(),
            ParseError::LabelTooLong(bad_len.into())
        );
    }

    #[test]
    fn test_domain_parse_domain_too_long() {
        const LABELS: usize = 10;
        // Each label is one length byte plus MAX_LABEL_SIZE bytes, followed by
        // one terminating zero byte: (63 + 1) * 10 + 1 = 641.
        const SIZE: usize = (MAX_LABEL_SIZE + 1) * LABELS + 1;
        let mut packet = Vec::<u8>::with_capacity(SIZE);
        for _ in 0..LABELS {
            packet.push(MAX_LABEL_SIZE as u8);
            for _ in 0..MAX_LABEL_SIZE {
                packet.push('f' as u8);
            }
        }
        packet.push(0);
        let mut bv = BufferViewWrapper(packet.as_ref());
        assert_eq!(Domain::parse(&mut bv, None).unwrap_err(), ParseError::DomainTooLong(641));
    }

    #[test]
    fn test_domain_parse_empty_message() {
        const PACKET: [u8; 1] = [0];
        let mut bv = BufferViewWrapper(&PACKET[..]);
        assert_eq!(Domain::parse(&mut bv, None).unwrap_err(), ParseError::UnexpectedZeroCharacter);
    }

    #[test]
    fn test_domain_parse_short_malformed() {
        const PACKET: [u8; 2] = [1, 0];
        let mut bv = BufferViewWrapper(&PACKET[..]);
        assert_eq!(Domain::parse(&mut bv, None).unwrap_err(), ParseError::UnexpectedZeroCharacter);
    }

    #[test]
    fn test_domain_bad_pointer_index() {
        // The pointer at offset 3 (0xc0 0x09) targets offset 9, which is past
        // the end of the 5-byte parent buffer.
        let packet: Vec<u8> = vec![0u8, 0x01, 'y' as u8, 0xc0, 0x09];
        let slice: &[u8] = packet.as_ref();
        let mut bv = BufferViewWrapper(slice);
        bv.take_front(3).unwrap();
        assert_eq!(
            Domain::parse(&mut bv, Some(&slice)).unwrap_err(),
            ParseError::BadPointerIndex(0x09)
        );
    }

    #[test]
    fn test_domain_pointer_with_no_parent() {
        let packet: Vec<u8> = vec![
            0u8, 0x03, 'f' as u8, 'o' as u8, 'o' as u8, 0x03, 'b' as u8, 'a' as u8, 'r' as u8,
            0x00, 0x03, 'b' as u8, 'a' as u8, 'z' as u8, 0x03, 'b' as u8, 'o' as u8, 'i' as u8,
            0xc0, 0x01,
        ];
        let slice: &[u8] = packet.as_ref();
        {
            let mut bv = BufferViewWrapper(slice);
            bv.take_front(10).unwrap();
            // Prove that with parent this is valid.
            let _: Domain<_> = Domain::parse(&mut bv, Some(&slice)).expect("should succeed");
        }

        {
            // Without parent the indirection is rejected.
            let mut bv = BufferViewWrapper(slice);
            bv.take_front(10).unwrap();
            assert_eq!(Domain::parse(&mut bv, None), Err(ParseError::BadPointerIndex(0x01)));
        }
    }

    #[test]
    fn test_domain_pointer_cycles() {
        // First packet: a pointer to itself. Second: two pointers that refer
        // to each other.
        for packet in [vec![0xc0, 0x00], vec![0x02, 0x02, 0x01, 0xc0, 0x05, 0xc0, 0x03]].iter() {
            let slice: &[u8] = packet.as_ref();
            let mut bv = BufferViewWrapper(slice);
            assert_eq!(Domain::parse(&mut bv, Some(slice)).unwrap_err(), ParseError::PointerCycle);
        }
    }

    #[test]
    fn test_domain_parse_fragmented_domains() {
        for test in [
            DomainParseTest {
                packet: vec![
                    0u8, 0x03, 'f' as u8, 'o' as u8, 'o' as u8, 0x03, 'b' as u8, 'a' as u8,
                    'r' as u8, 0x00, 0xc0, 0x01,
                ],
                expected_result: "foo.bar",
                parsing_offset: 10,
            },
            DomainParseTest {
                packet: vec![
                    0u8, 0x03, 'f' as u8, 'o' as u8, 'o' as u8, 0x03, 'b' as u8, 'a' as u8,
                    'r' as u8, 0x00, 0x03, 'b' as u8, 'a' as u8, 'z' as u8, 0x03, 'b' as u8,
                    'o' as u8, 'i' as u8, 0xc0, 0x01,
                ],
                expected_result: "baz.boi.foo.bar",
                parsing_offset: 10,
            },
            DomainParseTest {
                packet: vec![
                    2u8, 3u8, 0u8, 0x03, 'f' as u8, 'o' as u8, 'o' as u8, 0x03, 'b' as u8,
                    'a' as u8, 'r' as u8, 0x00, 0x03, 'b' as u8, 'a' as u8, 'z' as u8, 0x03,
                    'b' as u8, 'o' as u8, 'i' as u8, 0x07, '_' as u8, 'm' as u8, 'u' as u8,
                    'm' as u8, 'b' as u8, 'l' as u8, 'e' as u8, 0xc0, 0x03,
                ],
                expected_result: "baz.boi._mumble.foo.bar",
                parsing_offset: 12,
            },
            DomainParseTest {
                packet: vec![
                    2u8, 3u8, 0u8, 0x03, 'f' as u8, 'o' as u8, 'o' as u8, 0x03, 'b' as u8,
                    'a' as u8, 'r' as u8, 0xc0, 0x1f, 0x03, 'b' as u8, 'a' as u8, 'z' as u8, 0x03,
                    'b' as u8, 'o' as u8, 'i' as u8, 0x07, '_' as u8, 'm' as u8, 'u' as u8,
                    'm' as u8, 'b' as u8, 'l' as u8, 'e' as u8, 0xc0, 0x03, 0x04, 'q' as u8,
                    'u' as u8, 'u' as u8, 'x' as u8, 0x00,
                ],
                expected_result: "baz.boi._mumble.foo.bar.quux",
                parsing_offset: 13,
            },
        ]
        .iter()
        {
            let slice: &[u8] = test.packet.as_ref();
            let mut bv = BufferViewWrapper(slice);
            bv.take_front(test.parsing_offset).unwrap();
            let parsed = Domain::parse(&mut bv, Some(&slice)).unwrap();
            let mut s = String::new();
            write!(&mut s, "{}", parsed).unwrap();
            assert_eq!(s, test.expected_result);
            assert_eq!(parsed, test.expected_result);
        }
    }

    #[test]
    fn test_real_world_mdns_packet_response() {
        // This is a real world mDNS packet from a Fuchsia device. These bytes
        // were copied from wireshark (and the fields were extracted from there
        // as well).
        //
        // The structure is as follows:
        // Header:
        //  -- Flags 0x8400
        //  -- Question count 0
        //  -- Answer count 1
        //  -- Authority count 0
        //  -- Additional count 4
        //
        //  Answer 1:
        //   -- Type: PTR
        //   -- Domain: '_fuchsia._udp.local'
        //   -- Class: IN
        //   -- Flush: False
        //   -- TTL: 4500
        //   -- Data Length: 24
        //   -- Data (uncompressed): thumb-set-human-shred._fuchsia._udp.local
        //
        //  Additional record 1:
        //   -- Type: SRV
        //   -- Domain: thumb-set-human-shred._fuchsia._udp.local
        //   -- Class: IN
        //   -- Flush: True
        //   -- TTL: 120
        //   -- Data Length: 30
        //   -- Priority: 0
        //   -- Weight: 0
        //   -- Port: 5353
        //   -- Target: thumb-set-human-shred.local
        //
        //  Additional record 2:
        //   -- Type: TXT
        //   -- Domain: thumb-set-human-shred._fuchsia._udp.local
        //   -- Class: IN
        //   -- Flush: True
        //   -- TTL: 4500
        //   -- Data Length: 0
        //   -- Data: '\0'
        //
        //  Additional record 3:
        //   -- Type: A
        //   -- Domain: thumb-set-human-shred.local
        //   -- Class: IN
        //   -- Flush: True
        //   -- TTL: 120
        //   -- Data Length: 4
        //   -- Data: '172.16.243.38'
        //
        //  Additional record 4:
        //   -- Type: AAAA
        //   -- Domain: thumb-set-human-shred.local
        //   -- Class: IN
        //   -- Flush: True
        //   -- TTL: 120
        //   -- Data length: 16
        //   -- Data: 'fe80::8eae:4cff:fee9:c9d3'
        let packet: Vec<u8> = vec![
            0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x08, 0x5f,
            0x66, 0x75, 0x63, 0x68, 0x73, 0x69, 0x61, 0x04, 0x5f, 0x75, 0x64, 0x70, 0x05, 0x6c,
            0x6f, 0x63, 0x61, 0x6c, 0x00, 0x00, 0x0c, 0x00, 0x01, 0x00, 0x00, 0x11, 0x94, 0x00,
            0x18, 0x15, 0x74, 0x68, 0x75, 0x6d, 0x62, 0x2d, 0x73, 0x65, 0x74, 0x2d, 0x68, 0x75,
            0x6d, 0x61, 0x6e, 0x2d, 0x73, 0x68, 0x72, 0x65, 0x64, 0xc0, 0x0c, 0xc0, 0x2b, 0x00,
            0x21, 0x80, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x14,
            0xe9, 0x15, 0x74, 0x68, 0x75, 0x6d, 0x62, 0x2d, 0x73, 0x65, 0x74, 0x2d, 0x68, 0x75,
            0x6d, 0x61, 0x6e, 0x2d, 0x73, 0x68, 0x72, 0x65, 0x64, 0xc0, 0x1a, 0xc0, 0x2b, 0x00,
            0x10, 0x80, 0x01, 0x00, 0x00, 0x11, 0x94, 0x00, 0x01, 0x00, 0xc0, 0x55, 0x00, 0x01,
            0x80, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x04, 0xac, 0x10, 0xf3, 0x26, 0xc0, 0x55,
            0x00, 0x1c, 0x80, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x10, 0xfe, 0x80, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x8e, 0xae, 0x4c, 0xff, 0xfe, 0xe9, 0xc9, 0xd3,
        ];
        let mut packet_slice = packet.as_slice();
        let parsed = packet_slice.parse::<Message<_>>().expect("Failed to parse!");
        assert!(parsed.header.is_response());
        assert_eq!(parsed.header.question_count(), 0);
        assert_eq!(parsed.header.answer_count(), 1);
        assert_eq!(parsed.header.authority_count(), 0);
        assert_eq!(parsed.header.additional_count(), 4);
        let answer = &parsed.answers[0];
        assert_eq!(answer.rtype, Type::Ptr);
        assert_eq!(answer.domain, "_fuchsia._udp.local");
        assert_eq!(answer.class, Class::In);
        assert_eq!(answer.flush, false);
        assert_eq!(answer.ttl, 4500);
        assert_eq!(answer.rdata.domain().unwrap(), &"thumb-set-human-shred._fuchsia._udp.local");
        let srv = &parsed.additional[0];
        assert_eq!(srv.rtype, Type::Srv);
        assert_eq!(srv.domain, "thumb-set-human-shred._fuchsia._udp.local");
        assert_eq!(srv.class, Class::In);
        assert_eq!(srv.flush, true);
        assert_eq!(srv.ttl, 120);
        let srv_rdata = srv.rdata.srv().unwrap();
        assert_eq!(srv_rdata.weight, 0);
        assert_eq!(srv_rdata.priority, 0);
        assert_eq!(srv_rdata.port, 5353);
        assert_eq!(srv_rdata.domain, "thumb-set-human-shred.local");
        let txt = &parsed.additional[1];
        assert_eq!(txt.rtype, Type::Txt);
        assert_eq!(txt.domain, "thumb-set-human-shred._fuchsia._udp.local");
        assert_eq!(txt.class, Class::In);
        assert_eq!(txt.flush, true);
        assert_eq!(txt.ttl, 4500);
        assert_eq!(txt.rdata.bytes().unwrap().len(), 1);
        let a = &parsed.additional[2];
        assert_eq!(a.rtype, Type::A);
        assert_eq!(a.domain, "thumb-set-human-shred.local");
        assert_eq!(a.class, Class::In);
        assert_eq!(a.ttl, 120);
        if let IpAddr::V4(addr) = a.rdata.ip_addr().unwrap() {
            assert_eq!(&addr.octets()[..], &[172, 16, 243, 38]);
        } else {
            panic!("expected IpAddr::V4");
        }
        let aaaa = &parsed.additional[3];
        assert_eq!(aaaa.rtype, Type::Aaaa);
        assert_eq!(aaaa.domain, "thumb-set-human-shred.local");
        assert_eq!(aaaa.class, Class::In);
        assert_eq!(aaaa.ttl, 120);
        if let IpAddr::V6(addr) = aaaa.rdata.ip_addr().unwrap() {
            assert_eq!(
                &addr.octets()[..],
                &[
                    0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8e, 0xae, 0x4c, 0xff, 0xfe,
                    0xe9, 0xc9, 0xd3
                ]
            );
        } else {
            panic!("expected IpAddr::V6");
        }
    }

    #[test]
    fn test_real_world_mdns_packet_question() {
        // A single PTR question for '_fuchsia._udp.local', class IN.
        let packet: Vec<u8> = vec![
            0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x5f,
            0x66, 0x75, 0x63, 0x68, 0x73, 0x69, 0x61, 0x04, 0x5f, 0x75, 0x64, 0x70, 0x05, 0x6c,
            0x6f, 0x63, 0x61, 0x6c, 0x00, 0x00, 0x0c, 0x00, 0x01,
        ];
        let mut packet_slice = packet.as_slice();
        let parsed = packet_slice.parse::<Message<_>>().expect("Failed to parse!");
        assert!(parsed.header.is_query());
        assert_eq!(parsed.header.question_count(), 1);
        assert_eq!(parsed.header.answer_count(), 0);
        assert_eq!(parsed.header.authority_count(), 0);
        assert_eq!(parsed.header.additional_count(), 0);
        let q = &parsed.questions[0];
        assert_eq!(q.domain, "_fuchsia._udp.local");
        assert_eq!(q.qtype, Type::Ptr);
        assert_eq!(q.class, Class::In);
        assert_eq!(q.unicast, false);
    }
}