Skip to main content

netlink_packet_sock_diag/inet/
bytecode.rs

1// SPDX-License-Identifier: MIT
2
3//! Functionality for parsing and serializing INET_DIAG bytecode programs.
4//!
5//! SOCK_DIAG_BY_FAMILY requests with NLM_F_DUMP can accept a bytecode program.
6//! The program is run against all of the sockets matching the standard part of
7//! the request (though some fields, like socket_id, are not examined at all).
8//! If the program accepts a socket, it is returned to the caller. Acceptance is
9//! signalled by the program reaching the length of the buffer exactly.
10//! Rejection is signalled by the program jumping to somewhere past this.
11//!
//! Each instruction is composed of the following basic structure, where `yes`
//! and `no` are how many bytes to jump forward if the instruction matches or
//! does not match, respectively. Because jumps only ever go forward, there are
//! no loops and every program trivially terminates:
16//!
17//! ```c
18//! opcode: u8,
19//! yes: u8,
20//! no: u16,
21//! // Followed (optionally) by parameters for the instruction.
22//! ```
23//!
24//! Instructions are variable-length, which is unwieldy to deal with in Rust, so
25//! instead we represent a program as a series of fixed-length instructions,
26//! which requires mapping back and forth to byte offsets during parsing and
27//! serialization.
28//!
29//! There is a small loss of fidelity in this Rust representation. The types
//! here encode acceptance and rejection explicitly, which means there's only a
31//! single rejection target. It also encodes NOPs and jumps more simply,
32//! forgoing the `yes` and `no` fields entirely. While this shouldn't make a
33//! semantic difference, it does mean round-tripping a program might result in a
34//! different representation.
35
36use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
37use std::num::NonZeroUsize;
38
39use netlink_packet_utils::{DecodeError, buffer};
40
41use crate::constants::{
42    AF_INET, AF_INET6, AF_UNSPEC, INET_DIAG_BC_AUTO, INET_DIAG_BC_CGROUP_COND, INET_DIAG_BC_D_COND,
43    INET_DIAG_BC_D_EQ, INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE, INET_DIAG_BC_DEV_COND,
44    INET_DIAG_BC_JMP, INET_DIAG_BC_MARK_COND, INET_DIAG_BC_NOP, INET_DIAG_BC_S_COND,
45    INET_DIAG_BC_S_EQ, INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
46};
47
/// Strongly typed `usize` wrappers used while translating between byte
/// positions and instruction positions.
///
/// They vary along two axes: whether they count bytes or instructions, and
/// whether they are absolute indexes or forward offsets from the current
/// instruction. They live in their own module so that the rest of the file can
/// only manipulate them through the methods below.
mod wrappers {
    use std::num::NonZeroUsize;

    /// Absolute position of a parsed [`Instruction`](super::Instruction)
    /// within a program.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)]
    pub(super) struct InstructionIndex(usize);

    impl InstructionIndex {
        pub(super) fn new(offset: usize) -> Self {
            InstructionIndex(offset)
        }

        pub(super) fn get(self) -> usize {
            self.0
        }

        /// Moves `rhs` instructions forward; `None` on overflow.
        pub(super) fn checked_add(self, rhs: InstructionOffset) -> Option<Self> {
            let sum = self.0.checked_add(rhs.0.get())?;
            Some(InstructionIndex(sum))
        }
    }

    /// Strictly positive distance between two parsed
    /// [`Instruction`s](super::Instruction).
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)]
    pub(super) struct InstructionOffset(NonZeroUsize);

    impl InstructionOffset {
        /// Returns `None` when `offset` is zero.
        pub(super) fn new(offset: usize) -> Option<Self> {
            NonZeroUsize::new(offset).map(InstructionOffset)
        }

        pub(super) fn new_nonzero(offset: NonZeroUsize) -> Self {
            InstructionOffset(offset)
        }

        pub(super) fn get(self) -> NonZeroUsize {
            self.0
        }
    }

    /// Absolute position of a byte in an unparsed (serialized) program.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash)]
    pub(super) struct ByteIndex(usize);

    impl ByteIndex {
        pub(super) fn new(offset: usize) -> Self {
            ByteIndex(offset)
        }

        pub(super) fn get(self) -> usize {
            self.0
        }

        /// Moves `rhs` bytes forward; `None` on overflow.
        pub(super) fn checked_add(self, rhs: ByteOffset) -> Option<Self> {
            let sum = self.0.checked_add(rhs.0.get())?;
            Some(ByteIndex(sum))
        }

        /// Forward distance from `rhs` to `self`; `None` unless `self` lies
        /// strictly after `rhs`.
        pub(super) fn checked_sub(self, rhs: ByteIndex) -> Option<ByteOffset> {
            let diff = self.0.checked_sub(rhs.0)?;
            ByteOffset::new(diff)
        }
    }

    /// Strictly positive distance between two bytes of an unparsed program.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)]
    pub(super) struct ByteOffset(NonZeroUsize);

    impl ByteOffset {
        /// Returns `None` when `offset` is zero.
        pub(super) fn new(offset: usize) -> Option<Self> {
            NonZeroUsize::new(offset).map(ByteOffset)
        }

        pub(super) fn get(self) -> NonZeroUsize {
            self.0
        }
    }
}
140
141use wrappers::{ByteIndex, ByteOffset, InstructionIndex, InstructionOffset};
142
/// The size of a Linux `struct inet_diag_bc_op`.
const STRUCT_BC_OP_SIZE: usize = 4;
/// Payload size of a device condition: a single `u32` interface index.
const DEVICE_COND_SIZE: usize = 4;
/// Payload size of a mark condition: a `u32` mark followed by a `u32` mask.
const MARK_COND_SIZE: usize = 8;
/// Payload size of a cgroup condition: a single `u64` cgroup id.
const CGROUP_COND_SIZE: usize = 8;
/// Size of a tuple condition's fixed header (family, prefix length, two unused
/// bytes and an `i32` port), not counting the address that may follow it.
const TUPLE_COND_MIN_SIZE: usize = 8;
/// Length in bytes of an IPv4 address.
const AF_INET_ADDR_LEN: usize = 4;
/// Length in bytes of an IPv6 address.
const AF_INET6_ADDR_LEN: usize = 16;
151
/// A bytecode program used by Linux to match AF_INET sockets.
///
/// The program is a sequence of fixed-size [`Instruction`]s whose jumps are
/// expressed as [`Action`]s rather than byte offsets. Use [`Bytecode::parse`]
/// and [`Bytecode::serialize`] to convert from and to the kernel's
/// variable-length encoding.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Bytecode(pub Vec<Instruction>);
155
/// An error encountered when serializing a [`Bytecode`] program.
#[derive(Debug, PartialEq, Eq)]
pub enum SerializationError {
    /// The output buffer is shorter than [`Bytecode::serialized_len`].
    BufferTooSmall,
    /// An index in the instruction with index `at` was too large to fit into
    /// the serialized representation (u8 or u16 depending on the field).
    IndexTooLargeForSerializedType {
        /// Index of the offending instruction.
        at: usize,
    },
    /// An index computed while serializing the instruction with index `at`
    /// was out of range (e.g. it overflowed `usize`).
    IndexOverflow {
        /// Index of the offending instruction.
        at: usize,
    },
}
168
169enum SerializationErrorCode {
170    IndexTooLargeForSerializedType,
171    IndexOverflow,
172}
173
174impl SerializationErrorCode {
175    fn at_index(self, index: InstructionIndex) -> SerializationError {
176        match self {
177            Self::IndexTooLargeForSerializedType => {
178                SerializationError::IndexTooLargeForSerializedType { at: index.get() }
179            }
180            Self::IndexOverflow => SerializationError::IndexOverflow { at: index.get() },
181        }
182    }
183}
184
/// An error encountered when parsing a program from a raw byte buffer.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseError {
    /// The index in the provided buffer at which the error occurred.
    pub index: usize,
    /// The specific error that occurred.
    pub code: ParseErrorCode,
}

