1use derivative::Derivative;
6use smallvec::SmallVec;
7use std::marker::PhantomData;
8use std::mem::MaybeUninit;
9use std::ops::{Range, RangeBounds};
10use zerocopy::{FromBytes, IntoBytes};
11
12#[cfg(target_arch = "aarch64")]
13mod arm64;
14
15#[cfg(target_arch = "aarch64")]
16use arm64 as arch;
17
18#[cfg(target_arch = "x86_64")]
19mod x64;
20
21#[cfg(target_arch = "x86_64")]
22use x64 as arch;
23
24#[cfg(target_arch = "riscv64")]
25mod riscv64;
26
27#[cfg(target_arch = "riscv64")]
28use riscv64 as arch;
29
30#[derive(Derivative)]
34#[derivative(Copy(bound = ""), Clone(bound = ""))]
35pub struct EbpfPtr<'a, T> {
36 ptr: *mut T,
37 phantom: PhantomData<&'a T>,
38}
39
40#[allow(clippy::undocumented_unsafe_blocks, reason = "Force documented unsafe blocks in Starnix")]
41unsafe impl<'a, T> Send for EbpfPtr<'a, T> {}
42#[allow(clippy::undocumented_unsafe_blocks, reason = "Force documented unsafe blocks in Starnix")]
43unsafe impl<'a, T> Sync for EbpfPtr<'a, T> {}
44
45impl<'a, T> EbpfPtr<'a, T>
46where
47 T: Sized,
48{
49 pub unsafe fn new(ptr: *mut T) -> Self {
55 Self { ptr, phantom: PhantomData }
56 }
57
58 pub unsafe fn deref(&self) -> &'a T {
62 #[allow(clippy::undocumented_unsafe_blocks, reason = "2024 edition migration")]
63 unsafe {
64 &*self.ptr
65 }
66 }
67
68 pub unsafe fn deref_mut(&self) -> &'a mut T {
71 #[allow(clippy::undocumented_unsafe_blocks, reason = "2024 edition migration")]
72 unsafe {
73 &mut *self.ptr
74 }
75 }
76
77 pub fn get_field<F, const OFFSET: usize>(&self) -> EbpfPtr<'a, F> {
78 assert!(OFFSET + std::mem::size_of::<F>() <= std::mem::size_of::<T>());
79 let field_ptr = unsafe { self.ptr.byte_offset(OFFSET as isize) } as *mut F;
82 EbpfPtr::<'a, F> { ptr: field_ptr, phantom: PhantomData }
83 }
84
85 pub fn ptr(&self) -> *mut T {
86 self.ptr
87 }
88}
89
90impl<'a, T> From<&'a mut T> for EbpfPtr<'a, T>
91where
92 T: IntoBytes + FromBytes + Sized,
93{
94 fn from(value: &'a mut T) -> Self {
95 let ptr = value.as_mut_bytes().as_mut_ptr() as *mut T;
96 unsafe { Self::new(ptr) }
100 }
101}
102
103impl EbpfPtr<'_, u64> {
104 pub fn load_relaxed(&self) -> u64 {
107 unsafe { arch::load_u64(self.ptr) }
109 }
110
111 pub fn store_relaxed(&self, value: u64) {
114 unsafe { arch::store_u64(self.ptr, value) }
116 }
117}
118
119impl EbpfPtr<'_, u32> {
120 pub fn load_relaxed(&self) -> u32 {
123 unsafe { arch::load_u32(self.ptr) }
125 }
126
127 pub fn store_relaxed(&self, value: u32) {
130 unsafe { arch::store_u32(self.ptr, value) }
132 }
133}
134
135impl EbpfPtr<'_, i32> {
136 pub fn load_relaxed(&self) -> i32 {
139 unsafe { arch::load_u32(self.ptr as *mut u32) as i32 }
141 }
142
143 pub fn store_relaxed(&self, value: i32) {
146 unsafe { arch::store_u32(self.ptr as *mut u32, value as u32) }
148 }
149}
150
151impl EbpfPtr<'_, u16> {
152 pub fn load_relaxed(&self) -> u16 {
155 unsafe { arch::load_u16(self.ptr) }
157 }
158
159 pub fn store_relaxed(&self, value: u16) {
162 unsafe { arch::store_u16(self.ptr, value) }
164 }
165}
166
167impl EbpfPtr<'_, u8> {
168 pub fn load_relaxed(&self) -> u8 {
170 unsafe { arch::load_u8(self.ptr) }
172 }
173
174 pub fn store_relaxed(&self, value: u8) {
176 unsafe { arch::store_u8(self.ptr, value) }
178 }
179}
180
/// A raw pointer to a run of `size` bytes, tagged with the lifetime `'a`.
///
/// The value is a plain (pointer, length) pair and is freely copyable; all
/// access to the bytes goes through the methods of this type.
#[derive(Copy, Clone)]
pub struct EbpfBufferPtr<'a> {
    /// Address of the first byte of the buffer.
    ptr: *mut u8,

    /// Length of the buffer in bytes.
    size: usize,

    /// Zero-sized marker that attaches the lifetime `'a` to the raw pointer.
    phantom: PhantomData<&'a u8>,
}
194
195impl<'a> EbpfBufferPtr<'a> {
196 pub const ALIGNMENT: usize = size_of::<u64>();
197
198 pub unsafe fn new(ptr: *mut u8, size: usize) -> Self {
204 Self { ptr, size, phantom: PhantomData }
205 }
206
207 pub fn len(&self) -> usize {
209 self.size
210 }
211
212 pub fn raw_ptr(&self) -> *mut u8 {
214 self.ptr
215 }
216
217 unsafe fn get_ptr_internal<T>(&self, offset: usize) -> EbpfPtr<'a, T> {
220 unsafe { EbpfPtr::new(self.ptr.byte_offset(offset as isize) as *mut T) }
223 }
224
225 pub fn get_ptr<T>(&self, offset: usize) -> Option<EbpfPtr<'a, T>> {
227 if offset + std::mem::size_of::<T>() <= self.size {
228 Some(unsafe { self.get_ptr_internal(offset) })
230 } else {
231 None
232 }
233 }
234
235 pub fn slice(&self, range: impl RangeBounds<usize>) -> Option<Self> {
237 let start = match range.start_bound() {
238 std::ops::Bound::Included(&start) => start,
239 std::ops::Bound::Excluded(&start) => start + 1,
240 std::ops::Bound::Unbounded => 0,
241 };
242 let end = match range.end_bound() {
243 std::ops::Bound::Included(&end) => end + 1,
244 std::ops::Bound::Excluded(&end) => end,
245 std::ops::Bound::Unbounded => self.size,
246 };
247
248 assert!(start <= end);
249 (end <= self.size).then(|| {
250 unsafe {
254 Self {
255 ptr: self.ptr.byte_offset(start as isize),
256 size: end - start,
257 phantom: PhantomData,
258 }
259 }
260 })
261 }
262
263 pub fn load_to_slice(&self, dst: &mut [MaybeUninit<u8>]) {
266 assert_eq!(dst.len(), self.size);
267
268 let mut src_ptr = self.ptr;
269 let src_end = unsafe { src_ptr.add(self.size) };
271
272 let Range { start: dst_ptr, end: dst_end } = dst.as_mut_ptr_range();
273 let mut dst_ptr = dst_ptr as *mut u8;
274 let dst_end = dst_end as *mut u8;
275
276 if src_ptr as usize % 8 > 0 {
277 if src_ptr < src_end && src_ptr as usize % 2 > 0 {
278 unsafe {
280 let value: u8 = arch::load_u8(src_ptr as *const u8);
281 std::ptr::write_unaligned(dst_ptr, value);
282 src_ptr = src_ptr.add(1);
283 dst_ptr = dst_ptr.add(1);
284 };
285 }
286
287 if src_ptr as usize + 2 <= src_end as usize && src_ptr as usize % 4 > 0 {
288 unsafe {
290 let value: u16 = arch::load_u16(src_ptr as *const u16);
291 std::ptr::write_unaligned(dst_ptr as *mut u16, value);
292 src_ptr = src_ptr.add(2);
293 dst_ptr = dst_ptr.add(2);
294 }
295 }
296
297 if src_ptr as usize + 4 <= src_end as usize && src_ptr as usize % 8 > 0 {
298 unsafe {
300 let value: u32 = arch::load_u32(src_ptr as *const u32);
301 std::ptr::write_unaligned(dst_ptr as *mut u32, value);
302 src_ptr = src_ptr.add(4);
303 dst_ptr = dst_ptr.add(4);
304 }
305 }
306 }
307
308 while src_ptr as usize + 8 <= src_end as usize {
309 unsafe {
311 let value: u64 = arch::load_u64(src_ptr as *const u64);
312 std::ptr::write_unaligned(dst_ptr as *mut u64, value);
313 src_ptr = src_ptr.add(8);
314 dst_ptr = dst_ptr.add(8);
315 }
316 }
317
318 if src_ptr < src_end {
319 if src_ptr as usize + 4 <= src_end as usize {
320 unsafe {
322 let value: u32 = arch::load_u32(src_ptr as *const u32);
323 std::ptr::write_unaligned(dst_ptr as *mut u32, value);
324 src_ptr = src_ptr.add(4);
325 dst_ptr = dst_ptr.add(4);
326 }
327 }
328
329 if src_ptr as usize + 2 <= src_end as usize {
330 unsafe {
332 let value: u16 = arch::load_u16(src_ptr as *const u16);
333 std::ptr::write_unaligned(dst_ptr as *mut u16, value);
334 src_ptr = src_ptr.add(2);
335 dst_ptr = dst_ptr.add(2);
336 }
337 }
338
339 if src_ptr < src_end {
340 unsafe {
342 let value: u8 = arch::load_u8(src_ptr as *const u8);
343 std::ptr::write_unaligned(dst_ptr, value);
344 src_ptr = src_ptr.add(1);
345 dst_ptr = dst_ptr.add(1);
346 }
347 }
348 }
349
350 debug_assert_eq!(src_ptr, src_end);
351 debug_assert_eq!(dst_ptr, dst_end);
352 }
353
354 pub fn load<const N: usize>(&self) -> SmallVec<[u8; N]> {
356 if self.size <= N {
357 let mut buf = MaybeUninit::<[u8; N]>::uninit();
358 self.load_to_slice(&mut AsMut::<[MaybeUninit<u8>]>::as_mut(&mut buf)[..self.size]);
359 unsafe { SmallVec::from_buf_and_len_unchecked(buf, self.size) }
361 } else {
362 let mut vec = Vec::<u8>::with_capacity(self.size);
363 self.load_to_slice(vec.spare_capacity_mut());
364 unsafe { vec.set_len(self.size) };
366 SmallVec::from_vec(vec)
367 }
368 }
369
370 pub fn store(&self, data: &[u8]) {
372 assert!(data.len() <= self.size);
373
374 let mut ptr = self.ptr;
375 let end = unsafe { ptr.add(data.len()) };
377 let mut data_offset = 0;
378
379 if ptr as usize % 8 > 0 {
381 if ptr < end && ptr as usize % 2 > 0 {
382 let value = data[data_offset];
383 data_offset += 1;
384 unsafe {
386 arch::store_u8(ptr, value);
387 ptr = ptr.add(1)
388 };
389 }
390
391 if (ptr as usize) + 2 <= end as usize && ptr as usize % 4 > 0 {
392 let value = u16::read_from_bytes(&data[data_offset..(data_offset + 2)]).unwrap();
393 data_offset += 2;
394 unsafe {
396 arch::store_u16(ptr as *mut u16, value);
397 ptr = ptr.add(2)
398 };
399 }
400
401 if (ptr as usize) + 4 <= end as usize && ptr as usize % 8 > 0 {
402 let value = u32::read_from_bytes(&data[data_offset..(data_offset + 4)]).unwrap();
403 data_offset += 4;
404 unsafe {
406 arch::store_u32(ptr as *mut u32, value);
407 ptr = ptr.add(4)
408 };
409 }
410 }
411
412 while (ptr as usize) + 8 <= end as usize {
414 let value = u64::read_from_bytes(&data[data_offset..(data_offset + 8)]).unwrap();
415 data_offset += 8;
416 unsafe {
418 arch::store_u64(ptr as *mut u64, value);
419 ptr = ptr.add(8)
420 };
421 }
422
423 if ptr < end {
425 if (ptr as usize) + 4 <= end as usize {
426 let value = u32::read_from_bytes(&data[data_offset..(data_offset + 4)]).unwrap();
427 data_offset += 4;
428 unsafe {
430 arch::store_u32(ptr as *mut u32, value);
431 ptr = ptr.add(4)
432 };
433 }
434
435 if (ptr as usize) + 2 <= end as usize {
436 let value = u16::read_from_bytes(&data[data_offset..(data_offset + 2)]).unwrap();
437 data_offset += 2;
438 unsafe {
440 arch::store_u16(ptr as *mut u16, value);
441 ptr = ptr.add(2)
442 };
443 }
444
445 if ptr < end {
446 let value = data[data_offset];
447 data_offset += 1;
448 unsafe {
450 arch::store_u8(ptr, value);
451 ptr = ptr.add(1)
452 };
453 }
454 }
455
456 debug_assert_eq!(ptr, end);
457 debug_assert_eq!(data_offset, data.len());
458 }
459
460 pub fn copy(&self, src: &EbpfBufferPtr<'_>) {
463 assert!(src.len() <= self.size);
464
465 let mut dst_ptr = self.ptr;
466 let dst_end = unsafe { dst_ptr.add(src.len()) };
468
469 let mut src_ptr = src.ptr;
470 let src_end = unsafe { src_ptr.add(src.len()) };
472
473 if (src_ptr as usize) % 8 == 0 && (dst_ptr as usize) % 8 == 0 {
475 while (src_ptr as usize) + 8 <= src_end as usize {
476 unsafe {
478 let value: u64 = arch::load_u64(src_ptr as *const u64);
479 arch::store_u64(dst_ptr as *mut u64, value);
480 src_ptr = src_ptr.add(8);
481 dst_ptr = dst_ptr.add(8);
482 }
483 }
484
485 if src_ptr < src_end {
486 if (src_ptr as usize) + 4 <= src_end as usize {
487 unsafe {
489 let value: u32 = arch::load_u32(src_ptr as *const u32);
490 arch::store_u32(dst_ptr as *mut u32, value);
491 src_ptr = src_ptr.add(4);
492 dst_ptr = dst_ptr.add(4);
493 }
494 }
495
496 if (src_ptr as usize) + 2 <= src_end as usize {
497 unsafe {
499 let value: u16 = arch::load_u16(src_ptr as *const u16);
500 arch::store_u16(dst_ptr as *mut u16, value);
501 src_ptr = src_ptr.add(2);
502 dst_ptr = dst_ptr.add(2);
503 }
504 }
505
506 if src_ptr < src_end {
507 unsafe {
509 let value: u8 = arch::load_u8(src_ptr as *const u8);
510 arch::store_u8(dst_ptr as *mut u8, value);
511 src_ptr = src_ptr.add(1);
512 dst_ptr = dst_ptr.add(1);
513 }
514 }
515 }
516
517 debug_assert_eq!(src_ptr, src_end);
518 debug_assert_eq!(dst_ptr, dst_end);
519 } else {
520 self.store(&src.load::<128>());
524 }
525 }
526}
527
528impl<'a> From<&'a mut [u8]> for EbpfBufferPtr<'a> {
529 fn from(value: &'a mut [u8]) -> Self {
530 let ptr = value.as_mut_ptr() as *mut u8;
531 unsafe { Self::new(ptr, value.len()) }
535 }
536}
537
538impl<'a> From<&'a mut Vec<u8>> for EbpfBufferPtr<'a> {
539 fn from(value: &'a mut Vec<u8>) -> Self {
540 let ptr = value.as_mut_ptr() as *mut u8;
541 unsafe { Self::new(ptr, value.len()) }
545 }
546}
547impl<'a, const N: usize> From<&'a mut [u8; N]> for EbpfBufferPtr<'a> {
548 fn from(value: &'a mut [u8; N]) -> Self {
549 let ptr = value.as_mut_ptr() as *mut u8;
550 unsafe { Self::new(ptr, N) }
554 }
555}
556
#[cfg(test)]
mod test {
    use super::*;
    use fuchsia_runtime::vmar_root_self;
    use std::sync::Barrier;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::thread;

