Skip to main content

ebpf_api/
helpers.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::MapKey;
6use crate::maps::{Map, MapValueRef, RingBuffer, RingBufferWakeupPolicy};
7use ebpf::{BpfValue, EbpfBufferPtr, EbpfHelperImpl, EbpfProgramContext, FromBpfValue, HelperSet};
8use inspect_stubs::track_stub;
9use linux_uapi::{
10    BPF_SK_STORAGE_GET_F_CREATE, bpf_func_id_BPF_FUNC_get_current_pid_tgid,
11    bpf_func_id_BPF_FUNC_get_current_uid_gid, bpf_func_id_BPF_FUNC_get_netns_cookie,
12    bpf_func_id_BPF_FUNC_get_retval, bpf_func_id_BPF_FUNC_get_smp_processor_id,
13    bpf_func_id_BPF_FUNC_get_socket_cookie, bpf_func_id_BPF_FUNC_get_socket_uid,
14    bpf_func_id_BPF_FUNC_ktime_get_boot_ns, bpf_func_id_BPF_FUNC_ktime_get_coarse_ns,
15    bpf_func_id_BPF_FUNC_ktime_get_ns, bpf_func_id_BPF_FUNC_map_delete_elem,
16    bpf_func_id_BPF_FUNC_map_lookup_elem, bpf_func_id_BPF_FUNC_map_update_elem,
17    bpf_func_id_BPF_FUNC_probe_read_str, bpf_func_id_BPF_FUNC_probe_read_user,
18    bpf_func_id_BPF_FUNC_probe_read_user_str, bpf_func_id_BPF_FUNC_ringbuf_discard,
19    bpf_func_id_BPF_FUNC_ringbuf_reserve, bpf_func_id_BPF_FUNC_ringbuf_submit,
20    bpf_func_id_BPF_FUNC_set_retval, bpf_func_id_BPF_FUNC_sk_fullsock,
21    bpf_func_id_BPF_FUNC_sk_lookup_tcp, bpf_func_id_BPF_FUNC_sk_lookup_udp,
22    bpf_func_id_BPF_FUNC_sk_release, bpf_func_id_BPF_FUNC_sk_storage_get,
23    bpf_func_id_BPF_FUNC_skb_load_bytes, bpf_func_id_BPF_FUNC_skb_load_bytes_relative,
24    bpf_func_id_BPF_FUNC_trace_printk, bpf_map_type_BPF_MAP_TYPE_RINGBUF,
25    bpf_map_type_BPF_MAP_TYPE_SK_STORAGE, gid_t, pid_t, uid_t,
26};
27use smallvec::SmallVec;
28use std::slice;
29use zerocopy::IntoBytes as _;
30
31pub trait MapsContext<'a> {
32    fn on_map_access(&mut self, map: &Map);
33    fn add_value_ref(&mut self, map_ref: MapValueRef<'a>);
34}
35
36pub trait MapsProgramContext: EbpfProgramContext {
37    fn on_map_access(context: &mut Self::RunContext<'_>, map: &Map);
38    fn add_value_ref<'a>(context: &mut Self::RunContext<'a>, map_ref: MapValueRef<'a>);
39}
40
41impl<C: EbpfProgramContext> MapsProgramContext for C
42where
43    for<'a> C::RunContext<'a>: MapsContext<'a>,
44{
45    fn on_map_access(context: &mut Self::RunContext<'_>, map: &Map) {
46        context.on_map_access(map);
47    }
48
49    fn add_value_ref<'a>(context: &mut Self::RunContext<'a>, map_ref: MapValueRef<'a>) {
50        context.add_value_ref(map_ref);
51    }
52}
53
54fn bpf_map_lookup_elem<'a, C: MapsProgramContext>(
55    context: &mut C::RunContext<'a>,
56    map: BpfValue,
57    key: BpfValue,
58    _: BpfValue,
59    _: BpfValue,
60    _: BpfValue,
61) -> BpfValue {
62    // SAFETY: The `map` must be a reference to a `Map` object kept alive by the program itself.
63    let map: &Map = unsafe { &*map.as_ptr::<Map>() };
64
65    // SAFETY: safety is ensured by the verifier.
66    let key = unsafe { EbpfBufferPtr::new(key.as_ptr::<u8>(), map.schema.key_size as usize) };
67    let key: MapKey = key.load();
68
69    C::on_map_access(context, map);
70
71    let Some(value_ref) = map.lookup(&key) else {
72        return BpfValue::default();
73    };
74
75    let result: BpfValue = value_ref.ptr().raw_ptr().into();
76
77    // If this is a map with ref-counted elements then save the reference for
78    // the lifetime of the program.
79    if value_ref.is_ref_counted() {
80        C::add_value_ref(context, value_ref);
81    }
82
83    result
84}
85
86fn bpf_map_update_elem<C: MapsProgramContext>(
87    context: &mut C::RunContext<'_>,
88    map: BpfValue,
89    key: BpfValue,
90    value: BpfValue,
91    flags: BpfValue,
92    _: BpfValue,
93) -> BpfValue {
94    // SAFETY: The `map` must be a reference to a `Map` object kept alive by the program itself.
95    let map: &Map = unsafe { &*map.as_ptr::<Map>() };
96
97    // TODO(https://fxbug.dev/496639039): This should be checked by the verifier.
98    if map.schema.map_type == bpf_map_type_BPF_MAP_TYPE_SK_STORAGE {
99        return BpfValue::default();
100    }
101
102    // SAFETY: safety is ensured by the verifier.
103    let key = unsafe { EbpfBufferPtr::new(key.as_ptr::<u8>(), map.schema.key_size as usize) };
104    let key: MapKey = key.load();
105
106    // SAFETY: safety is ensured by the verifier.
107    let value = unsafe { EbpfBufferPtr::new(value.as_ptr::<u8>(), map.schema.value_size as usize) };
108    let flags = flags.as_u64();
109
110    C::on_map_access(context, map);
111
112    map.update(&key, value, flags).map(|_| 0).unwrap_or(u64::MAX).into()
113}
114
115fn bpf_map_delete_elem<C: MapsProgramContext>(
116    context: &mut C::RunContext<'_>,
117    map: BpfValue,
118    key: BpfValue,
119    _: BpfValue,
120    _: BpfValue,
121    _: BpfValue,
122) -> BpfValue {
123    // SAFETY: The `map` must be a reference to a `Map` object kept alive by the program itself.
124    let map: &Map = unsafe { &*map.as_ptr::<Map>() };
125
126    // TODO(https://fxbug.dev/496639039): This should be checked by the verifier.
127    if map.schema.map_type == bpf_map_type_BPF_MAP_TYPE_SK_STORAGE {
128        return BpfValue::default();
129    }
130
131    // SAFETY: safety is ensured by the verifier.
132    let key = unsafe { EbpfBufferPtr::new(key.as_ptr::<u8>(), map.schema.key_size as usize) };
133    let key: MapKey = key.load();
134
135    C::on_map_access(context, map);
136
137    map.delete(&key).map(|_| 0).unwrap_or(u64::MAX).into()
138}
139
140fn bpf_trace_printk<C: EbpfProgramContext>(
141    _context: &mut C::RunContext<'_>,
142    _fmt: BpfValue,
143    _fmt_size: BpfValue,
144    _: BpfValue,
145    _: BpfValue,
146    _: BpfValue,
147) -> BpfValue {
148    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_trace_printk");
149    0.into()
150}
151
152fn bpf_ktime_get_ns<C: EbpfProgramContext>(
153    _context: &mut C::RunContext<'_>,
154    _: BpfValue,
155    _: BpfValue,
156    _: BpfValue,
157    _: BpfValue,
158    _: BpfValue,
159) -> BpfValue {
160    zx::MonotonicInstant::get().into_nanos().into()
161}
162
163fn bpf_ringbuf_reserve<C: EbpfProgramContext>(
164    _context: &mut C::RunContext<'_>,
165    map: BpfValue,
166    size: BpfValue,
167    flags: BpfValue,
168    _: BpfValue,
169    _: BpfValue,
170) -> BpfValue {
171    // SAFETY: The safety of the operation is ensured by the bpf verifier. The `map` must be a
172    // reference to a `Map` object kept alive by the program itself.
173    let map: &Map = unsafe { &*map.as_ptr::<Map>() };
174
175    // Map type is checked by the verifier.
176    assert!(map.schema.map_type == bpf_map_type_BPF_MAP_TYPE_RINGBUF);
177
178    let Ok(size) = u32::try_from(size) else {
179        return BpfValue::default();
180    };
181    let flags = u64::from(flags);
182    map.ringbuf_reserve(size, flags).map(BpfValue::from).unwrap_or_else(|_| BpfValue::default())
183}
184
185fn bpf_ringbuf_submit<C: EbpfProgramContext>(
186    _context: &mut C::RunContext<'_>,
187    data: BpfValue,
188    flags: BpfValue,
189    _: BpfValue,
190    _: BpfValue,
191    _: BpfValue,
192) -> BpfValue {
193    let flags = RingBufferWakeupPolicy::from(flags);
194
195    // SAFETY: The safety of the operation is ensured by the bpf verifier. The data has to come from
196    // the result of a reserve call.
197    unsafe {
198        RingBuffer::submit(u64::from(data), flags);
199    }
200    0.into()
201}
202
203fn bpf_ringbuf_discard<C: EbpfProgramContext>(
204    _context: &mut C::RunContext<'_>,
205    data: BpfValue,
206    flags: BpfValue,
207    _: BpfValue,
208    _: BpfValue,
209    _: BpfValue,
210) -> BpfValue {
211    let flags = RingBufferWakeupPolicy::from(flags);
212
213    // SAFETY: The safety of the operation is ensured by the bpf verifier. The data has to come from
214    // the result of a reserve call.
215    unsafe {
216        RingBuffer::discard(u64::from(data), flags);
217    }
218    0.into()
219}
220
221fn bpf_ktime_get_boot_ns<C: EbpfProgramContext>(
222    _context: &mut C::RunContext<'_>,
223    _: BpfValue,
224    _: BpfValue,
225    _: BpfValue,
226    _: BpfValue,
227    _: BpfValue,
228) -> BpfValue {
229    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_ktime_get_boot_ns");
230    0.into()
231}
232
233fn bpf_probe_read_user<C: EbpfProgramContext>(
234    _context: &mut C::RunContext<'_>,
235    _: BpfValue,
236    _: BpfValue,
237    _: BpfValue,
238    _: BpfValue,
239    _: BpfValue,
240) -> BpfValue {
241    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_probe_read_user");
242    0.into()
243}
244
245fn bpf_probe_read_user_str<C: EbpfProgramContext>(
246    _context: &mut C::RunContext<'_>,
247    _: BpfValue,
248    _: BpfValue,
249    _: BpfValue,
250    _: BpfValue,
251    _: BpfValue,
252) -> BpfValue {
253    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_probe_read_user_str");
254    0.into()
255}
256
257fn bpf_ktime_get_coarse_ns<C: EbpfProgramContext>(
258    _context: &mut C::RunContext<'_>,
259    _: BpfValue,
260    _: BpfValue,
261    _: BpfValue,
262    _: BpfValue,
263    _: BpfValue,
264) -> BpfValue {
265    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_ktime_get_coarse_ns");
266    0.into()
267}
268
269fn bpf_probe_read_str<C: EbpfProgramContext>(
270    _context: &mut C::RunContext<'_>,
271    _: BpfValue,
272    _: BpfValue,
273    _: BpfValue,
274    _: BpfValue,
275    _: BpfValue,
276) -> BpfValue {
277    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_probe_read_str");
278    0.into()
279}
280
281fn bpf_get_smp_processor_id<C: EbpfProgramContext>(
282    _context: &mut C::RunContext<'_>,
283    _: BpfValue,
284    _: BpfValue,
285    _: BpfValue,
286    _: BpfValue,
287    _: BpfValue,
288) -> BpfValue {
289    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_get_smp_processor_id");
290    0.into()
291}
292
293pub trait CurrentTaskContext {
294    fn get_uid_gid(&self) -> (uid_t, gid_t);
295    fn get_tid_tgid(&self) -> (pid_t, pid_t);
296}
297
298pub trait CurrentTaskProgramContext: EbpfProgramContext {
299    fn get_uid_gid<'a>(context: &mut Self::RunContext<'a>) -> (uid_t, gid_t);
300    fn get_tid_tgid<'a>(context: &mut Self::RunContext<'a>) -> (pid_t, pid_t);
301}
302
303impl<C: EbpfProgramContext> CurrentTaskProgramContext for C
304where
305    for<'a> C::RunContext<'a>: CurrentTaskContext,
306{
307    fn get_uid_gid<'a>(context: &mut Self::RunContext<'a>) -> (uid_t, gid_t) {
308        context.get_uid_gid()
309    }
310    fn get_tid_tgid<'a>(context: &mut Self::RunContext<'a>) -> (pid_t, pid_t) {
311        context.get_tid_tgid()
312    }
313}
314
315fn bpf_get_current_uid_gid<C: CurrentTaskProgramContext>(
316    context: &mut C::RunContext<'_>,
317    _: BpfValue,
318    _: BpfValue,
319    _: BpfValue,
320    _: BpfValue,
321    _: BpfValue,
322) -> BpfValue {
323    let (uid, gid) = C::get_uid_gid(context);
324    (uid as u64 | (gid as u64) << 32).into()
325}
326
327fn bpf_get_current_pid_tgid<C: CurrentTaskProgramContext>(
328    context: &mut C::RunContext<'_>,
329    _: BpfValue,
330    _: BpfValue,
331    _: BpfValue,
332    _: BpfValue,
333    _: BpfValue,
334) -> BpfValue {
335    let (pid, tgid) = C::get_tid_tgid(context);
336    (pid as u64 | (tgid as u64) << 32).into()
337}
338
339// Trait for `EbpfProgramContext` where the first argument is a `SocketRef`,
340// i.e. it references a socket.
341pub trait Arg1IsSocketProgramContext: EbpfProgramContext {
342    type Arg1AsSocket<'a>: FromBpfValue<Self::RunContext<'a>> + SocketRef;
343}
344
345impl<C> Arg1IsSocketProgramContext for C
346where
347    C: EbpfProgramContext,
348    for<'a> Self::Arg1<'a>: FromBpfValue<Self::RunContext<'a>> + SocketRef,
349{
350    type Arg1AsSocket<'a> = Self::Arg1<'a>;
351}
352
353// Marker trait for `EbpfProgramContext` that supports `bpf_get_socket_uid`.
354pub trait SocketCookieProgramContext: Arg1IsSocketProgramContext {}
355impl<C> SocketCookieProgramContext for C where C: Arg1IsSocketProgramContext {}
356
357fn bpf_get_socket_cookie<'a, C: SocketCookieProgramContext>(
358    context: &mut C::RunContext<'a>,
359    arg1: BpfValue,
360    _: BpfValue,
361    _: BpfValue,
362    _: BpfValue,
363    _: BpfValue,
364) -> BpfValue {
365    // SAFETY: Verifier checks that the argument points at the value that
366    // that's passed as the first argument.
367    let arg1_as_socket = unsafe { C::Arg1AsSocket::from_bpf_value(context, arg1) };
368    arg1_as_socket.get_socket_cookie().unwrap_or(0).into()
369}
370
371pub trait SocketRef {
372    fn get_socket_cookie(&self) -> Option<u64>;
373    fn get_socket_uid(&self) -> Option<uid_t>;
374}
375
376// A trait for eBPF run context with `bpf_sock` pointers.
377pub trait BpfSockContext: Sized {
378    type BpfSockRef: SocketRef + FromBpfValue<Self>;
379}
380
381pub trait SkStorageProgramContext: EbpfProgramContext {
382    type BpfSockRef<'a>: SocketRef + FromBpfValue<Self::RunContext<'a>>;
383}
384
385impl<C> SkStorageProgramContext for C
386where
387    C: EbpfProgramContext,
388    for<'a> C::RunContext<'a>: BpfSockContext,
389{
390    type BpfSockRef<'a> = <C::RunContext<'a> as BpfSockContext>::BpfSockRef;
391}
392
/// Selects the packet header that `load_bytes_relative` offsets are measured
/// from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadBytesBase {
    /// Offsets are relative to the MAC header.
    MacHeader,
    /// Offsets are relative to the network header.
    NetworkHeader,
}
398
399// Marker trait for `EbpfProgramContext` that supports `bpf_get_socket_uid`.
400pub trait SocketUidProgramContext: Arg1IsSocketProgramContext {}
401impl<C> SocketUidProgramContext for C where C: Arg1IsSocketProgramContext {}
402
403fn bpf_get_socket_uid<'a, C: SocketUidProgramContext>(
404    context: &mut C::RunContext<'a>,
405    sk_buf: BpfValue,
406    _: BpfValue,
407    _: BpfValue,
408    _: BpfValue,
409    _: BpfValue,
410) -> BpfValue {
411    const OVERFLOW_UID: uid_t = 65534;
412    // SAFETY: Verifier checks that the first argument points at a `__sk_buff`.
413    let sk_buf = unsafe { C::Arg1AsSocket::from_bpf_value(context, sk_buf) };
414    sk_buf.get_socket_uid().unwrap_or(OVERFLOW_UID).into()
415}
416
417// Trait for packets that support `bpf_load_bytes_relative`.
418pub trait PacketWithLoadBytes {
419    fn load_bytes_relative(&self, base: LoadBytesBase, offset: usize, buf: &mut [u8]) -> i64;
420}
421
422// Trait for `EbpfProgramContext` that supports `bpf_load_bytes_relative`.
423pub trait SkbLoadBytesProgramContext: EbpfProgramContext {
424    fn skb_load_bytes_relative<'a>(
425        context: &mut Self::RunContext<'a>,
426        sk_buf: BpfValue,
427        base: LoadBytesBase,
428        offset: usize,
429        buf: &mut [u8],
430    ) -> i64;
431}
432
433impl<C: EbpfProgramContext> SkbLoadBytesProgramContext for C
434where
435    for<'b> C::Arg1<'b>: FromBpfValue<C::RunContext<'b>>,
436    for<'b> C::Arg1<'b>: PacketWithLoadBytes,
437{
438    fn skb_load_bytes_relative<'a>(
439        context: &mut Self::RunContext<'a>,
440        sk_buf: BpfValue,
441        base: LoadBytesBase,
442        offset: usize,
443        buf: &mut [u8],
444    ) -> i64 {
445        // SAFETY: Verifier checks that the argument points at the same value
446        // that was passed to the program as the first argument.
447        let sk_buf = unsafe { C::Arg1::from_bpf_value(context, sk_buf) };
448        sk_buf.load_bytes_relative(base, offset, buf)
449    }
450}
451
452fn bpf_skb_load_bytes<'a, C: SkbLoadBytesProgramContext>(
453    context: &mut C::RunContext<'a>,
454    sk_buf: BpfValue,
455    offset: BpfValue,
456    to: BpfValue,
457    len: BpfValue,
458    _: BpfValue,
459) -> BpfValue {
460    let base = LoadBytesBase::NetworkHeader;
461
462    let Ok(offset) = offset.as_u64().try_into() else {
463        return u64::MAX.into();
464    };
465
466    // SAFETY: The verifier ensures that `to` points to a valid buffer of at
467    // least `len` bytes that the eBPF program has permission to access.
468    let buf = unsafe { slice::from_raw_parts_mut(to.as_ptr::<u8>(), len.as_u64() as usize) };
469
470    C::skb_load_bytes_relative(context, sk_buf, base, offset, buf).into()
471}
472
473fn bpf_skb_load_bytes_relative<'a, C: SkbLoadBytesProgramContext>(
474    context: &mut C::RunContext<'a>,
475    sk_buf: BpfValue,
476    offset: BpfValue,
477    to: BpfValue,
478    len: BpfValue,
479    start_header: BpfValue,
480) -> BpfValue {
481    let base = match start_header.as_u64() {
482        0 => LoadBytesBase::MacHeader,
483        1 => LoadBytesBase::NetworkHeader,
484        _ => return u64::MAX.into(),
485    };
486
487    let Ok(offset) = offset.as_u64().try_into() else {
488        return u64::MAX.into();
489    };
490
491    // SAFETY: The verifier ensures that `to` points to a valid buffer of at
492    // least `len` bytes that the eBPF program has permission to access.
493    let buf = unsafe { slice::from_raw_parts_mut(to.as_ptr::<u8>(), len.as_u64() as usize) };
494
495    C::skb_load_bytes_relative(context, sk_buf, base, offset, buf).into()
496}
497
498fn bpf_sk_storage_get<'a, C: SkStorageProgramContext + MapsProgramContext>(
499    context: &mut C::RunContext<'a>,
500    map: BpfValue,
501    sk: BpfValue,
502    value: BpfValue,
503    flags: BpfValue,
504    _: BpfValue,
505) -> BpfValue {
506    if sk.is_zero() {
507        return BpfValue::default();
508    }
509
510    // SAFETY: Verifier ensures that `sk` is either null or a pointer to
511    // `bpf_sock`. The null case is checked above.
512    let bpf_sock = unsafe { C::BpfSockRef::from_bpf_value(context, sk) };
513
514    // Use socket cookie to identify the socket in the map.
515    let Some(socket_id) = bpf_sock.get_socket_cookie() else {
516        return BpfValue::default();
517    };
518
519    let key = socket_id.as_bytes();
520
521    // SAFETY: The `map` must be a reference to a `Map` object kept alive by the program itself.
522    let map: &Map = unsafe { &*map.as_ptr::<Map>() };
523
524    // Checked by the verifier.
525    assert!(map.schema.map_type == bpf_map_type_BPF_MAP_TYPE_SK_STORAGE);
526
527    C::on_map_access(context, map);
528
529    if let Some(value_ref) = map.lookup(key) {
530        let result: BpfValue = value_ref.ptr().raw_ptr().into();
531        C::add_value_ref(context, value_ref);
532        return result;
533    }
534
535    if flags.as_u32() & BPF_SK_STORAGE_GET_F_CREATE != 0 {
536        let mut vec;
537        let init_val = if value.as_u64() == 0 {
538            vec = SmallVec::<[u8; 128]>::new();
539            vec.resize(map.schema.value_size as usize, 0);
540            (&mut vec[..]).into()
541        } else {
542            // SAFETY: The verifier ensures that `value` points to a valid buffer.
543            unsafe { EbpfBufferPtr::new(value.as_ptr::<u8>(), map.schema.value_size as usize) }
544        };
545
546        let r = map.update(key, init_val, 0);
547        if r.is_ok() {
548            if let Some(value_ref) = map.lookup(key) {
549                let result: BpfValue = value_ref.ptr().raw_ptr().into();
550                C::add_value_ref(context, value_ref);
551                return result;
552            }
553        }
554    }
555
556    BpfValue::default()
557}
558
559fn bpf_sk_fullsock<C: EbpfProgramContext>(
560    _context: &mut C::RunContext<'_>,
561    _: BpfValue,
562    _: BpfValue,
563    _: BpfValue,
564    _: BpfValue,
565    _: BpfValue,
566) -> BpfValue {
567    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_sk_fullsock");
568    0.into()
569}
570
/// Run-context access to the return value slot used by `bpf_set_retval` and
/// `bpf_get_retval`.
pub trait ReturnValueContext {
    /// Stores `value`; the returned `i32` is handed back to the eBPF program
    /// as the result of `bpf_set_retval`.
    fn set_retval(&mut self, value: i32) -> i32;

