1use core::marker::PhantomData;
6use core::mem::{MaybeUninit, forget, needs_drop};
7use core::ops::Deref;
8use core::ptr::{NonNull, copy_nonoverlapping};
9use core::{fmt, slice};
10
11use munge::munge;
12
13use super::raw::RawVector;
14use crate::{
15 Chunk, Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encode, EncodeError,
16 Encoder, EncoderExt as _, FromWire, FromWireRef, IntoNatural, Slot, ValidationError, Wire,
17 wire,
18};
19
/// A vector in wire format.
///
/// Wraps a [`RawVector`] (an element count `len` and a pointer `ptr` to the
/// elements) and owns its elements: dropping a `Vector` runs each element's
/// destructor in place, and `into_iter` moves the elements out.
///
/// `#[repr(transparent)]` gives it exactly the layout of the inner `RawVector`.
#[repr(transparent)]
pub struct Vector<'de, T> {
    /// The underlying length/pointer pair.
    raw: RawVector<'de, T>,
}
25
26unsafe impl<T: Wire> Wire for Vector<'static, T> {
27 type Narrowed<'de> = Vector<'de, T::Narrowed<'de>>;
28
29 #[inline]
30 fn zero_padding(out: &mut MaybeUninit<Self>) {
31 munge!(let Self { raw } = out);
32 RawVector::<T>::zero_padding(raw);
33 }
34}
35
36impl<T> Drop for Vector<'_, T> {
37 fn drop(&mut self) {
38 if needs_drop::<T>() {
39 unsafe {
40 self.raw.as_slice_ptr().drop_in_place();
41 }
42 }
43 }
44}
45
46impl<'de, T> Vector<'de, T> {
47 pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
49 munge!(let Self { raw } = out);
50 RawVector::encode_present(raw, len);
51 }
52
53 pub fn len(&self) -> usize {
55 self.raw.len() as usize
56 }
57
58 pub fn is_empty(&self) -> bool {
60 self.len() == 0
61 }
62
63 fn as_slice_ptr(&self) -> NonNull<[T]> {
65 unsafe { NonNull::new_unchecked(self.raw.as_slice_ptr()) }
66 }
67
68 pub fn as_slice(&self) -> &[T] {
70 unsafe { self.as_slice_ptr().as_ref() }
71 }
72
73 pub unsafe fn decode_raw<D>(
80 mut slot: Slot<'_, Self>,
81 decoder: &mut D,
82 max_len: u64,
83 ) -> Result<(), DecodeError>
84 where
85 D: Decoder<'de> + ?Sized,
86 T: Decode<D>,
87 {
88 munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
89
90 if !wire::Pointer::is_encoded_present(ptr.as_mut())? {
91 return Err(DecodeError::RequiredValueAbsent);
92 }
93
94 if **len > max_len {
95 return Err(DecodeError::Validation(ValidationError::VectorTooLong {
96 count: **len,
97 limit: max_len,
98 }));
99 }
100
101 let slice = decoder.take_slice_slot::<T>(**len as usize)?;
102 wire::Pointer::set_decoded_slice(ptr, slice);
103
104 Ok(())
105 }
106
107 pub(crate) fn validate_max_len(
109 slot: Slot<'_, Self>,
110 limit: u64,
111 ) -> Result<(), ValidationError> {
112 munge!(let Self { raw: RawVector { len, ptr:_ } } = slot);
113 let count: u64 = **len;
114 if count > limit { Err(ValidationError::VectorTooLong { count, limit }) } else { Ok(()) }
115 }
116}
117
/// Constraint for a [`Vector`]: the maximum element count, paired with the
/// constraint applied to each element.
type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
119
120impl<T: Constrained> Constrained for Vector<'_, T> {
121 type Constraint = VectorConstraint<T>;
122
123 fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
124 let (limit, _) = constraint;
125
126 munge!(let Self { raw: RawVector { len, ptr:_ } } = slot);
127 let count = **len;
128 if count > limit {
129 return Err(ValidationError::VectorTooLong { count, limit });
130 }
131
132 Ok(())
133 }
134}
135
/// An iterator that moves elements out of a wire [`Vector`].
///
/// Created by `Vector::into_iter`. Elements that have not been yielded yet are
/// dropped when the iterator itself is dropped.
pub struct IntoIter<'de, T> {
    /// Pointer to the next element to yield.
    current: *mut T,
    /// Number of elements left to yield, starting at `current`.
    remaining: usize,
    /// Marks the iterator as borrowing `&'de mut [Chunk]`. Presumably this is the
    /// buffer the elements were decoded into — confirm.
    _phantom: PhantomData<&'de mut [Chunk]>,
}
142
143impl<T> Drop for IntoIter<'_, T> {
144 fn drop(&mut self) {
145 for i in 0..self.remaining {
146 unsafe {
147 self.current.add(i).drop_in_place();
148 }
149 }
150 }
151}
152
153impl<T> Iterator for IntoIter<'_, T> {
154 type Item = T;
155
156 fn next(&mut self) -> Option<Self::Item> {
157 if self.remaining == 0 {
158 None
159 } else {
160 let result = unsafe { self.current.read() };
161 self.current = unsafe { self.current.add(1) };
162 self.remaining -= 1;
163 Some(result)
164 }
165 }
166}
167
168impl<'de, T> IntoIterator for Vector<'de, T> {
169 type IntoIter = IntoIter<'de, T>;
170 type Item = T;
171
172 fn into_iter(self) -> Self::IntoIter {
173 let current = self.raw.as_ptr();
174 let remaining = self.len();
175 forget(self);
176
177 IntoIter { current, remaining, _phantom: PhantomData }
178 }
179}
180
181impl<T> Deref for Vector<'_, T> {
182 type Target = [T];
183
184 fn deref(&self) -> &Self::Target {
185 self.as_slice()
186 }
187}
188
189impl<T: fmt::Debug> fmt::Debug for Vector<'_, T> {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 self.as_slice().fmt(f)
192 }
193}
194
195impl<T, U: ?Sized> PartialEq<&U> for Vector<'_, T>
196where
197 for<'de> Vector<'de, T>: PartialEq<U>,
198{
199 fn eq(&self, other: &&U) -> bool {
200 self == *other
201 }
202}
203
204impl<T: PartialEq<U>, U, const N: usize> PartialEq<[U; N]> for Vector<'_, T> {
205 fn eq(&self, other: &[U; N]) -> bool {
206 self.as_slice() == other.as_slice()
207 }
208}
209
210impl<T: PartialEq<U>, U> PartialEq<[U]> for Vector<'_, T> {
211 fn eq(&self, other: &[U]) -> bool {
212 self.as_slice() == other
213 }
214}
215
216impl<T: PartialEq<U>, U> PartialEq<Vector<'_, U>> for Vector<'_, T> {
217 fn eq(&self, other: &Vector<'_, U>) -> bool {
218 self.as_slice() == other.as_slice()
219 }
220}
221
222unsafe impl<'de, D, T> Decode<D> for Vector<'de, T>
223where
224 D: Decoder<'de> + ?Sized,
225 T: Decode<D>,
226{
227 fn decode(
228 mut slot: Slot<'_, Self>,
229 decoder: &mut D,
230 constraint: Self::Constraint,
231 ) -> Result<(), DecodeError> {
232 munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
233
234 let (length_constraint, member_constraint) = constraint;
235
236 if **len > length_constraint {
237 return Err(DecodeError::Validation(ValidationError::VectorTooLong {
238 count: **len,
239 limit: length_constraint,
240 }));
241 }
242
243 if !wire::Pointer::is_encoded_present(ptr.as_mut())? {
244 return Err(DecodeError::RequiredValueAbsent);
245 }
246
247 let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
248 for i in 0..**len as usize {
249 T::decode(slice.index(i), decoder, member_constraint)?;
250 }
251 wire::Pointer::set_decoded_slice(ptr, slice);
252
253 Ok(())
254 }
255}
256
257#[inline]
258fn encode_to_vector<V, W, E, T>(
259 value: V,
260 encoder: &mut E,
261 out: &mut MaybeUninit<Vector<'static, W>>,
262 constraint: VectorConstraint<W>,
263) -> Result<(), EncodeError>
264where
265 V: AsRef<[T]> + IntoIterator,
266 V::IntoIter: ExactSizeIterator,
267 V::Item: Encode<W, E>,
268 W: Wire,
269 E: Encoder + ?Sized,
270 T: Encode<W, E>,
271{
272 let len = value.as_ref().len();
273 let (length_constraint, member_constraint) = constraint;
274
275 if len as u64 > length_constraint {
276 return Err(EncodeError::Validation(ValidationError::VectorTooLong {
277 count: len as u64,
278 limit: length_constraint,
279 }));
280 }
281
282 if T::COPY_OPTIMIZATION.is_enabled() {
283 let slice = value.as_ref();
284 let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
288 encoder.write(bytes);
289 } else {
291 encoder.encode_next_iter_with_constraint(value.into_iter(), member_constraint)?;
292 }
293 Vector::encode_present(out, len as u64);
294 Ok(())
295}
296
297unsafe impl<W, E, T> Encode<Vector<'static, W>, E> for Vec<T>
298where
299 W: Wire,
300 E: Encoder + ?Sized,
301 T: Encode<W, E>,
302{
303 fn encode(
304 self,
305 encoder: &mut E,
306 out: &mut MaybeUninit<Vector<'static, W>>,
307 constraint: VectorConstraint<W>,
308 ) -> Result<(), EncodeError> {
309 encode_to_vector(self, encoder, out, constraint)
310 }
311}
312
313unsafe impl<'a, W, E, T> Encode<Vector<'static, W>, E> for &'a Vec<T>
314where
315 W: Wire,
316 E: Encoder + ?Sized,
317 T: Encode<W, E>,
318 &'a T: Encode<W, E>,
319{
320 fn encode(
321 self,
322 encoder: &mut E,
323 out: &mut MaybeUninit<Vector<'static, W>>,
324 constraint: VectorConstraint<W>,
325 ) -> Result<(), EncodeError> {
326 encode_to_vector(self, encoder, out, constraint)
327 }
328}
329
330unsafe impl<W, E, T, const N: usize> Encode<Vector<'static, W>, E> for [T; N]
331where
332 W: Wire,
333 E: Encoder + ?Sized,
334 T: Encode<W, E>,
335{
336 fn encode(
337 self,
338 encoder: &mut E,
339 out: &mut MaybeUninit<Vector<'static, W>>,
340 constraint: VectorConstraint<W>,
341 ) -> Result<(), EncodeError> {
342 encode_to_vector(self, encoder, out, constraint)
343 }
344}
345
346unsafe impl<'a, W, E, T, const N: usize> Encode<Vector<'static, W>, E> for &'a [T; N]
347where
348 W: Wire,
349 E: Encoder + ?Sized,
350 T: Encode<W, E>,
351 &'a T: Encode<W, E>,
352{
353 fn encode(
354 self,
355 encoder: &mut E,
356 out: &mut MaybeUninit<Vector<'static, W>>,
357 constraint: VectorConstraint<W>,
358 ) -> Result<(), EncodeError> {
359 encode_to_vector(self, encoder, out, constraint)
360 }
361}
362
363unsafe impl<'a, W, E, T> Encode<Vector<'static, W>, E> for &'a [T]
364where
365 W: Wire,
366 E: Encoder + ?Sized,
367 T: Encode<W, E>,
368 &'a T: Encode<W, E>,
369{
370 fn encode(
371 self,
372 encoder: &mut E,
373 out: &mut MaybeUninit<Vector<'static, W>>,
374 constraint: VectorConstraint<W>,
375 ) -> Result<(), EncodeError> {
376 encode_to_vector(self, encoder, out, constraint)
377 }
378}
379
380impl<T: FromWire<W>, W> FromWire<Vector<'_, W>> for Vec<T> {
381 fn from_wire(wire: Vector<'_, W>) -> Self {
382 let mut result = Vec::<T>::with_capacity(wire.len());
383 if T::COPY_OPTIMIZATION.is_enabled() {
384 unsafe {
385 copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
386 }
387 unsafe {
388 result.set_len(wire.len());
389 }
390 forget(wire);
391 } else {
392 for item in wire.into_iter() {
393 result.push(T::from_wire(item));
394 }
395 }
396 result
397 }
398}
399
400impl<T: IntoNatural> IntoNatural for Vector<'_, T> {
401 type Natural = Vec<T::Natural>;
402}
403
404impl<T: FromWireRef<W>, W> FromWireRef<Vector<'_, W>> for Vec<T> {
405 fn from_wire_ref(wire: &Vector<'_, W>) -> Self {
406 let mut result = Vec::<T>::with_capacity(wire.len());
407 if T::COPY_OPTIMIZATION.is_enabled() {
408 unsafe {
409 copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
410 }
411 unsafe {
412 result.set_len(wire.len());
413 }
414 } else {
415 for item in wire.iter() {
416 result.push(T::from_wire_ref(item));
417 }
418 }
419 result
420 }
421}
422
#[cfg(test)]
mod tests {
    use crate::{DecoderExt as _, EncoderExt as _, chunks, wire};