    /// Checks that concurrent `store_relaxed`/`load_relaxed` calls on one `u64`
    /// never expose a torn value.
    ///
    /// Writers only ever store values whose eight bytes are all equal; readers
    /// verify that the halves of every value they observe still match.
    #[test]
    fn test_u64_atomicity() {
        let vmo_size = zx::system_get_page_size() as usize;
        let vmo = zx::Vmo::create(vmo_size as u64).unwrap();
        let addr = vmar_root_self()
            .map(0, &vmo, 0, vmo_size, zx::VmarFlags::PERM_READ | zx::VmarFlags::PERM_WRITE)
            .unwrap();
        // SAFETY: `addr` is the start of a page-sized read/write mapping that stays
        // mapped until the `unmap` call at the end of this test.
        let shared_ptr = unsafe { EbpfPtr::new(addr as *mut u64) };

        const NUM_THREADS: usize = 10;

        // One writer and one reader are spawned per iteration below, and all of
        // them start together.
        let barrier = Barrier::new(NUM_THREADS * 2);

        let finished_writers = AtomicU32::new(0);

        thread::scope(|scope| {
            let mut threads = Vec::new();

            // The count must match the barrier size above, otherwise `wait` would
            // block forever.
            for _ in 0..NUM_THREADS {
                threads.push(scope.spawn(|| {
                    barrier.wait();
                    for _ in 0..1000 {
                        for i in 0..255 {
                            // Replicate the low byte into all eight bytes.
                            let v = i << 8 | i;
                            let v = v << 16 | v;
                            let v = v << 32 | v;
                            shared_ptr.store_relaxed(v);
                        }
                    }
                    finished_writers.fetch_add(1, Ordering::Relaxed);
                }));

                threads.push(scope.spawn(|| {
                    barrier.wait();
                    loop {
                        for _ in 0..1000 {
                            let v = shared_ptr.load_relaxed();
                            // A torn value would break at least one of these.
                            assert!(v >> 32 == v & 0xffff_ffff);
                            assert!((v >> 16) & 0xffff == v & 0xffff);
                            assert!((v >> 8) & 0xff == v & 0xff);
                        }
                        if finished_writers.load(Ordering::Relaxed) == NUM_THREADS as u32 {
                            break;
                        }
                    }
                }));
            }

            for t in threads.into_iter() {
                t.join().expect("failed to join a test thread");
            }
        });

        // SAFETY: every thread that used `shared_ptr` has finished once
        // `thread::scope` returns, and the mapping is not accessed afterwards.
        unsafe {
            vmar_root_self().unmap(addr, vmo_size).unwrap()
        };
    }

