use std::marker::PhantomData;

use crate::error::{ProtoErrorKind, ProtoResult};

use super::BinEncodable;
use crate::op::Header;

mod private {
    use crate::error::{ProtoErrorKind, ProtoResult};

    /// A wrapper around a mutable `Vec<u8>` that refuses to grow beyond a maximum size.
    pub(crate) struct MaximalBuf<'a> {
        max_size: usize,
        buffer: &'a mut Vec<u8>,
    }

    impl<'a> MaximalBuf<'a> {
        pub(crate) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
            MaximalBuf {
                max_size: max_size as usize,
                buffer,
            }
        }

        /// Sets the maximum size to enforce on writes
        pub(crate) fn set_max_size(&mut self, max: u16) {
            self.max_size = max as usize;
        }

        /// Runs the writer against the buffer, returning an error if the additional bytes
        /// would push the buffer past the maximum size.
        pub(crate) fn enforced_write<F>(&mut self, additional: usize, writer: F) -> ProtoResult<()>
        where
            F: FnOnce(&mut Vec<u8>),
        {
            let expected_len = self.buffer.len() + additional;

            if expected_len > self.max_size {
                Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into())
            } else {
                self.buffer.reserve(additional);
                writer(self.buffer);

                debug_assert_eq!(self.buffer.len(), expected_len);
                Ok(())
            }
        }

        /// Truncates the underlying buffer to `len` bytes
        pub(crate) fn truncate(&mut self, len: usize) {
            self.buffer.truncate(len)
        }

        /// Returns the current length of the underlying buffer
        pub(crate) fn len(&self) -> usize {
            self.buffer.len()
        }

        /// Returns a view of the underlying buffer
        pub(crate) fn buffer(&'a self) -> &'a [u8] {
            self.buffer as &'a [u8]
        }

        /// Consumes the wrapper, returning a reference to the underlying buffer
        pub(crate) fn into_bytes(self) -> &'a Vec<u8> {
            self.buffer
        }
    }
}

/// Encodes DNS messages and resource record types into their binary wire format.
pub struct BinEncoder<'a> {
    offset: usize,
    buffer: private::MaximalBuf<'a>,
    /// start positions and bytes of labels already written, used for name compression
    name_pointers: Vec<(usize, Vec<u8>)>,
    mode: EncodeMode,
    canonical_names: bool,
}

impl<'a> BinEncoder<'a> {
    /// Creates a new encoder that writes into `buf`
    pub fn new(buf: &'a mut Vec<u8>) -> Self {
        Self::with_offset(buf, 0, EncodeMode::Normal)
    }

    /// Creates a new encoder with the given encoding mode
    pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
        Self::with_offset(buf, 0, mode)
    }

    /// Begins the encoder at the given offset
    ///
    /// Useful when the output will be appended to bytes written elsewhere, so that any
    /// label pointers emitted account for the data that precedes this encoder's output.
    pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
        // ensure at least 512 bytes of capacity up front, the classic DNS message size
        if buf.capacity() < 512 {
            let reserve = 512 - buf.capacity();
            buf.reserve(reserve);
        }

        BinEncoder {
            offset: offset as usize,
            buffer: private::MaximalBuf::new(u16::max_value(), buf),
            name_pointers: Vec::new(),
            mode,
            canonical_names: false,
        }
    }

    /// Sets the maximum size of the buffer
    ///
    /// DNS message lengths are carried in a `u16`, so the cap can never exceed 65_535 bytes.
    pub fn set_max_size(&mut self, max: u16) {
        self.buffer.set_max_size(max);
    }

    /// Consumes the encoder, returning a reference to the internal buffer
    pub fn into_bytes(self) -> &'a Vec<u8> {
        self.buffer.into_bytes()
    }

    /// Returns the current length of the buffer
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` if nothing has been written to the buffer
    pub fn is_empty(&self) -> bool {
        self.buffer.buffer().is_empty()
    }

    /// Returns the current offset into the buffer
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Sets the current offset, e.g. to overwrite previously written bytes
    pub fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }

    /// Returns the current encoding mode
    pub fn mode(&self) -> EncodeMode {
        self.mode
    }

    /// If set to true, names will be written to the buffer in canonical form
    pub fn set_canonical_names(&mut self, canonical_names: bool) {
        self.canonical_names = canonical_names;
    }

    /// Returns true if names should be written in canonical form
    pub fn is_canonical_names(&self) -> bool {
        self.canonical_names
    }

    /// Runs the closure with canonical names enabled, then restores the previous setting
    pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
        &mut self,
        f: F,
    ) -> ProtoResult<()> {
        let was_canonical = self.is_canonical_names();
        self.set_canonical_names(true);

        let res = f(self);
        self.set_canonical_names(was_canonical);

        res
    }

    /// Reserving space is unnecessary: the internal buffer grows as needed and the
    /// maximum size is enforced on every write, so this is a no-op.
    pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
        Ok(())
    }

    /// Truncates the buffer to the current offset and drops any label pointers
    /// that now point past the end.
    pub fn trim(&mut self) {
        let offset = self.offset;
        self.buffer.truncate(offset);
        self.name_pointers.retain(|&(start, _)| start < offset);
    }

    /// Returns a slice of the buffer over the range `start..end`
    ///
    /// # Panics
    ///
    /// Panics if `start` is not before the current offset, or if `end` is past the end of
    /// the buffer.
    pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
        assert!(start < self.offset);
        assert!(end <= self.buffer.len());
        &self.buffer.buffer()[start..end]
    }

    /// Records the already written bytes at `start..end` as a target for later name
    /// compression. Compression pointers can only address the first 0x3FFF bytes of a
    /// message, so nothing is stored beyond that point.
    pub fn store_label_pointer(&mut self, start: usize, end: usize) {
        assert!(start <= (u16::max_value() as usize));
        assert!(end <= (u16::max_value() as usize));
        assert!(start <= end);
        if self.offset < 0x3FFF_usize {
            self.name_pointers
                .push((start, self.slice_of(start, end).to_vec()));
        }
    }

    /// Looks up the offset of previously written bytes equal to those at `start..end`,
    /// if any were stored with `store_label_pointer`.
    pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
        let search = self.slice_of(start, end);

        for (match_start, matcher) in &self.name_pointers {
            if matcher.as_slice() == search {
                assert!(match_start <= &(u16::max_value() as usize));
                return Some(*match_start as u16);
            }
        }

        None
    }

    /// Emits a single byte, overwriting in place if the offset is inside the buffer,
    /// otherwise appending.
    pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
        if self.offset < self.buffer.len() {
            let offset = self.offset;
            self.buffer.enforced_write(0, |buffer| {
                *buffer
                    .get_mut(offset)
                    .expect("could not get index at offset") = b
            })?;
        } else {
            self.buffer.enforced_write(1, |buffer| buffer.push(b))?;
        }
        self.offset += 1;
        Ok(())
    }

    /// Emits character data: a one-byte length prefix followed by the bytes themselves
    ///
    /// Returns an error if the data is longer than 255 bytes.
    pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
        let char_bytes = char_data.as_ref();
        if char_bytes.len() > 255 {
            return Err(ProtoErrorKind::CharacterDataTooLong {
                max: 255,
                len: char_bytes.len(),
            }
            .into());
        }

        self.emit(char_bytes.len() as u8)?;
        self.write_slice(char_bytes)
    }

    /// Emits a single byte
    pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
        self.emit(data)
    }

    /// Emits a `u16` in network byte order (big-endian)
    pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Emits an `i32` in network byte order (big-endian)
    pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    /// Emits a `u32` in network byte order (big-endian)
    pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
        self.write_slice(&data.to_be_bytes())
    }

    fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
        // if the offset is inside the buffer, overwrite in place; otherwise append
        if self.offset < self.buffer.len() {
            let offset = self.offset;

            self.buffer.enforced_write(0, |buffer| {
                let mut offset = offset;
                for b in data {
                    *buffer
                        .get_mut(offset)
                        .expect("could not get index at offset for slice") = *b;
                    offset += 1;
                }
            })?;
        } else {
            self.buffer
                .enforced_write(data.len(), |buffer| buffer.extend_from_slice(data))?;
        }

        self.offset += data.len();

        Ok(())
    }

    /// Emits a slice of bytes to the encoder
    pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
        self.write_slice(data)
    }

    /// Emits all items in the iterator, see `emit_iter`
    pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        mut iter: I,
    ) -> ProtoResult<usize> {
        self.emit_iter(&mut iter)
    }

    /// Emits all items of an iterator of references, see `emit_iter`
    pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
    where
        'e: 'r,
        I: Iterator<Item = &'r &'e E>,
        E: 'r + 'e + BinEncodable,
    {
        let mut iter = iter.cloned();
        self.emit_iter(&mut iter)
    }

    /// Emits all items from the iterator
    ///
    /// If the maximum buffer size is hit part-way through an item, the partial write is
    /// rolled back and `NotAllRecordsWritten` is returned with the number of items that
    /// were fully written.
    #[allow(clippy::needless_return)]
    pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
        &mut self,
        iter: &mut I,
    ) -> ProtoResult<usize> {
        let mut count = 0;
        for i in iter {
            let rollback = self.set_rollback();
            i.emit(self).map_err(|e| {
                if let ProtoErrorKind::MaxBufferSizeExceeded(_) = e.kind() {
                    rollback.rollback(self);
                    return ProtoErrorKind::NotAllRecordsWritten { count }.into();
                } else {
                    return e;
                }
            })?;
            count += 1;
        }
        Ok(count)
    }

    /// Reserves a `Place` of `T::size_of()` bytes at the current position, to be filled
    /// in later with `emit_at` or `Place::replace`.
    pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
        let index = self.offset;
        let len = T::size_of();

        self.buffer
            .enforced_write(len, |buffer| buffer.resize(index + len, 0))?;

        self.offset += len;

        Ok(Place {
            start_index: index,
            phantom: PhantomData,
        })
    }

    /// Returns the number of bytes written since the place was reserved, not counting
    /// the placeholder itself.
    pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
        (self.offset - place.start_index) - place.size_of()
    }

    /// Writes `data` into the previously reserved place, then restores the offset
    pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
        let current_index = self.offset;

        assert!(place.start_index < current_index);
        self.offset = place.start_index;

        let emit_result = data.emit(self);

        assert!((self.offset - place.start_index) == place.size_of());

        self.offset = current_index;

        emit_result
    }

    fn set_rollback(&self) -> Rollback {
        Rollback {
            rollback_index: self.offset(),
        }
    }
}

/// A trait for types whose encoded size is known ahead of time, used with `place` and
/// `emit_at`.
pub trait EncodedSize: BinEncodable {
    /// Returns the size in bytes of the encoded form
    fn size_of() -> usize;
}

impl EncodedSize for u16 {
    fn size_of() -> usize {
        2
    }
}

impl EncodedSize for Header {
    fn size_of() -> usize {
        Self::len()
    }
}

/// A reserved location in the buffer, created by `BinEncoder::place`, which must be
/// written back with `replace` or `emit_at`.
#[derive(Debug)]
#[must_use = "data must be written back to the place"]
pub struct Place<T: EncodedSize> {
    start_index: usize,
    phantom: PhantomData<T>,
}

impl<T: EncodedSize> Place<T> {
    /// Writes `data` into the reserved location
    pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
        encoder.emit_at(self, data)
    }

    /// Returns the size in bytes reserved for `T`
    pub fn size_of(&self) -> usize {
        T::size_of()
    }
}

/// A saved offset used to roll back a partially written record
pub(crate) struct Rollback {
    rollback_index: usize,
}

impl Rollback {
    pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
        encoder.set_offset(self.rollback_index)
    }
}

/// The encoding mode
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum EncodeMode {
    /// Records are written in canonical form, e.g. for DNSSEC signing
    Signing,
    /// Standard encoding of a message
    Normal,
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{rdata::SRV, RData, Record, RecordType},
    };
    use crate::{rr::Name, serialize::binary::BinDecoder};

    #[test]
    fn test_label_compression_regression() {
        let data: Vec<u8> = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }
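
    // Illustrative sketch, not part of the original test suite: exercises the
    // name-pointer bookkeeping used for label compression. The byte values and
    // ranges are arbitrary.
    #[test]
    fn test_label_pointer_roundtrip() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.emit_vec(b"example").expect("failed to write");
        encoder.store_label_pointer(0, 7);

        // the stored bytes are found again at offset 0
        assert_eq!(encoder.get_label_pointer(0, 7), Some(0));
        // bytes that were never stored yield no pointer
        assert_eq!(encoder.get_label_pointer(1, 7), None);
    }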

    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }

    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }

        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("could not read u16").unverified();

        assert_eq!(written, 4);
    }
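
    // Illustrative sketch, not part of the original test suite: emit_character_data
    // writes a one-byte length prefix and rejects data longer than 255 bytes.
    #[test]
    fn test_character_data_length_limit() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.emit_character_data("abc").expect("failed to write");
        assert_eq!(encoder.into_bytes(), &vec![3, b'a', b'b', b'c']);

        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        let too_long = vec![0u8; 256];
        let error = encoder.emit_character_data(&too_long).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::CharacterDataTooLong { max: 255, len: 256 } => (),
            _ => panic!(),
        }
    }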

    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }
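
    // Illustrative sketch, not part of the original test suite: rewinding the offset
    // makes emit overwrite bytes in place instead of appending.
    #[test]
    fn test_overwrite_after_set_offset() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");

        encoder.set_offset(0);
        encoder.emit(9).expect("failed to overwrite");

        assert_eq!(encoder.len(), 2);
        assert_eq!(encoder.into_bytes(), &vec![9, 2]);
    }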

    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        let error = encoder.place::<u16>().unwrap_err();

        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }
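
    // Illustrative sketch, not part of the original test suite: with_canonical_names
    // enables canonical form only for the duration of the closure.
    #[test]
    fn test_with_canonical_names_restores_flag() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        assert!(!encoder.is_canonical_names());
        encoder
            .with_canonical_names(|encoder| {
                assert!(encoder.is_canonical_names());
                encoder.emit_u16(0)
            })
            .expect("closure failed");
        assert!(!encoder.is_canonical_names());
    }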

    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com").unwrap(),
            )),
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com").unwrap(),
            0,
            RData::CNAME(Name::from_str("www.foo.com").unwrap()),
        ));

        let bytes = msg.to_vec().unwrap();
        assert_eq!(bytes.len(), 130);
        assert!(Message::from_vec(&bytes).is_ok());
    }
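
    // Illustrative sketch, not part of the original test suite: when the size cap is hit
    // while writing a record, emit_all rolls the partial record back and reports how many
    // records were fully written. The names used here are arbitrary.
    #[test]
    fn test_emit_all_respects_max_size() {
        let records: Vec<Record> = (0..2)
            .map(|_| {
                Record::from_rdata(
                    Name::from_str("www.example.com.").unwrap(),
                    0,
                    RData::CNAME(Name::from_str("www.foo.com.").unwrap()),
                )
            })
            .collect();

        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_max_size(1);

        let error = encoder.emit_all(records.iter()).unwrap_err();
        match *error.kind() {
            ProtoErrorKind::NotAllRecordsWritten { count: 0 } => (),
            _ => panic!(),
        }
    }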
}