Skip to main content

fidl_next_codec/wire/vec/
required.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::marker::PhantomData;
6use core::mem::{MaybeUninit, forget, needs_drop};
7use core::ops::Deref;
8use core::ptr::{NonNull, copy_nonoverlapping};
9use core::{fmt, slice};
10
11use munge::munge;
12
13use super::raw::RawVector;
14use crate::{
15    Chunk, Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encode, EncodeError,
16    Encoder, EncoderExt as _, FromWire, FromWireRef, IntoNatural, Slot, ValidationError, Wire,
17    wire,
18};
19
/// A FIDL vector.
///
/// This is a `#[repr(transparent)]` wrapper around a `RawVector`, which holds the element
/// count (`len`) and the pointer to the elements (`ptr`). Dropping a `Vector` drops its
/// elements in place.
#[repr(transparent)]
pub struct Vector<'de, T> {
    // The underlying length + pointer representation.
    raw: RawVector<'de, T>,
}
25
26// SAFETY: `Vector` is `repr(transparent)` over `RawVector`, which implements `Wire`.
27// Lifetime erasure is safe since `Vector` is covariant over its lifetime.
28unsafe impl<T: Wire> Wire for Vector<'static, T> {
29    type Narrowed<'de> = Vector<'de, T::Narrowed<'de>>;
30
31    #[inline]
32    fn zero_padding(out: &mut MaybeUninit<Self>) {
33        munge!(let Self { raw } = out);
34        RawVector::<T>::zero_padding(raw);
35    }
36}
37
38impl<T> Drop for Vector<'_, T> {
39    fn drop(&mut self) {
40        if needs_drop::<T>() {
41            // SAFETY: If `T` needs to be dropped, the pointer has been decoded and points to a
42            // valid slice of initialized `T` elements.
43            unsafe {
44                self.raw.as_slice_ptr().drop_in_place();
45            }
46        }
47    }
48}
49
impl<'de, T> Vector<'de, T> {
    /// Encodes that a vector is present in a slot.
    ///
    /// Records `len` as the element count via `RawVector::encode_present`.
    pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
        munge!(let Self { raw } = out);
        RawVector::encode_present(raw, len);
    }

    /// Returns the length of the vector in elements.
    pub fn len(&self) -> usize {
        // NOTE(review): the raw length appears to be stored as `u64`; `as usize` truncates on
        // targets where `usize` is narrower than 64 bits — confirm such targets are unsupported.
        self.raw.len() as usize
    }

    /// Returns whether the vector is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns a pointer to the elements of the vector.
    fn as_slice_ptr(&self) -> NonNull<[T]> {
        // SAFETY: The underlying pointer is guaranteed to be non-null for a valid `Vector`.
        unsafe { NonNull::new_unchecked(self.raw.as_slice_ptr()) }
    }

    /// Returns a slice of the elements of the vector.
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: The pointer is aligned, initialized, and the lifetime of the reference
        // is bound to the lifetime of `self`.
        unsafe { self.as_slice_ptr().as_ref() }
    }

    /// Decodes a wire vector which contains raw data.
    ///
    /// First checks that the pointer is encoded as present, then that the encoded length does
    /// not exceed `max_len`. It then takes `len` element slots from `decoder` and points the
    /// vector at them. Unlike `Decode::decode`, the elements are not decoded individually.
    ///
    /// # Errors
    ///
    /// - `DecodeError::RequiredValueAbsent` if the pointer is encoded as absent.
    /// - `DecodeError::Validation(ValidationError::VectorTooLong { .. })` if the encoded length
    ///   exceeds `max_len`.
    /// - Any error returned by `wire::Pointer::is_encoded_present` or `take_slice_slot`.
    ///
    /// # Safety
    ///
    /// The elements of the wire vector must not need to be individually decoded, and must always be
    /// valid.
    pub unsafe fn decode_raw<D>(
        mut slot: Slot<'_, Self>,
        decoder: &mut D,
        max_len: u64,
    ) -> Result<(), DecodeError>
    where
        D: Decoder<'de> + ?Sized,
        T: Decode<D>,
    {
        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());

        // A required vector must have a present pointer.
        if !wire::Pointer::is_encoded_present(ptr.as_mut())? {
            return Err(DecodeError::RequiredValueAbsent);
        }

        if **len > max_len {
            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
                count: **len,
                limit: max_len,
            }));
        }

        // Claim the element storage without decoding each element (see `# Safety`).
        let slice = decoder.take_slice_slot::<T>(**len as usize)?;
        wire::Pointer::set_decoded_slice(ptr, slice);

        Ok(())
    }

    /// Validate that this vector's length falls within the limit.
    ///
    /// Only the encoded length is read; the pointer is ignored.
    ///
    /// # Errors
    ///
    /// Returns `ValidationError::VectorTooLong` when the length exceeds `limit`.
    pub(crate) fn validate_max_len(
        slot: Slot<'_, Self>,
        limit: u64,
    ) -> Result<(), ValidationError> {
        munge!(let Self { raw: RawVector { len, ptr:_ } } = slot);
        let count: u64 = **len;
        if count > limit { Err(ValidationError::VectorTooLong { count, limit }) } else { Ok(()) }
    }
}
124
/// Constraint for a vector: the maximum element count, paired with the constraint applied to
/// each element.
type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
126
127impl<T: Constrained> Constrained for Vector<'_, T> {
128    type Constraint = VectorConstraint<T>;
129
130    fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
131        let (limit, _) = constraint;
132
133        munge!(let Self { raw: RawVector { len, ptr:_ } } = slot);
134        let count = **len;
135        if count > limit {
136            return Err(ValidationError::VectorTooLong { count, limit });
137        }
138
139        Ok(())
140    }
141}
142
/// An owning iterator over the elements of a [`Vector`].
///
/// Created by `Vector::into_iter`. Elements that have not been yielded when the iterator is
/// dropped are dropped in place.
pub struct IntoIter<'de, T> {
    // Pointer to the next element to yield.
    current: *mut T,
    // Number of elements, starting at `current`, that have not been yielded yet.
    remaining: usize,
    // Ties the iterator to the `'de` lifetime of the borrowed `Chunk` buffer.
    _phantom: PhantomData<&'de mut [Chunk]>,
}
149
150impl<T> Drop for IntoIter<'_, T> {
151    fn drop(&mut self) {
152        for i in 0..self.remaining {
153            // SAFETY: `self.current.add(i)` points to an initialized element of `T` within the
154            // original vector's allocation that has not yet been yielded by the iterator.
155            unsafe {
156                self.current.add(i).drop_in_place();
157            }
158        }
159    }
160}
161
162impl<T> Iterator for IntoIter<'_, T> {
163    type Item = T;
164
165    fn next(&mut self) -> Option<Self::Item> {
166        if self.remaining == 0 {
167            None
168        } else {
169            // SAFETY: `self.current` points to a valid, initialized element of `T` that has not
170            // yet been read. We ownership-transfer it, and decrement `self.remaining` to ensure
171            // it is not dropped again.
172            let result = unsafe { self.current.read() };
173            // SAFETY: `self.current` is within the bounds of the allocated slice, so advancing
174            // it by 1 is safe (it may point to one-past-the-end if it was the last element).
175            self.current = unsafe { self.current.add(1) };
176            self.remaining -= 1;
177            Some(result)
178        }
179    }
180}
181
182impl<'de, T> IntoIterator for Vector<'de, T> {
183    type IntoIter = IntoIter<'de, T>;
184    type Item = T;
185
186    fn into_iter(self) -> Self::IntoIter {
187        let current = self.raw.as_ptr();
188        let remaining = self.len();
189        forget(self);
190
191        IntoIter { current, remaining, _phantom: PhantomData }
192    }
193}
194
195impl<T> Deref for Vector<'_, T> {
196    type Target = [T];
197
198    fn deref(&self) -> &Self::Target {
199        self.as_slice()
200    }
201}
202
203impl<T: fmt::Debug> fmt::Debug for Vector<'_, T> {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        self.as_slice().fmt(f)
206    }
207}
208
209impl<T, U: ?Sized> PartialEq<&U> for Vector<'_, T>
210where
211    for<'de> Vector<'de, T>: PartialEq<U>,
212{
213    fn eq(&self, other: &&U) -> bool {
214        self == *other
215    }
216}
217
218impl<T: PartialEq<U>, U, const N: usize> PartialEq<[U; N]> for Vector<'_, T> {
219    fn eq(&self, other: &[U; N]) -> bool {
220        self.as_slice() == other.as_slice()
221    }
222}
223
224impl<T: PartialEq<U>, U> PartialEq<[U]> for Vector<'_, T> {
225    fn eq(&self, other: &[U]) -> bool {
226        self.as_slice() == other
227    }
228}
229
230impl<T: PartialEq<U>, U> PartialEq<Vector<'_, U>> for Vector<'_, T> {
231    fn eq(&self, other: &Vector<'_, U>) -> bool {
232        self.as_slice() == other.as_slice()
233    }
234}
235
// SAFETY: If `decode` returns `Ok`, the `Vector` has been successfully decoded,
// and the underlying pointer is updated to point to a successfully decoded slice
// of `T` allocated by the decoder.
unsafe impl<'de, D, T> Decode<D> for Vector<'de, T>
where
    D: Decoder<'de> + ?Sized,
    T: Decode<D>,
{
    /// Decodes a required vector from `slot`, decoding each element with the member constraint.
    ///
    /// The checks run in this order: first the encoded length against the length limit, then
    /// pointer presence. Only after both pass are `len` element slots taken from `decoder` and
    /// decoded one by one. The slot's pointer is switched to the decoded slice only after every
    /// element has decoded successfully.
    ///
    /// # Errors
    ///
    /// - `DecodeError::Validation(ValidationError::VectorTooLong { .. })` if the encoded length
    ///   exceeds the length limit.
    /// - `DecodeError::RequiredValueAbsent` if the pointer is encoded as absent.
    /// - Any error returned by `wire::Pointer::is_encoded_present`, `take_slice_slot`, or
    ///   `T::decode`.
    fn decode(
        mut slot: Slot<'_, Self>,
        decoder: &mut D,
        constraint: Self::Constraint,
    ) -> Result<(), DecodeError> {
        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());

        let (length_constraint, member_constraint) = constraint;

        // Reject over-long vectors before taking any slots from the decoder.
        if **len > length_constraint {
            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
                count: **len,
                limit: length_constraint,
            }));
        }

        // A required vector must have a present pointer.
        if !wire::Pointer::is_encoded_present(ptr.as_mut())? {
            return Err(DecodeError::RequiredValueAbsent);
        }

        // NOTE(review): `**len as usize` truncates on targets with a 32-bit `usize`; confirm
        // whether such targets need to be supported.
        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
        for i in 0..**len as usize {
            T::decode(slice.index(i), decoder, member_constraint)?;
        }
        // Only publish the decoded slice once every element decoded successfully.
        wire::Pointer::set_decoded_slice(ptr, slice);

        Ok(())
    }
}
273
274#[inline]
275fn encode_to_vector<V, W, E, T>(
276    value: V,
277    encoder: &mut E,
278    out: &mut MaybeUninit<Vector<'static, W>>,
279    constraint: VectorConstraint<W>,
280) -> Result<(), EncodeError>
281where
282    V: AsRef<[T]> + IntoIterator,
283    V::IntoIter: ExactSizeIterator,
284    V::Item: Encode<W, E>,
285    W: Wire,
286    E: Encoder + ?Sized,
287    T: Encode<W, E>,
288{
289    let len = value.as_ref().len();
290    let (length_constraint, member_constraint) = constraint;
291
292    if len as u64 > length_constraint {
293        return Err(EncodeError::Validation(ValidationError::VectorTooLong {
294            count: len as u64,
295            limit: length_constraint,
296        }));
297    }
298
299    if T::COPY_OPTIMIZATION.is_enabled() {
300        let slice = value.as_ref();
301        // SAFETY: `T` has copy optimization enabled, which guarantees that it has no uninit bytes
302        // and can be copied directly to the output instead of calling `encode`. This means that we
303        // may cast `&[T]` to `&[u8]` and write those bytes.
304        let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
305        encoder.write(bytes);
306        // TODO: copy-optimized encodings don't currently check constraints
307    } else {
308        encoder.encode_next_iter_with_constraint(value.into_iter(), member_constraint)?;
309    }
310    Vector::encode_present(out, len as u64);
311    Ok(())
312}
313
314// SAFETY: `encode` delegates to `encode_to_vector`, which initializes the output.
315unsafe impl<W, E, T> Encode<Vector<'static, W>, E> for Vec<T>
316where
317    W: Wire,
318    E: Encoder + ?Sized,
319    T: Encode<W, E>,
320{
321    fn encode(
322        self,
323        encoder: &mut E,
324        out: &mut MaybeUninit<Vector<'static, W>>,
325        constraint: VectorConstraint<W>,
326    ) -> Result<(), EncodeError> {
327        encode_to_vector(self, encoder, out, constraint)
328    }
329}
330
331// SAFETY: `encode` delegates to `encode_to_vector`, which initializes the output.
332unsafe impl<'a, W, E, T> Encode<Vector<'static, W>, E> for &'a Vec<T>
333where
334    W: Wire,
335    E: Encoder + ?Sized,
336    T: Encode<W, E>,
337    &'a T: Encode<W, E>,
338{
339    fn encode(
340        self,
341        encoder: &mut E,
342        out: &mut MaybeUninit<Vector<'static, W>>,
343        constraint: VectorConstraint<W>,
344    ) -> Result<(), EncodeError> {
345        encode_to_vector(self, encoder, out, constraint)
346    }
347}
348
349// SAFETY: `encode` delegates to `encode_to_vector`, which initializes the output.
350unsafe impl<W, E, T, const N: usize> Encode<Vector<'static, W>, E> for [T; N]
351where
352    W: Wire,
353    E: Encoder + ?Sized,
354    T: Encode<W, E>,
355{
356    fn encode(
357        self,
358        encoder: &mut E,
359        out: &mut MaybeUninit<Vector<'static, W>>,
360        constraint: VectorConstraint<W>,
361    ) -> Result<(), EncodeError> {
362        encode_to_vector(self, encoder, out, constraint)
363    }
364}
365
366// SAFETY: `encode` delegates to `encode_to_vector`, which initializes the output.
367unsafe impl<'a, W, E, T, const N: usize> Encode<Vector<'static, W>, E> for &'a [T; N]
368where
369    W: Wire,
370    E: Encoder + ?Sized,
371    T: Encode<W, E>,
372    &'a T: Encode<W, E>,
373{
374    fn encode(
375        self,
376        encoder: &mut E,
377        out: &mut MaybeUninit<Vector<'static, W>>,
378        constraint: VectorConstraint<W>,
379    ) -> Result<(), EncodeError> {
380        encode_to_vector(self, encoder, out, constraint)
381    }
382}
383
384// SAFETY: `encode` delegates to `encode_to_vector`, which initializes the output.
385unsafe impl<'a, W, E, T> Encode<Vector<'static, W>, E> for &'a [T]
386where
387    W: Wire,
388    E: Encoder + ?Sized,
389    T: Encode<W, E>,
390    &'a T: Encode<W, E>,
391{
392    fn encode(
393        self,
394        encoder: &mut E,
395        out: &mut MaybeUninit<Vector<'static, W>>,
396        constraint: VectorConstraint<W>,
397    ) -> Result<(), EncodeError> {
398        encode_to_vector(self, encoder, out, constraint)
399    }
400}
401
402impl<T: FromWire<W>, W> FromWire<Vector<'_, W>> for Vec<T> {
403    fn from_wire(wire: Vector<'_, W>) -> Self {
404        let mut result = Vec::<T>::with_capacity(wire.len());
405        if T::COPY_OPTIMIZATION.is_enabled() {
406            // SAFETY: `T` has copy optimization enabled, meaning it is layout-compatible with `W`
407            // and can be safely copied. The destination buffer has been allocated with sufficient
408            // capacity, and the source and destination do not overlap.
409            unsafe {
410                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
411            }
412            // SAFETY: We have just initialized the first `wire.len()` elements of `result`
413            // via `copy_nonoverlapping`.
414            unsafe {
415                result.set_len(wire.len());
416            }
417            forget(wire);
418        } else {
419            for item in wire.into_iter() {
420                result.push(T::from_wire(item));
421            }
422        }
423        result
424    }
425}
426
/// The natural form of a wire vector is a `Vec` of the elements' natural forms.
impl<T: IntoNatural> IntoNatural for Vector<'_, T> {
    type Natural = Vec<T::Natural>;
}
430
431impl<T: FromWireRef<W>, W> FromWireRef<Vector<'_, W>> for Vec<T> {
432    fn from_wire_ref(wire: &Vector<'_, W>) -> Self {
433        let mut result = Vec::<T>::with_capacity(wire.len());
434        if T::COPY_OPTIMIZATION.is_enabled() {
435            // SAFETY: `T` has copy optimization enabled, meaning it is layout-compatible with `W`
436            // and can be safely copied. The destination buffer has been allocated with sufficient
437            // capacity, and the source and destination do not overlap.
438            unsafe {
439                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
440            }
441            // SAFETY: We have just initialized the first `wire.len()` elements of `result`
442            // via `copy_nonoverlapping`.
443            unsafe {
444                result.set_len(wire.len());
445            }
446        } else {
447            for item in wire.iter() {
448                result.push(T::from_wire_ref(item));
449            }
450        }
451        result
452    }
453}
454
#[cfg(test)]
mod tests {
    use crate::{DecoderExt as _, EncoderExt as _, chunks, wire};

    /// Decodes a two-element and an empty required vector of `wire::Uint32` and checks the
    /// resulting elements.
    #[test]
    fn decode_vec() {
        // Layout: 8-byte element count, 8-byte pointer marker (all 0xff), then the elements as
        // little-endian u32 values.
        assert_eq!(
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_slice(),
            &[wire::Uint32(0x12345678), wire::Uint32(0x9abcdef0)],
        );
        // A zero-length vector with a present pointer decodes to an empty slice.
        assert_eq!(
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_ref(),
            <[wire::Uint32; _]>::as_slice(&[]),
        );
    }

    /// Encodes a two-element and an empty vector of `u32` and checks the produced bytes.
    // NOTE(review): these encode `Some(..)` values, which may exercise the optional-vector
    // encoding rather than `Encode for Vec<T>` in this file; confirm this is intended.
    #[test]
    fn encode_vec() {
        assert_eq!(
            Vec::encode_with_constraint(Some(vec![0x12345678u32, 0x9abcdef0u32]), (1000, ()))
                .unwrap(),
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a,
            ],
        );
        assert_eq!(
            Vec::encode_with_constraint(Some(Vec::<u32>::new()), (1000, ())).unwrap(),
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff,
            ],
        );
    }
}