use derivative::Derivative;
use smallvec::SmallVec;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::{Range, RangeBounds};
use zerocopy::{FromBytes, IntoBytes};

#[cfg(target_arch = "aarch64")]
mod arm64;

#[cfg(target_arch = "aarch64")]
use arm64 as arch;

#[cfg(target_arch = "x86_64")]
mod x64;

#[cfg(target_arch = "x86_64")]
use x64 as arch;

#[cfg(target_arch = "riscv64")]
mod riscv64;

#[cfg(target_arch = "riscv64")]
use riscv64 as arch;

/// A pointer to a value of type `T` in memory that may be shared with eBPF
/// program code. The pointer remains valid for the lifetime `'a`.
#[derive(Derivative)]
#[derivative(Copy(bound = ""), Clone(bound = ""))]
pub struct EbpfPtr<'a, T> {
    ptr: *mut T,
    phantom: PhantomData<&'a T>,
}

#[allow(clippy::undocumented_unsafe_blocks, reason = "Force documented unsafe blocks in Starnix")]
unsafe impl<'a, T> Send for EbpfPtr<'a, T> {}
#[allow(clippy::undocumented_unsafe_blocks, reason = "Force documented unsafe blocks in Starnix")]
unsafe impl<'a, T> Sync for EbpfPtr<'a, T> {}

impl<'a, T> EbpfPtr<'a, T>
where
    T: Sized,
{
    /// Creates a new `EbpfPtr` wrapping `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be properly aligned and valid for reads and writes for the
    /// whole lifetime `'a`.
    pub unsafe fn new(ptr: *mut T) -> Self {
        Self { ptr, phantom: PhantomData }
    }

    /// Dereferences the pointer.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the pointed-to value is not mutated while
    /// the returned reference is alive.
    pub unsafe fn deref(&self) -> &'a T {
        #[allow(clippy::undocumented_unsafe_blocks, reason = "2024 edition migration")]
        unsafe {
            &*self.ptr
        }
    }

    /// Mutably dereferences the pointer.
    ///
    /// # Safety
    ///
    /// The caller must ensure that no other reference to the pointed-to value
    /// is alive while the returned reference is used.
    pub unsafe fn deref_mut(&self) -> &'a mut T {
        #[allow(clippy::undocumented_unsafe_blocks, reason = "2024 edition migration")]
        unsafe {
            &mut *self.ptr
        }
    }

    /// Returns a pointer to the field of type `F` located `OFFSET` bytes from
    /// the start of `T`.
    pub fn get_field<F, const OFFSET: usize>(&self) -> EbpfPtr<'a, F> {
        assert!(OFFSET + std::mem::size_of::<F>() <= std::mem::size_of::<T>());
        // SAFETY: the assert above guarantees that the field lies within `T`,
        // so the resulting pointer stays inside the same allocation.
        let field_ptr = unsafe { self.ptr.byte_offset(OFFSET as isize) } as *mut F;
        EbpfPtr::<'a, F> { ptr: field_ptr, phantom: PhantomData }
    }

    /// Returns the raw pointer.
    pub fn ptr(&self) -> *mut T {
        self.ptr
    }
}

impl<'a, T> From<&'a mut T> for EbpfPtr<'a, T>
where
    T: IntoBytes + FromBytes + Sized,
{
    fn from(value: &'a mut T) -> Self {
        let ptr = value.as_mut_bytes().as_mut_ptr() as *mut T;
        // SAFETY: `ptr` is derived from a mutable reference that is borrowed
        // for `'a`, so it is properly aligned and valid for reads and writes
        // for the whole lifetime.
        unsafe { Self::new(ptr) }
    }
}

impl EbpfPtr<'_, u64> {
    /// Loads the value with relaxed memory ordering.
    pub fn load_relaxed(&self) -> u64 {
        // SAFETY: `self.ptr` is valid and aligned per the `new()` contract.
        unsafe { arch::load_u64(self.ptr) }
    }

    /// Stores `value` with relaxed memory ordering.
    pub fn store_relaxed(&self, value: u64) {
        // SAFETY: `self.ptr` is valid and aligned per the `new()` contract.
        unsafe { arch::store_u64(self.ptr, value) }
    }
}

impl EbpfPtr<'_, u32> {
    /// Loads the value with relaxed memory ordering.
    pub fn load_relaxed(&self) -> u32 {
        // SAFETY: `self.ptr` is valid and aligned per the `new()` contract.
        unsafe { arch::load_u32(self.ptr) }
    }

    /// Stores `value` with relaxed memory ordering.
    pub fn store_relaxed(&self, value: u32) {
        // SAFETY: `self.ptr` is valid and aligned per the `new()` contract.
        unsafe { arch::store_u32(self.ptr, value) }
    }
}

impl EbpfPtr<'_, i32> {
    /// Loads the value with relaxed memory ordering.
    pub fn load_relaxed(&self) -> i32 {
        // SAFETY: `self.ptr` is valid and aligned per the `new()` contract.
        unsafe { arch::load_u32(self.ptr as *mut u32) as i32 }
    }

    /// Stores `value` with relaxed memory ordering.
    pub fn store_relaxed(&self, value: i32) {
        // SAFETY: `self.ptr` is valid and aligned per the `new()` contract.
        unsafe { arch::store_u32(self.ptr as *mut u32, value as u32) }
    }
}

/// A pointer to a byte buffer that may be shared with eBPF program code. All
/// accesses go through the per-architecture helpers in naturally-aligned 1-,
/// 2-, 4- and 8-byte chunks.
#[derive(Copy, Clone)]
pub struct EbpfBufferPtr<'a> {
    ptr: *mut u8,
    size: usize,
    phantom: PhantomData<&'a u8>,
}

impl<'a> EbpfBufferPtr<'a> {
    pub const ALIGNMENT: usize = size_of::<u64>();

    /// Creates a new `EbpfBufferPtr` for the `size` bytes starting at `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads and writes of `size` bytes for the whole
    /// lifetime `'a`.
    pub unsafe fn new(ptr: *mut u8, size: usize) -> Self {
        Self { ptr, size, phantom: PhantomData }
    }

    /// Returns the size of the buffer in bytes.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns the raw pointer to the start of the buffer.
    pub fn raw_ptr(&self) -> *mut u8 {
        self.ptr
    }

    unsafe fn get_ptr_internal<T>(&self, offset: usize) -> EbpfPtr<'a, T> {
        // SAFETY: the caller guarantees that `offset + size_of::<T>()` does
        // not exceed the buffer size.
        unsafe { EbpfPtr::new(self.ptr.byte_offset(offset as isize) as *mut T) }
    }

    /// Returns a typed pointer to the value at `offset`, or `None` if the
    /// value would not fit within the buffer.
    pub fn get_ptr<T>(&self, offset: usize) -> Option<EbpfPtr<'a, T>> {
        if offset + std::mem::size_of::<T>() <= self.size {
            // SAFETY: the check above guarantees that the value fits.
            Some(unsafe { self.get_ptr_internal(offset) })
        } else {
            None
        }
    }

    /// Returns a new `EbpfBufferPtr` for the specified sub-range of the
    /// buffer, or `None` if the range is out of bounds.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Option<Self> {
        let start = match range.start_bound() {
            std::ops::Bound::Included(&start) => start,
            std::ops::Bound::Excluded(&start) => start + 1,
            std::ops::Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            std::ops::Bound::Included(&end) => end + 1,
            std::ops::Bound::Excluded(&end) => end,
            std::ops::Bound::Unbounded => self.size,
        };

        assert!(start <= end);
        (end <= self.size).then(|| {
            unsafe {
                // SAFETY: `start <= end <= self.size`, so the sub-buffer stays
                // within the bounds of the original buffer.
                Self {
                    ptr: self.ptr.byte_offset(start as isize),
                    size: end - start,
                    phantom: PhantomData,
                }
            }
        })
    }

    /// Loads the contents of the buffer into `dst`, which must be exactly
    /// `self.len()` bytes long. Reads from the buffer use naturally-aligned
    /// 1-, 2-, 4- and 8-byte accesses; writes to `dst` may be unaligned.
    pub fn load_to_slice(&self, dst: &mut [MaybeUninit<u8>]) {
        assert_eq!(dst.len(), self.size);

        let mut src_ptr = self.ptr;
        // SAFETY: `self.ptr` is valid for `self.size` bytes.
        let src_end = unsafe { src_ptr.add(self.size) };

        let Range { start: dst_ptr, end: dst_end } = dst.as_mut_ptr_range();
        let mut dst_ptr = dst_ptr as *mut u8;
        let dst_end = dst_end as *mut u8;

        // Copy the head in 1-, 2- and 4-byte steps until `src_ptr` is 8-byte
        // aligned (or the end of the buffer is reached).
        if src_ptr as usize % 8 > 0 {
            if src_ptr < src_end && src_ptr as usize % 2 > 0 {
                unsafe {
                    // SAFETY: checked above that at least 1 byte remains.
                    let value: u8 = arch::load_u8(src_ptr as *const u8);
                    std::ptr::write_unaligned(dst_ptr, value);
                    src_ptr = src_ptr.add(1);
                    dst_ptr = dst_ptr.add(1);
                };
            }

            if src_ptr as usize + 2 <= src_end as usize && src_ptr as usize % 4 > 0 {
                unsafe {
                    // SAFETY: checked above that at least 2 bytes remain.
                    let value: u16 = arch::load_u16(src_ptr as *const u16);
                    std::ptr::write_unaligned(dst_ptr as *mut u16, value);
                    src_ptr = src_ptr.add(2);
                    dst_ptr = dst_ptr.add(2);
                }
            }

            if src_ptr as usize + 4 <= src_end as usize && src_ptr as usize % 8 > 0 {
                unsafe {
                    // SAFETY: checked above that at least 4 bytes remain.
                    let value: u32 = arch::load_u32(src_ptr as *const u32);
                    std::ptr::write_unaligned(dst_ptr as *mut u32, value);
                    src_ptr = src_ptr.add(4);
                    dst_ptr = dst_ptr.add(4);
                }
            }
        }

        // Copy the 8-byte-aligned body.
        while src_ptr as usize + 8 <= src_end as usize {
            unsafe {
                // SAFETY: checked above that at least 8 bytes remain.
                let value: u64 = arch::load_u64(src_ptr as *const u64);
                std::ptr::write_unaligned(dst_ptr as *mut u64, value);
                src_ptr = src_ptr.add(8);
                dst_ptr = dst_ptr.add(8);
            }
        }

        // Copy the remaining tail in 4-, 2- and 1-byte steps.
        if src_ptr < src_end {
            if src_ptr as usize + 4 <= src_end as usize {
                unsafe {
                    // SAFETY: checked above that at least 4 bytes remain.
                    let value: u32 = arch::load_u32(src_ptr as *const u32);
                    std::ptr::write_unaligned(dst_ptr as *mut u32, value);
                    src_ptr = src_ptr.add(4);
                    dst_ptr = dst_ptr.add(4);
                }
            }

            if src_ptr as usize + 2 <= src_end as usize {
                unsafe {
                    // SAFETY: checked above that at least 2 bytes remain.
                    let value: u16 = arch::load_u16(src_ptr as *const u16);
                    std::ptr::write_unaligned(dst_ptr as *mut u16, value);
                    src_ptr = src_ptr.add(2);
                    dst_ptr = dst_ptr.add(2);
                }
            }

            if src_ptr < src_end {
                unsafe {
                    // SAFETY: checked above that at least 1 byte remains.
                    let value: u8 = arch::load_u8(src_ptr as *const u8);
                    std::ptr::write_unaligned(dst_ptr, value);
                    src_ptr = src_ptr.add(1);
                    dst_ptr = dst_ptr.add(1);
                }
            }
        }

        debug_assert_eq!(src_ptr, src_end);
        debug_assert_eq!(dst_ptr, dst_end);
    }

    /// Loads the whole buffer into a `SmallVec`, avoiding a heap allocation
    /// when the buffer fits into `N` bytes.
    pub fn load<const N: usize>(&self) -> SmallVec<[u8; N]> {
        if self.size <= N {
            let mut buf = MaybeUninit::<[u8; N]>::uninit();
            self.load_to_slice(&mut AsMut::<[MaybeUninit<u8>]>::as_mut(&mut buf)[..self.size]);
            // SAFETY: `load_to_slice()` initialized the first `self.size` bytes.
            unsafe { SmallVec::from_buf_and_len_unchecked(buf, self.size) }
        } else {
            let mut vec = Vec::<u8>::with_capacity(self.size);
            self.load_to_slice(vec.spare_capacity_mut());
            // SAFETY: `load_to_slice()` initialized the first `self.size` bytes
            // of the spare capacity.
            unsafe { vec.set_len(self.size) };
            SmallVec::from_vec(vec)
        }
    }

    /// Stores `data` at the start of the buffer. `data` must not be longer
    /// than the buffer. Writes to the buffer use naturally-aligned 1-, 2-, 4-
    /// and 8-byte accesses.
    pub fn store(&self, data: &[u8]) {
        assert!(data.len() <= self.size);

        let mut ptr = self.ptr;
        // SAFETY: `self.ptr` is valid for `self.size >= data.len()` bytes.
        let end = unsafe { ptr.add(data.len()) };
        let mut data_offset = 0;

        // Store the head in 1-, 2- and 4-byte steps until `ptr` is 8-byte
        // aligned (or the end of the data is reached).
        if ptr as usize % 8 > 0 {
            if ptr < end && ptr as usize % 2 > 0 {
                let value = data[data_offset];
                data_offset += 1;
                unsafe {
                    // SAFETY: checked above that at least 1 byte remains.
                    arch::store_u8(ptr, value);
                    ptr = ptr.add(1)
                };
            }

            if (ptr as usize) + 2 <= end as usize && ptr as usize % 4 > 0 {
                let value = u16::read_from_bytes(&data[data_offset..(data_offset + 2)]).unwrap();
                data_offset += 2;
                unsafe {
                    // SAFETY: checked above that at least 2 bytes remain.
                    arch::store_u16(ptr as *mut u16, value);
                    ptr = ptr.add(2)
                };
            }

            if (ptr as usize) + 4 <= end as usize && ptr as usize % 8 > 0 {
                let value = u32::read_from_bytes(&data[data_offset..(data_offset + 4)]).unwrap();
                data_offset += 4;
                unsafe {
                    // SAFETY: checked above that at least 4 bytes remain.
                    arch::store_u32(ptr as *mut u32, value);
                    ptr = ptr.add(4)
                };
            }
        }

        // Store the 8-byte-aligned body.
        while (ptr as usize) + 8 <= end as usize {
            let value = u64::read_from_bytes(&data[data_offset..(data_offset + 8)]).unwrap();
            data_offset += 8;
            unsafe {
                // SAFETY: checked above that at least 8 bytes remain.
                arch::store_u64(ptr as *mut u64, value);
                ptr = ptr.add(8)
            };
        }

        // Store the remaining tail in 4-, 2- and 1-byte steps.
        if ptr < end {
            if (ptr as usize) + 4 <= end as usize {
                let value = u32::read_from_bytes(&data[data_offset..(data_offset + 4)]).unwrap();
                data_offset += 4;
                unsafe {
                    // SAFETY: checked above that at least 4 bytes remain.
                    arch::store_u32(ptr as *mut u32, value);
                    ptr = ptr.add(4)
                };
            }

            if (ptr as usize) + 2 <= end as usize {
                let value = u16::read_from_bytes(&data[data_offset..(data_offset + 2)]).unwrap();
                data_offset += 2;
                unsafe {
                    // SAFETY: checked above that at least 2 bytes remain.
                    arch::store_u16(ptr as *mut u16, value);
                    ptr = ptr.add(2)
                };
            }

            if ptr < end {
                let value = data[data_offset];
                data_offset += 1;
                unsafe {
                    // SAFETY: checked above that at least 1 byte remains.
                    arch::store_u8(ptr, value);
                    ptr = ptr.add(1)
                };
            }
        }

        debug_assert_eq!(ptr, end);
        debug_assert_eq!(data_offset, data.len());
    }

    /// Copies the contents of `src` into this buffer. `src` must not be
    /// longer than this buffer.
    pub fn copy(&self, src: &EbpfBufferPtr<'_>) {
        assert!(src.len() <= self.size);

        let mut dst_ptr = self.ptr;
        // SAFETY: `self.ptr` is valid for `self.size >= src.len()` bytes.
        let dst_end = unsafe { dst_ptr.add(src.len()) };

        let mut src_ptr = src.ptr;
        // SAFETY: `src.ptr` is valid for `src.len()` bytes.
        let src_end = unsafe { src_ptr.add(src.len()) };

        if (src_ptr as usize) % 8 == 0 && (dst_ptr as usize) % 8 == 0 {
            // Both buffers are 8-byte aligned: copy directly in 8-byte chunks
            // followed by a 4-, 2- and 1-byte tail.
            while (src_ptr as usize) + 8 <= src_end as usize {
                unsafe {
                    // SAFETY: checked above that at least 8 bytes remain.
                    let value: u64 = arch::load_u64(src_ptr as *const u64);
                    arch::store_u64(dst_ptr as *mut u64, value);
                    src_ptr = src_ptr.add(8);
                    dst_ptr = dst_ptr.add(8);
                }
            }

            if src_ptr < src_end {
                if (src_ptr as usize) + 4 <= src_end as usize {
                    unsafe {
                        // SAFETY: checked above that at least 4 bytes remain.
                        let value: u32 = arch::load_u32(src_ptr as *const u32);
                        arch::store_u32(dst_ptr as *mut u32, value);
                        src_ptr = src_ptr.add(4);
                        dst_ptr = dst_ptr.add(4);
                    }
                }

                if (src_ptr as usize) + 2 <= src_end as usize {
                    unsafe {
                        // SAFETY: checked above that at least 2 bytes remain.
                        let value: u16 = arch::load_u16(src_ptr as *const u16);
                        arch::store_u16(dst_ptr as *mut u16, value);
                        src_ptr = src_ptr.add(2);
                        dst_ptr = dst_ptr.add(2);
                    }
                }

                if src_ptr < src_end {
                    unsafe {
                        // SAFETY: checked above that at least 1 byte remains.
                        let value: u8 = arch::load_u8(src_ptr as *const u8);
                        arch::store_u8(dst_ptr as *mut u8, value);
                        src_ptr = src_ptr.add(1);
                        dst_ptr = dst_ptr.add(1);
                    }
                }
            }

            debug_assert_eq!(src_ptr, src_end);
            debug_assert_eq!(dst_ptr, dst_end);
        } else {
            // At least one of the buffers is not 8-byte aligned: fall back to
            // loading the source into a temporary buffer and storing it.
            self.store(&src.load::<128>());
        }
    }
}

impl<'a> From<&'a mut [u8]> for EbpfBufferPtr<'a> {
    fn from(value: &'a mut [u8]) -> Self {
        let ptr = value.as_mut_ptr() as *mut u8;
        // SAFETY: the pointer and length come from a mutable slice that is
        // borrowed for `'a`.
        unsafe { Self::new(ptr, value.len()) }
    }
}

impl<'a> From<&'a mut Vec<u8>> for EbpfBufferPtr<'a> {
    fn from(value: &'a mut Vec<u8>) -> Self {
        let ptr = value.as_mut_ptr() as *mut u8;
        // SAFETY: the pointer and length come from a `Vec` that is borrowed
        // for `'a`.
        unsafe { Self::new(ptr, value.len()) }
    }
}

impl<'a, const N: usize> From<&'a mut [u8; N]> for EbpfBufferPtr<'a> {
    fn from(value: &'a mut [u8; N]) -> Self {
        let ptr = value.as_mut_ptr() as *mut u8;
        // SAFETY: the pointer and length come from a mutable array reference
        // that is borrowed for `'a`.
        unsafe { Self::new(ptr, N) }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use fuchsia_runtime::vmar_root_self;
    use std::sync::Barrier;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::thread;

    #[test]
    fn test_u64_atomicity() {
        let vmo_size = zx::system_get_page_size() as usize;
        let vmo = zx::Vmo::create(vmo_size as u64).unwrap();
        let addr = vmar_root_self()
            .map(0, &vmo, 0, vmo_size, zx::VmarFlags::PERM_READ | zx::VmarFlags::PERM_WRITE)
            .unwrap();
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let shared_ptr = unsafe { EbpfPtr::new(addr as *mut u64) };

        const NUM_THREADS: usize = 10;

        // One writer and one reader are spawned per loop iteration below.
        let barrier = Barrier::new(NUM_THREADS * 2);

        let finished_writers = AtomicU32::new(0);

        thread::scope(|scope| {
            let mut threads = Vec::new();

            for _ in 0..NUM_THREADS {
                // Writers store values in which all 8 bytes are equal.
                threads.push(scope.spawn(|| {
                    barrier.wait();
                    for _ in 0..1000 {
                        for i in 0..255 {
                            let v = i << 8 | i;
                            let v = v << 16 | v;
                            let v = v << 32 | v;
                            shared_ptr.store_relaxed(v);
                        }
                    }
                    finished_writers.fetch_add(1, Ordering::Relaxed);
                }));

                // Readers verify that they never observe a torn value, i.e.
                // that all bytes of every loaded value are equal.
                threads.push(scope.spawn(|| {
                    barrier.wait();
                    loop {
                        for _ in 0..1000 {
                            let v = shared_ptr.load_relaxed();
                            assert!(v >> 32 == v & 0xffff_ffff);
                            assert!((v >> 16) & 0xffff == v & 0xffff);
                            assert!((v >> 8) & 0xff == v & 0xff);
                        }
                        if finished_writers.load(Ordering::Relaxed) == NUM_THREADS as u32 {
                            break;
                        }
                    }
                }));
            }

            for t in threads.into_iter() {
                t.join().expect("failed to join a test thread");
            }
        });

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        unsafe {
            vmar_root_self().unmap(addr, vmo_size).unwrap()
        };
    }

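    // Illustrative addition, not part of the original tests: a minimal sketch
    // exercising the `From<&mut T>` conversion, the relaxed accessors,
    // `get_field()` and the unsafe dereference helpers on a locally owned
    // value.
    #[test]
    fn test_ptr_basic_accessors() {
        let mut value: u64 = 0;
        let ptr = EbpfPtr::from(&mut value);

        ptr.store_relaxed(0x1122_3344_5566_7788);
        assert_eq!(ptr.load_relaxed(), 0x1122_3344_5566_7788);

        // A field pointer is offset from the parent pointer and must stay
        // within the parent type.
        let field = ptr.get_field::<u32, 4>();
        assert_eq!(field.ptr() as usize, ptr.ptr() as usize + 4);
        field.store_relaxed(0xaabb_ccdd);
        assert_eq!(field.load_relaxed(), 0xaabb_ccdd);

        // SAFETY: `value` outlives `ptr` and no other reference to it is alive
        // while the references returned below are used.
        unsafe {
            *ptr.deref_mut() = 42;
            assert_eq!(*ptr.deref(), 42);
        }
    }
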
    #[test]
    fn test_buffer_slice() {
        const SIZE: usize = 32;

        let mut buf = [0; SIZE];
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let buf_ptr = unsafe { EbpfBufferPtr::new(buf.as_mut_ptr(), SIZE) };

        buf_ptr.slice(8..16).unwrap().store(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let value = buf_ptr.slice(0..24).unwrap().load::<16>();
        assert_eq!(
            &value[..],
            &[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
        );

        assert!(buf_ptr.slice(8..40).is_none());
    }

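    // Illustrative addition, not part of the original tests: a minimal sketch
    // of the typed `get_ptr()` accessor. The buffer is backed by `u64`s so
    // that the base pointer is 8-byte aligned.
    #[test]
    fn test_buffer_get_ptr() {
        let mut storage = [0u64; 2];
        // SAFETY: `storage` outlives `buf_ptr`.
        let buf_ptr = unsafe { EbpfBufferPtr::new(storage.as_mut_ptr() as *mut u8, 16) };
        assert_eq!(buf_ptr.len(), 16);

        let word = buf_ptr.get_ptr::<u32>(4).expect("offset 4 is in bounds");
        word.store_relaxed(0xdead_beef);
        assert_eq!(word.load_relaxed(), 0xdead_beef);

        // Accesses that would run past the end of the buffer are rejected.
        assert!(buf_ptr.get_ptr::<u64>(12).is_none());
    }
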
    #[test]
    fn test_buffer_load() {
        const FULL_SIZE: usize = 40;
        let mut buf = (0..(FULL_SIZE as u8)).map(|v| v as u8).collect::<Vec<_>>();
        // SAFETY: `buf` outlives `buf_ptr` and is not accessed directly while
        // `buf_ptr` is in use.
        let buf_ptr = unsafe { EbpfBufferPtr::new(buf.as_mut_ptr(), FULL_SIZE) };

        for start in 0..FULL_SIZE {
            for end in start..=FULL_SIZE {
                let slice = buf_ptr.slice(start..end).unwrap();
                let loaded = slice.load::<16>();

                let expected = (start..end).map(|v| v as u8).collect::<Vec<_>>();
                assert_eq!(&loaded[..], &expected[..], "failed for range {}..{}", start, end);
            }
        }
    }

    #[test]
    fn test_buffer_store() {
        const FULL_SIZE: usize = 40;
        let mut buf = [0u8; FULL_SIZE];
        // SAFETY: `buf` outlives `buf_ptr` and is not accessed directly while
        // `buf_ptr` is in use.
        let buf_ptr = unsafe { EbpfBufferPtr::new(buf.as_mut_ptr(), FULL_SIZE) };

        for start in 0..FULL_SIZE {
            for end in start..=FULL_SIZE {
                let slice = buf_ptr.slice(start..end).unwrap();
                let data_to_store = (start..end).map(|v| v as u8).collect::<Vec<_>>();
                slice.store(&data_to_store);

                let loaded = slice.load::<16>();
                assert_eq!(&loaded[..], &data_to_store[..], "failed for range {}..{}", start, end);
            }
        }
    }

    #[test]
    fn test_buffer_copy() {
        const BASE_SIZE: usize = 48;
        let mut src_buf = (0..(BASE_SIZE as u8)).map(|v| v as u8).collect::<Vec<_>>();
        let mut dst_buf = [0u8; BASE_SIZE];

        // SAFETY: `src_buf` outlives `src_base`.
        let src_base = unsafe { EbpfBufferPtr::new(src_buf.as_mut_ptr(), BASE_SIZE) };

        // SAFETY: `dst_buf` outlives `dst_base`.
        let dst_base = unsafe { EbpfBufferPtr::new(dst_buf.as_mut_ptr(), BASE_SIZE) };

        for src_align in 0..8 {
            for dst_align in 0..8 {
                for len in 0..=32 {
                    dst_buf.fill(0);

                    let src_slice = src_base.slice(src_align..(src_align + len)).unwrap();
                    let dst_slice = dst_base.slice(dst_align..(dst_align + len)).unwrap();

                    dst_slice.copy(&src_slice);

                    let loaded = dst_slice.load::<16>();
                    let expected =
                        (src_align..(src_align + len)).map(|v| v as u8).collect::<Vec<_>>();
                    assert_eq!(
                        &loaded[..],
                        &expected[..],
                        "copy failed for length {} with src align {} and dst align {}",
                        len,
                        src_align,
                        dst_align
                    );

                    for i in 0..BASE_SIZE {
                        if i < dst_align || i >= dst_align + len {
                            assert_eq!(
                                dst_buf[i], 0,
                                "out-of-bounds memory modified at index {} (dst_align={}, len={})",
                                i, dst_align, len
                            );
                        }
                    }
                }
            }
        }
    }
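
    // Illustrative addition, not part of the original tests: a minimal sketch
    // checking that `slice()` accepts the other `RangeBounds` forms as well.
    #[test]
    fn test_buffer_slice_range_forms() {
        const SIZE: usize = 16;
        let mut buf = [0u8; SIZE];
        let buf_ptr = EbpfBufferPtr::from(&mut buf);

        assert_eq!(buf_ptr.slice(..).unwrap().len(), SIZE);
        assert_eq!(buf_ptr.slice(4..).unwrap().len(), SIZE - 4);
        assert_eq!(buf_ptr.slice(..=7).unwrap().len(), 8);
        assert!(buf_ptr.slice(..SIZE + 1).is_none());
    }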
}