1use core::mem::{MaybeUninit, needs_drop};
6use core::{fmt, slice};
7
8use munge::munge;
9
10use super::raw::RawVector;
11use crate::{
12 Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encode, EncodeError, EncodeOption,
13 Encoder, EncoderExt as _, FromWire, FromWireOption, FromWireOptionRef, FromWireRef,
14 IntoNatural, Slot, ValidationError, Wire, wire,
15};
16
/// An optional vector in wire form.
///
/// This is a `#[repr(transparent)]` wrapper around a [`RawVector`]. Once
/// decoded, a null element pointer means the vector is absent (see
/// `is_some`). Code in this module reinterprets an `OptionalVector` as a
/// `wire::Vector` (by pointer cast and by `transmute`), which relies on the
/// two types sharing the same layout.
#[repr(transparent)]
pub struct OptionalVector<'de, T> {
    /// The vector's length and element pointer.
    raw: RawVector<'de, T>,
}
22
23unsafe impl<T: Wire> Wire for OptionalVector<'static, T> {
24 type Narrowed<'de> = OptionalVector<'de, T::Narrowed<'de>>;
25
26 #[inline]
27 fn zero_padding(out: &mut MaybeUninit<Self>) {
28 munge!(let Self { raw } = out);
29 RawVector::<T>::zero_padding(raw);
30 }
31}
32
33impl<T> Drop for OptionalVector<'_, T> {
34 fn drop(&mut self) {
35 if needs_drop::<T>() && self.is_some() {
36 unsafe {
37 self.raw.as_slice_ptr().drop_in_place();
38 }
39 }
40 }
41}
42
43impl<'de, T> OptionalVector<'de, T> {
44 pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
46 munge!(let Self { raw } = out);
47 RawVector::encode_present(raw, len);
48 }
49
50 pub fn encode_absent(out: &mut MaybeUninit<Self>) {
52 munge!(let Self { raw } = out);
53 RawVector::encode_absent(raw);
54 }
55
56 pub fn is_some(&self) -> bool {
58 !self.raw.as_ptr().is_null()
59 }
60
61 pub fn is_none(&self) -> bool {
63 !self.is_some()
64 }
65
66 pub fn as_ref(&self) -> Option<&wire::Vector<'_, T>> {
68 if self.is_some() { Some(unsafe { &*(self as *const Self).cast() }) } else { None }
69 }
70
71 pub fn to_option(self) -> Option<wire::Vector<'de, T>> {
73 if self.is_some() {
74 Some(unsafe { core::mem::transmute::<Self, wire::Vector<'de, T>>(self) })
75 } else {
76 None
77 }
78 }
79
80 pub unsafe fn decode_raw<D>(
87 mut slot: Slot<'_, Self>,
88 decoder: &mut D,
89 max_len: u64,
90 ) -> Result<(), DecodeError>
91 where
92 D: Decoder<'de> + ?Sized,
93 T: Decode<D>,
94 {
95 munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
96
97 if wire::Pointer::is_encoded_present(ptr.as_mut())? {
98 if **len > max_len {
99 return Err(DecodeError::Validation(ValidationError::VectorTooLong {
100 count: **len,
101 limit: max_len,
102 }));
103 }
104
105 let slice = decoder.take_slice_slot::<T>(**len as usize)?;
106 wire::Pointer::set_decoded_slice(ptr, slice);
107 } else if *len != 0 {
108 return Err(DecodeError::InvalidOptionalSize(**len));
109 }
110
111 Ok(())
112 }
113
114 pub(crate) fn validate_max_len(
116 slot: Slot<'_, Self>,
117 limit: u64,
118 ) -> Result<(), ValidationError> {
119 munge!(let Self { raw: RawVector { len, ptr } } = slot);
120 let count = **len;
121 let is_present = ptr.as_bytes() != [0; 8];
122 if is_present && count > limit {
123 Err(ValidationError::VectorTooLong { count, limit })
124 } else {
125 Ok(())
126 }
127 }
128}
129
130type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
131
132impl<T: Constrained> Constrained for OptionalVector<'_, T> {
133 type Constraint = VectorConstraint<T>;
134
135 fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
136 let (limit, _member_constraint) = constraint;
137
138 Self::validate_max_len(slot, limit)
139 }
140}
141
142impl<T: fmt::Debug> fmt::Debug for OptionalVector<'_, T> {
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 self.as_ref().fmt(f)
145 }
146}
147
148impl<T, U> PartialEq<Option<U>> for OptionalVector<'_, T>
149where
150 for<'de> wire::Vector<'de, T>: PartialEq<U>,
151{
152 fn eq(&self, other: &Option<U>) -> bool {
153 match (self.as_ref(), other.as_ref()) {
154 (Some(lhs), Some(rhs)) => lhs == rhs,
155 (None, None) => true,
156 _ => false,
157 }
158 }
159}
160
161unsafe impl<'de, D, T> Decode<D> for OptionalVector<'de, T>
162where
163 D: Decoder<'de> + ?Sized,
164 T: Decode<D>,
165{
166 fn decode(
167 mut slot: Slot<'_, Self>,
168 decoder: &mut D,
169 constraint: Self::Constraint,
170 ) -> Result<(), DecodeError> {
171 munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
172
173 let (length_constraint, member_constraint) = constraint;
174
175 if wire::Pointer::is_encoded_present(ptr.as_mut())? {
176 if **len > length_constraint {
177 return Err(DecodeError::Validation(ValidationError::VectorTooLong {
178 count: **len,
179 limit: length_constraint,
180 }));
181 }
182
183 let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
184 for i in 0..**len as usize {
185 T::decode(slice.index(i), decoder, member_constraint)?;
186 }
187 wire::Pointer::set_decoded_slice(ptr, slice);
188 } else if *len != 0 {
189 return Err(DecodeError::InvalidOptionalSize(**len));
190 }
191
192 Ok(())
193 }
194}
195
196#[inline]
197fn encode_to_optional_vector<V, W, E, T>(
198 value: Option<V>,
199 encoder: &mut E,
200 out: &mut MaybeUninit<OptionalVector<'static, W>>,
201 constraint: VectorConstraint<W>,
202) -> Result<(), EncodeError>
203where
204 V: AsRef<[T]> + IntoIterator,
205 V::IntoIter: ExactSizeIterator,
206 V::Item: Encode<W, E>,
207 W: Wire,
208 E: Encoder + ?Sized,
209 T: Encode<W, E>,
210{
211 let (length_constraint, member_constraint) = constraint;
212
213 if let Some(value) = value {
214 let len = value.as_ref().len();
215
216 if len as u64 > length_constraint {
217 return Err(EncodeError::Validation(ValidationError::VectorTooLong {
218 count: len as u64,
219 limit: length_constraint,
220 }));
221 }
222
223 if T::COPY_OPTIMIZATION.is_enabled() {
224 let slice = value.as_ref();
225 let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
229 encoder.write(bytes);
230 } else {
231 encoder.encode_next_iter_with_constraint(value.into_iter(), member_constraint)?;
232 }
233 OptionalVector::encode_present(out, len as u64);
234 } else {
235 OptionalVector::encode_absent(out);
236 }
237 Ok(())
238}
239
240unsafe impl<W, E, T> EncodeOption<OptionalVector<'static, W>, E> for Vec<T>
241where
242 W: Wire,
243 E: Encoder + ?Sized,
244 T: Encode<W, E>,
245{
246 fn encode_option(
247 this: Option<Self>,
248 encoder: &mut E,
249 out: &mut MaybeUninit<OptionalVector<'static, W>>,
250 constraint: VectorConstraint<W>,
251 ) -> Result<(), EncodeError> {
252 encode_to_optional_vector(this, encoder, out, constraint)
253 }
254}
255
256unsafe impl<'a, W, E, T> EncodeOption<OptionalVector<'static, W>, E> for &'a Vec<T>
257where
258 W: Wire,
259 E: Encoder + ?Sized,
260 T: Encode<W, E>,
261 &'a T: Encode<W, E>,
262{
263 fn encode_option(
264 this: Option<Self>,
265 encoder: &mut E,
266 out: &mut MaybeUninit<OptionalVector<'static, W>>,
267 constraint: VectorConstraint<W>,
268 ) -> Result<(), EncodeError> {
269 encode_to_optional_vector(this, encoder, out, constraint)
270 }
271}
272
273unsafe impl<W, E, T, const N: usize> EncodeOption<OptionalVector<'static, W>, E> for [T; N]
274where
275 W: Wire,
276 E: Encoder + ?Sized,
277 T: Encode<W, E>,
278{
279 fn encode_option(
280 this: Option<Self>,
281 encoder: &mut E,
282 out: &mut MaybeUninit<OptionalVector<'static, W>>,
283 constraint: VectorConstraint<W>,
284 ) -> Result<(), EncodeError> {
285 encode_to_optional_vector(this, encoder, out, constraint)
286 }
287}
288
289unsafe impl<'a, W, E, T, const N: usize> EncodeOption<OptionalVector<'static, W>, E> for &'a [T; N]
290where
291 W: Wire,
292 E: Encoder + ?Sized,
293 T: Encode<W, E>,
294 &'a T: Encode<W, E>,
295{
296 fn encode_option(
297 this: Option<Self>,
298 encoder: &mut E,
299 out: &mut MaybeUninit<OptionalVector<'static, W>>,
300 constraint: VectorConstraint<W>,
301 ) -> Result<(), EncodeError> {
302 encode_to_optional_vector(this, encoder, out, constraint)
303 }
304}
305
306unsafe impl<'a, W, E, T> EncodeOption<OptionalVector<'static, W>, E> for &'a [T]
307where
308 W: Wire,
309 E: Encoder + ?Sized,
310 T: Encode<W, E>,
311 &'a T: Encode<W, E>,
312{
313 fn encode_option(
314 this: Option<Self>,
315 encoder: &mut E,
316 out: &mut MaybeUninit<OptionalVector<'static, W>>,
317 constraint: VectorConstraint<W>,
318 ) -> Result<(), EncodeError> {
319 encode_to_optional_vector(this, encoder, out, constraint)
320 }
321}
322
323impl<T: FromWire<W>, W> FromWireOption<OptionalVector<'_, W>> for Vec<T> {
324 fn from_wire_option(wire: OptionalVector<'_, W>) -> Option<Self> {
325 wire.to_option().map(Vec::from_wire)
326 }
327}
328
/// The natural (owned) form of an optional wire vector is an `Option<Vec<_>>`
/// of the elements' natural type.
impl<T: IntoNatural> IntoNatural for OptionalVector<'_, T> {
    type Natural = Option<Vec<T::Natural>>;
}
332
333impl<T: FromWireRef<W>, W> FromWireOptionRef<OptionalVector<'_, W>> for Vec<T> {
334 fn from_wire_option_ref(wire: &OptionalVector<'_, W>) -> Option<Self> {
335 wire.as_ref().map(Vec::from_wire_ref)
336 }
337}
338
#[cfg(test)]
mod tests {
    use crate::{
        DecodeError, DecoderExt as _, EncodeError, EncoderExt as _, ValidationError, chunks, wire,
    };

