trust_dns_proto/serialize/binary/encoder.rs

/*
 * Copyright (C) 2015 Benjamin Fry <benjaminfry@me.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
use std::marker::PhantomData;

use crate::error::{ProtoErrorKind, ProtoResult};

use super::BinEncodable;
use crate::op::Header;

// this is private to make sure there is no accidental access to the inner buffer.
mod private {
    use crate::error::{ProtoErrorKind, ProtoResult};

    /// A wrapper for a buffer that guarantees writes never exceed a defined set of bytes
    pub(crate) struct MaximalBuf<'a> {
        max_size: usize,
        buffer: &'a mut Vec<u8>,
    }

    impl<'a> MaximalBuf<'a> {
        pub(crate) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
            MaximalBuf {
                max_size: max_size as usize,
                buffer,
            }
        }

        /// Sets the maximum size to enforce
        pub(crate) fn set_max_size(&mut self, max: u16) {
            self.max_size = max as usize;
        }

        /// returns an error if the maximum buffer size would be exceeded by the additional number of elements
        ///
        /// and reserves the additional space in the buffer
        pub(crate) fn enforced_write<F>(&mut self, additional: usize, writer: F) -> ProtoResult<()>
        where
            F: FnOnce(&mut Vec<u8>),
        {
            let expected_len = self.buffer.len() + additional;

            if expected_len > self.max_size {
                Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into())
            } else {
                self.buffer.reserve(additional);
                writer(self.buffer);

                debug_assert_eq!(self.buffer.len(), expected_len);
                Ok(())
            }
        }

        /// truncates are always safe
        pub(crate) fn truncate(&mut self, len: usize) {
            self.buffer.truncate(len)
        }

        /// returns the length of the underlying buffer
        pub(crate) fn len(&self) -> usize {
            self.buffer.len()
        }

        /// Immutable reads are always safe
        pub(crate) fn buffer(&'a self) -> &'a [u8] {
            self.buffer as &'a [u8]
        }

        /// Returns a reference to the internal buffer
        pub(crate) fn into_bytes(self) -> &'a Vec<u8> {
            self.buffer
        }
    }
}

/// Encode DNS messages and resource record types.
pub struct BinEncoder<'a> {
    offset: usize,
    buffer: private::MaximalBuf<'a>,
    /// start of label pointers with their labels in fully decompressed form for easy comparison, smallvec here?
    name_pointers: Vec<(usize, Vec<u8>)>,
    mode: EncodeMode,
    canonical_names: bool,
}

impl<'a> BinEncoder<'a> {
    /// Create a new encoder with the Vec to fill
    pub fn new(buf: &'a mut Vec<u8>) -> Self {
        Self::with_offset(buf, 0, EncodeMode::Normal)
    }

    /// Specify the mode for encoding
    ///
    /// # Arguments
    ///
    /// * `mode` - In Signing mode, canonical forms of all data are encoded, otherwise format matches the source form
    pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
        Self::with_offset(buf, 0, mode)
    }

    /// Begins the encoder at the given offset
    ///
    /// This is used for pointers. If this encoder starts at some point further into the
    ///  sequence of bytes, the given offset is added to the location of any pointer that is
    ///  written, so that pointers remain correct relative to the start of the message.
    ///
    /// # Arguments
    ///
    /// * `offset` - index at which to start writing into the buffer
    pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
        if buf.capacity() < 512 {
            let reserve = 512 - buf.capacity();
            buf.reserve(reserve);
        }

        BinEncoder {
            offset: offset as usize,
            // TODO: add max_size to signature
            buffer: private::MaximalBuf::new(u16::max_value(), buf),
            name_pointers: Vec::new(),
            mode,
            canonical_names: false,
        }
    }

    // TODO: move to constructor (kept for backward compatibility)
    /// Sets the maximum size of the buffer
    ///
    /// DNS message lengths must be smaller than u16::max_value() due to hard limits in the protocol
    ///
    /// *this method will move to the constructor in a future release*
    pub fn set_max_size(&mut self, max: u16) {
        self.buffer.set_max_size(max);
    }

    /// Returns a reference to the internal buffer
    pub fn into_bytes(self) -> &'a Vec<u8> {
        self.buffer.into_bytes()
    }

    /// Returns the length of the buffer
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.buffer.buffer().is_empty()
    }

    /// Returns the current offset into the buffer
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// sets the current offset to the new offset
    pub fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }

    /// Returns the current Encoding mode
    pub fn mode(&self) -> EncodeMode {
        self.mode
    }

    /// If set to true, then names will be written into the buffer in canonical form
    pub fn set_canonical_names(&mut self, canonical_names: bool) {
        self.canonical_names = canonical_names;
    }

    /// Returns true if the encoder is writing in canonical form
    pub fn is_canonical_names(&self) -> bool {
        self.canonical_names
    }

    /// Emit all names in canonical form, useful for <https://tools.ietf.org/html/rfc3597>
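    ///
    /// A small illustrative example (not from the original docs), using only methods
    ///  defined on this encoder:
    ///
    /// ```
    /// use trust_dns_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    /// encoder
    ///     .with_canonical_names(|encoder| {
    ///         // canonical form is enforced inside the closure
    ///         assert!(encoder.is_canonical_names());
    ///         encoder.emit_u8(0)
    ///     })
    ///     .unwrap();
    /// // the previous setting is restored afterwards
    /// assert!(!encoder.is_canonical_names());
    /// ```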
    pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
        &mut self,
        f: F,
    ) -> ProtoResult<()> {
        let was_canonical = self.is_canonical_names();
        self.set_canonical_names(true);

        let res = f(self);
        self.set_canonical_names(was_canonical);

        res
    }

    // TODO: deprecate this...
    /// Reserve specified additional length in the internal buffer.
    pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
        Ok(())
    }

    /// trims to the current offset
    pub fn trim(&mut self) {
        let offset = self.offset;
        self.buffer.truncate(offset);
        self.name_pointers.retain(|&(start, _)| start < offset);
    }

    // /// returns an error if the maximum buffer size would be exceeded by the additional number of elements
    // ///
    // /// and reserves the additional space in the buffer
    // fn enforce_size(&mut self, additional: usize) -> ProtoResult<()> {
    //     if (self.buffer.len() + additional) > self.max_size {
    //         Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into())
    //     } else {
    //         self.reserve(additional);
    //         Ok(())
    //     }
    // }

    /// borrow a slice from the encoder
    pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
        assert!(start < self.offset);
        assert!(end <= self.buffer.len());
        &self.buffer.buffer()[start..end]
    }

    /// Stores a label pointer to an already written label
    ///
    /// The location is the current position in the buffer;
    ///  implicitly, it is expected that the name will be written to the stream after the current index.
    pub fn store_label_pointer(&mut self, start: usize, end: usize) {
        assert!(start <= (u16::max_value() as usize));
        assert!(end <= (u16::max_value() as usize));
        assert!(start <= end);
        if self.offset < 0x3FFF_usize {
            self.name_pointers
                .push((start, self.slice_of(start, end).to_vec())); // the next char will be at the len() location
        }
    }

    /// Looks up the index of an already written label
    pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
        let search = self.slice_of(start, end);

        for (match_start, matcher) in &self.name_pointers {
            if matcher.as_slice() == search {
                assert!(match_start <= &(u16::max_value() as usize));
                return Some(*match_start as u16);
            }
        }

        None
    }

    /// Emit one byte into the buffer
    pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
        if self.offset < self.buffer.len() {
            let offset = self.offset;
            self.buffer.enforced_write(0, |buffer| {
                *buffer
                    .get_mut(offset)
                    .expect("could not get index at offset") = b
            })?;
        } else {
            self.buffer.enforced_write(1, |buffer| buffer.push(b))?;
        }
        self.offset += 1;
        Ok(())
    }

    /// Emit a single character-data field: one length byte followed by the data itself
    ///
    /// ```
    /// use trust_dns_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///   encoder.emit_character_data("abc");
    /// }
    /// assert_eq!(bytes, vec![3,b'a',b'b',b'c']);
    /// ```
    pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
        let char_bytes = char_data.as_ref();
        if char_bytes.len() > 255 {
            return Err(ProtoErrorKind::CharacterDataTooLong {
                max: 255,
                len: char_bytes.len(),
            }
            .into());
        }

        // first the length is written
        self.emit(char_bytes.len() as u8)?;
        self.write_slice(char_bytes)
    }

    /// Emit one byte into the buffer
    pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
        self.emit(data)
    }

    /// Writes a u16 in network byte order to the buffer
    pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes an i32 in network byte order to the buffer
    pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Writes a u32 in network byte order to the buffer
    pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
        // replacement case, the necessary space should have been reserved already...
        if self.offset < self.buffer.len() {
            let offset = self.offset;

            self.buffer.enforced_write(0, |buffer| {
                let mut offset = offset;
                for b in data {
                    *buffer
                        .get_mut(offset)
                        .expect("could not get index at offset for slice") = *b;
                    offset += 1;
                }
            })?;
        } else {
            self.buffer
                .enforced_write(data.len(), |buffer| buffer.extend_from_slice(data))?;
        }

        self.offset += data.len();

        Ok(())
    }

    /// Writes the byte slice to the stream
    pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.write_slice(data)
    }

    /// Emits all the elements of an Iterator to the encoder
    pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        mut iter: I,
    ) -> ProtoResult<usize> {
        self.emit_iter(&mut iter)
    }

    // TODO: dedup with above emit_all
    /// Emits all the elements of an Iterator to the encoder
    pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
    where
        'e: 'r,
        I: Iterator<Item = &'r &'e E>,
        E: 'r + 'e + BinEncodable,
    {
        let mut iter = iter.cloned();
        self.emit_iter(&mut iter)
    }

    /// emits all items in the iterator, returning the number emitted
    #[allow(clippy::needless_return)]
    pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        iter: &mut I,
    ) -> ProtoResult<usize> {
        let mut count = 0;
        for i in iter {
            let rollback = self.set_rollback();
            i.emit(self).map_err(|e| {
                if let ProtoErrorKind::MaxBufferSizeExceeded(_) = e.kind() {
                    rollback.rollback(self);
                    return ProtoErrorKind::NotAllRecordsWritten { count }.into();
                } else {
                    return e;
                }
            })?;
            count += 1;
        }
        Ok(count)
    }

    /// capture a location to write back to
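    ///
    /// A brief sketch of the intended usage (illustrative, not from the original docs):
    ///
    /// ```
    /// use trust_dns_proto::serialize::binary::BinEncoder;
    ///
    /// let mut bytes: Vec<u8> = Vec::new();
    /// {
    ///     let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
    ///     // reserve two bytes for a length that is only known later
    ///     let place = encoder.place::<u16>().unwrap();
    ///     encoder.emit_u8(42).unwrap();
    ///     // write the number of bytes emitted since the place back into it
    ///     let len = encoder.len_since_place(&place) as u16;
    ///     place.replace(&mut encoder, len).unwrap();
    /// }
    /// assert_eq!(bytes, vec![0, 1, 42]);
    /// ```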
    pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
        let index = self.offset;
        let len = T::size_of();

        // resize the buffer
        self.buffer
            .enforced_write(len, |buffer| buffer.resize(index + len, 0))?;

        // update the offset
        self.offset += len;

        Ok(Place {
            start_index: index,
            phantom: PhantomData,
        })
    }

    /// calculates the length of data written since the place was created
    pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
        (self.offset - place.start_index) - place.size_of()
    }

    /// write back to a previously captured location
    pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
        // preserve current index
        let current_index = self.offset;

        // reset the current index back to place before writing
        //   this is an assert because it's a programming error for it to be wrong.
        assert!(place.start_index < current_index);
        self.offset = place.start_index;

        // emit the data to be written at this place
        let emit_result = data.emit(self);

        // double check that the correct number of bytes were written
        //   this is an assert because it's a programming error for it to be wrong.
        assert!((self.offset - place.start_index) == place.size_of());

        // reset to original location
        self.offset = current_index;

        emit_result
    }

    fn set_rollback(&self) -> Rollback {
        Rollback {
            rollback_index: self.offset(),
        }
    }
}

/// A trait to return the size of a type as it will be encoded in DNS
///
/// it does not necessarily equal `std::mem::size_of`, though it might, especially for primitives
pub trait EncodedSize: BinEncodable {
    /// Returns the size in bytes of the type as it will be encoded in DNS
    fn size_of() -> usize;
}

impl EncodedSize for u16 {
    fn size_of() -> usize {
        2
    }
}

impl EncodedSize for Header {
    fn size_of() -> usize {
        Self::len()
    }
}

#[derive(Debug)]
#[must_use = "data must be written back to the place"]
pub struct Place<T: EncodedSize> {
    start_index: usize,
    phantom: PhantomData<T>,
}

impl<T: EncodedSize> Place<T> {
    pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
        encoder.emit_at(self, data)
    }

    pub fn size_of(&self) -> usize {
        T::size_of()
    }
}

/// A type representing a rollback point in a stream
pub(crate) struct Rollback {
    rollback_index: usize,
}

impl Rollback {
    pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
        encoder.set_offset(self.rollback_index)
    }
}

/// In Verify mode there may be some things which are encoded differently, e.g. SIG0 records
///  should not be included in the additional count and not in the encoded data when in Verify
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum EncodeMode {
    /// In signing mode records are written in canonical form
    Signing,
    /// Write records in standard format
    Normal,
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{rdata::SRV, RData, Record, RecordType},
    };
    use crate::{rr::Name, serialize::binary::BinDecoder};

    #[test]
    fn test_label_compression_regression() {
        // https://github.com/bluejekyll/trust-dns/issues/339
        /*
        ;; QUESTION SECTION:
        ;bluedot.is.autonavi.com.gds.alibabadns.com. IN AAAA

        ;; AUTHORITY SECTION:
        gds.alibabadns.com.     1799    IN      SOA     gdsns1.alibabadns.com. none. 2015080610 1800 600 3600 360
        */
        let data: Vec<u8> = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }
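
    // An extra illustrative test (not part of the upstream suite): multi-byte integers
    // are written to the buffer in network (big-endian) byte order.
    #[test]
    fn test_emit_network_byte_order() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            encoder.emit_u16(0x1234).expect("failed to emit u16");
            encoder.emit_u32(0x0102_0304).expect("failed to emit u32");
        }
        assert_eq!(buf, vec![0x12, 0x34, 0x01, 0x02, 0x03, 0x04]);
    }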

    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }

        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("could not read u16").unverified();

        assert_eq!(written, 4);
    }
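
    // An extra illustrative test (not part of the upstream suite): a label already written
    // to the buffer can be registered and later found again for name-pointer compression.
    #[test]
    fn test_label_pointer_roundtrip() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        // a single label: length byte followed by "www"
        encoder.emit(3).expect("failed to emit length");
        encoder.emit_vec(b"www").expect("failed to emit label");

        encoder.store_label_pointer(0, 4);
        assert_eq!(encoder.get_label_pointer(0, 4), Some(0));
        assert_eq!(encoder.get_label_pointer(1, 4), None);
    }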

    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }
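
    // An extra illustrative test (not part of the upstream suite): character-data is
    // length-prefixed with a single byte, so anything longer than 255 bytes must be rejected.
    #[test]
    fn test_character_data_too_long() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        let too_long = vec![b'a'; 256];
        let error = encoder.emit_character_data(&too_long).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::CharacterDataTooLong { max, len } => {
                assert_eq!(max, 255);
                assert_eq!(len, 256);
            }
            _ => panic!(),
        }
    }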

    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        let error = encoder.place::<u16>().unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        // name here should use compressed label from target in previous records
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com").unwrap(),
            0,
            RData::CNAME(Name::from_str("www.foo.com").unwrap()),
        ));

        let bytes = msg.to_vec().unwrap();
        // label is compressed pointing to target, would be 145 otherwise
        assert_eq!(bytes.len(), 130);
        // check re-serializing
        assert!(Message::from_vec(&bytes).is_ok());
    }
}