    /// Returns the current return value.
    fn get_retval(&self) -> i32;
}
575
576pub trait ReturnValueProgramContext: EbpfProgramContext {
577    fn set_retval<'a>(context: &mut Self::RunContext<'a>, value: i32) -> i32;
578    fn get_retval<'a>(context: &mut Self::RunContext<'a>) -> i32;
579}
580
581impl<C: EbpfProgramContext> ReturnValueProgramContext for C
582where
583    for<'a> C::RunContext<'a>: ReturnValueContext,
584{
585    fn set_retval<'a>(context: &mut Self::RunContext<'a>, value: i32) -> i32 {
586        context.set_retval(value)
587    }
588    fn get_retval<'a>(context: &mut Self::RunContext<'a>) -> i32 {
589        context.get_retval()
590    }
591}
592
593fn bpf_set_retval<C: ReturnValueProgramContext>(
594    context: &mut C::RunContext<'_>,
595    value: BpfValue,
596    _: BpfValue,
597    _: BpfValue,
598    _: BpfValue,
599    _: BpfValue,
600) -> BpfValue {
601    C::set_retval(context, value.as_i32()).into()
602}
603
604fn bpf_get_retval<C: ReturnValueProgramContext>(
605    context: &mut C::RunContext<'_>,
606    _: BpfValue,
607    _: BpfValue,
608    _: BpfValue,
609    _: BpfValue,
610    _: BpfValue,
611) -> BpfValue {
612    C::get_retval(context).into()
613}
614
615fn bpf_sk_lookup_tcp<C: EbpfProgramContext>(
616    _context: &mut C::RunContext<'_>,
617    _: BpfValue,
618    _: BpfValue,
619    _: BpfValue,
620    _: BpfValue,
621    _: BpfValue,
622) -> BpfValue {
623    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_sk_lookup_tcp");
624    0.into()
625}
626
627fn bpf_sk_lookup_udp<C: EbpfProgramContext>(
628    _context: &mut C::RunContext<'_>,
629    _: BpfValue,
630    _: BpfValue,
631    _: BpfValue,
632    _: BpfValue,
633    _: BpfValue,
634) -> BpfValue {
635    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_sk_lookup_udp");
636    0.into()
637}
638
639fn bpf_sk_release<C: EbpfProgramContext>(
640    _context: &mut C::RunContext<'_>,
641    _: BpfValue,
642    _: BpfValue,
643    _: BpfValue,
644    _: BpfValue,
645    _: BpfValue,
646) -> BpfValue {
647    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_sk_release");
648    0.into()
649}
650
651fn bpf_get_netns_cookie<C: EbpfProgramContext>(
652    _context: &mut C::RunContext<'_>,
653    _: BpfValue,
654    _: BpfValue,
655    _: BpfValue,
656    _: BpfValue,
657    _: BpfValue,
658) -> BpfValue {
659    track_stub!(TODO("https://fxbug.dev/287120494"), "bpf_get_netns_cookie");
660    const DEFAULT_NETWORK_NAMESPACE_COOKIE: u64 = 1;
661    DEFAULT_NETWORK_NAMESPACE_COOKIE.into()
662}
663
664fn get_common_helpers<C: MapsProgramContext>() -> impl Iterator<Item = (u32, EbpfHelperImpl<C>)> {
665    [
666        (bpf_func_id_BPF_FUNC_ktime_get_boot_ns, EbpfHelperImpl(bpf_ktime_get_boot_ns)),
667        (bpf_func_id_BPF_FUNC_ktime_get_coarse_ns, EbpfHelperImpl(bpf_ktime_get_coarse_ns)),
668        (bpf_func_id_BPF_FUNC_ktime_get_ns, EbpfHelperImpl(bpf_ktime_get_ns)),
669        (bpf_func_id_BPF_FUNC_map_delete_elem, EbpfHelperImpl(bpf_map_delete_elem)),
670        (bpf_func_id_BPF_FUNC_map_lookup_elem, EbpfHelperImpl(bpf_map_lookup_elem)),
671        (bpf_func_id_BPF_FUNC_map_update_elem, EbpfHelperImpl(bpf_map_update_elem)),
672        (bpf_func_id_BPF_FUNC_probe_read_str, EbpfHelperImpl(bpf_probe_read_str)),
673        (bpf_func_id_BPF_FUNC_probe_read_user, EbpfHelperImpl(bpf_probe_read_user)),
674        (bpf_func_id_BPF_FUNC_probe_read_user_str, EbpfHelperImpl(bpf_probe_read_user_str)),
675        (bpf_func_id_BPF_FUNC_ringbuf_discard, EbpfHelperImpl(bpf_ringbuf_discard)),
676        (bpf_func_id_BPF_FUNC_ringbuf_reserve, EbpfHelperImpl(bpf_ringbuf_reserve)),
677        (bpf_func_id_BPF_FUNC_ringbuf_submit, EbpfHelperImpl(bpf_ringbuf_submit)),
678        (bpf_func_id_BPF_FUNC_trace_printk, EbpfHelperImpl(bpf_trace_printk)),
679        (bpf_func_id_BPF_FUNC_get_smp_processor_id, EbpfHelperImpl(bpf_get_smp_processor_id)),
680    ]
681    .into_iter()
682}
683
684/// Returns helper implementations that depend on `CurrentTask`.
685fn get_current_task_helpers<C: CurrentTaskProgramContext>()
686-> impl Iterator<Item = (u32, EbpfHelperImpl<C>)> {
687    [
688        (bpf_func_id_BPF_FUNC_get_current_uid_gid, EbpfHelperImpl(bpf_get_current_uid_gid)),
689        (bpf_func_id_BPF_FUNC_get_current_pid_tgid, EbpfHelperImpl(bpf_get_current_pid_tgid)),
690    ]
691    .into_iter()
692}
693
694// Trait for `EbpfProgramContext` implementations that are used for
695// `BPF_PROG_TYPE_CGROUP_SOCK` programs.
696pub trait CgroupSockProgramContext:
697    MapsProgramContext
698    + SocketCookieProgramContext
699    + CurrentTaskProgramContext
700    + SkStorageProgramContext
701{
702    fn get_helpers() -> HelperSet<Self> {
703        [
704            (bpf_func_id_BPF_FUNC_get_netns_cookie, EbpfHelperImpl(bpf_get_netns_cookie)),
705            (bpf_func_id_BPF_FUNC_get_socket_cookie, EbpfHelperImpl(bpf_get_socket_cookie)),
706            (bpf_func_id_BPF_FUNC_sk_storage_get, EbpfHelperImpl(bpf_sk_storage_get)),
707            (bpf_func_id_BPF_FUNC_sk_lookup_tcp, EbpfHelperImpl(bpf_sk_lookup_tcp)),
708            (bpf_func_id_BPF_FUNC_sk_lookup_udp, EbpfHelperImpl(bpf_sk_lookup_udp)),
709            (bpf_func_id_BPF_FUNC_sk_release, EbpfHelperImpl(bpf_sk_release)),
710        ]
711        .into_iter()
712        .chain(get_common_helpers())
713        .chain(get_current_task_helpers())
714        .collect()
715    }
716}
717
718// Trait for `EbpfProgramContext` implementations that are used for
719// `BPF_PROG_TYPE_CGROUP_SOCKADDR` programs.
720pub trait CgroupSockAddrProgramContext:
721    MapsProgramContext
722    + SocketCookieProgramContext
723    + CurrentTaskProgramContext
724    + SkStorageProgramContext
725{
726    fn get_helpers() -> HelperSet<Self> {
727        [
728            (bpf_func_id_BPF_FUNC_get_netns_cookie, EbpfHelperImpl(bpf_get_netns_cookie)),
729            (bpf_func_id_BPF_FUNC_get_socket_cookie, EbpfHelperImpl(bpf_get_socket_cookie)),
730            (bpf_func_id_BPF_FUNC_sk_storage_get, EbpfHelperImpl(bpf_sk_storage_get)),
731            (bpf_func_id_BPF_FUNC_sk_lookup_tcp, EbpfHelperImpl(bpf_sk_lookup_tcp)),
732            (bpf_func_id_BPF_FUNC_sk_lookup_udp, EbpfHelperImpl(bpf_sk_lookup_udp)),
733            (bpf_func_id_BPF_FUNC_sk_release, EbpfHelperImpl(bpf_sk_release)),
734        ]
735        .into_iter()
736        .chain(get_common_helpers())
737        .chain(get_current_task_helpers())
738        .collect()
739    }
740}
741
742// Trait for `EbpfProgramContext` implementations that are used for
743// `BPF_PROG_TYPE_CGROUP_SOCKOPT` programs.
744pub trait CgroupSockOptProgramContext:
745    MapsProgramContext + CurrentTaskProgramContext + ReturnValueProgramContext + SkStorageProgramContext
746{
747    fn get_helpers() -> HelperSet<Self> {
748        [
749            (bpf_func_id_BPF_FUNC_get_netns_cookie, EbpfHelperImpl(bpf_get_netns_cookie)),
750            (bpf_func_id_BPF_FUNC_set_retval, EbpfHelperImpl(bpf_set_retval)),
751            (bpf_func_id_BPF_FUNC_get_retval, EbpfHelperImpl(bpf_get_retval)),
752            (bpf_func_id_BPF_FUNC_sk_storage_get, EbpfHelperImpl(bpf_sk_storage_get)),
753            (bpf_func_id_BPF_FUNC_sk_lookup_tcp, EbpfHelperImpl(bpf_sk_lookup_tcp)),
754            (bpf_func_id_BPF_FUNC_sk_lookup_udp, EbpfHelperImpl(bpf_sk_lookup_udp)),
755            (bpf_func_id_BPF_FUNC_sk_release, EbpfHelperImpl(bpf_sk_release)),
756        ]
757        .into_iter()
758        .chain(get_common_helpers())
759        .chain(get_current_task_helpers())
760        .collect()
761    }
762}
763
764// Trait for `EbpfProgramContext` implementations that are used for
765// `BPF_PROG_TYPE_SOCKET_FILTER` programs.
766pub trait SocketFilterProgramContext:
767    MapsProgramContext
768    + SocketUidProgramContext
769    + SocketCookieProgramContext
770    + SkbLoadBytesProgramContext
771{
772    fn get_helpers() -> HelperSet<Self> {
773        vec![
774            (bpf_func_id_BPF_FUNC_get_netns_cookie, EbpfHelperImpl(bpf_get_netns_cookie)),
775            (bpf_func_id_BPF_FUNC_get_socket_uid, EbpfHelperImpl(bpf_get_socket_uid)),
776            (bpf_func_id_BPF_FUNC_get_socket_cookie, EbpfHelperImpl(bpf_get_socket_cookie)),
777            (bpf_func_id_BPF_FUNC_skb_load_bytes, EbpfHelperImpl(bpf_skb_load_bytes)),
778            (
779                bpf_func_id_BPF_FUNC_skb_load_bytes_relative,
780                EbpfHelperImpl(bpf_skb_load_bytes_relative),
781            ),
782        ]
783        .into_iter()
784        .chain(get_common_helpers())
785        .collect()
786    }
787}
788
789// Trait for `EbpfProgramContext` implementations that are used for
790// `BPF_PROG_TYPE_CGROUP_SKB` programs.
791pub trait CgroupSkbProgramContext:
792    MapsProgramContext
793    + SocketUidProgramContext
794    + SocketCookieProgramContext
795    + SkbLoadBytesProgramContext
796    + SkStorageProgramContext
797{
798    fn get_helpers() -> HelperSet<Self> {
799        vec![
800            (bpf_func_id_BPF_FUNC_get_netns_cookie, EbpfHelperImpl(bpf_get_netns_cookie)),
801            (bpf_func_id_BPF_FUNC_get_socket_uid, EbpfHelperImpl(bpf_get_socket_uid)),
802            (bpf_func_id_BPF_FUNC_get_socket_cookie, EbpfHelperImpl(bpf_get_socket_cookie)),
803            (bpf_func_id_BPF_FUNC_skb_load_bytes, EbpfHelperImpl(bpf_skb_load_bytes)),
804            (
805                bpf_func_id_BPF_FUNC_skb_load_bytes_relative,
806                EbpfHelperImpl(bpf_skb_load_bytes_relative),
807            ),
808            (bpf_func_id_BPF_FUNC_sk_storage_get, EbpfHelperImpl(bpf_sk_storage_get)),
809            (bpf_func_id_BPF_FUNC_sk_lookup_tcp, EbpfHelperImpl(bpf_sk_lookup_tcp)),
810            (bpf_func_id_BPF_FUNC_sk_lookup_udp, EbpfHelperImpl(bpf_sk_lookup_udp)),
811            (bpf_func_id_BPF_FUNC_sk_release, EbpfHelperImpl(bpf_sk_release)),
812            (bpf_func_id_BPF_FUNC_sk_fullsock, EbpfHelperImpl(bpf_sk_fullsock)),
813        ]
814        .into_iter()
815        .chain(get_common_helpers())
816        .collect()
817    }
818}
819
/// Declares the program type of an `EbpfProgramContext` implementation.
///
/// Expands to an empty `impl` of the given program-type trait for the
/// context type, plus an `ebpf::static_helper_set!` invocation that is fed
/// the trait's `get_helpers()` result.
///
/// # Example
///
/// Declaring that `MyEbpfProgramContext` is used to run socket filter
/// programs:
///
/// ```
/// ebpf_program_context_type!(MyEbpfProgramContext, SocketFilterProgramContext);
/// ```
#[macro_export]
macro_rules! ebpf_program_context_type {
    ($ctx:ty, $program_trait:ty) => {
        impl $program_trait for $ctx {}
        ebpf::static_helper_set!($ctx, <$ctx as $program_trait>::get_helpers());
    };
}
838
#[cfg(test)]
mod tests {
    use super::*;
    use crate::maps::{Map, PinnedMap};
    use ebpf::{BpfValue, EbpfProgramContext, FromBpfValue, MapFlags, MapSchema};
    use linux_uapi::{BPF_SK_STORAGE_GET_F_CREATE, bpf_map_type_BPF_MAP_TYPE_SK_STORAGE};

