1use core::marker::PhantomData;
6use core::mem::{MaybeUninit, forget, needs_drop};
7use core::ops::Deref;
8use core::ptr::{NonNull, copy_nonoverlapping};
9use core::{fmt, slice};
10
11use munge::munge;
12
13use super::raw::RawWireVector;
14use crate::{
15    Chunk, Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encode, EncodeError,
16    Encoder, EncoderExt as _, FromWire, FromWireRef, IntoNatural, Slot, ValidationError, Wire,
17    WirePointer,
18};
19
/// A wire-format vector of `T` elements.
///
/// This is a `#[repr(transparent)]` wrapper around [`RawWireVector`], so it has exactly the same
/// layout as the raw length/pointer pair. A `WireVector` owns its elements: dropping it drops
/// every element in place. The element storage itself is not freed here; it is presumably
/// borrowed for `'de` from the decoder's buffer — NOTE(review): confirm.
#[repr(transparent)]
pub struct WireVector<'de, T> {
    /// The underlying length + pointer pair.
    raw: RawWireVector<'de, T>,
}
25
26unsafe impl<T: Wire> Wire for WireVector<'static, T> {
27    type Owned<'de> = WireVector<'de, T::Owned<'de>>;
28
29    #[inline]
30    fn zero_padding(out: &mut MaybeUninit<Self>) {
31        munge!(let Self { raw } = out);
32        RawWireVector::<T>::zero_padding(raw);
33    }
34}
35
36impl<T> Drop for WireVector<'_, T> {
37    fn drop(&mut self) {
38        if needs_drop::<T>() {
39            unsafe {
40                self.raw.as_slice_ptr().drop_in_place();
41            }
42        }
43    }
44}
45
46impl<T> WireVector<'_, T> {
47    pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
49        munge!(let Self { raw } = out);
50        RawWireVector::encode_present(raw, len);
51    }
52
53    pub fn len(&self) -> usize {
55        self.raw.len() as usize
56    }
57
58    pub fn is_empty(&self) -> bool {
60        self.len() == 0
61    }
62
63    fn as_slice_ptr(&self) -> NonNull<[T]> {
65        unsafe { NonNull::new_unchecked(self.raw.as_slice_ptr()) }
66    }
67
68    pub fn as_slice(&self) -> &[T] {
70        unsafe { self.as_slice_ptr().as_ref() }
71    }
72
73    pub unsafe fn decode_raw<D>(
80        mut slot: Slot<'_, Self>,
81        mut decoder: &mut D,
82        max_len: u64,
83    ) -> Result<(), DecodeError>
84    where
85        D: Decoder + ?Sized,
86        T: Decode<D>,
87    {
88        munge!(let Self { raw: RawWireVector { len, mut ptr } } = slot.as_mut());
89
90        if !WirePointer::is_encoded_present(ptr.as_mut())? {
91            return Err(DecodeError::RequiredValueAbsent);
92        }
93
94        if **len > max_len {
95            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
96                count: **len,
97                limit: max_len,
98            }));
99        }
100
101        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
102        WirePointer::set_decoded(ptr, slice.as_mut_ptr().cast());
103
104        Ok(())
105    }
106
107    pub(crate) fn validate_max_len(
109        slot: Slot<'_, Self>,
110        limit: u64,
111    ) -> Result<(), crate::ValidationError> {
112        munge!(let Self { raw: RawWireVector { len, ptr:_ } } = slot);
113        let count: u64 = **len;
114        if count > limit { Err(ValidationError::VectorTooLong { count, limit }) } else { Ok(()) }
115    }
116}
117
/// Constraint for a [`WireVector`]: the maximum number of elements, paired with the constraint
/// passed to each element when it is decoded or encoded.
type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
119
120impl<T: Constrained> Constrained for WireVector<'_, T> {
121    type Constraint = VectorConstraint<T>;
122
123    fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
124        let (limit, _) = constraint;
125
126        munge!(let Self { raw: RawWireVector { len, ptr:_ } } = slot);
127        let count = **len;
128        if count > limit {
129            return Err(ValidationError::VectorTooLong { count, limit });
130        }
131
132        Ok(())
133    }
134}
135
/// An owning iterator over the elements of a [`WireVector`], created by its `into_iter`.
///
/// Elements that have not been yielded are dropped when the iterator is dropped.
pub struct IntoIter<'de, T> {
    /// Pointer to the next element to yield.
    current: *mut T,
    /// Number of elements left to yield, starting at `current`.
    remaining: usize,
    /// Ties the iterator to the `'de` lifetime of the underlying storage (presumably the
    /// decoder's chunk buffer — NOTE(review): confirm).
    _phantom: PhantomData<&'de mut [Chunk]>,
}
142
143impl<T> Drop for IntoIter<'_, T> {
144    fn drop(&mut self) {
145        for i in 0..self.remaining {
146            unsafe {
147                self.current.add(i).drop_in_place();
148            }
149        }
150    }
151}
152
153impl<T> Iterator for IntoIter<'_, T> {
154    type Item = T;
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if self.remaining == 0 {
158            None
159        } else {
160            let result = unsafe { self.current.read() };
161            self.current = unsafe { self.current.add(1) };
162            self.remaining -= 1;
163            Some(result)
164        }
165    }
166}
167
168impl<'de, T> IntoIterator for WireVector<'de, T> {
169    type IntoIter = IntoIter<'de, T>;
170    type Item = T;
171
172    fn into_iter(self) -> Self::IntoIter {
173        let current = self.raw.as_ptr();
174        let remaining = self.len();
175        forget(self);
176
177        IntoIter { current, remaining, _phantom: PhantomData }
178    }
179}
180
181impl<T> Deref for WireVector<'_, T> {
182    type Target = [T];
183
184    fn deref(&self) -> &Self::Target {
185        self.as_slice()
186    }
187}
188
189impl<T: fmt::Debug> fmt::Debug for WireVector<'_, T> {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        self.as_slice().fmt(f)
192    }
193}
194
195unsafe impl<D: Decoder + ?Sized, T: Decode<D>> Decode<D> for WireVector<'static, T> {
196    fn decode(
197        mut slot: Slot<'_, Self>,
198        mut decoder: &mut D,
199        constraint: <Self as Constrained>::Constraint,
200    ) -> Result<(), DecodeError> {
201        munge!(let Self { raw: RawWireVector { len, mut ptr } } = slot.as_mut());
202
203        let (length_constraint, member_constraint) = constraint;
204
205        if **len > length_constraint {
206            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
207                count: **len,
208                limit: length_constraint,
209            }));
210        }
211
212        if !WirePointer::is_encoded_present(ptr.as_mut())? {
213            return Err(DecodeError::RequiredValueAbsent);
214        }
215
216        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
217        for i in 0..**len as usize {
218            T::decode(slice.index(i), decoder, member_constraint)?;
219        }
220        WirePointer::set_decoded(ptr, slice.as_mut_ptr().cast());
221
222        Ok(())
223    }
224}
225
226#[inline]
227fn encode_to_vector<V, W, E, T>(
228    value: V,
229    encoder: &mut E,
230    out: &mut MaybeUninit<WireVector<'_, W>>,
231    constraint: VectorConstraint<W>,
232) -> Result<(), EncodeError>
233where
234    V: AsRef<[T]> + IntoIterator,
235    V::IntoIter: ExactSizeIterator,
236    V::Item: Encode<W, E>,
237    W: Constrained + Wire,
238    E: Encoder + ?Sized,
239    T: Encode<W, E>,
240{
241    let len = value.as_ref().len();
242    let (length_constraint, member_constraint) = constraint;
243
244    if len as u64 > length_constraint {
245        return Err(EncodeError::Validation(ValidationError::VectorTooLong {
246            count: len as u64,
247            limit: length_constraint,
248        }));
249    }
250
251    if T::COPY_OPTIMIZATION.is_enabled() {
252        let slice = value.as_ref();
253        let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
257        encoder.write(bytes);
258        } else {
260        encoder.encode_next_iter(value.into_iter(), member_constraint)?;
261    }
262    WireVector::encode_present(out, len as u64);
263    Ok(())
264}
265
266unsafe impl<W, E, T> Encode<WireVector<'static, W>, E> for Vec<T>
267where
268    W: Constrained + Wire,
269    E: Encoder + ?Sized,
270    T: Encode<W, E>,
271{
272    fn encode(
273        self,
274        encoder: &mut E,
275        out: &mut MaybeUninit<WireVector<'static, W>>,
276        constraint: VectorConstraint<W>,
277    ) -> Result<(), EncodeError> {
278        encode_to_vector(self, encoder, out, constraint)
279    }
280}
281
282unsafe impl<'a, W, E, T> Encode<WireVector<'static, W>, E> for &'a Vec<T>
283where
284    W: Constrained + Wire,
285    E: Encoder + ?Sized,
286    T: Encode<W, E>,
287    &'a T: Encode<W, E>,
288{
289    fn encode(
290        self,
291        encoder: &mut E,
292        out: &mut MaybeUninit<WireVector<'static, W>>,
293        constraint: VectorConstraint<W>,
294    ) -> Result<(), EncodeError> {
295        encode_to_vector(self, encoder, out, constraint)
296    }
297}
298
299unsafe impl<W, E, T, const N: usize> Encode<WireVector<'static, W>, E> for [T; N]
300where
301    W: Constrained + Wire,
302    E: Encoder + ?Sized,
303    T: Encode<W, E>,
304{
305    fn encode(
306        self,
307        encoder: &mut E,
308        out: &mut MaybeUninit<WireVector<'static, W>>,
309        constraint: VectorConstraint<W>,
310    ) -> Result<(), EncodeError> {
311        encode_to_vector(self, encoder, out, constraint)
312    }
313}
314
315unsafe impl<'a, W, E, T, const N: usize> Encode<WireVector<'static, W>, E> for &'a [T; N]
316where
317    W: Constrained + Wire,
318    E: Encoder + ?Sized,
319    T: Encode<W, E>,
320    &'a T: Encode<W, E>,
321{
322    fn encode(
323        self,
324        encoder: &mut E,
325        out: &mut MaybeUninit<WireVector<'static, W>>,
326        constraint: VectorConstraint<W>,
327    ) -> Result<(), EncodeError> {
328        encode_to_vector(self, encoder, out, constraint)
329    }
330}
331
332unsafe impl<'a, W, E, T> Encode<WireVector<'static, W>, E> for &'a [T]
333where
334    W: Constrained + Wire,
335    E: Encoder + ?Sized,
336    T: Encode<W, E>,
337    &'a T: Encode<W, E>,
338{
339    fn encode(
340        self,
341        encoder: &mut E,
342        out: &mut MaybeUninit<WireVector<'static, W>>,
343        constraint: VectorConstraint<W>,
344    ) -> Result<(), EncodeError> {
345        encode_to_vector(self, encoder, out, constraint)
346    }
347}
348
349impl<T: FromWire<W>, W> FromWire<WireVector<'_, W>> for Vec<T> {
350    fn from_wire(wire: WireVector<'_, W>) -> Self {
351        let mut result = Vec::<T>::with_capacity(wire.len());
352        if T::COPY_OPTIMIZATION.is_enabled() {
353            unsafe {
354                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
355            }
356            unsafe {
357                result.set_len(wire.len());
358            }
359            forget(wire);
360        } else {
361            for item in wire.into_iter() {
362                result.push(T::from_wire(item));
363            }
364        }
365        result
366    }
367}
368
369impl<T: IntoNatural> IntoNatural for WireVector<'_, T> {
370    type Natural = Vec<T::Natural>;
371}
372
373impl<T: FromWireRef<W>, W> FromWireRef<WireVector<'_, W>> for Vec<T> {
374    fn from_wire_ref(wire: &WireVector<'_, W>) -> Self {
375        let mut result = Vec::<T>::with_capacity(wire.len());
376        if T::COPY_OPTIMIZATION.is_enabled() {
377            unsafe {
378                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
379            }
380            unsafe {
381                result.set_len(wire.len());
382            }
383        } else {
384            for item in wire.iter() {
385                result.push(T::from_wire_ref(item));
386            }
387        }
388        result
389    }
390}