/// The kind of failure reported by a [`ParseError`].
#[derive(Debug, PartialEq, Eq)]
pub enum ParseErrorCode {
    /// The buffer ended before an instruction, or its condition payload, was
    /// complete.
    TruncatedInstruction,
    /// The instruction's opcode is not one this module recognizes.
    UnknownOpcode,
    /// A jump lands neither on the start of an instruction nor on the accept
    /// or reject offset.
    InvalidJumpTarget,
    /// A jump offset used by the instruction is zero, i.e. it targets itself.
    SelfReference,
    /// A tuple condition's address family is not `AF_INET`, `AF_INET6` or
    /// `AF_UNSPEC`.
    InvalidAddressFamily,
    /// A tuple condition's prefix length exceeds the width of its address in
    /// bits (0 bits when no address is present).
    PrefixLengthLongerThanAddress,
    /// Computing a byte offset overflowed `usize`.
    IndexOverflow,
}
204
205impl ParseErrorCode {
206    fn at_index(self, index: ByteIndex) -> ParseError {
207        ParseError { index: index.get(), code: self }
208    }
209}
210
211impl Bytecode {
212    /// Returns the length of the serialized form of this bytecode.
213    ///
214    /// Useful for sizing the buffer passed to [`Bytecode::serialize`].
215    pub fn serialized_len(&self) -> usize {
216        self.0.iter().map(Instruction::serialized_len).sum()
217    }
218
219    /// Parse a bytecode program from the provided buffer.
220    pub fn parse(buf: &[u8]) -> Result<Self, ParseError> {
221        let mut raw_ops = vec![];
222        let mut curr_byte_index = ByteIndex::new(0);
223        let buf_len = ByteIndex::new(buf.len());
224        let mut instruction_index_by_byte_offset = std::collections::HashMap::new();
225
226        // First, we build up a map of instruction index to byte offset.
227        while curr_byte_index < buf_len {
228            instruction_index_by_byte_offset
229                .insert(curr_byte_index, InstructionIndex::new(raw_ops.len()));
230            let inst = RawInstruction::parse(&buf[curr_byte_index.get()..])
231                .map_err(|code| code.at_index(curr_byte_index))?;
232            let inst_len = inst.serialized_len();
233            raw_ops.push((curr_byte_index, inst));
234            curr_byte_index = curr_byte_index
235                .checked_add(inst_len)
236                .ok_or_else(|| ParseErrorCode::IndexOverflow.at_index(curr_byte_index))?;
237        }
238
239        // If curr_byte_index < buf_len, we would have looped again.
240        // If curr_byte_index > buf_len, we would have returned
241        // TruncatedInstruction.
242        assert_eq!(curr_byte_index, buf_len);
243
244        // Now, we resolve the raw byte offsets to indexes.
245
246        let raw_accept_offset = curr_byte_index;
247        // Linux bytecode validation ensures that there is only a single valid rejection offset.
248        let raw_reject_offset = raw_accept_offset
249            .checked_add(ByteOffset::new(4).unwrap())
250            .ok_or_else(|| ParseErrorCode::IndexOverflow.at_index(ByteIndex::new(0)))?;
251
252        let resolve = |target_offset: ByteIndex, current_index: InstructionIndex| {
253            if target_offset == raw_accept_offset {
254                Ok(Action::Accept)
255            } else if target_offset == raw_reject_offset {
256                Ok(Action::Reject)
257            } else if let Some(&target_index) = instruction_index_by_byte_offset.get(&target_offset)
258            {
259                // By construction we know that an instruction can't reference
260                // an earlier one, so current_index will always be less than or
261                // equal to target_index.
262                let offset = target_index.get().checked_sub(current_index.get()).unwrap();
263                let index_offset =
264                    InstructionOffset::new(offset).ok_or(ParseErrorCode::SelfReference)?;
265                Ok(Action::AdvanceBy(index_offset.get()))
266            } else {
267                Err(ParseErrorCode::InvalidJumpTarget)
268            }
269        };
270
271        let resolved_ops = raw_ops
272            .into_iter()
273            .enumerate()
274            .map(|(curr_instr_index, (curr_byte_index, raw_op))| {
275                let curr_instr_index = InstructionIndex::new(curr_instr_index);
276
277                match raw_op {
278                    RawInstruction::Nop(offset) => {
279                        let target = curr_byte_index.checked_add(offset).ok_or_else(|| {
280                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
281                        })?;
282                        let action = resolve(target, curr_instr_index)
283                            .map_err(|code| code.at_index(curr_byte_index))?;
284                        Ok(Instruction::Nop(action))
285                    }
286                    RawInstruction::Jmp(offset) => {
287                        let target = curr_byte_index.checked_add(offset).ok_or_else(|| {
288                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
289                        })?;
290                        let action = resolve(target, curr_instr_index)
291                            .map_err(|code| code.at_index(curr_byte_index))?;
292                        Ok(Instruction::Jmp(action))
293                    }
294                    RawInstruction::Condition { yes, no, condition } => {
295                        let yes_target = curr_byte_index.checked_add(yes).ok_or_else(|| {
296                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
297                        })?;
298                        let no_target = curr_byte_index.checked_add(no).ok_or_else(|| {
299                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
300                        })?;
301
302                        let yes = resolve(yes_target, curr_instr_index)
303                            .map_err(|code| code.at_index(curr_byte_index))?;
304                        let no = resolve(no_target, curr_instr_index)
305                            .map_err(|code| code.at_index(curr_byte_index))?;
306                        Ok(Instruction::Condition { yes, no, condition })
307                    }
308                }
309            })
310            .collect::<Result<Vec<_>, ParseError>>()?;
311
312        Ok(Bytecode(resolved_ops))
313    }
314
315    /// Serialize the bytecode into the provided buffer.
316    pub fn serialize(self, buf: &mut [u8]) -> Result<(), SerializationError> {
317        let Self(instructions) = self;
318
319        let mut total_len = ByteIndex::new(0);
320        let byte_indices_by_instruction_index: Vec<_> =
321            instructions
322                .iter()
323                .enumerate()
324                .map(|(i, inst)| {
325                    let res = total_len;
326                    match total_len.checked_add(ByteOffset::new(inst.serialized_len()).unwrap()) {
327                        Some(new_len) => {
328                            total_len = new_len;
329                            Ok(res)
330                        }
331                        None => Err(SerializationErrorCode::IndexOverflow
332                            .at_index(InstructionIndex::new(i))),
333                    }
334                })
335                .collect::<Result<Vec<_>, _>>()?;
336
337        if total_len.get() > buf.len() {
338            return Err(SerializationError::BufferTooSmall);
339        }
340
341        instructions.into_iter().enumerate().try_for_each(|(curr_inst_index, inst)| {
342            let curr_inst_index = InstructionIndex::new(curr_inst_index);
343            let curr_byte_index = byte_indices_by_instruction_index[curr_inst_index.get()];
344
345            inst.try_into_raw(
346                &byte_indices_by_instruction_index,
347                curr_inst_index,
348                curr_byte_index,
349                total_len,
350            )
351            .and_then(|raw| raw.serialize(&mut buf[curr_byte_index.get()..]))
352            .map_err(|e| e.at_index(curr_inst_index))
353        })
354    }
355}
356
/// A single bytecode instruction whose jump targets are expressed as
/// [`Action`]s instead of byte offsets.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Instruction {
    /// Encoded as `INET_DIAG_BC_NOP`; continues at the given target, which is
    /// carried in the `yes` field.
    Nop(Action),
    /// Encoded as `INET_DIAG_BC_JMP`; continues at the given target, which is
    /// carried in the `no` field.
    Jmp(Action),
    /// Evaluates `condition`, continuing at `yes` if it matches and at `no`
    /// otherwise.
    Condition { yes: Action, no: Action, condition: Condition },
}
363
364impl Instruction {
365    fn serialized_len(&self) -> usize {
366        STRUCT_BC_OP_SIZE
367            + match self {
368                Self::Nop(_) | Self::Jmp(_) => 0,
369                Self::Condition { condition, .. } => condition.serialized_len(),
370            }
371    }
372
373    fn try_into_raw(
374        self,
375        byte_indices_by_instruction_index: &[ByteIndex],
376        instruction_index: InstructionIndex,
377        byte_index: ByteIndex,
378        total_len: ByteIndex,
379    ) -> Result<RawInstruction, SerializationErrorCode> {
380        // Calculate relative offsets
381        let calculate_rel = |action| {
382            let target = match action {
383                Action::Accept => total_len,
384                // Linux checks that all targets are multiples of 4.
385                Action::Reject => total_len
386                    .checked_add(ByteOffset::new(4).unwrap())
387                    .ok_or(SerializationErrorCode::IndexOverflow)?,
388                Action::AdvanceBy(dist) => {
389                    let target_index = instruction_index
390                        .checked_add(InstructionOffset::new_nonzero(dist))
391                        .ok_or(SerializationErrorCode::IndexOverflow)?;
392                    byte_indices_by_instruction_index[target_index.get()]
393                }
394            };
395
396            // This is safe because the elements of offsets are strictly
397            // increasing, so indexing into my_index+dist (and we know dist
398            // can't be zero because of its type) must give a larger value.
399            Ok(target.checked_sub(byte_index).unwrap())
400        };
401
402        match self {
403            Instruction::Nop(action) => Ok(RawInstruction::Nop(calculate_rel(action)?)),
404            Instruction::Jmp(action) => Ok(RawInstruction::Jmp(calculate_rel(action)?)),
405            Instruction::Condition { yes, no, condition } => Ok(RawInstruction::Condition {
406                yes: calculate_rel(yes)?,
407                no: calculate_rel(no)?,
408                condition,
409            }),
410        }
411    }
412}
413
/// Where execution continues after an instruction.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Action {
    /// Accept the socket; encoded as a jump to exactly the end of the program.
    Accept,
    /// Reject the socket; encoded as a jump to four bytes past the end of the
    /// program.
    Reject,
    /// Continue with the instruction this many instructions ahead (counted in
    /// instructions, not bytes).
    AdvanceBy(NonZeroUsize),
}
420
/// An instruction as encoded on the wire: jump targets are forward byte
/// offsets from the start of the instruction.
enum RawInstruction {
    /// `INET_DIAG_BC_NOP`; the offset comes from the `yes` field.
    Nop(ByteOffset),
    /// `INET_DIAG_BC_JMP`; the offset comes from the `no` field.
    Jmp(ByteOffset),
    /// Any condition opcode, with both branch offsets and the decoded payload.
    Condition { yes: ByteOffset, no: ByteOffset, condition: Condition },
}
426
// Wire layout of a `struct inet_diag_bc_op`: a one-byte opcode, a one-byte
// `yes` offset and a two-byte `no` offset, followed by any condition payload.
buffer!(RawInstructionBuffer(STRUCT_BC_OP_SIZE) {
    code: (u8, 0),
    yes: (u8, 1),
    no: (u16, 2..4),
    payload: (slice, STRUCT_BC_OP_SIZE..),
});
433
434impl RawInstruction {
435    fn serialized_len(&self) -> ByteOffset {
436        ByteOffset::new(
437            STRUCT_BC_OP_SIZE
438                + match self {
439                    Self::Nop(_) => 0,
440                    Self::Jmp(_) => 0,
441                    Self::Condition { condition, .. } => condition.serialized_len(),
442                },
443        )
444        .unwrap()
445    }
446
447    fn parse(buf: &[u8]) -> Result<RawInstruction, ParseErrorCode> {
448        let buf =
449            RawInstructionBuffer::new(buf).map_err(|_| ParseErrorCode::TruncatedInstruction)?;
450
451        let code = buf.code();
452        let yes = ByteOffset::new(buf.yes().into());
453        let no = ByteOffset::new(buf.no().into());
454
455        // Handle these separately because they don't follow the same pattern
456        // for how to handle yes and no as the other instructions.
457        if code == INET_DIAG_BC_NOP {
458            return Ok(RawInstruction::Nop(yes.ok_or(ParseErrorCode::SelfReference)?));
459        } else if code == INET_DIAG_BC_JMP {
460            return Ok(RawInstruction::Jmp(no.ok_or(ParseErrorCode::SelfReference)?));
461        }
462
463        fn port_cond<F>(buf: &[u8], f: F) -> Result<Condition, ParseErrorCode>
464        where
465            F: FnOnce(u16) -> Condition,
466        {
467            match PortConditionBuffer::new(buf) {
468                Ok(buf) => Ok(f(buf.port())),
469                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
470            }
471        }
472
473        // Put the condition at the beginning of buf.
474        let payload = buf.payload();
475        let condition = match code {
476            // Handled above.
477            INET_DIAG_BC_NOP => unreachable!(),
478            INET_DIAG_BC_JMP => unreachable!(),
479
480            INET_DIAG_BC_S_COND => TupleCondition::parse(payload).map(Condition::SrcTuple),
481            INET_DIAG_BC_D_COND => TupleCondition::parse(payload).map(Condition::DstTuple),
482            INET_DIAG_BC_DEV_COND => match DeviceConditionBuffer::new(payload) {
483                Ok(buf) => Ok(Condition::Device(buf.ifindex())),
484                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
485            },
486            INET_DIAG_BC_MARK_COND => match MarkConditionBuffer::new(payload) {
487                Ok(buf) => Ok(Condition::Mark { mark: buf.mark(), mask: buf.mask() }),
488                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
489            },
490            INET_DIAG_BC_S_EQ => port_cond(payload, Condition::SrcPortEq),
491            INET_DIAG_BC_D_EQ => port_cond(payload, Condition::DstPortEq),
492            INET_DIAG_BC_S_GE => port_cond(payload, Condition::SrcPortGe),
493            INET_DIAG_BC_D_GE => port_cond(payload, Condition::DstPortGe),
494            INET_DIAG_BC_S_LE => port_cond(payload, Condition::SrcPortLe),
495            INET_DIAG_BC_D_LE => port_cond(payload, Condition::DstPortLe),
496            INET_DIAG_BC_AUTO => Ok(Condition::AutoPort),
497            INET_DIAG_BC_CGROUP_COND => match CgroupConditionBuffer::new(payload) {
498                Ok(buf) => Ok(Condition::Cgroup(buf.cgroup_id())),
499                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
500            },
501            _ => Err(ParseErrorCode::UnknownOpcode),
502        }?;
503
504        let yes = yes.ok_or(ParseErrorCode::SelfReference)?;
505        let no = no.ok_or(ParseErrorCode::SelfReference)?;
506
507        let inst = RawInstruction::Condition { yes, no, condition };
508
509        Ok(inst)
510    }
511
512    fn serialize(&self, buf: &mut [u8]) -> Result<(), SerializationErrorCode> {
513        // NOTE: buffer length was already checked in Bytecode::serialize, so we
514        // don't need to do that again in this function.
515
516        let (code, yes, no, condition) = match self {
517            Self::Nop(offset) => (INET_DIAG_BC_NOP, offset.get().get(), 0, None),
518            // Linux requires that the yes field always points at the next
519            // instruction, even though JMP doesn't use it.
520            Self::Jmp(offset) => (INET_DIAG_BC_JMP, STRUCT_BC_OP_SIZE, offset.get().get(), None),
521            Self::Condition { yes, no, condition } => {
522                (condition.code(), yes.get().get(), no.get().get(), Some(condition))
523            }
524        };
525
526        let mut buf = RawInstructionBuffer::new(buf).unwrap();
527        buf.set_code(code);
528        buf.set_yes(
529            yes.try_into().map_err(|_| SerializationErrorCode::IndexTooLargeForSerializedType)?,
530        );
531        buf.set_no(
532            u16::try_from(no)
533                .map_err(|_| SerializationErrorCode::IndexTooLargeForSerializedType)?,
534        );
535
536        let buf = buf.payload_mut();
537        if let Some(condition) = condition {
538            match condition {
539                Condition::AutoPort => {}
540                Condition::SrcPortGe(port)
541                | Condition::SrcPortLe(port)
542                | Condition::DstPortGe(port)
543                | Condition::DstPortLe(port)
544                | Condition::SrcPortEq(port)
545                | Condition::DstPortEq(port) => {
546                    let mut buf = PortConditionBuffer::new(buf).unwrap();
547                    buf.set_port(*port);
548                }
549                Condition::SrcTuple(c) | Condition::DstTuple(c) => c.serialize(buf),
550                Condition::Device(ifindex) => {
551                    let mut buf = DeviceConditionBuffer::new(buf).unwrap();
552                    buf.set_ifindex(*ifindex);
553                }
554                Condition::Mark { mark, mask } => {
555                    let mut buf = MarkConditionBuffer::new(buf).unwrap();
556                    buf.set_mark(*mark);
557                    buf.set_mask(*mask);
558                }
559                Condition::Cgroup(cgroup_id) => {
560                    let mut buf = CgroupConditionBuffer::new(buf).unwrap();
561                    buf.set_cgroup_id(*cgroup_id);
562                }
563            }
564        }
565
566        Ok(())
567    }
568}
569
/// The test performed by an [`Instruction::Condition`].
///
/// Each variant is serialized under its own `INET_DIAG_BC_*` opcode (noted
/// below), followed by the payload described.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Condition {
    /// `INET_DIAG_BC_S_GE`; the port follows in a second `inet_diag_bc_op`.
    SrcPortGe(u16),
    /// `INET_DIAG_BC_S_LE`; the port follows in a second `inet_diag_bc_op`.
    SrcPortLe(u16),
    /// `INET_DIAG_BC_D_GE`; the port follows in a second `inet_diag_bc_op`.
    DstPortGe(u16),
    /// `INET_DIAG_BC_D_LE`; the port follows in a second `inet_diag_bc_op`.
    DstPortLe(u16),
    /// `INET_DIAG_BC_S_EQ`; the port follows in a second `inet_diag_bc_op`.
    SrcPortEq(u16),
    /// `INET_DIAG_BC_D_EQ`; the port follows in a second `inet_diag_bc_op`.
    DstPortEq(u16),
    /// `INET_DIAG_BC_AUTO`; no payload.
    AutoPort,
    /// `INET_DIAG_BC_S_COND`; an address/port tuple.
    SrcTuple(TupleCondition),
    /// `INET_DIAG_BC_D_COND`; an address/port tuple.
    DstTuple(TupleCondition),
    /// `INET_DIAG_BC_DEV_COND`; a `u32` interface index.
    Device(u32),
    /// `INET_DIAG_BC_MARK_COND`; a `u32` mark followed by a `u32` mask.
    Mark { mark: u32, mask: u32 },
    /// `INET_DIAG_BC_CGROUP_COND`; a `u64` cgroup id.
    Cgroup(u64),
}
585
586// Linux uses a struct inet_diag_bc_op for the condition payload, but just the
587// `no` field.
588buffer!(PortConditionBuffer(STRUCT_BC_OP_SIZE) {
589    padding: (slice, 0..1),
590    port: (u16, 2..4),
591});
592
// Payload of an `INET_DIAG_BC_DEV_COND` instruction: the interface index.
buffer!(DeviceConditionBuffer(DEVICE_COND_SIZE) {
    ifindex: (u32, 0..4),
});

