Skip to main content

netlink_packet_sock_diag/inet/
bytecode.rs

1// SPDX-License-Identifier: MIT
2
3//! Functionality for parsing and serializing INET_DIAG bytecode programs.
4//!
5//! SOCK_DIAG_BY_FAMILY requests with NLM_F_DUMP can accept a bytecode program.
6//! The program is run against all of the sockets matching the standard part of
7//! the request (though some fields, like socket_id, are not examined at all).
8//! If the program accepts a socket, it is returned to the caller. Acceptance is
9//! signalled by the program reaching the length of the buffer exactly.
10//! Rejection is signalled by the program jumping to somewhere past this.
11//!
//! Each instruction is composed of the following basic structure, where `yes`
//! and `no` are how many bytes to jump forward if the instruction matches or
//! does not match, respectively.
14//! Note that this means there are no loops and all programs trivially must
15//! terminate:
16//!
17//! ```c
18//! opcode: u8,
19//! yes: u8,
20//! no: u16,
21//! // Followed (optionally) by parameters for the instruction.
22//! ```
23//!
24//! Instructions are variable-length, which is unwieldy to deal with in Rust, so
25//! instead we represent a program as a series of fixed-length instructions,
26//! which requires mapping back and forth to byte offsets during parsing and
27//! serialization.
28//!
29//! There is a small loss of fidelity in this Rust representation. The types
//! here encode acceptance and rejection explicitly, which means there's only a
31//! single rejection target. It also encodes NOPs and jumps more simply,
32//! forgoing the `yes` and `no` fields entirely. While this shouldn't make a
33//! semantic difference, it does mean round-tripping a program might result in a
34//! different representation.
35
36use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
37use std::num::NonZeroUsize;
38
39use arbitrary::Arbitrary;
40use netlink_packet_utils::{DecodeError, buffer};
41
42use crate::constants::{
43    AF_INET, AF_INET6, AF_UNSPEC, INET_DIAG_BC_AUTO, INET_DIAG_BC_CGROUP_COND, INET_DIAG_BC_D_COND,
44    INET_DIAG_BC_D_EQ, INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE, INET_DIAG_BC_DEV_COND,
45    INET_DIAG_BC_JMP, INET_DIAG_BC_MARK_COND, INET_DIAG_BC_NOP, INET_DIAG_BC_S_COND,
46    INET_DIAG_BC_S_EQ, INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
47};
48
/// Newtypes for the various `usize` quantities juggled while parsing and
/// serializing.
///
/// Quantities differ along two axes: whether they count bytes or
/// instructions, and whether they are absolute indices or offsets relative to
/// the current instruction. Offsets are always non-zero. The types live in a
/// submodule so that the rest of the file can only use them through the
/// methods defined here.
mod wrappers {
    use std::num::NonZeroUsize;

    /// Absolute position of a parsed [`Instruction`](super::Instruction).
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)]
    pub(super) struct InstructionIndex(usize);

    impl InstructionIndex {
        pub(super) fn new(offset: usize) -> Self {
            InstructionIndex(offset)
        }

        pub(super) fn get(self) -> usize {
            self.0
        }

        /// Moves forward by `rhs` instructions; `None` on `usize` overflow.
        pub(super) fn checked_add(self, rhs: InstructionOffset) -> Option<Self> {
            let step = rhs.get().get();
            self.0.checked_add(step).map(InstructionIndex)
        }
    }

    /// Non-zero distance between two parsed
    /// [`Instruction`s](super::Instruction).
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)]
    pub(super) struct InstructionOffset(NonZeroUsize);

    impl InstructionOffset {
        /// Returns `None` when `offset` is zero.
        pub(super) fn new(offset: usize) -> Option<Self> {
            NonZeroUsize::new(offset).map(InstructionOffset)
        }

        pub(super) fn new_nonzero(offset: NonZeroUsize) -> Self {
            InstructionOffset(offset)
        }

        pub(super) fn get(self) -> NonZeroUsize {
            self.0
        }
    }

    /// Absolute position of a byte within an unparsed program.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash)]
    pub(super) struct ByteIndex(usize);

    impl ByteIndex {
        pub(super) fn new(offset: usize) -> Self {
            ByteIndex(offset)
        }

        pub(super) fn get(self) -> usize {
            self.0
        }

        /// Moves forward by `rhs` bytes; `None` on `usize` overflow.
        pub(super) fn checked_add(self, rhs: ByteOffset) -> Option<Self> {
            let step = rhs.get().get();
            self.0.checked_add(step).map(ByteIndex)
        }

        /// Distance from `rhs` up to `self`; `None` unless `self > rhs`.
        pub(super) fn checked_sub(self, rhs: ByteIndex) -> Option<ByteOffset> {
            match self.0.checked_sub(rhs.0) {
                Some(distance) => ByteOffset::new(distance),
                None => None,
            }
        }
    }

    /// Non-zero distance between two bytes in an unparsed program.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)]
    pub(super) struct ByteOffset(NonZeroUsize);

    impl ByteOffset {
        /// Returns `None` when `offset` is zero.
        pub(super) fn new(offset: usize) -> Option<Self> {
            NonZeroUsize::new(offset).map(ByteOffset)
        }

        pub(super) fn get(self) -> NonZeroUsize {
            self.0
        }
    }
}
141
142use wrappers::{ByteIndex, ByteOffset, InstructionIndex, InstructionOffset};
143
/// The size of a Linux `struct inet_diag_bc_op`.
///
/// Every instruction starts with one: `code: u8`, `yes: u8`, `no: u16`.
const STRUCT_BC_OP_SIZE: usize = 4;
/// Payload size of an `INET_DIAG_BC_DEV_COND` instruction (a `u32` ifindex).
const DEVICE_COND_SIZE: usize = 4;
/// Payload size of an `INET_DIAG_BC_MARK_COND` instruction (`u32` mark, `u32` mask).
const MARK_COND_SIZE: usize = 8;
/// Payload size of an `INET_DIAG_BC_CGROUP_COND` instruction (a `u64` cgroup id).
const CGROUP_COND_SIZE: usize = 8;
/// Size of a host (tuple) condition without its address: family, prefix
/// length, two unused bytes, and an `i32` port.
const TUPLE_COND_MIN_SIZE: usize = 8;
/// Length in bytes of an IPv4 address inside a tuple condition.
const AF_INET_ADDR_LEN: usize = 4;
/// Length in bytes of an IPv6 address inside a tuple condition.
const AF_INET6_ADDR_LEN: usize = 16;

