Skip to main content

fidl_next_codec/wire/vec/
required.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::marker::PhantomData;
6use core::mem::{MaybeUninit, forget, needs_drop};
7use core::ops::Deref;
8use core::ptr::{NonNull, copy_nonoverlapping};
9use core::{fmt, slice};
10
11use munge::munge;
12
13use super::raw::RawVector;
14use crate::{
15    Chunk, Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encode, EncodeError,
16    Encoder, EncoderExt as _, FromWire, FromWireRef, IntoNatural, Slot, ValidationError, Wire,
17    wire,
18};
19
/// A required FIDL vector in its wire representation.
///
/// This is a `#[repr(transparent)]` wrapper around [`RawVector`]. Decoding rejects an absent
/// vector with [`DecodeError::RequiredValueAbsent`]. Once decoded, the `Vector` owns its
/// elements and drops them when it is dropped.
#[repr(transparent)]
pub struct Vector<'de, T> {
    // The underlying length + pointer pair (`RawVector { len, ptr }`).
    raw: RawVector<'de, T>,
}
25
26unsafe impl<T: Wire> Wire for Vector<'static, T> {
27    type Narrowed<'de> = Vector<'de, T::Narrowed<'de>>;
28
29    #[inline]
30    fn zero_padding(out: &mut MaybeUninit<Self>) {
31        munge!(let Self { raw } = out);
32        RawVector::<T>::zero_padding(raw);
33    }
34}
35
36impl<T> Drop for Vector<'_, T> {
37    fn drop(&mut self) {
38        if needs_drop::<T>() {
39            unsafe {
40                self.raw.as_slice_ptr().drop_in_place();
41            }
42        }
43    }
44}
45
46impl<'de, T> Vector<'de, T> {
47    /// Encodes that a vector is present in a slot.
48    pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
49        munge!(let Self { raw } = out);
50        RawVector::encode_present(raw, len);
51    }
52
53    /// Returns the length of the vector in elements.
54    pub fn len(&self) -> usize {
55        self.raw.len() as usize
56    }
57
58    /// Returns whether the vector is empty.
59    pub fn is_empty(&self) -> bool {
60        self.len() == 0
61    }
62
63    /// Returns a pointer to the elements of the vector.
64    fn as_slice_ptr(&self) -> NonNull<[T]> {
65        unsafe { NonNull::new_unchecked(self.raw.as_slice_ptr()) }
66    }
67
68    /// Returns a slice of the elements of the vector.
69    pub fn as_slice(&self) -> &[T] {
70        unsafe { self.as_slice_ptr().as_ref() }
71    }
72
73    /// Decodes a wire vector which contains raw data.
74    ///
75    /// # Safety
76    ///
77    /// The elements of the wire vector must not need to be individually decoded, and must always be
78    /// valid.
79    pub unsafe fn decode_raw<D>(
80        mut slot: Slot<'_, Self>,
81        decoder: &mut D,
82        max_len: u64,
83    ) -> Result<(), DecodeError>
84    where
85        D: Decoder<'de> + ?Sized,
86        T: Decode<D>,
87    {
88        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
89
90        if !wire::Pointer::is_encoded_present(ptr.as_mut())? {
91            return Err(DecodeError::RequiredValueAbsent);
92        }
93
94        if **len > max_len {
95            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
96                count: **len,
97                limit: max_len,
98            }));
99        }
100
101        let slice = decoder.take_slice_slot::<T>(**len as usize)?;
102        wire::Pointer::set_decoded_slice(ptr, slice);
103
104        Ok(())
105    }
106
107    /// Validate that this vector's length falls within the limit.
108    pub(crate) fn validate_max_len(
109        slot: Slot<'_, Self>,
110        limit: u64,
111    ) -> Result<(), ValidationError> {
112        munge!(let Self { raw: RawVector { len, ptr:_ } } = slot);
113        let count: u64 = **len;
114        if count > limit { Err(ValidationError::VectorTooLong { count, limit }) } else { Ok(()) }
115    }
116}
117
/// Constraint for a [`Vector`]: the maximum element count, paired with the constraint applied to
/// each element.
type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
119
120impl<T: Constrained> Constrained for Vector<'_, T> {
121    type Constraint = VectorConstraint<T>;
122
123    fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
124        let (limit, _) = constraint;
125
126        munge!(let Self { raw: RawVector { len, ptr:_ } } = slot);
127        let count = **len;
128        if count > limit {
129            return Err(ValidationError::VectorTooLong { count, limit });
130        }
131
132        Ok(())
133    }
134}
135
/// An owning iterator over the elements of a [`Vector`].
///
/// Created by `Vector::into_iter`. Any elements that have not been yielded are dropped together
/// with the iterator.
pub struct IntoIter<'de, T> {
    // Pointer to the next element to yield.
    current: *mut T,
    // Number of elements not yet yielded, starting at `current`.
    remaining: usize,
    // Ties the iterator to the `'de` borrow of the underlying `[Chunk]` buffer.
    _phantom: PhantomData<&'de mut [Chunk]>,
}
142
143impl<T> Drop for IntoIter<'_, T> {
144    fn drop(&mut self) {
145        for i in 0..self.remaining {
146            unsafe {
147                self.current.add(i).drop_in_place();
148            }
149        }
150    }
151}
152
153impl<T> Iterator for IntoIter<'_, T> {
154    type Item = T;
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if self.remaining == 0 {
158            None
159        } else {
160            let result = unsafe { self.current.read() };
161            self.current = unsafe { self.current.add(1) };
162            self.remaining -= 1;
163            Some(result)
164        }
165    }
166}
167
168impl<'de, T> IntoIterator for Vector<'de, T> {
169    type IntoIter = IntoIter<'de, T>;
170    type Item = T;
171
172    fn into_iter(self) -> Self::IntoIter {
173        let current = self.raw.as_ptr();
174        let remaining = self.len();
175        forget(self);
176
177        IntoIter { current, remaining, _phantom: PhantomData }
178    }
179}
180
181impl<T> Deref for Vector<'_, T> {
182    type Target = [T];
183
184    fn deref(&self) -> &Self::Target {
185        self.as_slice()
186    }
187}
188
189impl<T: fmt::Debug> fmt::Debug for Vector<'_, T> {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        self.as_slice().fmt(f)
192    }
193}
194
195impl<T, U: ?Sized> PartialEq<&U> for Vector<'_, T>
196where
197    for<'de> Vector<'de, T>: PartialEq<U>,
198{
199    fn eq(&self, other: &&U) -> bool {
200        self == *other
201    }
202}
203
204impl<T: PartialEq<U>, U, const N: usize> PartialEq<[U; N]> for Vector<'_, T> {
205    fn eq(&self, other: &[U; N]) -> bool {
206        self.as_slice() == other.as_slice()
207    }
208}
209
210impl<T: PartialEq<U>, U> PartialEq<[U]> for Vector<'_, T> {
211    fn eq(&self, other: &[U]) -> bool {
212        self.as_slice() == other
213    }
214}
215
216impl<T: PartialEq<U>, U> PartialEq<Vector<'_, U>> for Vector<'_, T> {
217    fn eq(&self, other: &Vector<'_, U>) -> bool {
218        self.as_slice() == other.as_slice()
219    }
220}
221
222unsafe impl<'de, D, T> Decode<D> for Vector<'de, T>
223where
224    D: Decoder<'de> + ?Sized,
225    T: Decode<D>,
226{
227    fn decode(
228        mut slot: Slot<'_, Self>,
229        decoder: &mut D,
230        constraint: Self::Constraint,
231    ) -> Result<(), DecodeError> {
232        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
233
234        let (length_constraint, member_constraint) = constraint;
235
236        if **len > length_constraint {
237            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
238                count: **len,
239                limit: length_constraint,
240            }));
241        }
242
243        if !wire::Pointer::is_encoded_present(ptr.as_mut())? {
244            return Err(DecodeError::RequiredValueAbsent);
245        }
246
247        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
248        for i in 0..**len as usize {
249            T::decode(slice.index(i), decoder, member_constraint)?;
250        }
251        wire::Pointer::set_decoded_slice(ptr, slice);
252
253        Ok(())
254    }
255}
256
257#[inline]
258fn encode_to_vector<V, W, E, T>(
259    value: V,
260    encoder: &mut E,
261    out: &mut MaybeUninit<Vector<'static, W>>,
262    constraint: VectorConstraint<W>,
263) -> Result<(), EncodeError>
264where
265    V: AsRef<[T]> + IntoIterator,
266    V::IntoIter: ExactSizeIterator,
267    V::Item: Encode<W, E>,
268    W: Wire,
269    E: Encoder + ?Sized,
270    T: Encode<W, E>,
271{
272    let len = value.as_ref().len();
273    let (length_constraint, member_constraint) = constraint;
274
275    if len as u64 > length_constraint {
276        return Err(EncodeError::Validation(ValidationError::VectorTooLong {
277            count: len as u64,
278            limit: length_constraint,
279        }));
280    }
281
282    if T::COPY_OPTIMIZATION.is_enabled() {
283        let slice = value.as_ref();
284        // SAFETY: `T` has copy optimization enabled, which guarantees that it has no uninit bytes
285        // and can be copied directly to the output instead of calling `encode`. This means that we
286        // may cast `&[T]` to `&[u8]` and write those bytes.
287        let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
288        encoder.write(bytes);
289        // TODO: copy-optimized encodings don't currently check constraints
290    } else {
291        encoder.encode_next_iter_with_constraint(value.into_iter(), member_constraint)?;
292    }
293    Vector::encode_present(out, len as u64);
294    Ok(())
295}
296
297unsafe impl<W, E, T> Encode<Vector<'static, W>, E> for Vec<T>
298where
299    W: Wire,
300    E: Encoder + ?Sized,
301    T: Encode<W, E>,
302{
303    fn encode(
304        self,
305        encoder: &mut E,
306        out: &mut MaybeUninit<Vector<'static, W>>,
307        constraint: VectorConstraint<W>,
308    ) -> Result<(), EncodeError> {
309        encode_to_vector(self, encoder, out, constraint)
310    }
311}
312
313unsafe impl<'a, W, E, T> Encode<Vector<'static, W>, E> for &'a Vec<T>
314where
315    W: Wire,
316    E: Encoder + ?Sized,
317    T: Encode<W, E>,
318    &'a T: Encode<W, E>,
319{
320    fn encode(
321        self,
322        encoder: &mut E,
323        out: &mut MaybeUninit<Vector<'static, W>>,
324        constraint: VectorConstraint<W>,
325    ) -> Result<(), EncodeError> {
326        encode_to_vector(self, encoder, out, constraint)
327    }
328}
329
330unsafe impl<W, E, T, const N: usize> Encode<Vector<'static, W>, E> for [T; N]
331where
332    W: Wire,
333    E: Encoder + ?Sized,
334    T: Encode<W, E>,
335{
336    fn encode(
337        self,
338        encoder: &mut E,
339        out: &mut MaybeUninit<Vector<'static, W>>,
340        constraint: VectorConstraint<W>,
341    ) -> Result<(), EncodeError> {
342        encode_to_vector(self, encoder, out, constraint)
343    }
344}
345
346unsafe impl<'a, W, E, T, const N: usize> Encode<Vector<'static, W>, E> for &'a [T; N]
347where
348    W: Wire,
349    E: Encoder + ?Sized,
350    T: Encode<W, E>,
351    &'a T: Encode<W, E>,
352{
353    fn encode(
354        self,
355        encoder: &mut E,
356        out: &mut MaybeUninit<Vector<'static, W>>,
357        constraint: VectorConstraint<W>,
358    ) -> Result<(), EncodeError> {
359        encode_to_vector(self, encoder, out, constraint)
360    }
361}
362
363unsafe impl<'a, W, E, T> Encode<Vector<'static, W>, E> for &'a [T]
364where
365    W: Wire,
366    E: Encoder + ?Sized,
367    T: Encode<W, E>,
368    &'a T: Encode<W, E>,
369{
370    fn encode(
371        self,
372        encoder: &mut E,
373        out: &mut MaybeUninit<Vector<'static, W>>,
374        constraint: VectorConstraint<W>,
375    ) -> Result<(), EncodeError> {
376        encode_to_vector(self, encoder, out, constraint)
377    }
378}
379
380impl<T: FromWire<W>, W> FromWire<Vector<'_, W>> for Vec<T> {
381    fn from_wire(wire: Vector<'_, W>) -> Self {
382        let mut result = Vec::<T>::with_capacity(wire.len());
383        if T::COPY_OPTIMIZATION.is_enabled() {
384            unsafe {
385                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
386            }
387            unsafe {
388                result.set_len(wire.len());
389            }
390            forget(wire);
391        } else {
392            for item in wire.into_iter() {
393                result.push(T::from_wire(item));
394            }
395        }
396        result
397    }
398}
399
400impl<T: IntoNatural> IntoNatural for Vector<'_, T> {
401    type Natural = Vec<T::Natural>;
402}
403
404impl<T: FromWireRef<W>, W> FromWireRef<Vector<'_, W>> for Vec<T> {
405    fn from_wire_ref(wire: &Vector<'_, W>) -> Self {
406        let mut result = Vec::<T>::with_capacity(wire.len());
407        if T::COPY_OPTIMIZATION.is_enabled() {
408            unsafe {
409                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
410            }
411            unsafe {
412                result.set_len(wire.len());
413            }
414        } else {
415            for item in wire.iter() {
416                result.push(T::from_wire_ref(item));
417            }
418        }
419        result
420    }
421}
422
#[cfg(test)]
mod tests {
    use crate::{DecoderExt as _, EncoderExt as _, chunks, wire};

    #[test]
    fn decode_vec() {
        assert_eq!(
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_slice(),
            &[wire::Uint32(0x12345678), wire::Uint32(0x9abcdef0)],
        );
        assert_eq!(
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_ref(),
            <[wire::Uint32; _]>::as_slice(&[]),
        );
    }

    #[test]
    fn decode_vec_exceeding_length_constraint_fails() {
        // Two elements are encoded, but the length constraint only allows one.
        assert!(
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1, ()))
            .is_err()
        );
    }

    #[test]
    fn decode_absent_vec_fails() {
        // The pointer word is zero instead of the all-ones marker used by the present vectors
        // above, so a required vector must be rejected.
        assert!(
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1000, ()))
            .is_err()
        );
    }

    #[test]
    fn encode_vec() {
        assert_eq!(
            Vec::encode_with_constraint(Some(vec![0x12345678u32, 0x9abcdef0u32]), (1000, ()))
                .unwrap(),
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a,
            ],
        );
        assert_eq!(
            Vec::encode_with_constraint(Some(Vec::<u32>::new()), (1000, ())).unwrap(),
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff,
            ],
        );
    }
}