// Payload of an `INET_DIAG_BC_MARK_COND` instruction: mark, then mask.
buffer!(MarkConditionBuffer(MARK_COND_SIZE) {
    mark: (u32, 0..4),
    mask: (u32, 4..8),
});

// Payload of an `INET_DIAG_BC_CGROUP_COND` instruction: the cgroup id.
buffer!(CgroupConditionBuffer(CGROUP_COND_SIZE) {
    cgroup_id: (u64, 0..8),
});
605
606impl Condition {
607    fn serialized_len(&self) -> usize {
608        match self {
609            Condition::AutoPort => 0,
610            // Linux puts the port in the no field of a struct inet_diag_bc_op.
611            Condition::SrcPortGe(_)
612            | Condition::SrcPortLe(_)
613            | Condition::DstPortGe(_)
614            | Condition::DstPortLe(_)
615            | Condition::SrcPortEq(_)
616            | Condition::DstPortEq(_) => STRUCT_BC_OP_SIZE,
617            Condition::SrcTuple(c) | Condition::DstTuple(c) => c.serialized_len(),
618            Condition::Device(_) => DEVICE_COND_SIZE,
619            Condition::Mark { .. } => MARK_COND_SIZE,
620            Condition::Cgroup(_) => CGROUP_COND_SIZE,
621        }
622    }
623
624    fn code(&self) -> u8 {
625        match self {
626            Self::SrcPortGe(_) => INET_DIAG_BC_S_GE,
627            Self::SrcPortLe(_) => INET_DIAG_BC_S_LE,
628            Self::DstPortGe(_) => INET_DIAG_BC_D_GE,
629            Self::DstPortLe(_) => INET_DIAG_BC_D_LE,
630            Self::AutoPort => INET_DIAG_BC_AUTO,
631            Self::SrcTuple(_) => INET_DIAG_BC_S_COND,
632            Self::DstTuple(_) => INET_DIAG_BC_D_COND,
633            Self::Device(_) => INET_DIAG_BC_DEV_COND,
634            Self::Mark { .. } => INET_DIAG_BC_MARK_COND,
635            Self::SrcPortEq(_) => INET_DIAG_BC_S_EQ,
636            Self::DstPortEq(_) => INET_DIAG_BC_D_EQ,
637            Self::Cgroup(_) => INET_DIAG_BC_CGROUP_COND,
638        }
639    }
640}
641
/// The operand of a [`Condition::SrcTuple`] or [`Condition::DstTuple`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TupleCondition {
    /// Prefix length in bits. Parsing rejects values larger than the address
    /// width in bits, which is 0 when `addr` is `None`.
    pub prefix_len: u8,
    /// The address, or `None`; `None` is serialized as `AF_UNSPEC` with no
    /// address bytes.
    pub addr: Option<IpAddr>,
    /// The port, or `None`; `None` is serialized as `-1`.
    pub port: Option<u16>,
}
648
// Fixed header of a tuple condition payload. Bytes 2..4 are not mapped by this
// buffer; the address bytes, if any, follow in `payload`.
buffer!(TupleConditionBuffer(TUPLE_COND_MIN_SIZE) {
    family: (u8, 0),
    prefix_len: (u8, 1),
    port: (i32, 4..8),
    payload: (slice, TUPLE_COND_MIN_SIZE..),
});