/// A bytecode program used by Linux to match AF_INET sockets.
///
/// Each element is one instruction. Jumps are expressed as instruction counts
/// (see [`Action`]) rather than byte offsets; use [`Bytecode::parse`] and
/// [`Bytecode::serialize`] to convert to and from the kernel's byte format.
#[derive(Debug, PartialEq, Eq, Clone, Arbitrary)]
pub struct Bytecode(pub Vec<Instruction>);
156
/// An error returned by [`Bytecode::serialize`].
#[derive(Debug, PartialEq, Eq)]
pub enum SerializationError {
    /// The target buffer provided was too small.
    BufferTooSmall,
    /// An error occurred during serialization of the instruction at index `at`.
    InvalidInstruction { at: usize, error: InvalidInstructionError },
}

impl std::fmt::Display for SerializationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BufferTooSmall => f.write_str("buffer is too small for the serialized bytecode"),
            Self::InvalidInstruction { at, error } => {
                write!(f, "invalid instruction at index {at}: {error}")
            }
        }
    }
}

impl std::error::Error for SerializationError {}

/// Why a single instruction could not be serialized.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum InvalidInstructionError {
    /// A byte offset does not fit in its field (`u8` for `yes`, `u16` for `no`).
    IndexTooLargeForSerializedType,
    /// Offset arithmetic overflowed `usize`.
    IndexOverflow,
    /// An [`Action::AdvanceBy`] targets an instruction past the last one.
    IndexPastEnd,
    /// A tuple condition's prefix length exceeds its address width in bits.
    PrefixLengthLongerThanAddress,
}

impl std::fmt::Display for InvalidInstructionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::IndexTooLargeForSerializedType => {
                "jump offset does not fit in its serialized field"
            }
            Self::IndexOverflow => "offset arithmetic overflowed",
            Self::IndexPastEnd => "jump target is past the last instruction",
            Self::PrefixLengthLongerThanAddress => "prefix length is longer than the address",
        })
    }
}
172
173impl InvalidInstructionError {
174    fn at_index(self, index: InstructionIndex) -> SerializationError {
175        SerializationError::InvalidInstruction { at: index.get(), error: self }
176    }
177}
178
/// An error encountered when parsing a program from a raw byte buffer.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ParseError {
    /// The index in the provided buffer at which the error occurred.
    pub index: usize,
    /// The specific error that occurred.
    pub code: ParseErrorCode,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} at byte {}", self.code, self.index)
    }
}

impl std::error::Error for ParseError {}

/// The reason a [`ParseError`] occurred.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ParseErrorCode {
    /// The buffer ends before the instruction (or its payload) does.
    TruncatedInstruction,
    /// The opcode is not a recognized `INET_DIAG_BC_*` value.
    UnknownOpcode,
    /// A jump lands somewhere other than the start of a later instruction,
    /// the accept offset, or the reject offset.
    InvalidJumpTarget,
    /// A jump offset is zero, i.e. the instruction targets itself.
    SelfReference,
    /// A tuple condition's family is not `AF_INET`, `AF_INET6` or `AF_UNSPEC`.
    InvalidAddressFamily,
    /// A tuple condition's prefix length exceeds its address width in bits.
    PrefixLengthLongerThanAddress,
    /// Computing a byte offset overflowed `usize`.
    IndexOverflow,
}

