netlink_packet_generic/ctrl/nlas/
mod.rs

1// SPDX-License-Identifier: MIT
2
3use crate::constants::*;
4use anyhow::Context;
5use byteorder::{ByteOrder, NativeEndian};
6use netlink_packet_utils::nla::{Nla, NlaBuffer, NlasIterator};
7use netlink_packet_utils::parsers::*;
8use netlink_packet_utils::traits::*;
9use netlink_packet_utils::DecodeError;
10use std::mem::size_of_val;
11
12mod mcast;
13mod oppolicy;
14mod ops;
15mod policy;
16
17pub use mcast::*;
18pub use oppolicy::*;
19pub use ops::*;
20pub use policy::*;
21
22#[derive(Clone, Debug, PartialEq, Eq)]
23pub enum GenlCtrlAttrs {
24    FamilyId(u16),
25    FamilyName(String),
26    Version(u32),
27    HdrSize(u32),
28    MaxAttr(u32),
29    Ops(Vec<Vec<OpAttrs>>),
30    McastGroups(Vec<Vec<McastGrpAttrs>>),
31    Policy(PolicyAttr),
32    OpPolicy(OppolicyAttr),
33    Op(u32),
34}
35
36impl Nla for GenlCtrlAttrs {
37    fn value_len(&self) -> usize {
38        use GenlCtrlAttrs::*;
39        match self {
40            FamilyId(v) => size_of_val(v),
41            FamilyName(s) => s.len() + 1,
42            Version(v) => size_of_val(v),
43            HdrSize(v) => size_of_val(v),
44            MaxAttr(v) => size_of_val(v),
45            Ops(nlas) => OpList::from(nlas).as_slice().buffer_len(),
46            McastGroups(nlas) => McastGroupList::from(nlas).as_slice().buffer_len(),
47            Policy(nla) => nla.buffer_len(),
48            OpPolicy(nla) => nla.buffer_len(),
49            Op(v) => size_of_val(v),
50        }
51    }
52
53    fn kind(&self) -> u16 {
54        use GenlCtrlAttrs::*;
55        match self {
56            FamilyId(_) => CTRL_ATTR_FAMILY_ID,
57            FamilyName(_) => CTRL_ATTR_FAMILY_NAME,
58            Version(_) => CTRL_ATTR_VERSION,
59            HdrSize(_) => CTRL_ATTR_HDRSIZE,
60            MaxAttr(_) => CTRL_ATTR_MAXATTR,
61            Ops(_) => CTRL_ATTR_OPS,
62            McastGroups(_) => CTRL_ATTR_MCAST_GROUPS,
63            Policy(_) => CTRL_ATTR_POLICY,
64            OpPolicy(_) => CTRL_ATTR_OP_POLICY,
65            Op(_) => CTRL_ATTR_OP,
66        }
67    }
68
69    fn emit_value(&self, buffer: &mut [u8]) {
70        use GenlCtrlAttrs::*;
71        match self {
72            FamilyId(v) => NativeEndian::write_u16(buffer, *v),
73            FamilyName(s) => {
74                buffer[..s.len()].copy_from_slice(s.as_bytes());
75                buffer[s.len()] = 0;
76            }
77            Version(v) => NativeEndian::write_u32(buffer, *v),
78            HdrSize(v) => NativeEndian::write_u32(buffer, *v),
79            MaxAttr(v) => NativeEndian::write_u32(buffer, *v),
80            Ops(nlas) => {
81                OpList::from(nlas).as_slice().emit(buffer);
82            }
83            McastGroups(nlas) => {
84                McastGroupList::from(nlas).as_slice().emit(buffer);
85            }
86            Policy(nla) => nla.emit_value(buffer),
87            OpPolicy(nla) => nla.emit_value(buffer),
88            Op(v) => NativeEndian::write_u32(buffer, *v),
89        }
90    }
91}
92
93impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>> for GenlCtrlAttrs {
94    type Error = DecodeError;
95    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
96        let payload = buf.value();
97        Ok(match buf.kind() {
98            CTRL_ATTR_FAMILY_ID => {
99                Self::FamilyId(parse_u16(payload).context("invalid CTRL_ATTR_FAMILY_ID value")?)
100            }
101            CTRL_ATTR_FAMILY_NAME => Self::FamilyName(
102                parse_string(payload).context("invalid CTRL_ATTR_FAMILY_NAME value")?,
103            ),
104            CTRL_ATTR_VERSION => {
105                Self::Version(parse_u32(payload).context("invalid CTRL_ATTR_VERSION value")?)
106            }
107            CTRL_ATTR_HDRSIZE => {
108                Self::HdrSize(parse_u32(payload).context("invalid CTRL_ATTR_HDRSIZE value")?)
109            }
110            CTRL_ATTR_MAXATTR => {
111                Self::MaxAttr(parse_u32(payload).context("invalid CTRL_ATTR_MAXATTR value")?)
112            }
113            CTRL_ATTR_OPS => {
114                let ops = NlasIterator::new(payload)
115                    .map(|nlas| {
116                        nlas.map_err(|err| DecodeError::from(err)).and_then(|nlas| {
117                            NlasIterator::new(nlas.value())
118                                .map(|nla| {
119                                    nla.map_err(|err| DecodeError::from(err))
120                                        .and_then(|nla| OpAttrs::parse(&nla))
121                                })
122                                .collect::<Result<Vec<_>, _>>()
123                        })
124                    })
125                    .collect::<Result<Vec<Vec<_>>, _>>()
126                    .context("failed to parse CTRL_ATTR_OPS")?;
127                Self::Ops(ops)
128            }
129            CTRL_ATTR_MCAST_GROUPS => {
130                let groups = NlasIterator::new(payload)
131                    .map(|nlas| {
132                        nlas.map_err(|err| DecodeError::from(err)).and_then(|nlas| {
133                            NlasIterator::new(nlas.value())
134                                .map(|nla| {
135                                    nla.map_err(|err| DecodeError::from(err))
136                                        .and_then(|nla| McastGrpAttrs::parse(&nla))
137                                })
138                                .collect::<Result<Vec<_>, _>>()
139                        })
140                    })
141                    .collect::<Result<Vec<Vec<_>>, _>>()
142                    .context("failed to parse CTRL_ATTR_MCAST_GROUPS")?;
143                Self::McastGroups(groups)
144            }
145            CTRL_ATTR_POLICY => Self::Policy(
146                PolicyAttr::parse(&NlaBuffer::new(payload))
147                    .context("failed to parse CTRL_ATTR_POLICY")?,
148            ),
149            CTRL_ATTR_OP_POLICY => Self::OpPolicy(
150                OppolicyAttr::parse(&NlaBuffer::new(payload))
151                    .context("failed to parse CTRL_ATTR_OP_POLICY")?,
152            ),
153            CTRL_ATTR_OP => Self::Op(parse_u32(payload)?),
154            kind => return Err(DecodeError::from(format!("Unknown NLA type: {kind}"))),
155        })
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Emits `attr` into a buffer of exactly `buffer_len()` bytes, then parses
    /// the bytes back into a `GenlCtrlAttrs`.
    fn roundtrip(attr: &GenlCtrlAttrs) -> GenlCtrlAttrs {
        let mut buf = vec![0u8; attr.buffer_len()];
        attr.emit(&mut buf);
        let nla_buffer =
            NlaBuffer::new_checked(&buf[..]).expect("Failed to create NlaBuffer");
        GenlCtrlAttrs::parse(&nla_buffer).expect("Failed to parse emitted attribute")
    }

    #[test]
    fn mcast_groups_parse() {
        let mcast_bytes: [u8; 24] = [
            24, 0, // NLA header length
            7, 0, // NLA header kind (CTRL_ATTR_MCAST_GROUPS)
            20, 0, // Mcast group nested NLA length
            1, 0, // Mcast group kind
            8, 0, // Id length
            2, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Name length
            1, 0, // Name kind
            b't', b'e', b's', b't', // Name
        ];
        let nla_buffer =
            NlaBuffer::new_checked(&mcast_bytes[..]).expect("Failed to create NlaBuffer");
        let result_attr =
            GenlCtrlAttrs::parse(&nla_buffer).expect("Failed to parse encoded McastGroups");
        let expected_attr = GenlCtrlAttrs::McastGroups(vec![vec![
            McastGrpAttrs::Id(1),
            McastGrpAttrs::Name("test".to_string()),
        ]]);
        assert_eq!(expected_attr, result_attr);
    }

    #[test]
    fn mcast_groups_emit() {
        let mcast_attr = GenlCtrlAttrs::McastGroups(vec![
            vec![McastGrpAttrs::Id(7), McastGrpAttrs::Name("group1".to_string())],
            vec![McastGrpAttrs::Id(8), McastGrpAttrs::Name("group2".to_string())],
        ]);
        let expected_bytes: [u8; 52] = [
            52, 0, // NLA header length
            7, 0, // NLA header kind (CTRL_ATTR_MCAST_GROUPS)
            24, 0, // Mcast group nested NLA length
            1, 0, // Mcast group kind (index 1)
            8, 0, // Id length
            2, 0, // Id kind
            7, 0, 0, 0, // Id
            11, 0, // Name length
            1, 0, // Name kind
            b'g', b'r', b'o', b'u', b'p', b'1', 0, // Name
            0, // mcast group padding
            24, 0, // Mcast group nested NLA length
            2, 0, // Mcast group kind (index 2)
            8, 0, // Id length
            2, 0, // Id kind
            8, 0, 0, 0, // Id
            11, 0, // Name length
            1, 0, // Name kind
            b'g', b'r', b'o', b'u', b'p', b'2', 0, // Name
            0, // padding
        ];
        let mut buf = vec![0u8; 100];
        mcast_attr.emit(&mut buf);

        assert_eq!(&expected_bytes[..], &buf[..expected_bytes.len()]);
    }

    #[test]
    fn ops_parse() {
        let ops_bytes: [u8; 24] = [
            24, 0, // NLA header length
            6, 0, // NLA header kind (CTRL_ATTR_OPS)
            20, 0, // Op nested NLA length
            0, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            123, 0, 0, 0, // Flags
        ];
        let nla_buffer =
            NlaBuffer::new_checked(&ops_bytes[..]).expect("Failed to create NlaBuffer");
        let result_attr =
            GenlCtrlAttrs::parse(&nla_buffer).expect("Failed to parse encoded Ops");
        let expected_attr = GenlCtrlAttrs::Ops(vec![vec![OpAttrs::Id(1), OpAttrs::Flags(123)]]);
        assert_eq!(expected_attr, result_attr);
    }

    #[test]
    fn ops_emit() {
        let ops = GenlCtrlAttrs::Ops(vec![
            vec![OpAttrs::Id(1), OpAttrs::Flags(11)],
            vec![OpAttrs::Id(3), OpAttrs::Flags(33)],
        ]);
        let expected_bytes: [u8; 44] = [
            44, 0, // NLA header length
            6, 0, // NLA header kind (CTRL_ATTR_OPS)
            20, 0, // Op nested NLA length
            1, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            11, 0, 0, 0, // Flags
            20, 0, // Op nested NLA length
            2, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            3, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            33, 0, 0, 0, // Flags
        ];
        let mut buf = vec![0u8; 100];
        ops.emit(&mut buf);

        assert_eq!(&expected_bytes[..], &buf[..expected_bytes.len()]);
    }

    #[test]
    fn scalar_attrs_roundtrip() {
        let attrs = [
            GenlCtrlAttrs::FamilyId(0x1c),
            GenlCtrlAttrs::FamilyName("nl80211".to_string()),
            GenlCtrlAttrs::Version(1),
            GenlCtrlAttrs::HdrSize(0),
            GenlCtrlAttrs::MaxAttr(42),
            GenlCtrlAttrs::Op(5),
        ];
        for attr in attrs.iter() {
            assert_eq!(&roundtrip(attr), attr);
        }
    }

    #[test]
    fn nested_attrs_roundtrip() {
        let ops = GenlCtrlAttrs::Ops(vec![
            vec![OpAttrs::Id(1), OpAttrs::Flags(11)],
            vec![OpAttrs::Id(3), OpAttrs::Flags(33)],
        ]);
        assert_eq!(roundtrip(&ops), ops);

        let groups = GenlCtrlAttrs::McastGroups(vec![
            vec![McastGrpAttrs::Id(7), McastGrpAttrs::Name("group1".to_string())],
            vec![McastGrpAttrs::Id(8), McastGrpAttrs::Name("group2".to_string())],
        ]);
        assert_eq!(roundtrip(&groups), groups);
    }
}