Skip to main content

fidl_next_codec/
encoder.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! The core [`Encoder`] trait.
6
7use core::marker::PhantomData;
8use core::mem::MaybeUninit;
9use core::slice::from_raw_parts;
10
11use crate::wire::Uint64;
12use crate::{CHUNK_SIZE, Chunk, Encode, EncodeError, Slot, Wire};
13
/// Handle-counting hook that every [`Encoder`] must also implement (crate-internal plumbing).
pub trait InternalHandleEncoder {
    /// Reports how many handles have been written into this encoder so far.
    ///
    /// This leaks details about Fuchsia resources that plain FIDL encoding should not need to
    /// care about, so it is hidden from the docs and must not be relied upon outside this crate.
    #[doc(hidden)]
    fn __internal_handle_count(&self) -> usize;
}
23
24/// An encoder for FIDL messages.
25pub trait Encoder: InternalHandleEncoder {
26    /// Returns the number of bytes written to the encoder.
27    fn bytes_written(&self) -> usize;
28
29    /// Writes zeroed bytes to the end of the encoder.
30    ///
31    /// Additional bytes are written to pad the written data to a multiple of [`CHUNK_SIZE`].
32    fn write_zeroes(&mut self, len: usize);
33
34    /// Copies bytes to the end of the encoder.
35    ///
36    /// Additional bytes are written to pad the written data to a multiple of [`CHUNK_SIZE`].
37    fn write(&mut self, bytes: &[u8]);
38
39    /// Rewrites bytes at a position in the encoder.
40    fn rewrite(&mut self, pos: usize, bytes: &[u8]);
41}
42
43impl InternalHandleEncoder for Vec<Chunk> {
44    #[inline]
45    fn __internal_handle_count(&self) -> usize {
46        0
47    }
48}
49
50impl Encoder for Vec<Chunk> {
51    #[inline]
52    fn bytes_written(&self) -> usize {
53        self.len() * CHUNK_SIZE
54    }
55
56    #[inline]
57    fn write_zeroes(&mut self, len: usize) {
58        let count = len.div_ceil(CHUNK_SIZE);
59        self.reserve(count);
60        // SAFETY: `reserve` ensures the vector has enough capacity for `count` additional
61        // elements.
62        let ptr = unsafe { self.as_mut_ptr().add(self.len()) };
63        // SAFETY: `ptr` is valid for writing `count` elements because of the previous `reserve`
64        // call.
65        unsafe {
66            ptr.write_bytes(0, count);
67        }
68        // SAFETY: The memory up to the new length has been initialized to zero.
69        unsafe {
70            self.set_len(self.len() + count);
71        }
72    }
73
74    #[inline]
75    fn write(&mut self, bytes: &[u8]) {
76        if bytes.is_empty() {
77            return;
78        }
79
80        let count = bytes.len().div_ceil(CHUNK_SIZE);
81        self.reserve(count);
82
83        // Zero out the last chunk
84        // SAFETY: `reserve` ensures the pointer is within the allocated capacity.
85        unsafe {
86            self.as_mut_ptr().add(self.len() + count - 1).write(Uint64(0));
87        }
88        // SAFETY: `reserve` ensures the pointer is within the allocated capacity.
89        let ptr = unsafe { self.as_mut_ptr().add(self.len()).cast::<u8>() };
90
91        // Copy all the bytes
92        // SAFETY: `ptr` has sufficient capacity for `bytes.len()`,
93        // which is less than or equal to `count * CHUNK_SIZE`.
94        unsafe {
95            ptr.copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());
96        }
97
98        // Set the new length
99        // SAFETY: All `count` chunks have been initialized.
100        unsafe {
101            self.set_len(self.len() + count);
102        }
103    }
104
105    #[inline]
106    fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
107        assert!(pos + bytes.len() <= self.bytes_written());
108
109        // SAFETY: `pos` is within the initialized bounds of the vector.
110        let ptr = unsafe { self.as_mut_ptr().cast::<u8>().add(pos) };
111        // SAFETY: The destination pointer is valid for writes of `bytes.len()` and
112        // does not overlap with `bytes`.
113        unsafe {
114            ptr.copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());
115        }
116    }
117}
118
119/// Extension methods for [`Encoder`].
120pub trait EncoderExt {
121    /// Pre-allocates space for a slice of elements.
122    fn preallocate<T>(&mut self, len: usize) -> Preallocated<'_, Self, T>;
123
124    /// Encodes an iterator of elements.
125    ///
126    /// Returns `Err` if encoding failed.
127    fn encode_next_iter<W, T>(
128        &mut self,
129        values: impl ExactSizeIterator<Item = T>,
130    ) -> Result<(), EncodeError>
131    where
132        W: Wire<Constraint = ()>,
133        T: Encode<W, Self>;
134
135    /// Encodes an iterator of elements.
136    ///
137    /// Returns `Err` if encoding failed.
138    fn encode_next_iter_with_constraint<W, T>(
139        &mut self,
140        values: impl ExactSizeIterator<Item = T>,
141        constraint: W::Constraint,
142    ) -> Result<(), EncodeError>
143    where
144        W: Wire,
145        T: Encode<W, Self>;
146
147    /// Encodes a value.
148    ///
149    /// Returns `Err` if encoding failed.
150    fn encode_next<W, T>(&mut self, value: T) -> Result<(), EncodeError>
151    where
152        W: Wire<Constraint = ()>,
153        T: Encode<W, Self>;
154
155    /// Encodes a value with a constraint.
156    ///
157    /// Returns `Err` if encoding failed.
158    fn encode_next_with_constraint<W: Wire, T: Encode<W, Self>>(
159        &mut self,
160        value: T,
161        constraint: W::Constraint,
162    ) -> Result<(), EncodeError>;
163
164    /// Encodes a value into a new instance of the encoder.
165    ///
166    /// Returns `Err` if encoding failed.
167    fn encode<W, T>(value: T) -> Result<Self, EncodeError>
168    where
169        Self: Default,
170        W: Wire<Constraint = ()>,
171        T: Encode<W, Self>;
172
173    /// Encodes a value with a constraint into a new instance of the encoder.
174    ///
175    /// Returns `Err` if encoding failed.
176    fn encode_with_constraint<W, T>(
177        value: T,
178        constraint: W::Constraint,
179    ) -> Result<Self, EncodeError>
180    where
181        Self: Default,
182        W: Wire,
183        T: Encode<W, Self>;
184}
185
186impl<E: Encoder + ?Sized> EncoderExt for E {
187    fn preallocate<T>(&mut self, len: usize) -> Preallocated<'_, Self, T> {
188        let pos = self.bytes_written();
189
190        // Zero out the next `count` bytes
191        self.write_zeroes(len * size_of::<T>());
192
193        Preallocated {
194            encoder: self,
195            pos,
196            #[cfg(debug_assertions)]
197            remaining: len,
198            _phantom: PhantomData,
199        }
200    }
201
202    fn encode_next_iter<W, T>(
203        &mut self,
204        values: impl ExactSizeIterator<Item = T>,
205    ) -> Result<(), EncodeError>
206    where
207        W: Wire<Constraint = ()>,
208        T: Encode<W, Self>,
209    {
210        self.encode_next_iter_with_constraint(values, ())
211    }
212
213    fn encode_next_iter_with_constraint<W, T>(
214        &mut self,
215        values: impl ExactSizeIterator<Item = T>,
216        constraint: W::Constraint,
217    ) -> Result<(), EncodeError>
218    where
219        W: Wire,
220        T: Encode<W, Self>,
221    {
222        let mut outputs = self.preallocate::<W>(values.len());
223
224        let mut out = MaybeUninit::<W>::uninit();
225        <W as Wire>::zero_padding(&mut out);
226        for value in values {
227            value.encode(outputs.encoder, &mut out, constraint)?;
228            // SAFETY: `out` has been fully initialized by `W::zero_padding` and `value.encode`.
229            W::validate(unsafe { Slot::new_unchecked_from_maybe_uninit(&mut out) }, constraint)
230                .map_err(EncodeError::Validation)?;
231            // SAFETY: `out` has been fully initialized.
232            unsafe {
233                outputs.write_next(out.assume_init_ref());
234            }
235        }
236
237        Ok(())
238    }
239
240    fn encode_next<W, T>(&mut self, value: T) -> Result<(), EncodeError>
241    where
242        W: Wire<Constraint = ()>,
243        T: Encode<W, Self>,
244    {
245        self.encode_next_with_constraint(value, ())
246    }
247
248    fn encode_next_with_constraint<W, T>(
249        &mut self,
250        value: T,
251        constraint: W::Constraint,
252    ) -> Result<(), EncodeError>
253    where
254        W: Wire,
255        T: Encode<W, Self>,
256    {
257        self.encode_next_iter_with_constraint(core::iter::once(value), constraint)
258    }
259
260    fn encode<W, T>(value: T) -> Result<Self, EncodeError>
261    where
262        Self: Default,
263        W: Wire<Constraint = ()>,
264        T: Encode<W, Self>,
265    {
266        Self::encode_with_constraint(value, ())
267    }
268
269    fn encode_with_constraint<W, T>(
270        value: T,
271        constraint: W::Constraint,
272    ) -> Result<Self, EncodeError>
273    where
274        Self: Default,
275        W: Wire,
276        T: Encode<W, Self>,
277    {
278        let mut result = Self::default();
279        result.encode_next_with_constraint(value, constraint)?;
280        Ok(result)
281    }
282}
283
/// A run of zero-initialized, fixed-size slots reserved at the end of an encoder.
///
/// The slots are filled one after another with [`Preallocated::write_next`].
pub struct Preallocated<'a, E: ?Sized, T> {
    /// The encoder that holds the reserved slots.
    pub encoder: &'a mut E,
    /// Byte offset of the next slot to be written.
    pos: usize,
    /// How many slots may still be written; only tracked in debug builds.
    #[cfg(debug_assertions)]
    remaining: usize,
    /// Records the slot element type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
293
294impl<E: Encoder + ?Sized, T> Preallocated<'_, E, T> {
295    /// Writes into the next pre-allocated slot in the encoder.
296    ///
297    /// # Safety
298    ///
299    /// All of the bytes of `value` must be initialized, including padding.
300    pub unsafe fn write_next(&mut self, value: &T) {
301        #[cfg(debug_assertions)]
302        {
303            assert!(self.remaining > 0, "attemped to write more slots than preallocated");
304            self.remaining -= 1;
305        }
306
307        let bytes_ptr = (value as *const T).cast::<u8>();
308        // SAFETY: `value` is valid for reads of `size_of::<T>()` bytes, and the
309        // caller guarantees it is fully initialized.
310        let bytes = unsafe { from_raw_parts(bytes_ptr, size_of::<T>()) };
311        self.encoder.rewrite(self.pos, bytes);
312        self.pos += size_of::<T>();
313    }
314}