trust_dns_proto/op/
message.rs

1// Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Basic protocol message for DNS
9
10use std::{fmt, iter, mem, ops::Deref, sync::Arc};
11
12use tracing::{debug, warn};
13
14use crate::{
15    error::*,
16    op::{Edns, Header, MessageType, OpCode, Query, ResponseCode},
17    rr::{Record, RecordType},
18    serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode},
19    xfer::DnsResponse,
20};
21
22/// The basic request and response datastructure, used for all DNS protocols.
23///
24/// [RFC 1035, DOMAIN NAMES - IMPLEMENTATION AND SPECIFICATION, November 1987](https://tools.ietf.org/html/rfc1035)
25///
26/// ```text
27/// 4.1. Format
28///
29/// All communications inside of the domain protocol are carried in a single
30/// format called a message.  The top level format of message is divided
31/// into 5 sections (some of which are empty in certain cases) shown below:
32///
33///     +--------------------------+
34///     |        Header            |
35///     +--------------------------+
36///     |  Question / Zone         | the question for the name server
37///     +--------------------------+
38///     |   Answer  / Prerequisite | RRs answering the question
39///     +--------------------------+
40///     | Authority / Update       | RRs pointing toward an authority
41///     +--------------------------+
42///     |      Additional          | RRs holding additional information
43///     +--------------------------+
44///
45/// The header section is always present.  The header includes fields that
46/// specify which of the remaining sections are present, and also specify
47/// whether the message is a query or a response, a standard query or some
48/// other opcode, etc.
49///
50/// The names of the sections after the header are derived from their use in
51/// standard queries.  The question section contains fields that describe a
52/// question to a name server.  These fields are a query type (QTYPE), a
53/// query class (QCLASS), and a query domain name (QNAME).  The last three
54/// sections have the same format: a possibly empty list of concatenated
55/// resource records (RRs).  The answer section contains RRs that answer the
56/// question; the authority section contains RRs that point toward an
57/// authoritative name server; the additional records section contains RRs
58/// which relate to the query, but are not strictly answers for the
59/// question.
60/// ```
61///
62/// By default Message is a Query. Use the Message::as_update() to create and update, or
63///  Message::new_update()
64#[derive(Clone, Debug, PartialEq, Eq, Default)]
65pub struct Message {
66    header: Header,
67    queries: Vec<Query>,
68    answers: Vec<Record>,
69    name_servers: Vec<Record>,
70    additionals: Vec<Record>,
71    signature: Vec<Record>,
72    edns: Option<Edns>,
73}
74
75/// Returns a new Header with accurate counts for each Message section
76pub fn update_header_counts(
77    current_header: &Header,
78    is_truncated: bool,
79    counts: HeaderCounts,
80) -> Header {
81    assert!(counts.query_count <= u16::max_value() as usize);
82    assert!(counts.answer_count <= u16::max_value() as usize);
83    assert!(counts.nameserver_count <= u16::max_value() as usize);
84    assert!(counts.additional_count <= u16::max_value() as usize);
85
86    // TODO: should the function just take by value?
87    let mut header = *current_header;
88    header
89        .set_query_count(counts.query_count as u16)
90        .set_answer_count(counts.answer_count as u16)
91        .set_name_server_count(counts.nameserver_count as u16)
92        .set_additional_count(counts.additional_count as u16)
93        .set_truncated(is_truncated);
94
95    header
96}
97
/// Number of entries in each section of a Message, used to fill in the header's counts.
///
/// This is only used internally during serialization.
#[derive(Debug, Clone, Copy)]
pub struct HeaderCounts {
    /// How many queries the question section holds
    pub query_count: usize,
    /// How many records the answer section holds
    pub answer_count: usize,
    /// How many records the authority (name server) section holds
    pub nameserver_count: usize,
    /// How many records the additional section holds
    pub additional_count: usize,
}
112
113impl Message {
114    /// Returns a new "empty" Message
115    pub fn new() -> Self {
116        Self {
117            header: Header::new(),
118            queries: Vec::new(),
119            answers: Vec::new(),
120            name_servers: Vec::new(),
121            additionals: Vec::new(),
122            signature: Vec::new(),
123            edns: None,
124        }
125    }
126
127    /// Returns a Message constructed with error details to return to a client
128    ///
129    /// # Arguments
130    ///
131    /// * `id` - message id should match the request message id
132    /// * `op_code` - operation of the request
133    /// * `response_code` - the error code for the response
134    pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
135        let mut message = Self::new();
136        message
137            .set_message_type(MessageType::Response)
138            .set_id(id)
139            .set_response_code(response_code)
140            .set_op_code(op_code);
141
142        message
143    }
144
145    /// Truncates a Message, this blindly removes all response fields and sets truncated to `true`
146    pub fn truncate(&self) -> Self {
147        let mut truncated = self.clone();
148        truncated.set_truncated(true);
149        // drops additional/answer/queries so len is 0
150        truncated.take_additionals();
151        truncated.take_answers();
152        truncated.take_queries();
153
154        // TODO, perhaps just quickly add a few response records here? that we know would fit?
155        truncated
156    }
157
158    /// Sets the `Header` with provided
159    pub fn set_header(&mut self, header: Header) -> &mut Self {
160        self.header = header;
161        self
162    }
163
164    /// see `Header::set_id`
165    pub fn set_id(&mut self, id: u16) -> &mut Self {
166        self.header.set_id(id);
167        self
168    }
169
170    /// see `Header::set_message_type`
171    pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
172        self.header.set_message_type(message_type);
173        self
174    }
175
176    /// see `Header::set_op_code`
177    pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
178        self.header.set_op_code(op_code);
179        self
180    }
181
182    /// see `Header::set_authoritative`
183    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
184        self.header.set_authoritative(authoritative);
185        self
186    }
187
188    /// see `Header::set_truncated`
189    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
190        self.header.set_truncated(truncated);
191        self
192    }
193
194    /// see `Header::set_recursion_desired`
195    pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
196        self.header.set_recursion_desired(recursion_desired);
197        self
198    }
199
200    /// see `Header::set_recursion_available`
201    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
202        self.header.set_recursion_available(recursion_available);
203        self
204    }
205
206    /// see `Header::set_authentic_data`
207    pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
208        self.header.set_authentic_data(authentic_data);
209        self
210    }
211
212    /// see `Header::set_checking_disabled`
213    pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
214        self.header.set_checking_disabled(checking_disabled);
215        self
216    }
217
218    /// see `Header::set_response_code`
219    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
220        self.header.set_response_code(response_code);
221        self
222    }
223
224    /// Add a query to the Message, either the query response from the server, or the request Query.
225    pub fn add_query(&mut self, query: Query) -> &mut Self {
226        self.queries.push(query);
227        self
228    }
229
230    /// Adds an iterator over a set of Queries to be added to the message
231    pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
232    where
233        Q: IntoIterator<Item = Query, IntoIter = I>,
234        I: Iterator<Item = Query>,
235    {
236        for query in queries {
237            self.add_query(query);
238        }
239
240        self
241    }
242
243    /// Add an answer to the Message
244    pub fn add_answer(&mut self, record: Record) -> &mut Self {
245        self.answers.push(record);
246        self
247    }
248
249    /// Add all the records from the iterator to the answers section of the Message
250    pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
251    where
252        R: IntoIterator<Item = Record, IntoIter = I>,
253        I: Iterator<Item = Record>,
254    {
255        for record in records {
256            self.add_answer(record);
257        }
258
259        self
260    }
261
262    /// Sets the answers to the specified set of Records.
263    ///
264    /// # Panics
265    ///
266    /// Will panic if answer records are already associated to the message.
267    pub fn insert_answers(&mut self, records: Vec<Record>) {
268        assert!(self.answers.is_empty());
269        self.answers = records;
270    }
271
272    /// Add a name server record to the Message
273    pub fn add_name_server(&mut self, record: Record) -> &mut Self {
274        self.name_servers.push(record);
275        self
276    }
277
278    /// Add all the records in the Iterator to the name server section of the message
279    pub fn add_name_servers<R, I>(&mut self, records: R) -> &mut Self
280    where
281        R: IntoIterator<Item = Record, IntoIter = I>,
282        I: Iterator<Item = Record>,
283    {
284        for record in records {
285            self.add_name_server(record);
286        }
287
288        self
289    }
290
291    /// Sets the name_servers to the specified set of Records.
292    ///
293    /// # Panics
294    ///
295    /// Will panic if name_servers records are already associated to the message.
296    pub fn insert_name_servers(&mut self, records: Vec<Record>) {
297        assert!(self.name_servers.is_empty());
298        self.name_servers = records;
299    }
300
301    /// Add an additional Record to the message
302    pub fn add_additional(&mut self, record: Record) -> &mut Self {
303        self.additionals.push(record);
304        self
305    }
306
307    /// Add all the records from the iterator to the additionals section of the Message
308    pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
309    where
310        R: IntoIterator<Item = Record, IntoIter = I>,
311        I: Iterator<Item = Record>,
312    {
313        for record in records {
314            self.add_additional(record);
315        }
316
317        self
318    }
319
320    /// Sets the additional to the specified set of Records.
321    ///
322    /// # Panics
323    ///
324    /// Will panic if additional records are already associated to the message.
325    pub fn insert_additionals(&mut self, records: Vec<Record>) {
326        assert!(self.additionals.is_empty());
327        self.additionals = records;
328    }
329
330    /// Add the EDNS section to the Message
331    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
332        self.edns = Some(edns);
333        self
334    }
335
336    /// Add a SIG0 record, i.e. sign this message
337    ///
338    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
339    #[cfg(feature = "dnssec")]
340    #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
341    pub fn add_sig0(&mut self, record: Record) -> &mut Self {
342        assert_eq!(RecordType::SIG, record.rr_type());
343        self.signature.push(record);
344        self
345    }
346
347    /// Add a TSIG record, i.e. authenticate this message
348    ///
349    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
350    #[cfg(feature = "dnssec")]
351    #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
352    pub fn add_tsig(&mut self, record: Record) -> &mut Self {
353        assert_eq!(RecordType::TSIG, record.rr_type());
354        self.signature.push(record);
355        self
356    }
357
358    /// Gets the header of the Message
359    pub fn header(&self) -> &Header {
360        &self.header
361    }
362
363    /// see `Header::id()`
364    pub fn id(&self) -> u16 {
365        self.header.id()
366    }
367
368    /// see `Header::message_type()`
369    pub fn message_type(&self) -> MessageType {
370        self.header.message_type()
371    }
372
373    /// see `Header::op_code()`
374    pub fn op_code(&self) -> OpCode {
375        self.header.op_code()
376    }
377
378    /// see `Header::authoritative()`
379    pub fn authoritative(&self) -> bool {
380        self.header.authoritative()
381    }
382
383    /// see `Header::truncated()`
384    pub fn truncated(&self) -> bool {
385        self.header.truncated()
386    }
387
388    /// see `Header::recursion_desired()`
389    pub fn recursion_desired(&self) -> bool {
390        self.header.recursion_desired()
391    }
392
393    /// see `Header::recursion_available()`
394    pub fn recursion_available(&self) -> bool {
395        self.header.recursion_available()
396    }
397
398    /// see `Header::authentic_data()`
399    pub fn authentic_data(&self) -> bool {
400        self.header.authentic_data()
401    }
402
403    /// see `Header::checking_disabled()`
404    pub fn checking_disabled(&self) -> bool {
405        self.header.checking_disabled()
406    }
407
408    /// # Return value
409    ///
410    /// The `ResponseCode`, if this is an EDNS message then this will join the section from the OPT
411    ///  record to create the EDNS `ResponseCode`
412    pub fn response_code(&self) -> ResponseCode {
413        self.header.response_code()
414    }
415
416    /// Returns the query from this Message.
417    ///
418    /// In almost all cases, a Message will only contain one query. This is a convenience function to get the single query.
419    /// See the alternative `queries*` methods for the raw set of queries in the Message
420    pub fn query(&self) -> Option<&Query> {
421        self.queries.first()
422    }
423
424    /// ```text
425    /// Question        Carries the query name and other query parameters.
426    /// ```
427    pub fn queries(&self) -> &[Query] {
428        &self.queries
429    }
430
431    /// Provides mutable access to `queries`
432    pub fn queries_mut(&mut self) -> &mut Vec<Query> {
433        &mut self.queries
434    }
435
436    /// Removes all the answers from the Message
437    pub fn take_queries(&mut self) -> Vec<Query> {
438        mem::take(&mut self.queries)
439    }
440
441    /// ```text
442    /// Answer          Carries RRs which directly answer the query.
443    /// ```
444    pub fn answers(&self) -> &[Record] {
445        &self.answers
446    }
447
448    /// Provides mutable access to `answers`
449    pub fn answers_mut(&mut self) -> &mut Vec<Record> {
450        &mut self.answers
451    }
452
453    /// Removes all the answers from the Message
454    pub fn take_answers(&mut self) -> Vec<Record> {
455        mem::take(&mut self.answers)
456    }
457
458    /// ```text
459    /// Authority       Carries RRs which describe other authoritative servers.
460    ///                 May optionally carry the SOA RR for the authoritative
461    ///                 data in the answer section.
462    /// ```
463    pub fn name_servers(&self) -> &[Record] {
464        &self.name_servers
465    }
466
467    /// Provides mutable access to `name_servers`
468    pub fn name_servers_mut(&mut self) -> &mut Vec<Record> {
469        &mut self.name_servers
470    }
471
472    /// Remove the name servers from the Message
473    pub fn take_name_servers(&mut self) -> Vec<Record> {
474        mem::take(&mut self.name_servers)
475    }
476
477    /// ```text
478    /// Additional      Carries RRs which may be helpful in using the RRs in the
479    ///                 other sections.
480    /// ```
481    pub fn additionals(&self) -> &[Record] {
482        &self.additionals
483    }
484
485    /// Provides mutable access to `additionals`
486    pub fn additionals_mut(&mut self) -> &mut Vec<Record> {
487        &mut self.additionals
488    }
489
490    /// Remove the additional Records from the Message
491    pub fn take_additionals(&mut self) -> Vec<Record> {
492        mem::take(&mut self.additionals)
493    }
494
495    /// All sections chained
496    pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
497        self.answers
498            .iter()
499            .chain(self.name_servers().iter())
500            .chain(self.additionals.iter())
501    }
502
503    /// [RFC 6891, EDNS(0) Extensions, April 2013](https://tools.ietf.org/html/rfc6891#section-6.1.1)
504    ///
505    /// ```text
506    /// 6.1.1.  Basic Elements
507    ///
508    ///  An OPT pseudo-RR (sometimes called a meta-RR) MAY be added to the
509    ///  additional data section of a request.
510    ///
511    ///  The OPT RR has RR type 41.
512    ///
513    ///  If an OPT record is present in a received request, compliant
514    ///  responders MUST include an OPT record in their respective responses.
515    ///
516    ///  An OPT record does not carry any DNS data.  It is used only to
517    ///  contain control information pertaining to the question-and-answer
518    ///  sequence of a specific transaction.  OPT RRs MUST NOT be cached,
519    ///  forwarded, or stored in or loaded from Zone Files.
520    ///
521    ///  The OPT RR MAY be placed anywhere within the additional data section.
522    ///  When an OPT RR is included within any DNS message, it MUST be the
523    ///  only OPT RR in that message.  If a query message with more than one
524    ///  OPT RR is received, a FORMERR (RCODE=1) MUST be returned.  The
525    ///  placement flexibility for the OPT RR does not override the need for
526    ///  the TSIG or SIG(0) RRs to be the last in the additional section
527    ///  whenever they are present.
528    /// ```
529    /// # Return value
530    ///
531    /// Optionally returns a reference to EDNS section
532    #[deprecated(note = "Please use `extensions()`")]
533    pub fn edns(&self) -> Option<&Edns> {
534        self.edns.as_ref()
535    }
536
537    /// Optionally returns mutable reference to EDNS section
538    #[deprecated(
539        note = "Please use `extensions_mut()`. You can chain `.get_or_insert_with(Edns::new)` to recover original behavior of adding Edns if not present"
540    )]
541    pub fn edns_mut(&mut self) -> &mut Edns {
542        if self.edns.is_none() {
543            self.set_edns(Edns::new());
544        }
545        self.edns.as_mut().unwrap()
546    }
547
548    /// Returns reference of Edns section
549    pub fn extensions(&self) -> &Option<Edns> {
550        &self.edns
551    }
552
553    /// Returns mutable reference of Edns section
554    pub fn extensions_mut(&mut self) -> &mut Option<Edns> {
555        &mut self.edns
556    }
557
558    /// # Return value
559    ///
560    /// the max payload value as it's defined in the EDNS section.
561    pub fn max_payload(&self) -> u16 {
562        let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
563        if max_size < 512 {
564            512
565        } else {
566            max_size
567        }
568    }
569
570    /// # Return value
571    ///
572    /// the version as defined in the EDNS record
573    pub fn version(&self) -> u8 {
574        self.edns.as_ref().map_or(0, Edns::version)
575    }
576
577    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
578    ///
579    /// ```text
580    /// A DNS request may be optionally signed by including one or more SIGs
581    ///  at the end of the query. Such SIGs are identified by having a "type
582    ///  covered" field of zero. They sign the preceding DNS request message
583    ///  including DNS header but not including the IP header or any request
584    ///  SIGs at the end and before the request RR counts have been adjusted
585    ///  for the inclusions of any request SIG(s).
586    /// ```
587    ///
588    /// # Return value
589    ///
590    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
591    // comportment change: can now return TSIG instead of SIG0. Maybe should get deprecated in
592    // favor of signature() which have more correct naming ?
593    pub fn sig0(&self) -> &[Record] {
594        &self.signature
595    }
596
597    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
598    ///
599    /// ```text
600    /// A DNS request may be optionally signed by including one or more SIGs
601    ///  at the end of the query. Such SIGs are identified by having a "type
602    ///  covered" field of zero. They sign the preceding DNS request message
603    ///  including DNS header but not including the IP header or any request
604    ///  SIGs at the end and before the request RR counts have been adjusted
605    ///  for the inclusions of any request SIG(s).
606    /// ```
607    ///
608    /// # Return value
609    ///
610    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
611    pub fn signature(&self) -> &[Record] {
612        &self.signature
613    }
614
615    /// Remove signatures from the Message
616    pub fn take_signature(&mut self) -> Vec<Record> {
617        mem::take(&mut self.signature)
618    }
619
620    // TODO: only necessary in tests, should it be removed?
621    /// this is necessary to match the counts in the header from the record sections
622    ///  this happens implicitly on write_to, so no need to call before write_to
623    #[cfg(test)]
624    pub fn update_counts(&mut self) -> &mut Self {
625        self.header = update_header_counts(
626            &self.header,
627            self.truncated(),
628            HeaderCounts {
629                query_count: self.queries.len(),
630                answer_count: self.answers.len(),
631                nameserver_count: self.name_servers.len(),
632                additional_count: self.additionals.len(),
633            },
634        );
635        self
636    }
637
638    /// Attempts to read the specified number of `Query`s
639    pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
640        let mut queries = Vec::with_capacity(count);
641        for _ in 0..count {
642            queries.push(Query::read(decoder)?);
643        }
644        Ok(queries)
645    }
646
647    /// Attempts to read the specified number of records
648    ///
649    /// # Returns
650    ///
651    /// This returns a tuple of first standard Records, then a possibly associated Edns, and then finally any optionally associated SIG0 and TSIG records.
652    #[cfg_attr(not(feature = "dnssec"), allow(unused_mut))]
653    pub fn read_records(
654        decoder: &mut BinDecoder<'_>,
655        count: usize,
656        is_additional: bool,
657    ) -> ProtoResult<(Vec<Record>, Option<Edns>, Vec<Record>)> {
658        let mut records: Vec<Record> = Vec::with_capacity(count);
659        let mut edns: Option<Edns> = None;
660        let mut sigs: Vec<Record> = Vec::with_capacity(if is_additional { 1 } else { 0 });
661
662        // sig0 must be last, once this is set, disable.
663        let mut saw_sig0 = false;
664        // tsig must be last, once this is set, disable.
665        let mut saw_tsig = false;
666        for _ in 0..count {
667            let record = Record::read(decoder)?;
668            if saw_tsig {
669                return Err("tsig must be final resource record".into());
670            } // TSIG must be last and multiple TSIG records are not allowed
671            if !is_additional {
672                if saw_sig0 {
673                    return Err("sig0 must be final resource record".into());
674                } // SIG0 must be last
675                records.push(record)
676            } else {
677                match record.rr_type() {
678                    #[cfg(feature = "dnssec")]
679                    RecordType::SIG => {
680                        saw_sig0 = true;
681                        sigs.push(record);
682                    }
683                    #[cfg(feature = "dnssec")]
684                    RecordType::TSIG => {
685                        if saw_sig0 {
686                            return Err("sig0 must be final resource record".into());
687                        } // SIG0 must be last
688                        saw_tsig = true;
689                        sigs.push(record);
690                    }
691                    RecordType::OPT => {
692                        if saw_sig0 {
693                            return Err("sig0 must be final resource record".into());
694                        } // SIG0 must be last
695                        if edns.is_some() {
696                            return Err("more than one edns record present".into());
697                        }
698                        edns = Some((&record).into());
699                    }
700                    _ => {
701                        if saw_sig0 {
702                            return Err("sig0 must be final resource record".into());
703                        } // SIG0 must be last
704                        records.push(record);
705                    }
706                }
707            }
708        }
709
710        Ok((records, edns, sigs))
711    }
712
713    /// Decodes a message from the buffer.
714    pub fn from_vec(buffer: &[u8]) -> ProtoResult<Self> {
715        let mut decoder = BinDecoder::new(buffer);
716        Self::read(&mut decoder)
717    }
718
719    /// Encodes the Message into a buffer
720    pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
721        // TODO: this feels like the right place to verify the max packet size of the message,
722        //  will need to update the header for truncation and the lengths if we send less than the
723        //  full response. This needs to conform with the EDNS settings of the server...
724        let mut buffer = Vec::with_capacity(512);
725        {
726            let mut encoder = BinEncoder::new(&mut buffer);
727            self.emit(&mut encoder)?;
728        }
729
730        Ok(buffer)
731    }
732
733    /// Finalize the message prior to sending.
734    ///
735    /// Subsequent to calling this, the Message should not change.
736    #[allow(clippy::match_single_binding)]
737    pub fn finalize<MF: MessageFinalizer>(
738        &mut self,
739        finalizer: &MF,
740        inception_time: u32,
741    ) -> ProtoResult<Option<MessageVerifier>> {
742        debug!("finalizing message: {:?}", self);
743        let (finals, verifier): (Vec<Record>, Option<MessageVerifier>) =
744            finalizer.finalize_message(self, inception_time)?;
745
746        // append all records to message
747        for fin in finals {
748            match fin.rr_type() {
749                // SIG0's are special, and come at the very end of the message
750                #[cfg(feature = "dnssec")]
751                RecordType::SIG => self.add_sig0(fin),
752                #[cfg(feature = "dnssec")]
753                RecordType::TSIG => self.add_tsig(fin),
754                _ => self.add_additional(fin),
755            };
756        }
757
758        Ok(verifier)
759    }
760
761    /// Consumes `Message` and returns into components
762    pub fn into_parts(self) -> MessageParts {
763        self.into()
764    }
765}
766
767/// Consumes `Message` giving public access to fields in `Message` so they can be
768/// destructured and taken by value
769/// ```rust
770/// use trust_dns_proto::op::{Message, MessageParts};
771///
772///  let msg = Message::new();
773///  let MessageParts { queries, .. } = msg.into_parts();
774/// ```
775#[derive(Clone, Debug, PartialEq, Eq, Default)]
776pub struct MessageParts {
777    /// message header
778    pub header: Header,
779    /// message queries
780    pub queries: Vec<Query>,
781    /// message answers
782    pub answers: Vec<Record>,
783    /// message name_servers
784    pub name_servers: Vec<Record>,
785    /// message additional records
786    pub additionals: Vec<Record>,
787    /// sig0 or tsig
788    // this can now contains TSIG too. It should probably be renamed to reflect that, but it's a
789    // breaking change
790    pub sig0: Vec<Record>,
791    /// optional edns records
792    pub edns: Option<Edns>,
793}
794
795impl From<Message> for MessageParts {
796    fn from(msg: Message) -> Self {
797        let Message {
798            header,
799            queries,
800            answers,
801            name_servers,
802            additionals,
803            signature,
804            edns,
805        } = msg;
806        Self {
807            header,
808            queries,
809            answers,
810            name_servers,
811            additionals,
812            sig0: signature,
813            edns,
814        }
815    }
816}
817
818impl From<MessageParts> for Message {
819    fn from(msg: MessageParts) -> Self {
820        let MessageParts {
821            header,
822            queries,
823            answers,
824            name_servers,
825            additionals,
826            sig0,
827            edns,
828        } = msg;
829        Self {
830            header,
831            queries,
832            answers,
833            name_servers,
834            additionals,
835            signature: sig0,
836            edns,
837        }
838    }
839}
840
841impl Deref for Message {
842    type Target = Header;
843
844    fn deref(&self) -> &Self::Target {
845        &self.header
846    }
847}
848
849/// Alias for a function verifying if a message is properly signed
850pub type MessageVerifier = Box<dyn FnMut(&[u8]) -> ProtoResult<DnsResponse> + Send>;
851
852/// A trait for performing final amendments to a Message before it is sent.
853///
854/// An example of this is a SIG0 signer, which needs the final form of the message,
855///  but then needs to attach additional data to the body of the message.
856pub trait MessageFinalizer: Send + Sync + 'static {
857    /// The message taken in should be processed and then return [`Record`]s which should be
858    ///  appended to the additional section of the message.
859    ///
860    /// # Arguments
861    ///
862    /// * `message` - message to process
863    /// * `current_time` - the current time as specified by the system, it's not recommended to read the current time as that makes testing complicated.
864    ///
865    /// # Return
866    ///
867    /// A vector to append to the additionals section of the message, sorted in the order as they should appear in the message.
868    fn finalize_message(
869        &self,
870        message: &Message,
871        current_time: u32,
872    ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)>;
873
874    /// Return whether the message require futher processing before being sent
875    /// By default, returns true for AXFR and IXFR queries, and Update and Notify messages
876    fn should_finalize_message(&self, message: &Message) -> bool {
877        [OpCode::Update, OpCode::Notify].contains(&message.op_code())
878            || message
879                .queries()
880                .iter()
881                .any(|q| [RecordType::AXFR, RecordType::IXFR].contains(&q.query_type()))
882    }
883}
884
885/// A MessageFinalizer which does nothing
886///
887/// *WARNING* This should only be used in None context, it will panic in all cases where finalize is called.
888#[derive(Clone, Copy, Debug)]
889pub struct NoopMessageFinalizer;
890
891impl NoopMessageFinalizer {
892    /// Always returns None
893    pub fn new() -> Option<Arc<Self>> {
894        None
895    }
896}
897
898impl MessageFinalizer for NoopMessageFinalizer {
899    fn finalize_message(
900        &self,
901        _: &Message,
902        _: u32,
903    ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)> {
904        panic!("Misused NoopMessageFinalizer, None should be used instead")
905    }
906
907    fn should_finalize_message(&self, _: &Message) -> bool {
908        true
909    }
910}
911
912/// Returns the count written and a boolean if it was truncated
913pub fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
914    result.map(|count| (count, false)).or_else(|e| {
915        if let ProtoErrorKind::NotAllRecordsWritten { count } = e.kind() {
916            return Ok((*count, true));
917        }
918
919        Err(e)
920    })
921}
922
923/// A trait that defines types which can be emitted as a set, with the associated count returned.
924pub trait EmitAndCount {
925    /// Emit self to the encoder and return the count of items
926    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
927}
928
929impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
930    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
931        encoder.emit_all(self)
932    }
933}
934
935/// Emits the different sections of a message properly
936///
937/// # Return
938///
939/// In the case of a successful emit, the final header (updated counts, etc) is returned for help with logging, etc.
940#[allow(clippy::too_many_arguments)]
941pub fn emit_message_parts<Q, A, N, D>(
942    header: &Header,
943    queries: &mut Q,
944    answers: &mut A,
945    name_servers: &mut N,
946    additionals: &mut D,
947    edns: Option<&Edns>,
948    signature: &[Record],
949    encoder: &mut BinEncoder<'_>,
950) -> ProtoResult<Header>
951where
952    Q: EmitAndCount,
953    A: EmitAndCount,
954    N: EmitAndCount,
955    D: EmitAndCount,
956{
957    let include_signature: bool = encoder.mode() != EncodeMode::Signing;
958    let place = encoder.place::<Header>()?;
959
960    let query_count = queries.emit(encoder)?;
961    // TODO: need to do something on max records
962    //  return offset of last emitted record.
963    let answer_count = count_was_truncated(answers.emit(encoder))?;
964    let nameserver_count = count_was_truncated(name_servers.emit(encoder))?;
965    let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
966
967    if let Some(mut edns) = edns.cloned() {
968        // need to commit the error code
969        edns.set_rcode_high(header.response_code().high());
970
971        let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
972        additional_count.0 += count.0;
973        additional_count.1 |= count.1;
974    } else if header.response_code().high() > 0 {
975        warn!(
976            "response code: {} for request: {} requires EDNS but none available",
977            header.response_code(),
978            header.id()
979        );
980    }
981
982    // this is a little hacky, but if we are Verifying a signature, i.e. the original Message
983    //  then the SIG0 records should not be encoded and the edns record (if it exists) is already
984    //  part of the additionals section.
985    if include_signature {
986        let count = count_was_truncated(encoder.emit_all(signature.iter()))?;
987        additional_count.0 += count.0;
988        additional_count.1 |= count.1;
989    }
990
991    let counts = HeaderCounts {
992        query_count,
993        answer_count: answer_count.0,
994        nameserver_count: nameserver_count.0,
995        additional_count: additional_count.0,
996    };
997    let was_truncated =
998        header.truncated() || answer_count.1 || nameserver_count.1 || additional_count.1;
999
1000    let final_header = update_header_counts(header, was_truncated, counts);
1001    place.replace(encoder, final_header)?;
1002    Ok(final_header)
1003}
1004
1005impl BinEncodable for Message {
1006    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
1007        emit_message_parts(
1008            &self.header,
1009            &mut self.queries.iter(),
1010            &mut self.answers.iter(),
1011            &mut self.name_servers.iter(),
1012            &mut self.additionals.iter(),
1013            self.edns.as_ref(),
1014            &self.signature,
1015            encoder,
1016        )?;
1017
1018        Ok(())
1019    }
1020}
1021
1022impl<'r> BinDecodable<'r> for Message {
1023    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
1024        let mut header = Header::read(decoder)?;
1025
1026        // TODO: return just header, and in the case of the rest of message getting an error.
1027        //  this could improve error detection while decoding.
1028
1029        // get the questions
1030        let count = header.query_count() as usize;
1031        let mut queries = Vec::with_capacity(count);
1032        for _ in 0..count {
1033            queries.push(Query::read(decoder)?);
1034        }
1035
1036        // get all counts before header moves
1037        let answer_count = header.answer_count() as usize;
1038        let name_server_count = header.name_server_count() as usize;
1039        let additional_count = header.additional_count() as usize;
1040
1041        let (answers, _, _) = Self::read_records(decoder, answer_count, false)?;
1042        let (name_servers, _, _) = Self::read_records(decoder, name_server_count, false)?;
1043        let (additionals, edns, signature) = Self::read_records(decoder, additional_count, true)?;
1044
1045        // need to grab error code from EDNS (which might have a higher value)
1046        if let Some(edns) = &edns {
1047            let high_response_code = edns.rcode_high();
1048            header.merge_response_code(high_response_code);
1049        }
1050
1051        Ok(Self {
1052            header,
1053            queries,
1054            answers,
1055            name_servers,
1056            additionals,
1057            signature,
1058            edns,
1059        })
1060    }
1061}
1062
1063impl fmt::Display for Message {
1064    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1065        let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1066            for d in slice {
1067                writeln!(f, ";; {}", d)?;
1068            }
1069
1070            Ok(())
1071        };
1072
1073        let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1074            for d in slice {
1075                writeln!(f, "{}", d)?;
1076            }
1077
1078            Ok(())
1079        };
1080
1081        writeln!(f, "; header {header}", header = self.header())?;
1082
1083        if let Some(edns) = self.extensions() {
1084            writeln!(f, "; edns {}", edns)?;
1085        }
1086
1087        writeln!(f, "; query")?;
1088        write_query(self.queries(), f)?;
1089
1090        if self.header().message_type() == MessageType::Response
1091            || self.header().op_code() == OpCode::Update
1092        {
1093            writeln!(f, "; answers {}", self.answer_count())?;
1094            write_slice(self.answers(), f)?;
1095            writeln!(f, "; nameservers {}", self.name_server_count())?;
1096            write_slice(self.name_servers(), f)?;
1097            writeln!(f, "; additionals {}", self.additional_count())?;
1098            write_slice(self.additionals(), f)?;
1099        }
1100
1101        Ok(())
1102    }
1103}
1104
/// Round-trips a message that sets only header fields (no queries or records).
#[test]
fn test_emit_and_read_header() {
    let mut message = Message::new();
    message.set_id(10);
    message.set_message_type(MessageType::Response);
    message.set_op_code(OpCode::Update);
    message.set_authoritative(true);
    message.set_truncated(false);
    message.set_recursion_desired(true);
    message.set_recursion_available(true);
    message.set_response_code(ResponseCode::ServFail);

    test_emit_and_read(message);
}
1120
/// Round-trips a truncated Update response carrying a single default query.
#[test]
fn test_emit_and_read_query() {
    let mut message = Message::new();
    message.set_id(10);
    message.set_message_type(MessageType::Response);
    message.set_op_code(OpCode::Update);
    message.set_authoritative(true);
    message.set_truncated(true);
    message.set_recursion_desired(true);
    message.set_recursion_available(true);
    message.set_response_code(ResponseCode::ServFail);
    message.add_query(Query::new());
    // we're not testing the query parsing, just message
    message.update_counts();

    test_emit_and_read(message);
}
1138
/// Round-trips a message with one default record in each of the answer, name server and
/// additional sections, with the AD and CD flags also set.
#[test]
fn test_emit_and_read_records() {
    let mut message = Message::new();
    message.set_id(10);
    message.set_message_type(MessageType::Response);
    message.set_op_code(OpCode::Update);
    message.set_authoritative(true);
    message.set_truncated(true);
    message.set_recursion_desired(true);
    message.set_recursion_available(true);
    message.set_authentic_data(true);
    message.set_checking_disabled(true);
    message.set_response_code(ResponseCode::ServFail);

    message.add_answer(Record::new());
    message.add_name_server(Record::new());
    message.add_additional(Record::new());
    // the decoded header carries real counts, so update them for the equality check
    message.update_counts();

    test_emit_and_read(message);
}
1161
/// Encodes `message`, decodes the bytes again, and asserts the result equals the original.
#[cfg(test)]
fn test_emit_and_read(message: Message) {
    let mut bytes: Vec<u8> = Vec::with_capacity(512);
    message.emit(&mut BinEncoder::new(&mut bytes)).unwrap();

    let decoded = Message::read(&mut BinDecoder::new(&bytes)).unwrap();
    assert_eq!(decoded, message);
}
1175
/// Decodes a hand-assembled A-record response for www.example.com and checks its id. It then
/// re-encodes that message, decodes the bytes again, and checks the id survives the round trip.
#[test]
#[rustfmt::skip]
fn test_legit_message() {
    let buf: Vec<u8> = vec![
    0x10,0x00,0x81,0x80, // id = 4096, response, op=query, recursion_desired, recursion_available, no_error
    0x00,0x01,0x00,0x01, // 1 query, 1 answer,
    0x00,0x00,0x00,0x00, // 0 name servers, 0 additional records

    0x03,b'w',b'w',b'w', // query --- www.example.com
    0x07,b'e',b'x',b'a', //
    b'm',b'p',b'l',b'e', //
    0x03,b'c',b'o',b'm', //
    0x00,                // 0 = end of name
    0x00,0x01,0x00,0x01, // RecordType = A, Class = IN

    0xC0,0x0C,           // name pointer to www.example.com (offset 12, just after the header)
    0x00,0x01,0x00,0x01, // RecordType = A, Class = IN
    0x00,0x00,0x00,0x02, // TTL = 2 seconds
    0x00,0x04,           // record length = 4 (ipv4 address)
    0x5D,0xB8,0xD8,0x22, // address = 93.184.216.34
    ];

    let mut decoder = BinDecoder::new(&buf);
    let message = Message::read(&mut decoder).unwrap();

    assert_eq!(message.id(), 4096);

    // re-encode the decoded message and make sure it decodes again with the same id
    let mut buf: Vec<u8> = Vec::with_capacity(512);
    {
        let mut encoder = BinEncoder::new(&mut buf);
        message.emit(&mut encoder).unwrap();
    }

    let mut decoder = BinDecoder::new(&buf);
    let message = Message::read(&mut decoder).unwrap();

    assert_eq!(message.id(), 4096);
}
1214
/// Decoding must fail for this 24-byte message. It has a header (flags 0x000D) claiming one
/// answer, followed by a root-named record of type 0, class 1, TTL 1, with one byte of RDATA.
#[test]
#[rustfmt::skip]
fn rdata_zero_roundtrip() {
    let bytes: [u8; 24] = [
        160, 160, 0, 13, // id, flags
        0, 0, 0, 1,      // 0 queries, 1 answer
        0, 0, 0, 0,      // 0 name servers, 0 additionals
        0,               // record name: root
        0, 0,            // record type 0
        0, 1,            // class 1
        0, 0, 0, 1,      // TTL 1
        0, 1,            // rdata length 1
        0,               // rdata
    ];

    let decoded = Message::from_bytes(&bytes);
    assert!(decoded.is_err());
}