const_oid/
encoder.rs

1//! OID encoder with `const` support.
2
3use crate::{
4    arcs::{ARC_MAX_FIRST, ARC_MAX_SECOND},
5    Arc, Error, ObjectIdentifier, Result,
6};
7
8/// BER/DER encoder
9#[derive(Debug)]
10pub(crate) struct Encoder {
11    /// Current state
12    state: State,
13
14    /// Bytes of the OID being encoded in-progress
15    bytes: [u8; ObjectIdentifier::MAX_SIZE],
16
17    /// Current position within the byte buffer
18    cursor: usize,
19}
20
21/// Current state of the encoder
22#[derive(Debug)]
23enum State {
24    /// Initial state - no arcs yet encoded
25    Initial,
26
27    /// First arc parsed
28    FirstArc(Arc),
29
30    /// Encoding base 128 body of the OID
31    Body,
32}
33
34impl Encoder {
35    /// Create a new encoder initialized to an empty default state.
36    pub(crate) const fn new() -> Self {
37        Self {
38            state: State::Initial,
39            bytes: [0u8; ObjectIdentifier::MAX_SIZE],
40            cursor: 0,
41        }
42    }
43
44    /// Extend an existing OID.
45    pub(crate) const fn extend(oid: ObjectIdentifier) -> Self {
46        Self {
47            state: State::Body,
48            bytes: oid.bytes,
49            cursor: oid.length as usize,
50        }
51    }
52
53    /// Encode an [`Arc`] as base 128 into the internal buffer.
54    pub(crate) const fn arc(mut self, arc: Arc) -> Result<Self> {
55        match self.state {
56            State::Initial => {
57                if arc > ARC_MAX_FIRST {
58                    return Err(Error::ArcInvalid { arc });
59                }
60
61                self.state = State::FirstArc(arc);
62                Ok(self)
63            }
64            // Ensured not to overflow by `ARC_MAX_SECOND` check
65            #[allow(clippy::integer_arithmetic)]
66            State::FirstArc(first_arc) => {
67                if arc > ARC_MAX_SECOND {
68                    return Err(Error::ArcInvalid { arc });
69                }
70
71                self.state = State::Body;
72                self.bytes[0] = (first_arc * (ARC_MAX_SECOND + 1)) as u8 + arc as u8;
73                self.cursor = 1;
74                Ok(self)
75            }
76            // TODO(tarcieri): finer-grained overflow safety / checked arithmetic
77            #[allow(clippy::integer_arithmetic)]
78            State::Body => {
79                // Total number of bytes in encoded arc - 1
80                let nbytes = base128_len(arc);
81
82                // Shouldn't overflow on any 16-bit+ architectures
83                if self.cursor + nbytes + 1 >= ObjectIdentifier::MAX_SIZE {
84                    return Err(Error::Length);
85                }
86
87                let new_cursor = self.cursor + nbytes + 1;
88
89                // TODO(tarcieri): use `?` when stable in `const fn`
90                match self.encode_base128_byte(arc, nbytes, false) {
91                    Ok(mut encoder) => {
92                        encoder.cursor = new_cursor;
93                        Ok(encoder)
94                    }
95                    Err(err) => Err(err),
96                }
97            }
98        }
99    }
100
101    /// Finish encoding an OID.
102    pub(crate) const fn finish(self) -> Result<ObjectIdentifier> {
103        if self.cursor >= 2 {
104            Ok(ObjectIdentifier {
105                bytes: self.bytes,
106                length: self.cursor as u8,
107            })
108        } else {
109            Err(Error::NotEnoughArcs)
110        }
111    }
112
113    /// Encode a single byte of a Base 128 value.
114    const fn encode_base128_byte(mut self, mut n: u32, i: usize, continued: bool) -> Result<Self> {
115        let mask = if continued { 0b10000000 } else { 0 };
116
117        // Underflow checked by branch
118        #[allow(clippy::integer_arithmetic)]
119        if n > 0x80 {
120            self.bytes[checked_add!(self.cursor, i)] = (n & 0b1111111) as u8 | mask;
121            n >>= 7;
122
123            if i > 0 {
124                self.encode_base128_byte(n, i.saturating_sub(1), true)
125            } else {
126                Err(Error::Base128)
127            }
128        } else {
129            self.bytes[self.cursor] = n as u8 | mask;
130            Ok(self)
131        }
132    }
133}
134
135/// Compute the length - 1 of an arc when encoded in base 128.
136const fn base128_len(arc: Arc) -> usize {
137    match arc {
138        0..=0x7f => 0,
139        0x80..=0x3fff => 1,
140        0x4000..=0x1fffff => 2,
141        0x200000..=0x1fffffff => 3,
142        _ => 4,
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::Encoder;
    use hex_literal::hex;

    /// ASN.1 BER/DER encoding of the OID `1.2.840.10045.2.1`.
    const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201");

    #[test]
    fn encode() {
        let mut encoder = Encoder::new();

        // Feed the arcs in order; each call consumes and returns the encoder.
        for &arc in &[1, 2, 840, 10045, 2, 1] {
            encoder = encoder.arc(arc).unwrap();
        }

        assert_eq!(&encoder.bytes[..encoder.cursor], EXAMPLE_OID_BER);
    }
}