impl std::fmt::Display for ParseErrorCode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::TruncatedInstruction => "truncated instruction",
            Self::UnknownOpcode => "unknown opcode",
            Self::InvalidJumpTarget => {
                "jump target is not an instruction boundary, accept, or reject offset"
            }
            Self::SelfReference => "instruction jumps to itself",
            Self::InvalidAddressFamily => "invalid address family",
            Self::PrefixLengthLongerThanAddress => "prefix length is longer than the address",
            Self::IndexOverflow => "byte offset overflowed",
        })
    }
}
198
199impl ParseErrorCode {
200    fn at_index(self, index: ByteIndex) -> ParseError {
201        ParseError { index: index.get(), code: self }
202    }
203}
204
205impl Bytecode {
206    /// Returns the length of the serialized form of this bytecode.
207    ///
208    /// Useful for sizing the buffer passed to [`Bytecode::serialize`].
209    pub fn serialized_len(&self) -> usize {
210        self.0.iter().map(Instruction::serialized_len).sum()
211    }
212
213    /// Parse a bytecode program from the provided buffer.
214    pub fn parse(buf: &[u8]) -> Result<Self, ParseError> {
215        let mut raw_ops = vec![];
216        let mut curr_byte_index = ByteIndex::new(0);
217        let buf_len = ByteIndex::new(buf.len());
218        let mut instruction_index_by_byte_offset = std::collections::HashMap::new();
219
220        // First, we build up a map of instruction index to byte offset.
221        while curr_byte_index < buf_len {
222            instruction_index_by_byte_offset
223                .insert(curr_byte_index, InstructionIndex::new(raw_ops.len()));
224            let inst = RawInstruction::parse(&buf[curr_byte_index.get()..])
225                .map_err(|code| code.at_index(curr_byte_index))?;
226            let inst_len = inst.serialized_len();
227            raw_ops.push((curr_byte_index, inst));
228            curr_byte_index = curr_byte_index
229                .checked_add(inst_len)
230                .ok_or_else(|| ParseErrorCode::IndexOverflow.at_index(curr_byte_index))?;
231        }
232
233        // If curr_byte_index < buf_len, we would have looped again.
234        // If curr_byte_index > buf_len, we would have returned
235        // TruncatedInstruction.
236        assert_eq!(curr_byte_index, buf_len);
237
238        // Now, we resolve the raw byte offsets to indexes.
239
240        let raw_accept_offset = curr_byte_index;
241        // Linux bytecode validation ensures that there is only a single valid rejection offset.
242        let raw_reject_offset = raw_accept_offset
243            .checked_add(ByteOffset::new(4).unwrap())
244            .ok_or_else(|| ParseErrorCode::IndexOverflow.at_index(ByteIndex::new(0)))?;
245
246        let resolve = |target_offset: ByteIndex, current_index: InstructionIndex| {
247            if target_offset == raw_accept_offset {
248                Ok(Action::Accept)
249            } else if target_offset == raw_reject_offset {
250                Ok(Action::Reject)
251            } else if let Some(&target_index) = instruction_index_by_byte_offset.get(&target_offset)
252            {
253                // By construction we know that an instruction can't reference
254                // an earlier one, so current_index will always be less than or
255                // equal to target_index.
256                let offset = target_index.get().checked_sub(current_index.get()).unwrap();
257                let index_offset =
258                    InstructionOffset::new(offset).ok_or(ParseErrorCode::SelfReference)?;
259                Ok(Action::AdvanceBy(index_offset.get()))
260            } else {
261                Err(ParseErrorCode::InvalidJumpTarget)
262            }
263        };
264
265        let resolved_ops = raw_ops
266            .into_iter()
267            .enumerate()
268            .map(|(curr_instr_index, (curr_byte_index, raw_op))| {
269                let curr_instr_index = InstructionIndex::new(curr_instr_index);
270
271                match raw_op {
272                    RawInstruction::Nop(offset) => {
273                        let target = curr_byte_index.checked_add(offset).ok_or_else(|| {
274                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
275                        })?;
276                        let action = resolve(target, curr_instr_index)
277                            .map_err(|code| code.at_index(curr_byte_index))?;
278                        Ok(Instruction::Nop(action))
279                    }
280                    RawInstruction::Jmp(offset) => {
281                        let target = curr_byte_index.checked_add(offset).ok_or_else(|| {
282                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
283                        })?;
284                        let action = resolve(target, curr_instr_index)
285                            .map_err(|code| code.at_index(curr_byte_index))?;
286                        Ok(Instruction::Jmp(action))
287                    }
288                    RawInstruction::Condition { yes, no, condition } => {
289                        let yes_target = curr_byte_index.checked_add(yes).ok_or_else(|| {
290                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
291                        })?;
292                        let no_target = curr_byte_index.checked_add(no).ok_or_else(|| {
293                            ParseErrorCode::IndexOverflow.at_index(curr_byte_index)
294                        })?;
295
296                        let yes = resolve(yes_target, curr_instr_index)
297                            .map_err(|code| code.at_index(curr_byte_index))?;
298                        let no = resolve(no_target, curr_instr_index)
299                            .map_err(|code| code.at_index(curr_byte_index))?;
300                        Ok(Instruction::Condition { yes, no, condition })
301                    }
302                }
303            })
304            .collect::<Result<Vec<_>, ParseError>>()?;
305
306        Ok(Bytecode(resolved_ops))
307    }
308
309    /// Serialize the bytecode into the provided buffer.
310    pub fn serialize(self, buf: &mut [u8]) -> Result<(), SerializationError> {
311        let Self(instructions) = self;
312
313        let mut total_len = ByteIndex::new(0);
314        let byte_indices_by_instruction_index: Vec<_> =
315            instructions
316                .iter()
317                .enumerate()
318                .map(|(i, inst)| {
319                    let res = total_len;
320                    match total_len.checked_add(ByteOffset::new(inst.serialized_len()).unwrap()) {
321                        Some(new_len) => {
322                            total_len = new_len;
323                            Ok(res)
324                        }
325                        None => Err(InvalidInstructionError::IndexOverflow
326                            .at_index(InstructionIndex::new(i))),
327                    }
328                })
329                .collect::<Result<Vec<_>, _>>()?;
330
331        if total_len.get() > buf.len() {
332            return Err(SerializationError::BufferTooSmall);
333        }
334
335        instructions.into_iter().enumerate().try_for_each(|(curr_inst_index, inst)| {
336            let curr_inst_index = InstructionIndex::new(curr_inst_index);
337            let curr_byte_index = byte_indices_by_instruction_index[curr_inst_index.get()];
338
339            inst.try_into_raw(
340                &byte_indices_by_instruction_index,
341                curr_inst_index,
342                curr_byte_index,
343                total_len,
344            )
345            .and_then(|raw| raw.serialize(&mut buf[curr_byte_index.get()..]))
346            .map_err(|e| e.at_index(curr_inst_index))
347        })
348    }
349}
350
/// A single bytecode instruction, with jump targets expressed as [`Action`]s.
#[derive(Debug, PartialEq, Eq, Clone, Arbitrary)]
pub enum Instruction {
    /// Continue at the given action unconditionally (serialized in the `yes` field).
    Nop(Action),
    /// Continue at the given action unconditionally (serialized in the `no` field).
    Jmp(Action),
    /// Evaluate `condition`; continue at `yes` if it matches, otherwise at `no`.
    Condition { yes: Action, no: Action, condition: Condition },
}
357
358impl Instruction {
359    fn serialized_len(&self) -> usize {
360        STRUCT_BC_OP_SIZE
361            + match self {
362                Self::Nop(_) | Self::Jmp(_) => 0,
363                Self::Condition { condition, .. } => condition.serialized_len(),
364            }
365    }
366
367    fn try_into_raw(
368        self,
369        byte_indices_by_instruction_index: &[ByteIndex],
370        instruction_index: InstructionIndex,
371        byte_index: ByteIndex,
372        total_len: ByteIndex,
373    ) -> Result<RawInstruction, InvalidInstructionError> {
374        // Calculate relative offsets
375        let calculate_rel = |action| {
376            let target = match action {
377                Action::Accept => total_len,
378                // Linux checks that all targets are multiples of 4.
379                Action::Reject => total_len
380                    .checked_add(ByteOffset::new(4).unwrap())
381                    .ok_or(InvalidInstructionError::IndexOverflow)?,
382                Action::AdvanceBy(dist) => {
383                    let target_index = instruction_index
384                        .checked_add(InstructionOffset::new_nonzero(dist))
385                        .ok_or(InvalidInstructionError::IndexOverflow)?;
386                    *byte_indices_by_instruction_index
387                        .get(target_index.get())
388                        .ok_or(InvalidInstructionError::IndexPastEnd)?
389                }
390            };
391
392            // This is safe because the elements of offsets are strictly
393            // increasing, so indexing into my_index+dist (and we know dist
394            // can't be zero because of its type) must give a larger value.
395            Ok(target.checked_sub(byte_index).unwrap())
396        };
397
398        match self {
399            Instruction::Nop(action) => Ok(RawInstruction::Nop(calculate_rel(action)?)),
400            Instruction::Jmp(action) => Ok(RawInstruction::Jmp(calculate_rel(action)?)),
401            Instruction::Condition { yes, no, condition } => Ok(RawInstruction::Condition {
402                yes: calculate_rel(yes)?,
403                no: calculate_rel(no)?,
404                condition,
405            }),
406        }
407    }
408}
409
/// Where execution continues after an instruction.
#[derive(Debug, PartialEq, Eq, Clone, Arbitrary)]
pub enum Action {
    /// Accept the socket: jump to exactly the end of the program.
    Accept,
    /// Reject the socket: jump 4 bytes past the end of the program.
    Reject,
    /// Continue at the instruction this many instructions ahead.
    AdvanceBy(NonZeroUsize),
}