    /// What an 8-byte storage value reads back as (viewed as a `u64`) when
    /// every byte was initialized to `0x11`.
    const FIRST_ENTRY_PATTERN: u64 = 0x1111111111111111;

    /// Schema used by every test in this module: an `SK_STORAGE` map with
    /// 8-byte values and no preallocation.
    fn sk_storage_schema() -> MapSchema {
        MapSchema {
            map_type: bpf_map_type_BPF_MAP_TYPE_SK_STORAGE,
            key_size: 4,
            value_size: 8,
            max_entries: 0,
            flags: MapFlags::NoPrealloc,
        }
    }

    /// Socket stand-in that carries nothing but a cookie.
    struct MockSocket {
        cookie: u64,
    }

    impl SocketRef for MockSocket {
        fn get_socket_cookie(&self) -> Option<u64> {
            Some(self.cookie)
        }

        // Every mock socket reports uid 0.
        fn get_socket_uid(&self) -> Option<uid_t> {
            Some(0)
        }
    }

    impl<'a> FromBpfValue<TestRunContext<'a>> for MockSocket {
        // The raw BPF value is taken verbatim as the socket cookie.
        unsafe fn from_bpf_value(_context: &mut TestRunContext<'a>, value: BpfValue) -> Self {
            let cookie = value.as_u64();
            Self { cookie }
        }
    }

