der/writer/
slice.rs

1//! Slice writer.
2
3use crate::{
4    asn1::*, Encode, EncodeValue, ErrorKind, Header, Length, Result, Tag, TagMode, TagNumber,
5    Tagged, Writer,
6};
7
8/// [`Writer`] which encodes DER into a mutable output byte slice.
9#[derive(Debug)]
10pub struct SliceWriter<'a> {
11    /// Buffer into which DER-encoded message is written
12    bytes: &'a mut [u8],
13
14    /// Has the encoding operation failed?
15    failed: bool,
16
17    /// Total number of bytes written to buffer so far
18    position: Length,
19}
20
21impl<'a> SliceWriter<'a> {
22    /// Create a new encoder with the given byte slice as a backing buffer.
23    pub fn new(bytes: &'a mut [u8]) -> Self {
24        Self {
25            bytes,
26            failed: false,
27            position: Length::ZERO,
28        }
29    }
30
31    /// Encode a value which impls the [`Encode`] trait.
32    pub fn encode<T: Encode>(&mut self, encodable: &T) -> Result<()> {
33        if self.is_failed() {
34            self.error(ErrorKind::Failed)?
35        }
36
37        encodable.encode(self).map_err(|e| {
38            self.failed = true;
39            e.nested(self.position)
40        })
41    }
42
43    /// Return an error with the given [`ErrorKind`], annotating it with
44    /// context about where the error occurred.
45    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
46        self.failed = true;
47        Err(kind.at(self.position))
48    }
49
50    /// Did the decoding operation fail due to an error?
51    pub fn is_failed(&self) -> bool {
52        self.failed
53    }
54
55    /// Finish encoding to the buffer, returning a slice containing the data
56    /// written to the buffer.
57    pub fn finish(self) -> Result<&'a [u8]> {
58        let position = self.position;
59
60        if self.is_failed() {
61            return Err(ErrorKind::Failed.at(position));
62        }
63
64        self.bytes
65            .get(..usize::try_from(position)?)
66            .ok_or_else(|| ErrorKind::Overlength.at(position))
67    }
68
69    /// Encode a `CONTEXT-SPECIFIC` field with the provided tag number and mode.
70    pub fn context_specific<T>(
71        &mut self,
72        tag_number: TagNumber,
73        tag_mode: TagMode,
74        value: &T,
75    ) -> Result<()>
76    where
77        T: EncodeValue + Tagged,
78    {
79        ContextSpecificRef {
80            tag_number,
81            tag_mode,
82            value,
83        }
84        .encode(self)
85    }
86
87    /// Encode an ASN.1 `SEQUENCE` of the given length.
88    ///
89    /// Spawns a nested slice writer which is expected to be exactly the
90    /// specified length upon completion.
91    pub fn sequence<F>(&mut self, length: Length, f: F) -> Result<()>
92    where
93        F: FnOnce(&mut SliceWriter<'_>) -> Result<()>,
94    {
95        Header::new(Tag::Sequence, length).and_then(|header| header.encode(self))?;
96
97        let mut nested_encoder = SliceWriter::new(self.reserve(length)?);
98        f(&mut nested_encoder)?;
99
100        if nested_encoder.finish()?.len() == usize::try_from(length)? {
101            Ok(())
102        } else {
103            self.error(ErrorKind::Length { tag: Tag::Sequence })
104        }
105    }
106
107    /// Reserve a portion of the internal buffer, updating the internal cursor
108    /// position and returning a mutable slice.
109    fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
110        if self.is_failed() {
111            return Err(ErrorKind::Failed.at(self.position));
112        }
113
114        let len = len
115            .try_into()
116            .or_else(|_| self.error(ErrorKind::Overflow))?;
117
118        let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
119        let slice = self
120            .bytes
121            .get_mut(self.position.try_into()?..end.try_into()?)
122            .ok_or_else(|| ErrorKind::Overlength.at(end))?;
123
124        self.position = end;
125        Ok(slice)
126    }
127}
128
129impl<'a> Writer for SliceWriter<'a> {
130    fn write(&mut self, slice: &[u8]) -> Result<()> {
131        self.reserve(slice.len())?.copy_from_slice(slice);
132        Ok(())
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::SliceWriter;
    use crate::{Encode, ErrorKind, Length};

    #[test]
    fn overlength_message() {
        // A zero-capacity buffer cannot hold any part of the encoding of `false`.
        let mut empty: [u8; 0] = [];
        let mut writer = SliceWriter::new(&mut empty);

        let err = false.encode(&mut writer).unwrap_err();

        assert_eq!(err.kind(), ErrorKind::Overlength);
        // Overlength errors are reported at the end offset of the rejected write.
        assert_eq!(err.position(), Some(Length::ONE));
    }
}
149}