/// An instruction whose jump targets are byte offsets relative to its own
/// start, as they appear in the serialized program.
enum RawInstruction {
    /// Jump forward by this many bytes (taken from the `yes` field).
    Nop(ByteOffset),
    /// Jump forward by this many bytes (taken from the `no` field).
    Jmp(ByteOffset),
    /// Byte distances to jump on match (`yes`) and on mismatch (`no`).
    Condition { yes: ByteOffset, no: ByteOffset, condition: Condition },
}
422
// Layout of a `struct inet_diag_bc_op` header. `payload` is everything after
// the header up to the end of the underlying buffer, not just this
// instruction's payload.
buffer!(RawInstructionBuffer(STRUCT_BC_OP_SIZE) {
    code: (u8, 0),
    yes: (u8, 1),
    no: (u16, 2..4),
    payload: (slice, STRUCT_BC_OP_SIZE..),
});
429
impl RawInstruction {
    /// Number of bytes this instruction occupies: the 4-byte header plus any
    /// condition payload. Never zero, hence a [`ByteOffset`].
    fn serialized_len(&self) -> ByteOffset {
        ByteOffset::new(
            STRUCT_BC_OP_SIZE
                + match self {
                    Self::Nop(_) => 0,
                    Self::Jmp(_) => 0,
                    Self::Condition { condition, .. } => condition.serialized_len(),
                },
        )
        .unwrap()
    }

    /// Decodes the instruction at the start of `buf`.
    ///
    /// `buf` may extend past the instruction (it is normally the remainder of
    /// the program); only the bytes the opcode requires are read.
    ///
    /// # Errors
    ///
    /// - `TruncatedInstruction` if `buf` is shorter than the header or the
    ///   opcode's payload.
    /// - `UnknownOpcode` for an unrecognized `code`.
    /// - `SelfReference` if a jump offset the opcode uses is zero.
    /// - Any error from [`TupleCondition::parse`] for tuple conditions.
    ///
    /// For conditions the payload is decoded before `yes`/`no` are checked,
    /// so payload errors take precedence over a zero offset.
    fn parse(buf: &[u8]) -> Result<RawInstruction, ParseErrorCode> {
        let buf =
            RawInstructionBuffer::new(buf).map_err(|_| ParseErrorCode::TruncatedInstruction)?;

        let code = buf.code();
        let yes = ByteOffset::new(buf.yes().into());
        let no = ByteOffset::new(buf.no().into());

        // NOP and JMP only use one of their two offsets (NOP follows `yes`,
        // JMP follows `no`), unlike conditions, which need both to be
        // non-zero. Handle them before the shared condition logic.
        if code == INET_DIAG_BC_NOP {
            return Ok(RawInstruction::Nop(yes.ok_or(ParseErrorCode::SelfReference)?));
        } else if code == INET_DIAG_BC_JMP {
            return Ok(RawInstruction::Jmp(no.ok_or(ParseErrorCode::SelfReference)?));
        }

        // Decodes a port-comparison payload and wraps the port with `f`.
        fn port_cond<F>(buf: &[u8], f: F) -> Result<Condition, ParseErrorCode>
        where
            F: FnOnce(u16) -> Condition,
        {
            match PortConditionBuffer::new(buf) {
                Ok(buf) => Ok(f(buf.port())),
                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
            }
        }

        // Everything after the header; each condition decodes from its start.
        let payload = buf.payload();
        let condition = match code {
            // Handled above.
            INET_DIAG_BC_NOP => unreachable!(),
            INET_DIAG_BC_JMP => unreachable!(),

            INET_DIAG_BC_S_COND => TupleCondition::parse(payload).map(Condition::SrcTuple),
            INET_DIAG_BC_D_COND => TupleCondition::parse(payload).map(Condition::DstTuple),
            INET_DIAG_BC_DEV_COND => match DeviceConditionBuffer::new(payload) {
                Ok(buf) => Ok(Condition::Device(buf.ifindex())),
                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
            },
            INET_DIAG_BC_MARK_COND => match MarkConditionBuffer::new(payload) {
                Ok(buf) => Ok(Condition::Mark { mark: buf.mark(), mask: buf.mask() }),
                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
            },
            INET_DIAG_BC_S_EQ => port_cond(payload, Condition::SrcPortEq),
            INET_DIAG_BC_D_EQ => port_cond(payload, Condition::DstPortEq),
            INET_DIAG_BC_S_GE => port_cond(payload, Condition::SrcPortGe),
            INET_DIAG_BC_D_GE => port_cond(payload, Condition::DstPortGe),
            INET_DIAG_BC_S_LE => port_cond(payload, Condition::SrcPortLe),
            INET_DIAG_BC_D_LE => port_cond(payload, Condition::DstPortLe),
            // AUTO carries no payload.
            INET_DIAG_BC_AUTO => Ok(Condition::AutoPort),
            INET_DIAG_BC_CGROUP_COND => match CgroupConditionBuffer::new(payload) {
                Ok(buf) => Ok(Condition::Cgroup(buf.cgroup_id())),
                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
            },
            _ => Err(ParseErrorCode::UnknownOpcode),
        }?;

        // Conditions use both targets, so both must be non-zero.
        let yes = yes.ok_or(ParseErrorCode::SelfReference)?;
        let no = no.ok_or(ParseErrorCode::SelfReference)?;

        let inst = RawInstruction::Condition { yes, no, condition };

        Ok(inst)
    }