    #[test]
    fn decode_optional_vec() {
        assert_eq!(
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::OptionalVector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_ref(),
            None,
        );
    }

    #[test]
    fn decode_absent_optional_vec_with_nonzero_len_fails() {
        // Length 1 paired with an absent (all-zero) pointer is malformed.
        let mut buf = chunks![
            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00,
        ];
        let result = buf
            .as_mut_slice()
            .decode_with_constraint::<wire::OptionalVector<'_, wire::Uint32>>((1000, ()));
        assert!(matches!(result, Err(DecodeError::InvalidOptionalSize(1))));
    }

    #[test]
    fn decode_optional_vec_over_limit_fails() {
        // A present vector claiming 2 elements, decoded with a limit of 1.
        let mut buf = chunks![
            0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
            0xff, 0xff,
        ];
        let result = buf
            .as_mut_slice()
            .decode_with_constraint::<wire::OptionalVector<'_, wire::Uint32>>((1, ()));
        assert!(result.is_err());
    }

    #[test]
    fn encode_optional_vec() {
        assert_eq!(
            Vec::encode_with_constraint(None::<Vec<u32>>, (1000, ())).unwrap(),
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ],
        );
    }

    #[test]
    fn encode_optional_vec_over_limit_fails() {
        let result = Vec::encode_with_constraint(Some(Vec::from([1u32, 2, 3])), (2, ()));
        assert!(matches!(
            result,
            Err(EncodeError::Validation(ValidationError::VectorTooLong { count: 3, limit: 2 }))
        ));
    }
}