    /// Checks that `slice` selects the requested window and rejects a range that
    /// reaches past the end of the buffer.
    #[test]
    fn test_buffer_slice() {
        const SIZE: usize = 32;

        let mut buf = [0; SIZE];
        // SAFETY: the pointer and length describe `buf`, which outlives `buf_ptr` and
        // is not accessed directly while the pointer is in use.
        let buf_ptr = unsafe { EbpfBufferPtr::new(buf.as_mut_ptr(), SIZE) };

        buf_ptr.slice(8..16).unwrap().store(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let value = buf_ptr.slice(0..24).unwrap().load::<16>();
        assert_eq!(
            &value[..],
            &[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
        );

        assert!(buf_ptr.slice(8..40).is_none());
    }

    /// Loads every sub-range of a 40-byte buffer, covering every start alignment
    /// and both the inline and the heap path of `load::<16>`.
    #[test]
    fn test_buffer_load() {
        const FULL_SIZE: usize = 40;
        let mut buf: Vec<u8> = (0..FULL_SIZE as u8).collect();
        // SAFETY: the pointer and length describe `buf`, which outlives `buf_ptr` and
        // is not accessed directly while the pointer is in use.
        let buf_ptr = unsafe { EbpfBufferPtr::new(buf.as_mut_ptr(), FULL_SIZE) };

        for start in 0..FULL_SIZE {
            for end in start..=FULL_SIZE {
                let slice = buf_ptr.slice(start..end).unwrap();
                let loaded = slice.load::<16>();

                let expected = (start..end).map(|v| v as u8).collect::<Vec<_>>();
                assert_eq!(&loaded[..], &expected[..], "failed for range {}..{}", start, end);
            }
        }
    }

    /// Stores to every sub-range of a 40-byte buffer and reads the data back.
    #[test]
    fn test_buffer_store() {
        const FULL_SIZE: usize = 40;
        let mut buf = [0u8; FULL_SIZE];
        // SAFETY: the pointer and length describe `buf`, which outlives `buf_ptr` and
        // is not accessed directly while the pointer is in use.
        let buf_ptr = unsafe { EbpfBufferPtr::new(buf.as_mut_ptr(), FULL_SIZE) };

        for start in 0..FULL_SIZE {
            for end in start..=FULL_SIZE {
                let slice = buf_ptr.slice(start..end).unwrap();
                let data_to_store = (start..end).map(|v| v as u8).collect::<Vec<_>>();
                slice.store(&data_to_store);

                let loaded = slice.load::<16>();
                assert_eq!(&loaded[..], &data_to_store[..], "failed for range {}..{}", start, end);
            }
        }
    }

    /// Copies between buffers for every combination of source alignment,
    /// destination alignment and length up to 32, and checks that nothing outside
    /// the destination window is modified.
    #[test]
    fn test_buffer_copy() {
        const BASE_SIZE: usize = 48;
        let mut src_buf: Vec<u8> = (0..BASE_SIZE as u8).collect();
        let mut dst_buf = [0u8; BASE_SIZE];

        // SAFETY: the pointer and length describe `src_buf`, which outlives
        // `src_base` and is not accessed directly while the pointer is in use.
        let src_base = unsafe { EbpfBufferPtr::new(src_buf.as_mut_ptr(), BASE_SIZE) };

        for src_align in 0..8 {
            for dst_align in 0..8 {
                for len in 0..=32 {
                    dst_buf.fill(0);

                    // Derive the destination pointer after `fill`: a pointer obtained
                    // earlier must not be used again once the array has been written
                    // through a fresh `&mut` borrow.
                    // SAFETY: the pointer and length describe `dst_buf`, which is only
                    // read directly again after the last use of `dst_base` in this
                    // iteration.
                    let dst_base = unsafe { EbpfBufferPtr::new(dst_buf.as_mut_ptr(), BASE_SIZE) };

                    let src_slice = src_base.slice(src_align..(src_align + len)).unwrap();
                    let dst_slice = dst_base.slice(dst_align..(dst_align + len)).unwrap();

                    dst_slice.copy(&src_slice);

                    let loaded = dst_slice.load::<16>();
                    let expected =
                        (src_align..(src_align + len)).map(|v| v as u8).collect::<Vec<_>>();
                    assert_eq!(
                        &loaded[..],
                        &expected[..],
                        "copy failed for length {} with src align {} and dst align {}",
                        len,
                        src_align,
                        dst_align
                    );

                    for i in 0..BASE_SIZE {
                        if i < dst_align || i >= dst_align + len {
                            assert_eq!(
                                dst_buf[i], 0,
                                "out-of-bounds memory modified at index {} (dst_align={}, len={})",
                                i, dst_align, len
                            );
                        }
                    }
                }
            }
        }
    }
}