    /// Decoding a present vector yields its elements. A present vector with a
    /// zero count decodes to an empty slice.
    #[test]
    fn decode_vec() {
        // Layout: 8-byte little-endian count (2), 8 bytes of 0xff for the
        // pointer, then the two little-endian u32 elements. The vector type is
        // required, so the 0xff pointer must decode as present. The constraint
        // is (max length 1000, unit member constraint).
        assert_eq!(
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_slice(),
            &[wire::Uint32(0x12345678), wire::Uint32(0x9abcdef0)],
        );
        // Zero count with a present pointer and no element data.
        assert_eq!(
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::Vector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_ref(),
            <[wire::Uint32; _]>::as_slice(&[]),
        );
    }

    /// Encoding produces the same byte layout that `decode_vec` consumes.
    #[test]
    fn encode_vec() {
        // NOTE(review): the input is wrapped in `Some(..)`. The expected bytes
        // here are identical to the required-vector layout decoded above.
        assert_eq!(
            Vec::encode_with_constraint(Some(vec![0x12345678u32, 0x9abcdef0u32]), (1000, ()))
                .unwrap(),
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a,
            ],
        );
        // An empty vector encodes as just the count and present pointer.
        assert_eq!(
            Vec::encode_with_constraint(Some(Vec::<u32>::new()), (1000, ())).unwrap(),
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff,
            ],
        );
    }
}