// Address payload of an AF_INET tuple condition.
// NOTE(review): the kernel declares this address as big-endian (`__be32`);
// confirm the byte order of the integer accessor before relying on `addr()`.
buffer!(Ipv4AddrBuffer(AF_INET_ADDR_LEN) {
    addr: (u32, 0..AF_INET_ADDR_LEN),
});

// Address payload of an AF_INET6 tuple condition.
// NOTE(review): as above, the address bytes are presumably in network order.
buffer!(Ipv6AddrBuffer(AF_INET6_ADDR_LEN) {
    addr: (u128, 0..AF_INET6_ADDR_LEN),
});
663
664impl TupleCondition {
665    fn serialized_len(&self) -> usize {
666        TUPLE_COND_MIN_SIZE
667            + match self.addr {
668                Some(IpAddr::V4(_)) => AF_INET_ADDR_LEN,
669                Some(IpAddr::V6(_)) => AF_INET6_ADDR_LEN,
670                None => 0,
671            }
672    }
673
674    fn parse(buf: &[u8]) -> Result<Self, ParseErrorCode> {
675        let buf =
676            TupleConditionBuffer::new(buf).map_err(|_| ParseErrorCode::TruncatedInstruction)?;
677        let family = buf.family();
678        let prefix_len = buf.prefix_len();
679        let port = buf.port();
680        let port = if port == -1 { None } else { Some(port as u16) };
681
682        let payload = buf.payload();
683        let addr = match family {
684            AF_INET => match Ipv4AddrBuffer::new(payload) {
685                Ok(buf) => Ok(Some(IpAddr::V4(Ipv4Addr::from(buf.addr())))),
686                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
687            },
688            AF_INET6 => match Ipv6AddrBuffer::new(payload) {
689                Ok(buf) => Ok(Some(IpAddr::V6(Ipv6Addr::from(buf.addr())))),
690                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
691            },
692            AF_UNSPEC => Ok(None),
693            _ => Err(ParseErrorCode::InvalidAddressFamily),
694        }?;
695
696        let max_prefix_len = addr
697            .map(|a| match a {
698                IpAddr::V4(_) => AF_INET_ADDR_LEN,
699                IpAddr::V6(_) => AF_INET6_ADDR_LEN,
700            })
701            .unwrap_or(0)
702            * 8;
703
704        if usize::from(prefix_len) > max_prefix_len {
705            return Err(ParseErrorCode::PrefixLengthLongerThanAddress);
706        }
707
708        Ok(TupleCondition { prefix_len, port, addr })
709    }
710
711    fn serialize(&self, buf: &mut [u8]) {
712        let mut buf = TupleConditionBuffer::new(buf).unwrap();
713
714        match self.addr {
715            Some(IpAddr::V4(addr)) => {
716                buf.set_family(AF_INET);
717                Ipv4AddrBuffer::new(buf.payload_mut()).unwrap().set_addr(addr.into());
718            }
719            Some(IpAddr::V6(addr)) => {
720                buf.set_family(AF_INET6);
721                Ipv6AddrBuffer::new(buf.payload_mut()).unwrap().set_addr(addr.into());
722            }
723            None => {
724                buf.set_family(AF_UNSPEC);
725            }
726        };
727
728        buf.set_prefix_len(self.prefix_len);
729        buf.set_port(self.port.map(i32::from).unwrap_or(-1));
730    }
731}
732
733#[cfg(test)]
734mod tests {
735    use super::*;
736
737    #[test]
738    fn instructions_roundtrip() {
739        let conditions = vec![
740            Condition::SrcPortGe(100),
741            Condition::SrcPortLe(200),
742            Condition::DstPortGe(300),
743            Condition::DstPortLe(400),
744            Condition::SrcPortEq(500),
745            Condition::DstPortEq(600),
746            Condition::AutoPort,
747            Condition::SrcTuple(TupleCondition {
748                prefix_len: 24,
749                port: Some(8080),
750                addr: Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))),
751            }),
752            Condition::SrcTuple(TupleCondition {
753                prefix_len: 128,
754                port: Some(8081),
755                addr: Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
756            }),
757            Condition::DstTuple(TupleCondition {
758                prefix_len: 24,
759                port: Some(9090),
760                addr: Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))),
761            }),
762            Condition::DstTuple(TupleCondition {
763                prefix_len: 64,
764                port: None,
765                addr: Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
766            }),
767            Condition::Device(1),
768            Condition::Mark { mark: 0x1234, mask: 0xFFFF },
769            Condition::Cgroup(123456789),
770        ];
771
772        for condition in conditions {
773            let bc = Bytecode(vec![
774                Instruction::Condition {
775                    yes: Action::AdvanceBy(NonZeroUsize::new(1).unwrap()),
776                    no: Action::Accept,
777                    condition: condition.clone(),
778                },
779                Instruction::Nop(Action::Accept),
780            ]);
781
782            let mut buf = vec![0u8; bc.serialized_len()];
783            bc.clone().serialize(&mut buf).unwrap();
784            let parsed = Bytecode::parse(&buf)
785                .unwrap_or_else(|e| panic!("parse failed for {:?}: {:?}", condition, e));
786            assert_eq!(parsed, bc, "roundtrip failed for {:?}", condition);
787        }
788    }
789
790    #[test]
791    fn accept_reject_mapping() {
792        let bc = Bytecode(vec![Instruction::Jmp(Action::Accept), Instruction::Jmp(Action::Reject)]);
793
794        let mut buf = vec![0u8; bc.serialized_len()];
795        bc.clone().serialize(&mut buf).unwrap();
796
797        assert_eq!(
798            buf,
799            [
800                INET_DIAG_BC_JMP,
801                4, // yes
802                8u16.to_ne_bytes()[0],
803                8u16.to_ne_bytes()[1], // no
804                INET_DIAG_BC_JMP,
805                4, // yes
806                8u16.to_ne_bytes()[0],
807                8u16.to_ne_bytes()[1], // no
808            ]
809        );
810
811        let parsed = Bytecode::parse(&buf).unwrap();
812        assert_eq!(parsed, bc);
813    }
814
815    #[test]
816    fn buffer_too_small() {
817        let bc = Bytecode(vec![Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))]);
818        let mut buf = vec![0u8; 3]; // Nop is 4 bytes
819        assert_eq!(bc.serialize(&mut buf), Err(SerializationError::BufferTooSmall));
820    }
821
822    #[test]
823    fn index_too_large_yes() {
824        const COUNT: usize = 64;
825
826        let mut ops = vec![Instruction::Condition {
827            yes: Action::AdvanceBy(NonZeroUsize::new(COUNT + 1).unwrap()),
828            no: Action::Accept,
829            condition: Condition::AutoPort,
830        }];
831        // Each NOP is 4 bytes, so 64 NOPs is 256 bytes.
832        ops.extend(
833            (0..COUNT).map(|_| Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))),
834        );
835        ops.push(Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))); // Target
836
837        let bc = Bytecode(ops);
838        let mut buf = vec![0u8; bc.serialized_len()];
839        assert_eq!(
840            bc.serialize(&mut buf),
841            Err(SerializationError::IndexTooLargeForSerializedType { at: 0 })
842        );
843    }
844
845    #[test]
846    fn index_too_large_no() {
847        const COUNT: usize = 16384;
848
849        let mut ops =
850            vec![Instruction::Jmp(Action::AdvanceBy(NonZeroUsize::new(COUNT + 1).unwrap()))];
851        ops.extend(
852            (0..COUNT).map(|_| Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))),
853        );
854        ops.push(Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))); // Target
855
856        let bc = Bytecode(ops);
857        let mut buf = vec![0u8; bc.serialized_len()];
858        assert_eq!(
859            bc.serialize(&mut buf),
860            Err(SerializationError::IndexTooLargeForSerializedType { at: 0 })
861        );
862    }
863
864    #[test]
865    fn index_overflow() {
866        let ops = vec![
867            Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap())),
868            Instruction::Jmp(Action::AdvanceBy(NonZeroUsize::MAX)),
869        ];
870        let bc = Bytecode(ops);
871        let mut buf = vec![0u8; bc.serialized_len()];
872        assert_eq!(bc.serialize(&mut buf), Err(SerializationError::IndexOverflow { at: 1 }));
873    }
874
875    #[test]
876    fn advance_by_mapping() {
877        let bc = Bytecode(vec![
878            Instruction::Jmp(Action::AdvanceBy(NonZeroUsize::new(2).unwrap())),
879            Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap())),
880            Instruction::Nop(Action::Accept),
881        ]);
882
883        let mut buf = vec![0u8; bc.serialized_len()];
884        bc.clone().serialize(&mut buf).unwrap();
885
886        let parsed = Bytecode::parse(&buf).unwrap();
887        assert_eq!(parsed, bc);
888    }
889
890    #[test]
891    fn parse_errors() {
892        // Invalid bytecode!
893        let buf = vec![255, 4, 0, 0];
894        assert_eq!(
895            Bytecode::parse(&buf),
896            Err(ParseError { index: 0, code: ParseErrorCode::UnknownOpcode })
897        );
898
899        // Invalid target jump (jumping into the middle of an instruction).
900        let mut buf = vec![];
901        buf.push(INET_DIAG_BC_NOP);
902        buf.push(4); // yes
903        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
904
905        buf.push(INET_DIAG_BC_MARK_COND);
906        buf.push(4); // yes.
907        buf.extend_from_slice(&6u16.to_ne_bytes()); // no. Middle of the next instruction. Invalid!
908        buf.extend_from_slice(&0u32.to_ne_bytes());
909        buf.extend_from_slice(&0u32.to_ne_bytes());
910
911        buf.push(INET_DIAG_BC_JMP);
912        buf.push(4);
913        buf.extend_from_slice(&4u16.to_ne_bytes());
914        assert_eq!(
915            Bytecode::parse(&buf),
916            Err(ParseError { index: 4, code: ParseErrorCode::InvalidJumpTarget })
917        );
918
919        // Truncated instruction body (SrcPortGe missing bytes)
920        // S_GE: (4 bytes) + 4 bytes payload.
921        let mut buf = Vec::new();
922        buf.push(INET_DIAG_BC_S_GE);
923        buf.push(4); // yes.
924        buf.extend_from_slice(&4u16.to_ne_bytes());
925        assert_eq!(
926            Bytecode::parse(&buf),
927            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
928        );
929
930        // Invalid self-reference (yes=0).
931        let mut buf = Vec::new();
932        buf.push(INET_DIAG_BC_AUTO);
933        buf.push(0); // yes. Invalid!
934        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
935        assert_eq!(
936            Bytecode::parse(&buf),
937            Err(ParseError { index: 0, code: ParseErrorCode::SelfReference })
938        );
939
940        // Invalid self-reference (no=0).
941        let mut buf = Vec::new();
942        buf.push(INET_DIAG_BC_AUTO);
943        buf.push(4); // yes
944        buf.extend_from_slice(&0u16.to_ne_bytes()); // no. Invalid!
945        assert_eq!(
946            Bytecode::parse(&buf),
947            Err(ParseError { index: 0, code: ParseErrorCode::SelfReference })
948        );
949
950        // Invalid address family.
951        let mut buf = Vec::new();
952        buf.push(INET_DIAG_BC_S_COND);
953        buf.push(12); // yes
954        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
955        buf.push(255); // Invalid family
956        buf.push(0); // prefix len
957        buf.push(0); // pad
958        buf.push(0); // pad
959        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
960        assert_eq!(
961            Bytecode::parse(&buf),
962            Err(ParseError { index: 0, code: ParseErrorCode::InvalidAddressFamily })
963        );
964
965        // Prefix length longer than address.
966        let mut buf = Vec::new();
967        buf.push(INET_DIAG_BC_S_COND);
968        buf.push(16); // yes
969        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
970        buf.push(AF_INET);
971        buf.push(33); // prefix len
972        buf.push(0); // pad
973        buf.push(0); // pad
974        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
975        buf.extend_from_slice(&[0, 0, 0, 0]); // address
976        assert_eq!(
977            Bytecode::parse(&buf),
978            Err(ParseError { index: 0, code: ParseErrorCode::PrefixLengthLongerThanAddress })
979        );
980    }
981
982    #[test]
983    fn truncated_payloads() {
984        let mut buf = Vec::new();
985        buf.push(INET_DIAG_BC_DEV_COND);
986        buf.push(8); // yes
987        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
988        buf.extend_from_slice(&[0; 3]); // 3 bytes instead of 4
989        assert_eq!(
990            Bytecode::parse(&buf),
991            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
992        );
993
994        let mut buf = Vec::new();
995        buf.push(INET_DIAG_BC_MARK_COND);
996        buf.push(12); // yes
997        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
998        buf.extend_from_slice(&[0; 7]); // 7 bytes instead of 8
999        assert_eq!(
1000            Bytecode::parse(&buf),
1001            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1002        );
1003
1004        let mut buf = Vec::new();
1005        buf.push(INET_DIAG_BC_CGROUP_COND);
1006        buf.push(12); // yes
1007        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1008        buf.extend_from_slice(&[0; 7]); // 7 bytes instead of 8
1009        assert_eq!(
1010            Bytecode::parse(&buf),
1011            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1012        );
1013
1014        let mut buf = Vec::new();
1015        buf.push(INET_DIAG_BC_S_COND);
1016        buf.push(12); // yes
1017        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1018        buf.extend_from_slice(&[0; 7]); // 7 bytes instead of 8 (min header size)
1019        assert_eq!(
1020            Bytecode::parse(&buf),
1021            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1022        );
1023
1024        let mut buf = Vec::new();
1025        buf.push(INET_DIAG_BC_S_COND);
1026        buf.push(16); // yes
1027        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1028        buf.push(AF_INET);
1029        buf.push(0); // prefix len
1030        buf.push(0);
1031        buf.push(0); // pad
1032        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
1033        buf.extend_from_slice(&[0; 3]); // 3 bytes instead of 4 for IPv4
1034        assert_eq!(
1035            Bytecode::parse(&buf),
1036            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1037        );
1038
1039        let mut buf = Vec::new();
1040        buf.push(INET_DIAG_BC_S_COND);
1041        buf.push(28); // yes
1042        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1043        buf.push(AF_INET6);
1044        buf.push(0); // prefix len
1045        buf.push(0);
1046        buf.push(0); // pad
1047        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
1048        buf.extend_from_slice(&[0; 15]); // 15 bytes instead of 16 for IPv6
1049        assert_eq!(
1050            Bytecode::parse(&buf),
1051            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1052        );
1053    }
1054}