Skip to main content

fidl_next_codec/wire/vec/
optional.rs

// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use core::mem::{MaybeUninit, needs_drop};
6use core::{fmt, slice};
7
8use munge::munge;
9
10use super::raw::RawVector;
11use crate::{
12    Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encode, EncodeError, EncodeOption,
13    Encoder, EncoderExt as _, FromWire, FromWireOption, FromWireOptionRef, FromWireRef,
14    IntoNatural, Slot, ValidationError, Wire, wire,
15};
16
17/// An optional FIDL vector
18#[repr(transparent)]
19pub struct OptionalVector<'de, T> {
20    raw: RawVector<'de, T>,
21}
22
23unsafe impl<T: Wire> Wire for OptionalVector<'static, T> {
24    type Narrowed<'de> = OptionalVector<'de, T::Narrowed<'de>>;
25
26    #[inline]
27    fn zero_padding(out: &mut MaybeUninit<Self>) {
28        munge!(let Self { raw } = out);
29        RawVector::<T>::zero_padding(raw);
30    }
31}
32
33impl<T> Drop for OptionalVector<'_, T> {
34    fn drop(&mut self) {
35        if needs_drop::<T>() && self.is_some() {
36            unsafe {
37                self.raw.as_slice_ptr().drop_in_place();
38            }
39        }
40    }
41}
42
43impl<'de, T> OptionalVector<'de, T> {
44    /// Encodes that a vector is present in a slot.
45    pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
46        munge!(let Self { raw } = out);
47        RawVector::encode_present(raw, len);
48    }
49
50    /// Encodes that a vector is absent in a slot.
51    pub fn encode_absent(out: &mut MaybeUninit<Self>) {
52        munge!(let Self { raw } = out);
53        RawVector::encode_absent(raw);
54    }
55
56    /// Returns whether the vector is present.
57    pub fn is_some(&self) -> bool {
58        !self.raw.as_ptr().is_null()
59    }
60
61    /// Returns whether the vector is absent.
62    pub fn is_none(&self) -> bool {
63        !self.is_some()
64    }
65
66    /// Gets a reference to the vector, if any.
67    pub fn as_ref(&self) -> Option<&wire::Vector<'_, T>> {
68        if self.is_some() { Some(unsafe { &*(self as *const Self).cast() }) } else { None }
69    }
70
71    /// Converts the optional wire vector to an `Option<WireVector>`.
72    pub fn to_option(self) -> Option<wire::Vector<'de, T>> {
73        if self.is_some() {
74            Some(unsafe { core::mem::transmute::<Self, wire::Vector<'de, T>>(self) })
75        } else {
76            None
77        }
78    }
79
80    /// Decodes a wire vector which contains raw data.
81    ///
82    /// # Safety
83    ///
84    /// The elements of the wire vector must not need to be individually decoded, and must always be
85    /// valid.
86    pub unsafe fn decode_raw<D>(
87        mut slot: Slot<'_, Self>,
88        decoder: &mut D,
89        max_len: u64,
90    ) -> Result<(), DecodeError>
91    where
92        D: Decoder<'de> + ?Sized,
93        T: Decode<D>,
94    {
95        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
96
97        if wire::Pointer::is_encoded_present(ptr.as_mut())? {
98            if **len > max_len {
99                return Err(DecodeError::Validation(ValidationError::VectorTooLong {
100                    count: **len,
101                    limit: max_len,
102                }));
103            }
104
105            let slice = decoder.take_slice_slot::<T>(**len as usize)?;
106            wire::Pointer::set_decoded_slice(ptr, slice);
107        } else if *len != 0 {
108            return Err(DecodeError::InvalidOptionalSize(**len));
109        }
110
111        Ok(())
112    }
113
114    /// Validate that this vector's length falls within the limit.
115    pub(crate) fn validate_max_len(
116        slot: Slot<'_, Self>,
117        limit: u64,
118    ) -> Result<(), ValidationError> {
119        munge!(let Self { raw: RawVector { len, ptr } } = slot);
120        let count = **len;
121        let is_present = ptr.as_bytes() != [0; 8];
122        if is_present && count > limit {
123            Err(ValidationError::VectorTooLong { count, limit })
124        } else {
125            Ok(())
126        }
127    }
128}
129
130type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
131
132impl<T: Constrained> Constrained for OptionalVector<'_, T> {
133    type Constraint = VectorConstraint<T>;
134
135    fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
136        let (limit, _member_constraint) = constraint;
137
138        Self::validate_max_len(slot, limit)
139    }
140}
141
142impl<T: fmt::Debug> fmt::Debug for OptionalVector<'_, T> {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        self.as_ref().fmt(f)
145    }
146}
147
148impl<T, U> PartialEq<Option<U>> for OptionalVector<'_, T>
149where
150    for<'de> wire::Vector<'de, T>: PartialEq<U>,
151{
152    fn eq(&self, other: &Option<U>) -> bool {
153        match (self.as_ref(), other.as_ref()) {
154            (Some(lhs), Some(rhs)) => lhs == rhs,
155            (None, None) => true,
156            _ => false,
157        }
158    }
159}
160
161unsafe impl<'de, D, T> Decode<D> for OptionalVector<'de, T>
162where
163    D: Decoder<'de> + ?Sized,
164    T: Decode<D>,
165{
166    fn decode(
167        mut slot: Slot<'_, Self>,
168        decoder: &mut D,
169        constraint: Self::Constraint,
170    ) -> Result<(), DecodeError> {
171        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
172
173        let (length_constraint, member_constraint) = constraint;
174
175        if wire::Pointer::is_encoded_present(ptr.as_mut())? {
176            if **len > length_constraint {
177                return Err(DecodeError::Validation(ValidationError::VectorTooLong {
178                    count: **len,
179                    limit: length_constraint,
180                }));
181            }
182
183            let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
184            for i in 0..**len as usize {
185                T::decode(slice.index(i), decoder, member_constraint)?;
186            }
187            wire::Pointer::set_decoded_slice(ptr, slice);
188        } else if *len != 0 {
189            return Err(DecodeError::InvalidOptionalSize(**len));
190        }
191
192        Ok(())
193    }
194}
195
196#[inline]
197fn encode_to_optional_vector<V, W, E, T>(
198    value: Option<V>,
199    encoder: &mut E,
200    out: &mut MaybeUninit<OptionalVector<'static, W>>,
201    constraint: VectorConstraint<W>,
202) -> Result<(), EncodeError>
203where
204    V: AsRef<[T]> + IntoIterator,
205    V::IntoIter: ExactSizeIterator,
206    V::Item: Encode<W, E>,
207    W: Wire,
208    E: Encoder + ?Sized,
209    T: Encode<W, E>,
210{
211    let (length_constraint, member_constraint) = constraint;
212
213    if let Some(value) = value {
214        let len = value.as_ref().len();
215
216        if len as u64 > length_constraint {
217            return Err(EncodeError::Validation(ValidationError::VectorTooLong {
218                count: len as u64,
219                limit: length_constraint,
220            }));
221        }
222
223        if T::COPY_OPTIMIZATION.is_enabled() {
224            let slice = value.as_ref();
225            // SAFETY: `T` has copy optimization enabled, which guarantees that it has no uninit
226            // bytes and can be copied directly to the output instead of calling `encode`. This
227            // means that we may cast `&[T]` to `&[u8]` and write those bytes.
228            let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
229            encoder.write(bytes);
230        } else {
231            encoder.encode_next_iter_with_constraint(value.into_iter(), member_constraint)?;
232        }
233        OptionalVector::encode_present(out, len as u64);
234    } else {
235        OptionalVector::encode_absent(out);
236    }
237    Ok(())
238}
239
240unsafe impl<W, E, T> EncodeOption<OptionalVector<'static, W>, E> for Vec<T>
241where
242    W: Wire,
243    E: Encoder + ?Sized,
244    T: Encode<W, E>,
245{
246    fn encode_option(
247        this: Option<Self>,
248        encoder: &mut E,
249        out: &mut MaybeUninit<OptionalVector<'static, W>>,
250        constraint: VectorConstraint<W>,
251    ) -> Result<(), EncodeError> {
252        encode_to_optional_vector(this, encoder, out, constraint)
253    }
254}
255
256unsafe impl<'a, W, E, T> EncodeOption<OptionalVector<'static, W>, E> for &'a Vec<T>
257where
258    W: Wire,
259    E: Encoder + ?Sized,
260    T: Encode<W, E>,
261    &'a T: Encode<W, E>,
262{
263    fn encode_option(
264        this: Option<Self>,
265        encoder: &mut E,
266        out: &mut MaybeUninit<OptionalVector<'static, W>>,
267        constraint: VectorConstraint<W>,
268    ) -> Result<(), EncodeError> {
269        encode_to_optional_vector(this, encoder, out, constraint)
270    }
271}
272
273unsafe impl<W, E, T, const N: usize> EncodeOption<OptionalVector<'static, W>, E> for [T; N]
274where
275    W: Wire,
276    E: Encoder + ?Sized,
277    T: Encode<W, E>,
278{
279    fn encode_option(
280        this: Option<Self>,
281        encoder: &mut E,
282        out: &mut MaybeUninit<OptionalVector<'static, W>>,
283        constraint: VectorConstraint<W>,
284    ) -> Result<(), EncodeError> {
285        encode_to_optional_vector(this, encoder, out, constraint)
286    }
287}
288
289unsafe impl<'a, W, E, T, const N: usize> EncodeOption<OptionalVector<'static, W>, E> for &'a [T; N]
290where
291    W: Wire,
292    E: Encoder + ?Sized,
293    T: Encode<W, E>,
294    &'a T: Encode<W, E>,
295{
296    fn encode_option(
297        this: Option<Self>,
298        encoder: &mut E,
299        out: &mut MaybeUninit<OptionalVector<'static, W>>,
300        constraint: VectorConstraint<W>,
301    ) -> Result<(), EncodeError> {
302        encode_to_optional_vector(this, encoder, out, constraint)
303    }
304}
305
306unsafe impl<'a, W, E, T> EncodeOption<OptionalVector<'static, W>, E> for &'a [T]
307where
308    W: Wire,
309    E: Encoder + ?Sized,
310    T: Encode<W, E>,
311    &'a T: Encode<W, E>,
312{
313    fn encode_option(
314        this: Option<Self>,
315        encoder: &mut E,
316        out: &mut MaybeUninit<OptionalVector<'static, W>>,
317        constraint: VectorConstraint<W>,
318    ) -> Result<(), EncodeError> {
319        encode_to_optional_vector(this, encoder, out, constraint)
320    }
321}
322
323impl<T: FromWire<W>, W> FromWireOption<OptionalVector<'_, W>> for Vec<T> {
324    fn from_wire_option(wire: OptionalVector<'_, W>) -> Option<Self> {
325        wire.to_option().map(Vec::from_wire)
326    }
327}
328
329impl<T: IntoNatural> IntoNatural for OptionalVector<'_, T> {
330    type Natural = Option<Vec<T::Natural>>;
331}
332
333impl<T: FromWireRef<W>, W> FromWireOptionRef<OptionalVector<'_, W>> for Vec<T> {
334    fn from_wire_option_ref(wire: &OptionalVector<'_, W>) -> Option<Self> {
335        wire.as_ref().map(Vec::from_wire_ref)
336    }
337}
338
#[cfg(test)]
mod tests {
    use crate::{DecoderExt as _, EncoderExt as _, chunks, wire};

    /// Sixteen zero bytes decode as an absent optional vector.
    #[test]
    fn decode_optional_vec() {
        assert_eq!(
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::OptionalVector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_ref(),
            None,
        );
    }

    /// Encoding `None` produces sixteen zero bytes.
    #[test]
    fn encode_optional_vec() {
        let expected = chunks![
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00,
        ];
        let actual = Vec::encode_with_constraint(None::<Vec<u32>>, (1000, ())).unwrap();
        assert_eq!(actual, expected);
    }
}