    /// Run context that simply collects every map value reference handed to it.
    struct TestRunContext<'a> {
        map_refs: Vec<MapValueRef<'a>>,
    }

    impl<'a> BpfSockContext for TestRunContext<'a> {
        type BpfSockRef = MockSocket;
    }

    impl<'a> MapsContext<'a> for TestRunContext<'a> {
        // Map accesses are not tracked by the tests.
        fn on_map_access(&mut self, _map: &Map) {}

        fn add_value_ref(&mut self, map_ref: MapValueRef<'a>) {
            self.map_refs.push(map_ref);
        }
    }

    /// Program context whose only meaningful associated types are the run
    /// context and the map type; packet and arguments are all `()`.
    struct TestContext;

    impl EbpfProgramContext for TestContext {
        type RunContext<'a> = TestRunContext<'a>;
        type Packet<'a> = ();
        type Arg1<'a> = ();
        type Arg2<'a> = ();
        type Arg3<'a> = ();
        type Arg4<'a> = ();
        type Arg5<'a> = ();
        type Map = PinnedMap;
    }

    /// A pointer returned by a creating `bpf_sk_storage_get` call must stay
    /// readable and unchanged after its entry is deleted from the map and a
    /// new entry is created for a different socket.
    #[fuchsia::test]
    fn test_sk_storage_get_uaf() {
        let map = Map::new(sk_storage_schema(), "test").unwrap();
        let map_value = BpfValue::from(&*map as *const Map);

        let mut context = TestRunContext { map_refs: Vec::new() };

        // Step 1: create the storage entry for socket 42, filled with 0x11 bytes.
        let first_cookie: u64 = 42;
        let first_socket = BpfValue::from(first_cookie);
        let first_init = [0x11u8; 8];
        let first_init_ptr = BpfValue::from(first_init.as_ptr());
        let create_flag = BpfValue::from(BPF_SK_STORAGE_GET_F_CREATE as u64);

        let first_ptr = bpf_sk_storage_get::<TestContext>(
            &mut context,
            map_value,
            first_socket,
            first_init_ptr,
            create_flag,
            BpfValue::default(),
        );
        assert!(!first_ptr.is_zero());

        // The freshly created entry must hold the initial value.
        // SAFETY: first_ptr is a valid pointer to the map value.
        let initial = unsafe { *first_ptr.as_ptr::<u64>() };
        assert_eq!(initial, FIRST_ENTRY_PATTERN);

        // Step 2: remove socket 42's entry from the map.
        map.delete(&first_cookie.to_ne_bytes()).unwrap();

        // Step 3: create an entry for socket 43 with a different fill byte.
        // If a use-after-free exists, this allocation should reuse the block
        // that was just freed.
        let second_socket = BpfValue::from(43u64);
        let second_init = [0x22u8; 8];
        let second_init_ptr = BpfValue::from(second_init.as_ptr());

        let second_ptr = bpf_sk_storage_get::<TestContext>(
            &mut context,
            map_value,
            second_socket,
            second_init_ptr,
            create_flag,
            BpfValue::default(),
        );
        assert!(!second_ptr.is_zero());

        // The value behind first_ptr must be untouched, i.e. its memory was not
        // reused. Without the fix this fails (the block is overwritten with the
        // second entry's init value); with the fix the memory is kept alive.
        // SAFETY: first_ptr points to memory that is kept alive by the
        // reference in `context`.
        let after_reuse = unsafe { *first_ptr.as_ptr::<u64>() };
        assert_eq!(after_reuse, FIRST_ENTRY_PATTERN);
    }

    /// Same scenario as above, except the pointer that must stay valid comes
    /// from a plain lookup (no CREATE flag) made after the creation reference
    /// was dropped from the context.
    #[fuchsia::test]
    fn test_sk_storage_get_uaf_query() {
        let map = Map::new(sk_storage_schema(), "test").unwrap();
        let map_value = BpfValue::from(&*map as *const Map);

        let mut context = TestRunContext { map_refs: Vec::new() };

        // Step 1: create the storage entry for socket 42, filled with 0x11 bytes.
        let first_cookie: u64 = 42;
        let first_socket = BpfValue::from(first_cookie);
        let first_init = [0x11u8; 8];
        let first_init_ptr = BpfValue::from(first_init.as_ptr());
        let create_flag = BpfValue::from(BPF_SK_STORAGE_GET_F_CREATE as u64);

        let created_ptr = bpf_sk_storage_get::<TestContext>(
            &mut context,
            map_value,
            first_socket,
            first_init_ptr,
            create_flag,
            BpfValue::default(),
        );
        assert!(!created_ptr.is_zero());

        // Drop the references gathered so far to simulate no longer holding
        // the creation reference. The map still holds the reference.
        context.map_refs.clear();

        // Step 2: look up socket 42 again, this time without the CREATE flag
        // and without an initial value; the same address must come back.
        let queried_ptr = bpf_sk_storage_get::<TestContext>(
            &mut context,
            map_value,
            first_socket,
            BpfValue::default(),
            BpfValue::default(),
            BpfValue::default(),
        );
        assert_eq!(created_ptr.as_u64(), queried_ptr.as_u64());

        // Step 3: remove socket 42's entry from the map.
        map.delete(&first_cookie.to_ne_bytes()).unwrap();

        // Step 4: create an entry for socket 43 with a different fill byte.
        // If a use-after-free exists, this allocation should reuse the block
        // that was just freed.
        let second_socket = BpfValue::from(43u64);
        let second_init = [0x22u8; 8];
        let second_init_ptr = BpfValue::from(second_init.as_ptr());

        let second_ptr = bpf_sk_storage_get::<TestContext>(
            &mut context,
            map_value,
            second_socket,
            second_init_ptr,
            create_flag,
            BpfValue::default(),
        );
        assert!(!second_ptr.is_zero());

        // The value behind queried_ptr must not have changed.
        // SAFETY: queried_ptr points to memory that is kept alive by the
        // reference in `context` (from the query).
        let after_reuse = unsafe { *queried_ptr.as_ptr::<u64>() };
        assert_eq!(after_reuse, FIRST_ENTRY_PATTERN);
    }
}