1use core::mem::MaybeUninit;
8use core::ops::{Deref, DerefMut};
9use kalloc::{AllocError, Allocator, Box, DefaultAllocator};
10
11#[repr(C)]
21pub struct InlineArray<T, const N: usize, A: Allocator = DefaultAllocator> {
22 count: usize,
25
26 ptr: *mut T,
32
33 inline_storage: [MaybeUninit<T>; N],
37
38 allocator: A,
42}
43
44zr::static_assert!(core::mem::size_of::<InlineArray<u32, 4>>() == 32);
45zr::static_assert!(core::mem::align_of::<InlineArray<u32, 4>>() == 8);
46
47impl<T, const N: usize, A: Allocator> InlineArray<T, N, A> {
48 pub fn try_new_in_with<F>(count: usize, allocator: A, mut f: F) -> Result<Self, AllocError>
51 where
52 F: FnMut() -> T,
53 {
54 if count <= N {
55 let mut inline_storage = [const { MaybeUninit::uninit() }; N];
56 for i in 0..count {
57 inline_storage[i].write(f());
58 }
59 Ok(InlineArray { count, ptr: core::ptr::null_mut(), inline_storage, allocator })
60 } else {
61 let mut heap_data = Box::try_new_uninit_slice_in(count, allocator)?;
62 for i in 0..count {
63 heap_data[i].write(f());
64 }
65 let (fat_ptr, allocator) =
67 Box::into_raw_with_allocator(unsafe { heap_data.assume_init() });
68 let ptr = fat_ptr as *mut T;
69 Ok(InlineArray {
70 count,
71 ptr,
72 inline_storage: [const { MaybeUninit::uninit() }; N],
73 allocator,
74 })
75 }
76 }
77
78 pub fn try_new_in(count: usize, allocator: A) -> Result<Self, AllocError>
81 where
82 T: Default,
83 {
84 Self::try_new_in_with(count, allocator, T::default)
85 }
86
87 const fn is_inline(&self) -> bool {
89 self.count <= N
90 }
91
92 pub const fn len(&self) -> usize {
94 self.count
95 }
96
97 pub const fn is_empty(&self) -> bool {
99 self.count == 0
100 }
101
102 pub fn as_slice(&self) -> &[T] {
104 if self.is_inline() {
105 unsafe { self.inline_storage[..self.count].assume_init_ref() }
107 } else {
108 unsafe { core::slice::from_raw_parts(self.ptr, self.count) }
110 }
111 }
112
113 pub fn as_mut_slice(&mut self) -> &mut [T] {
115 if self.is_inline() {
116 unsafe { self.inline_storage[..self.count].assume_init_mut() }
118 } else {
119 unsafe { core::slice::from_raw_parts_mut(self.ptr, self.count) }
121 }
122 }
123}
124
125impl<T, const N: usize> InlineArray<T, N, DefaultAllocator> {
126 pub fn try_new(count: usize) -> Result<Self, AllocError>
129 where
130 T: Default,
131 {
132 Self::try_new_in(count, DefaultAllocator)
133 }
134
135 pub fn try_new_with<F>(count: usize, f: F) -> Result<Self, AllocError>
138 where
139 F: FnMut() -> T,
140 {
141 Self::try_new_in_with(count, DefaultAllocator, f)
142 }
143}
144
145impl<T, const N: usize, A: Allocator> Deref for InlineArray<T, N, A> {
146 type Target = [T];
147 fn deref(&self) -> &Self::Target {
148 self.as_slice()
149 }
150}
151
152impl<T, const N: usize, A: Allocator> DerefMut for InlineArray<T, N, A> {
153 fn deref_mut(&mut self) -> &mut Self::Target {
154 self.as_mut_slice()
155 }
156}
157
158impl<T, const N: usize, A: Allocator> Drop for InlineArray<T, N, A> {
159 fn drop(&mut self) {
160 if self.is_inline() {
161 unsafe {
162 self.inline_storage[..self.count].assume_init_drop();
163 }
164 } else {
165 unsafe {
169 let _ = Box::from_raw_in(
170 core::ptr::slice_from_raw_parts_mut(self.ptr, self.count),
171 self.allocator.clone(),
172 );
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use core::cell::Cell;
    use core::ptr::NonNull;

    /// Counters shared by `TestAllocator` and `TestObject` so tests can observe
    /// constructions, drops and allocation requests.
    #[derive(Debug, PartialEq, Eq)]
    struct TestState {
        /// `TestObject`s constructed but not yet dropped.
        live_obj_count: Cell<usize>,
        /// Total `TestObject` constructions.
        ctor_count: Cell<usize>,
        /// Total `TestObject` drops.
        dtor_count: Cell<usize>,
        /// Allocation requests (`allocate`, `allocate_zeroed`, `grow`, `shrink`) seen so far.
        alloc_count: Cell<usize>,
        /// Zero-based index of the first allocation request that must fail.
        fail_threshold: Cell<usize>,
    }

    impl Default for TestState {
        fn default() -> Self {
            Self {
                live_obj_count: Cell::new(0),
                ctor_count: Cell::new(0),
                dtor_count: Cell::new(0),
                alloc_count: Cell::new(0),
                fail_threshold: Cell::new(usize::MAX),
            }
        }
    }

    /// Allocator that forwards to `DefaultAllocator` but counts requests and fails every
    /// request from `fail_threshold` onwards.
    #[derive(Clone)]
    struct TestAllocator<'a> {
        state: &'a TestState,
    }

    impl TestAllocator<'_> {
        /// Records one allocation request and fails it once the threshold has been reached.
        fn charge(&self) -> Result<(), AllocError> {
            let current = self.state.alloc_count.get();
            self.state.alloc_count.set(current + 1);
            if current >= self.state.fail_threshold.get() {
                Err(AllocError)
            } else {
                Ok(())
            }
        }
    }

    impl<'a> Allocator for TestAllocator<'a> {
        fn allocate(&self, layout: core::alloc::Layout) -> Result<NonNull<[u8]>, AllocError> {
            self.charge()?;
            DefaultAllocator::default().allocate(layout)
        }

        unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: core::alloc::Layout) {
            // SAFETY: the block came from a `DefaultAllocator` (see `allocate`); this assumes
            // any `DefaultAllocator` value can free it. The caller upholds the rest.
            unsafe { DefaultAllocator::default().deallocate(ptr, layout) }
        }

        unsafe fn grow(
            &self,
            ptr: NonNull<u8>,
            old_layout: core::alloc::Layout,
            new_layout: core::alloc::Layout,
        ) -> Result<NonNull<[u8]>, AllocError> {
            self.charge()?;
            // SAFETY: forwarded unchanged; the caller upholds `grow`'s contract.
            unsafe { DefaultAllocator::default().grow(ptr, old_layout, new_layout) }
        }

        unsafe fn shrink(
            &self,
            ptr: NonNull<u8>,
            old_layout: core::alloc::Layout,
            new_layout: core::alloc::Layout,
        ) -> Result<NonNull<[u8]>, AllocError> {
            self.charge()?;
            // SAFETY: forwarded unchanged; the caller upholds `shrink`'s contract.
            unsafe { DefaultAllocator::default().shrink(ptr, old_layout, new_layout) }
        }

        fn allocate_zeroed(
            &self,
            layout: core::alloc::Layout,
        ) -> Result<NonNull<[u8]>, AllocError> {
            self.charge()?;
            DefaultAllocator::default().allocate_zeroed(layout)
        }
    }

    /// Element type that reports its construction and destruction to a `TestState`.
    #[derive(Debug)]
    struct TestObject<'a> {
        state: &'a TestState,
    }

    impl<'a> TestObject<'a> {
        fn new(state: &'a TestState) -> Self {
            state.live_obj_count.set(state.live_obj_count.get() + 1);
            state.ctor_count.set(state.ctor_count.get() + 1);
            TestObject { state }
        }
    }

    impl<'a> Drop for TestObject<'a> {
        fn drop(&mut self) {
            self.state.live_obj_count.set(self.state.live_obj_count.get() - 1);
            self.state.dtor_count.set(self.state.dtor_count.get() + 1);
        }
    }

    /// Builds an `InlineArray<TestObject, 3>` of every size in `sizes` and checks that each
    /// element is constructed exactly once and dropped exactly once.
    fn check_object_lifecycle(sizes: &[usize]) {
        let state = TestState::default();

        for &sz in sizes {
            state.ctor_count.set(0);
            state.dtor_count.set(0);
            {
                let ia =
                    InlineArray::<TestObject<'_>, 3>::try_new_with(sz, || TestObject::new(&state))
                        .unwrap();
                assert_eq!(ia.len(), sz);
                assert_eq!(ia.is_empty(), sz == 0);
                assert_eq!(state.live_obj_count.get(), sz);
            }
            assert_eq!(state.ctor_count.get(), sz);
            assert_eq!(state.dtor_count.get(), sz);
            assert_eq!(state.live_obj_count.get(), 0);
        }
    }

    #[test]
    fn test_inline() {
        check_object_lifecycle(&[0, 1, 2, 3]);
    }

    #[test]
    fn test_non_inline() {
        check_object_lifecycle(&[4, 5, 6, 10, 100]);
    }

    #[test]
    fn test_custom_allocator_heap() {
        let state = TestState::default();
        {
            let ia = InlineArray::<TestObject<'_>, 3, TestAllocator<'_>>::try_new_in_with(
                5,
                TestAllocator { state: &state },
                || TestObject::new(&state),
            )
            .unwrap();
            assert_eq!(ia.len(), 5);
            assert_eq!(state.live_obj_count.get(), 5);
        }
        // The spilled buffer must have been requested from the custom allocator.
        assert!(state.alloc_count.get() > 0);
        assert_eq!(state.live_obj_count.get(), 0);
    }

    #[test]
    fn test_allocation_failure() {
        let state = TestState::default();
        state.fail_threshold.set(0);

        // Fits inline: no allocation is requested, so the failing allocator is never hit.
        let ia = InlineArray::<u32, 3, TestAllocator<'_>>::try_new_in(
            3,
            TestAllocator { state: &state },
        );
        assert!(ia.is_ok());
        assert_eq!(state.alloc_count.get(), 0);

        // Spills to the heap: the allocator refuses and the error is propagated.
        let ia = InlineArray::<u32, 3, TestAllocator<'_>>::try_new_in(
            4,
            TestAllocator { state: &state },
        );
        assert!(ia.is_err());
    }

    #[test]
    fn test_try_new_defaults() {
        for count in [0usize, 2, 4] {
            let ia = InlineArray::<u32, 2>::try_new(count).unwrap();
            assert_eq!(ia.len(), count);
            assert!(ia.iter().all(|&x| x == 0));
        }
    }

    #[test]
    fn test_deref_and_deref_mut() {
        let expected: [u32; 5] = [1, 2, 3, 4, 5];

        // 0 and 2 stay inline (N = 2); 5 spills to the heap.
        for count in [0usize, 2, 5] {
            let mut next = 0u32;
            let mut ia = InlineArray::<u32, 2>::try_new_with(count, || {
                next += 1;
                next
            })
            .unwrap();
            assert_eq!(&ia[..], &expected[..count]);

            for x in ia.iter_mut() {
                *x *= 10;
            }
            assert!(ia.iter().zip(&expected).all(|(&got, &want)| got == want * 10));
        }
    }
}