    /// Encodes this instruction at the start of `buf`.
    ///
    /// Only the header and the condition's own fields are written; bytes no
    /// field covers (e.g. the first two bytes of a port-condition payload)
    /// are not touched here.
    ///
    /// # Errors
    ///
    /// `IndexTooLargeForSerializedType` if `yes` does not fit in a `u8` or
    /// `no` does not fit in a `u16`, plus any error from
    /// [`TupleCondition::serialize`].
    fn serialize(&self, buf: &mut [u8]) -> Result<(), InvalidInstructionError> {
        // NOTE: buffer length was already checked in Bytecode::serialize, so we
        // don't need to do that again in this function.

        let (code, yes, no, condition) = match self {
            Self::Nop(offset) => (INET_DIAG_BC_NOP, offset.get().get(), 0, None),
            // Linux requires that the yes field always points at the next
            // instruction, even though JMP doesn't use it.
            Self::Jmp(offset) => (INET_DIAG_BC_JMP, STRUCT_BC_OP_SIZE, offset.get().get(), None),
            Self::Condition { yes, no, condition } => {
                (condition.code(), yes.get().get(), no.get().get(), Some(condition))
            }
        };

        let mut buf = RawInstructionBuffer::new(buf).unwrap();
        buf.set_code(code);
        buf.set_yes(
            yes.try_into().map_err(|_| InvalidInstructionError::IndexTooLargeForSerializedType)?,
        );
        buf.set_no(
            u16::try_from(no)
                .map_err(|_| InvalidInstructionError::IndexTooLargeForSerializedType)?,
        );

        // Condition payloads start right after the 4-byte header.
        let buf = buf.payload_mut();
        if let Some(condition) = condition {
            match condition {
                Condition::AutoPort => {}
                Condition::SrcPortGe(port)
                | Condition::SrcPortLe(port)
                | Condition::DstPortGe(port)
                | Condition::DstPortLe(port)
                | Condition::SrcPortEq(port)
                | Condition::DstPortEq(port) => {
                    let mut buf = PortConditionBuffer::new(buf).unwrap();
                    buf.set_port(*port);
                }
                Condition::SrcTuple(c) | Condition::DstTuple(c) => c.serialize(buf)?,
                Condition::Device(ifindex) => {
                    let mut buf = DeviceConditionBuffer::new(buf).unwrap();
                    buf.set_ifindex(*ifindex);
                }
                Condition::Mark { mark, mask } => {
                    let mut buf = MarkConditionBuffer::new(buf).unwrap();
                    buf.set_mark(*mark);
                    buf.set_mask(*mask);
                }
                Condition::Cgroup(cgroup_id) => {
                    let mut buf = CgroupConditionBuffer::new(buf).unwrap();
                    buf.set_cgroup_id(*cgroup_id);
                }
            }
        }

        Ok(())
    }
}
565
/// The test performed by an [`Instruction::Condition`].
///
/// Each variant maps to one `INET_DIAG_BC_*` opcode (see `Condition::code`).
#[derive(Debug, PartialEq, Eq, Clone, Arbitrary)]
pub enum Condition {
    /// Source-port `>=` comparison (`INET_DIAG_BC_S_GE`).
    SrcPortGe(u16),
    /// Source-port `<=` comparison (`INET_DIAG_BC_S_LE`).
    SrcPortLe(u16),
    /// Destination-port `>=` comparison (`INET_DIAG_BC_D_GE`).
    DstPortGe(u16),
    /// Destination-port `<=` comparison (`INET_DIAG_BC_D_LE`).
    DstPortLe(u16),
    /// Source-port equality (`INET_DIAG_BC_S_EQ`).
    SrcPortEq(u16),
    /// Destination-port equality (`INET_DIAG_BC_D_EQ`).
    DstPortEq(u16),
    /// `INET_DIAG_BC_AUTO`; serialized without a payload.
    AutoPort,
    /// Source address/prefix and optional port (`INET_DIAG_BC_S_COND`).
    SrcTuple(TupleCondition),
    /// Destination address/prefix and optional port (`INET_DIAG_BC_D_COND`).
    DstTuple(TupleCondition),
    /// Interface index (`INET_DIAG_BC_DEV_COND`).
    Device(u32),
    /// Mark value and mask (`INET_DIAG_BC_MARK_COND`).
    Mark { mark: u32, mask: u32 },
    /// Cgroup id (`INET_DIAG_BC_CGROUP_COND`).
    Cgroup(u64),
}
581
582// Linux uses a struct inet_diag_bc_op for the condition payload, but just the
583// `no` field.
584buffer!(PortConditionBuffer(STRUCT_BC_OP_SIZE) {
585    padding: (slice, 0..1),
586    port: (u16, 2..4),
587});
588
// Payload of `INET_DIAG_BC_DEV_COND`: the interface index.
buffer!(DeviceConditionBuffer(DEVICE_COND_SIZE) {
    ifindex: (u32, 0..4),
});

// Payload of `INET_DIAG_BC_MARK_COND`: a mark followed by a mask.
buffer!(MarkConditionBuffer(MARK_COND_SIZE) {
    mark: (u32, 0..4),
    mask: (u32, 4..8),
});

// Payload of `INET_DIAG_BC_CGROUP_COND`: a 64-bit cgroup id.
buffer!(CgroupConditionBuffer(CGROUP_COND_SIZE) {
    cgroup_id: (u64, 0..8),
});
601
602impl Condition {
603    fn serialized_len(&self) -> usize {
604        match self {
605            Condition::AutoPort => 0,
606            // Linux puts the port in the no field of a struct inet_diag_bc_op.
607            Condition::SrcPortGe(_)
608            | Condition::SrcPortLe(_)
609            | Condition::DstPortGe(_)
610            | Condition::DstPortLe(_)
611            | Condition::SrcPortEq(_)
612            | Condition::DstPortEq(_) => STRUCT_BC_OP_SIZE,
613            Condition::SrcTuple(c) | Condition::DstTuple(c) => c.serialized_len(),
614            Condition::Device(_) => DEVICE_COND_SIZE,
615            Condition::Mark { .. } => MARK_COND_SIZE,
616            Condition::Cgroup(_) => CGROUP_COND_SIZE,
617        }
618    }
619
620    fn code(&self) -> u8 {
621        match self {
622            Self::SrcPortGe(_) => INET_DIAG_BC_S_GE,
623            Self::SrcPortLe(_) => INET_DIAG_BC_S_LE,
624            Self::DstPortGe(_) => INET_DIAG_BC_D_GE,
625            Self::DstPortLe(_) => INET_DIAG_BC_D_LE,
626            Self::AutoPort => INET_DIAG_BC_AUTO,
627            Self::SrcTuple(_) => INET_DIAG_BC_S_COND,
628            Self::DstTuple(_) => INET_DIAG_BC_D_COND,
629            Self::Device(_) => INET_DIAG_BC_DEV_COND,
630            Self::Mark { .. } => INET_DIAG_BC_MARK_COND,
631            Self::SrcPortEq(_) => INET_DIAG_BC_S_EQ,
632            Self::DstPortEq(_) => INET_DIAG_BC_D_EQ,
633            Self::Cgroup(_) => INET_DIAG_BC_CGROUP_COND,
634        }
635    }
636}
637
/// An address/port condition used by `INET_DIAG_BC_S_COND` and
/// `INET_DIAG_BC_D_COND`.
///
/// `addr == None` is encoded with family `AF_UNSPEC`, in which case
/// `prefix_len` must be 0.
#[derive(Debug, PartialEq, Eq, Clone, Arbitrary)]
pub struct TupleCondition {
    /// Number of leading address bits to compare; at most the address width in bits.
    pub prefix_len: u8,
    /// Address to compare against, or `None` (encoded as `AF_UNSPEC`).
    pub addr: Option<IpAddr>,
    /// Port to compare against; `None` is encoded as a port of `-1`.
    pub port: Option<u16>,
}

