fidl_next_codec/
encoder.rs1use core::marker::PhantomData;
8use core::mem::MaybeUninit;
9use core::slice::from_raw_parts;
10
11use crate::{CHUNK_SIZE, Chunk, Constrained, Encode, EncodeError, Slot, Wire, WireU64};
12
/// Low-level hook through which generic code can ask an encoder how many
/// handles it is carrying.
///
/// [`Encoder`] requires this trait as a supertrait. The double-underscore
/// method name and `#[doc(hidden)]` mark it as an internal detail rather than
/// part of the public encoding API.
pub trait InternalHandleEncoder {
    /// Reports the number of handles held by this encoder (by its name; the
    /// byte-only `Vec<Chunk>` encoder always reports zero).
    #[doc(hidden)]
    fn __internal_handle_count(&self) -> usize;
}
22
23pub trait Encoder: InternalHandleEncoder {
25 fn bytes_written(&self) -> usize;
27
28 fn write_zeroes(&mut self, len: usize);
32
33 fn write(&mut self, bytes: &[u8]);
37
38 fn rewrite(&mut self, pos: usize, bytes: &[u8]);
40}
41
42impl InternalHandleEncoder for Vec<Chunk> {
43 #[inline]
44 fn __internal_handle_count(&self) -> usize {
45 0
46 }
47}
48
49impl Encoder for Vec<Chunk> {
50 #[inline]
51 fn bytes_written(&self) -> usize {
52 self.len() * CHUNK_SIZE
53 }
54
55 #[inline]
56 fn write_zeroes(&mut self, len: usize) {
57 let count = len.div_ceil(CHUNK_SIZE);
58 self.reserve(count);
59 let ptr = unsafe { self.as_mut_ptr().add(self.len()) };
60 unsafe {
61 ptr.write_bytes(0, count);
62 }
63 unsafe {
64 self.set_len(self.len() + count);
65 }
66 }
67
68 #[inline]
69 fn write(&mut self, bytes: &[u8]) {
70 if bytes.is_empty() {
71 return;
72 }
73
74 let count = bytes.len().div_ceil(CHUNK_SIZE);
75 self.reserve(count);
76
77 unsafe {
79 self.as_mut_ptr().add(self.len() + count - 1).write(WireU64(0));
80 }
81 let ptr = unsafe { self.as_mut_ptr().add(self.len()).cast::<u8>() };
82
83 unsafe {
85 ptr.copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());
86 }
87
88 unsafe {
90 self.set_len(self.len() + count);
91 }
92 }
93
94 #[inline]
95 fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
96 assert!(pos + bytes.len() <= self.bytes_written());
97
98 let ptr = unsafe { self.as_mut_ptr().cast::<u8>().add(pos) };
99 unsafe {
100 ptr.copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());
101 }
102 }
103}
104
105pub trait EncoderExt {
107 fn preallocate<T>(&mut self, len: usize) -> Preallocated<'_, Self, T>;
109
110 fn encode_next_iter<W: Constrained + Wire, T: Encode<W, Self>>(
114 &mut self,
115 values: impl ExactSizeIterator<Item = T>,
116 constraint: W::Constraint,
117 ) -> Result<(), EncodeError>;
118
119 fn encode_next<W: Constrained + Wire, T: Encode<W, Self>>(
123 &mut self,
124 value: T,
125 constraint: W::Constraint,
126 ) -> Result<(), EncodeError>;
127}
128
129impl<E: Encoder + ?Sized> EncoderExt for E {
130 fn preallocate<T>(&mut self, len: usize) -> Preallocated<'_, Self, T> {
131 let pos = self.bytes_written();
132
133 self.write_zeroes(len * size_of::<T>());
135
136 Preallocated {
137 encoder: self,
138 pos,
139 #[cfg(debug_assertions)]
140 remaining: len,
141 _phantom: PhantomData,
142 }
143 }
144
145 fn encode_next_iter<W: Constrained + Wire, T: Encode<W, Self>>(
146 &mut self,
147 values: impl ExactSizeIterator<Item = T>,
148 constraint: W::Constraint,
149 ) -> Result<(), EncodeError> {
150 let mut outputs = self.preallocate::<W>(values.len());
151
152 let mut out = MaybeUninit::<W>::uninit();
153 <W as Wire>::zero_padding(&mut out);
154 for value in values {
155 value.encode(outputs.encoder, &mut out, constraint)?;
156 <W as Constrained>::validate(
157 unsafe { Slot::new_unchecked_from_maybe_uninit(&mut out) },
158 constraint,
159 )
160 .map_err(EncodeError::Validation)?;
161 unsafe {
162 outputs.write_next(out.assume_init_ref());
163 }
164 }
165
166 Ok(())
167 }
168
169 fn encode_next<W: Constrained + Wire, T: Encode<W, Self>>(
170 &mut self,
171 value: T,
172 constraint: W::Constraint,
173 ) -> Result<(), EncodeError> {
174 self.encode_next_iter(core::iter::once(value), constraint)
175 }
176}
177
/// Write cursor over a run of zeroed, fixed-size slots reserved by
/// [`EncoderExt::preallocate`].
///
/// Slots are filled front to back with [`Preallocated::write_next`].
pub struct Preallocated<'a, E: ?Sized, T> {
    /// The encoder in which the slots were reserved.
    pub encoder: &'a mut E,
    /// Byte offset of the next slot to fill.
    pos: usize,
    /// Number of slots not yet filled; tracked only in debug builds, to catch
    /// overruns.
    #[cfg(debug_assertions)]
    remaining: usize,
    /// Records the slot type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
187
188impl<E: Encoder + ?Sized, T> Preallocated<'_, E, T> {
189 pub unsafe fn write_next(&mut self, value: &T) {
195 #[cfg(debug_assertions)]
196 {
197 assert!(self.remaining > 0, "attemped to write more slots than preallocated");
198 self.remaining -= 1;
199 }
200
201 let bytes_ptr = (value as *const T).cast::<u8>();
202 let bytes = unsafe { from_raw_parts(bytes_ptr, size_of::<T>()) };
203 self.encoder.rewrite(self.pos, bytes);
204 self.pos += size_of::<T>();
205 }
206}