// Fixed header of a tuple condition: family, prefix length, (bytes 2..4
// unused), and a signed port. The address bytes follow in `payload`.
buffer!(TupleConditionBuffer(TUPLE_COND_MIN_SIZE) {
    family: (u8, 0),
    prefix_len: (u8, 1),
    port: (i32, 4..8),
    payload: (slice, TUPLE_COND_MIN_SIZE..),
});

// NOTE(review): these expose the address as a single integer via the buffer
// macro's integer accessors, presumably in host byte order, while Linux's
// `inet_diag_hostcond.addr` is `__be32[]` (network byte order). Confirm the
// accessor byte order before using them to encode addresses.
buffer!(Ipv4AddrBuffer(AF_INET_ADDR_LEN) {
    addr: (u32, 0..AF_INET_ADDR_LEN),
});

buffer!(Ipv6AddrBuffer(AF_INET6_ADDR_LEN) {
    addr: (u128, 0..AF_INET6_ADDR_LEN),
});
659
660impl TupleCondition {
661    fn serialized_len(&self) -> usize {
662        TUPLE_COND_MIN_SIZE
663            + match self.addr {
664                Some(IpAddr::V4(_)) => AF_INET_ADDR_LEN,
665                Some(IpAddr::V6(_)) => AF_INET6_ADDR_LEN,
666                None => 0,
667            }
668    }
669
670    fn parse(buf: &[u8]) -> Result<Self, ParseErrorCode> {
671        let buf =
672            TupleConditionBuffer::new(buf).map_err(|_| ParseErrorCode::TruncatedInstruction)?;
673        let family = buf.family();
674        let prefix_len = buf.prefix_len();
675        let port = buf.port();
676        let port = if port == -1 { None } else { Some(port as u16) };
677
678        let payload = buf.payload();
679        let addr = match family {
680            AF_INET => match Ipv4AddrBuffer::new(payload) {
681                Ok(buf) => Ok(Some(IpAddr::V4(Ipv4Addr::from(buf.addr())))),
682                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
683            },
684            AF_INET6 => match Ipv6AddrBuffer::new(payload) {
685                Ok(buf) => Ok(Some(IpAddr::V6(Ipv6Addr::from(buf.addr())))),
686                Err(_) => Err(ParseErrorCode::TruncatedInstruction),
687            },
688            AF_UNSPEC => Ok(None),
689            _ => Err(ParseErrorCode::InvalidAddressFamily),
690        }?;
691
692        let max_prefix_len = addr
693            .map(|a| match a {
694                IpAddr::V4(_) => AF_INET_ADDR_LEN,
695                IpAddr::V6(_) => AF_INET6_ADDR_LEN,
696            })
697            .unwrap_or(0)
698            * 8;
699
700        if usize::from(prefix_len) > max_prefix_len {
701            return Err(ParseErrorCode::PrefixLengthLongerThanAddress);
702        }
703
704        Ok(TupleCondition { prefix_len, port, addr })
705    }
706
707    fn serialize(&self, buf: &mut [u8]) -> Result<(), InvalidInstructionError> {
708        let mut buf = TupleConditionBuffer::new(buf).unwrap();
709
710        let max_prefix_len = match self.addr {
711            Some(IpAddr::V4(addr)) => {
712                buf.set_family(AF_INET);
713                Ipv4AddrBuffer::new(buf.payload_mut()).unwrap().set_addr(addr.into());
714
715                32
716            }
717            Some(IpAddr::V6(addr)) => {
718                buf.set_family(AF_INET6);
719                Ipv6AddrBuffer::new(buf.payload_mut()).unwrap().set_addr(addr.into());
720
721                128
722            }
723            None => {
724                buf.set_family(AF_UNSPEC);
725
726                0
727            }
728        };
729
730        if self.prefix_len > max_prefix_len {
731            return Err(InvalidInstructionError::PrefixLengthLongerThanAddress);
732        }
733
734        buf.set_prefix_len(self.prefix_len);
735        buf.set_port(self.port.map(i32::from).unwrap_or(-1));
736        Ok(())
737    }
738}
739
740#[cfg(test)]
741mod tests {
742    use super::*;
743
744    #[test]
745    fn instructions_roundtrip() {
746        let conditions = vec![
747            Condition::SrcPortGe(100),
748            Condition::SrcPortLe(200),
749            Condition::DstPortGe(300),
750            Condition::DstPortLe(400),
751            Condition::SrcPortEq(500),
752            Condition::DstPortEq(600),
753            Condition::AutoPort,
754            Condition::SrcTuple(TupleCondition {
755                prefix_len: 24,
756                port: Some(8080),
757                addr: Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))),
758            }),
759            Condition::SrcTuple(TupleCondition {
760                prefix_len: 128,
761                port: Some(8081),
762                addr: Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
763            }),
764            Condition::DstTuple(TupleCondition {
765                prefix_len: 24,
766                port: Some(9090),
767                addr: Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))),
768            }),
769            Condition::DstTuple(TupleCondition {
770                prefix_len: 64,
771                port: None,
772                addr: Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
773            }),
774            Condition::Device(1),
775            Condition::Mark { mark: 0x1234, mask: 0xFFFF },
776            Condition::Cgroup(123456789),
777        ];
778
779        for condition in conditions {
780            let bc = Bytecode(vec![
781                Instruction::Condition {
782                    yes: Action::AdvanceBy(NonZeroUsize::new(1).unwrap()),
783                    no: Action::Accept,
784                    condition: condition.clone(),
785                },
786                Instruction::Nop(Action::Accept),
787            ]);
788
789            let mut buf = vec![0u8; bc.serialized_len()];
790            bc.clone().serialize(&mut buf).unwrap();
791            let parsed = Bytecode::parse(&buf)
792                .unwrap_or_else(|e| panic!("parse failed for {:?}: {:?}", condition, e));
793            assert_eq!(parsed, bc, "roundtrip failed for {:?}", condition);
794        }
795    }
796
797    #[test]
798    fn accept_reject_mapping() {
799        let bc = Bytecode(vec![Instruction::Jmp(Action::Accept), Instruction::Jmp(Action::Reject)]);
800
801        let mut buf = vec![0u8; bc.serialized_len()];
802        bc.clone().serialize(&mut buf).unwrap();
803
804        assert_eq!(
805            buf,
806            [
807                INET_DIAG_BC_JMP,
808                4, // yes
809                8u16.to_ne_bytes()[0],
810                8u16.to_ne_bytes()[1], // no
811                INET_DIAG_BC_JMP,
812                4, // yes
813                8u16.to_ne_bytes()[0],
814                8u16.to_ne_bytes()[1], // no
815            ]
816        );
817
818        let parsed = Bytecode::parse(&buf).unwrap();
819        assert_eq!(parsed, bc);
820    }
821
822    #[test]
823    fn buffer_too_small() {
824        let bc = Bytecode(vec![Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))]);
825        let mut buf = vec![0u8; 3]; // Nop is 4 bytes
826        assert_eq!(bc.serialize(&mut buf), Err(SerializationError::BufferTooSmall));
827    }
828
829    #[test]
830    fn index_too_large_yes() {
831        const COUNT: usize = 64;
832
833        let mut ops = vec![Instruction::Condition {
834            yes: Action::AdvanceBy(NonZeroUsize::new(COUNT + 1).unwrap()),
835            no: Action::Accept,
836            condition: Condition::AutoPort,
837        }];
838        // Each NOP is 4 bytes, so 64 NOPs is 256 bytes.
839        ops.extend(
840            (0..COUNT).map(|_| Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))),
841        );
842        ops.push(Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))); // Target
843
844        let bc = Bytecode(ops);
845        let mut buf = vec![0u8; bc.serialized_len()];
846        assert_eq!(
847            bc.serialize(&mut buf),
848            Err(SerializationError::InvalidInstruction {
849                at: 0,
850                error: InvalidInstructionError::IndexTooLargeForSerializedType
851            })
852        );
853    }
854
855    #[test]
856    fn index_too_large_no() {
857        const COUNT: usize = 16384;
858
859        let mut ops =
860            vec![Instruction::Jmp(Action::AdvanceBy(NonZeroUsize::new(COUNT + 1).unwrap()))];
861        ops.extend(
862            (0..COUNT).map(|_| Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))),
863        );
864        ops.push(Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap()))); // Target
865
866        let bc = Bytecode(ops);
867        let mut buf = vec![0u8; bc.serialized_len()];
868        assert_eq!(
869            bc.serialize(&mut buf),
870            Err(SerializationError::InvalidInstruction {
871                at: 0,
872                error: InvalidInstructionError::IndexTooLargeForSerializedType
873            })
874        );
875    }
876
877    #[test]
878    fn index_overflow() {
879        let ops = vec![
880            Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap())),
881            Instruction::Jmp(Action::AdvanceBy(NonZeroUsize::MAX)),
882        ];
883        let bc = Bytecode(ops);
884        let mut buf = vec![0u8; bc.serialized_len()];
885        assert_eq!(
886            bc.serialize(&mut buf),
887            Err(SerializationError::InvalidInstruction {
888                at: 1,
889                error: InvalidInstructionError::IndexOverflow
890            })
891        );
892    }
893
894    #[test]
895    fn advance_by_mapping() {
896        let bc = Bytecode(vec![
897            Instruction::Jmp(Action::AdvanceBy(NonZeroUsize::new(2).unwrap())),
898            Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(1).unwrap())),
899            Instruction::Nop(Action::Accept),
900        ]);
901
902        let mut buf = vec![0u8; bc.serialized_len()];
903        bc.clone().serialize(&mut buf).unwrap();
904
905        let parsed = Bytecode::parse(&buf).unwrap();
906        assert_eq!(parsed, bc);
907    }
908
909    #[test]
910    fn parse_errors() {
911        // Invalid bytecode!
912        let buf = vec![255, 4, 0, 0];
913        assert_eq!(
914            Bytecode::parse(&buf),
915            Err(ParseError { index: 0, code: ParseErrorCode::UnknownOpcode })
916        );
917
918        // Invalid target jump (jumping into the middle of an instruction).
919        let mut buf = vec![];
920        buf.push(INET_DIAG_BC_NOP);
921        buf.push(4); // yes
922        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
923
924        buf.push(INET_DIAG_BC_MARK_COND);
925        buf.push(4); // yes.
926        buf.extend_from_slice(&6u16.to_ne_bytes()); // no. Middle of the next instruction. Invalid!
927        buf.extend_from_slice(&0u32.to_ne_bytes());
928        buf.extend_from_slice(&0u32.to_ne_bytes());
929
930        buf.push(INET_DIAG_BC_JMP);
931        buf.push(4);
932        buf.extend_from_slice(&4u16.to_ne_bytes());
933        assert_eq!(
934            Bytecode::parse(&buf),
935            Err(ParseError { index: 4, code: ParseErrorCode::InvalidJumpTarget })
936        );
937
938        // Truncated instruction body (SrcPortGe missing bytes)
939        // S_GE: (4 bytes) + 4 bytes payload.
940        let mut buf = Vec::new();
941        buf.push(INET_DIAG_BC_S_GE);
942        buf.push(4); // yes.
943        buf.extend_from_slice(&4u16.to_ne_bytes());
944        assert_eq!(
945            Bytecode::parse(&buf),
946            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
947        );
948
949        // Invalid self-reference (yes=0).
950        let mut buf = Vec::new();
951        buf.push(INET_DIAG_BC_AUTO);
952        buf.push(0); // yes. Invalid!
953        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
954        assert_eq!(
955            Bytecode::parse(&buf),
956            Err(ParseError { index: 0, code: ParseErrorCode::SelfReference })
957        );
958
959        // Invalid self-reference (no=0).
960        let mut buf = Vec::new();
961        buf.push(INET_DIAG_BC_AUTO);
962        buf.push(4); // yes
963        buf.extend_from_slice(&0u16.to_ne_bytes()); // no. Invalid!
964        assert_eq!(
965            Bytecode::parse(&buf),
966            Err(ParseError { index: 0, code: ParseErrorCode::SelfReference })
967        );
968
969        // Invalid address family.
970        let mut buf = Vec::new();
971        buf.push(INET_DIAG_BC_S_COND);
972        buf.push(12); // yes
973        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
974        buf.push(255); // Invalid family
975        buf.push(0); // prefix len
976        buf.push(0); // pad
977        buf.push(0); // pad
978        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
979        assert_eq!(
980            Bytecode::parse(&buf),
981            Err(ParseError { index: 0, code: ParseErrorCode::InvalidAddressFamily })
982        );
983
984        // Prefix length longer than address.
985        let mut buf = Vec::new();
986        buf.push(INET_DIAG_BC_S_COND);
987        buf.push(16); // yes
988        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
989        buf.push(AF_INET);
990        buf.push(33); // prefix len
991        buf.push(0); // pad
992        buf.push(0); // pad
993        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
994        buf.extend_from_slice(&[0, 0, 0, 0]); // address
995        assert_eq!(
996            Bytecode::parse(&buf),
997            Err(ParseError { index: 0, code: ParseErrorCode::PrefixLengthLongerThanAddress })
998        );
999    }
1000
1001    #[test]
1002    fn truncated_payloads() {
1003        let mut buf = Vec::new();
1004        buf.push(INET_DIAG_BC_DEV_COND);
1005        buf.push(8); // yes
1006        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1007        buf.extend_from_slice(&[0; 3]); // 3 bytes instead of 4
1008        assert_eq!(
1009            Bytecode::parse(&buf),
1010            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1011        );
1012
1013        let mut buf = Vec::new();
1014        buf.push(INET_DIAG_BC_MARK_COND);
1015        buf.push(12); // yes
1016        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1017        buf.extend_from_slice(&[0; 7]); // 7 bytes instead of 8
1018        assert_eq!(
1019            Bytecode::parse(&buf),
1020            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1021        );
1022
1023        let mut buf = Vec::new();
1024        buf.push(INET_DIAG_BC_CGROUP_COND);
1025        buf.push(12); // yes
1026        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1027        buf.extend_from_slice(&[0; 7]); // 7 bytes instead of 8
1028        assert_eq!(
1029            Bytecode::parse(&buf),
1030            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1031        );
1032
1033        let mut buf = Vec::new();
1034        buf.push(INET_DIAG_BC_S_COND);
1035        buf.push(12); // yes
1036        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1037        buf.extend_from_slice(&[0; 7]); // 7 bytes instead of 8 (min header size)
1038        assert_eq!(
1039            Bytecode::parse(&buf),
1040            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1041        );
1042
1043        let mut buf = Vec::new();
1044        buf.push(INET_DIAG_BC_S_COND);
1045        buf.push(16); // yes
1046        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1047        buf.push(AF_INET);
1048        buf.push(0); // prefix len
1049        buf.push(0);
1050        buf.push(0); // pad
1051        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
1052        buf.extend_from_slice(&[0; 3]); // 3 bytes instead of 4 for IPv4
1053        assert_eq!(
1054            Bytecode::parse(&buf),
1055            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1056        );
1057
1058        let mut buf = Vec::new();
1059        buf.push(INET_DIAG_BC_S_COND);
1060        buf.push(28); // yes
1061        buf.extend_from_slice(&4u16.to_ne_bytes()); // no
1062        buf.push(AF_INET6);
1063        buf.push(0); // prefix len
1064        buf.push(0);
1065        buf.push(0); // pad
1066        buf.extend_from_slice(&(-1i32).to_ne_bytes()); // port (none)
1067        buf.extend_from_slice(&[0; 15]); // 15 bytes instead of 16 for IPv6
1068        assert_eq!(
1069            Bytecode::parse(&buf),
1070            Err(ParseError { index: 0, code: ParseErrorCode::TruncatedInstruction })
1071        );
1072    }
1073
1074    #[test]
1075    fn index_past_end() {
1076        let bc = Bytecode(vec![Instruction::Nop(Action::AdvanceBy(NonZeroUsize::new(2).unwrap()))]);
1077        let mut buf = vec![0u8; bc.serialized_len()];
1078        assert_eq!(
1079            bc.serialize(&mut buf),
1080            Err(SerializationError::InvalidInstruction {
1081                at: 0,
1082                error: InvalidInstructionError::IndexPastEnd
1083            })
1084        );
1085    }
1086
1087    #[test]
1088    fn prefix_length_longer_than_address_v4() {
1089        let bc = Bytecode(vec![Instruction::Condition {
1090            yes: Action::Accept,
1091            no: Action::Accept,
1092            condition: Condition::SrcTuple(TupleCondition {
1093                prefix_len: 33,
1094                port: None,
1095                addr: Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
1096            }),
1097        }]);
1098        let mut buf = vec![0u8; bc.serialized_len()];
1099        assert_eq!(
1100            bc.serialize(&mut buf),
1101            Err(SerializationError::InvalidInstruction {
1102                at: 0,
1103                error: InvalidInstructionError::PrefixLengthLongerThanAddress
1104            })
1105        );
1106    }
1107
1108    #[test]
1109    fn prefix_length_longer_than_address_v6() {
1110        let bc = Bytecode(vec![Instruction::Condition {
1111            yes: Action::Accept,
1112            no: Action::Accept,
1113            condition: Condition::SrcTuple(TupleCondition {
1114                prefix_len: 129,
1115                port: None,
1116                addr: Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
1117            }),
1118        }]);
1119        let mut buf = vec![0u8; bc.serialized_len()];
1120        assert_eq!(
1121            bc.serialize(&mut buf),
1122            Err(SerializationError::InvalidInstruction {
1123                at: 0,
1124                error: InvalidInstructionError::PrefixLengthLongerThanAddress
1125            })
1126        );
1127    }
1128
1129    #[test]
1130    fn prefix_length_longer_than_address_unspec() {
1131        let bc = Bytecode(vec![Instruction::Condition {
1132            yes: Action::Accept,
1133            no: Action::Accept,
1134            condition: Condition::SrcTuple(TupleCondition {
1135                prefix_len: 1,
1136                port: None,
1137                addr: None,
1138            }),
1139        }]);
1140        let mut buf = vec![0u8; bc.serialized_len()];
1141        assert_eq!(
1142            bc.serialize(&mut buf),
1143            Err(SerializationError::InvalidInstruction {
1144                at: 0,
1145                error: InvalidInstructionError::PrefixLengthLongerThanAddress
1146            })
1147        );
1148    }
1149}