netlink_packet_generic/ctrl/nlas/
mod.rs

1// SPDX-License-Identifier: MIT
2
3use crate::constants::*;
4use anyhow::Context;
5use byteorder::{ByteOrder, NativeEndian};
6use netlink_packet_utils::DecodeError;
7use netlink_packet_utils::nla::{Nla, NlaBuffer, NlasIterator};
8use netlink_packet_utils::parsers::*;
9use netlink_packet_utils::traits::*;
10use std::mem::size_of_val;
11
12mod mcast;
13mod oppolicy;
14mod ops;
15mod policy;
16
17pub use mcast::*;
18pub use oppolicy::*;
19pub use ops::*;
20pub use policy::*;
21
/// Attributes handled by the generic netlink controller attribute parser.
///
/// Each variant maps to exactly one `CTRL_ATTR_*` kind constant (see the
/// `Nla::kind` implementation for this type). Integer payloads are read and
/// written in native byte order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GenlCtrlAttrs {
    /// `CTRL_ATTR_FAMILY_ID`: 16-bit value.
    FamilyId(u16),
    /// `CTRL_ATTR_FAMILY_NAME`: string, emitted with a trailing NUL byte.
    FamilyName(String),
    /// `CTRL_ATTR_VERSION`: 32-bit value.
    Version(u32),
    /// `CTRL_ATTR_HDRSIZE`: 32-bit value.
    HdrSize(u32),
    /// `CTRL_ATTR_MAXATTR`: 32-bit value.
    MaxAttr(u32),
    /// `CTRL_ATTR_OPS`: a list of nested entries; each inner `Vec` holds the
    /// attributes of one entry.
    Ops(Vec<Vec<OpAttrs>>),
    /// `CTRL_ATTR_MCAST_GROUPS`: a list of nested entries; each inner `Vec`
    /// holds the attributes of one entry.
    McastGroups(Vec<Vec<McastGrpAttrs>>),
    /// `CTRL_ATTR_POLICY`: a single nested attribute.
    Policy(PolicyAttr),
    /// `CTRL_ATTR_OP_POLICY`: a single nested attribute.
    OpPolicy(OppolicyAttr),
    /// `CTRL_ATTR_OP`: 32-bit value.
    Op(u32),
}
35
36impl Nla for GenlCtrlAttrs {
37    fn value_len(&self) -> usize {
38        use GenlCtrlAttrs::*;
39        match self {
40            FamilyId(v) => size_of_val(v),
41            FamilyName(s) => s.len() + 1,
42            Version(v) => size_of_val(v),
43            HdrSize(v) => size_of_val(v),
44            MaxAttr(v) => size_of_val(v),
45            Ops(nlas) => OpList::from(nlas).as_slice().buffer_len(),
46            McastGroups(nlas) => McastGroupList::from(nlas).as_slice().buffer_len(),
47            Policy(nla) => nla.buffer_len(),
48            OpPolicy(nla) => nla.buffer_len(),
49            Op(v) => size_of_val(v),
50        }
51    }
52
53    fn kind(&self) -> u16 {
54        use GenlCtrlAttrs::*;
55        match self {
56            FamilyId(_) => CTRL_ATTR_FAMILY_ID,
57            FamilyName(_) => CTRL_ATTR_FAMILY_NAME,
58            Version(_) => CTRL_ATTR_VERSION,
59            HdrSize(_) => CTRL_ATTR_HDRSIZE,
60            MaxAttr(_) => CTRL_ATTR_MAXATTR,
61            Ops(_) => CTRL_ATTR_OPS,
62            McastGroups(_) => CTRL_ATTR_MCAST_GROUPS,
63            Policy(_) => CTRL_ATTR_POLICY,
64            OpPolicy(_) => CTRL_ATTR_OP_POLICY,
65            Op(_) => CTRL_ATTR_OP,
66        }
67    }
68
69    fn emit_value(&self, buffer: &mut [u8]) {
70        use GenlCtrlAttrs::*;
71        match self {
72            FamilyId(v) => NativeEndian::write_u16(buffer, *v),
73            FamilyName(s) => {
74                buffer[..s.len()].copy_from_slice(s.as_bytes());
75                buffer[s.len()] = 0;
76            }
77            Version(v) => NativeEndian::write_u32(buffer, *v),
78            HdrSize(v) => NativeEndian::write_u32(buffer, *v),
79            MaxAttr(v) => NativeEndian::write_u32(buffer, *v),
80            Ops(nlas) => {
81                OpList::from(nlas).as_slice().emit(buffer);
82            }
83            McastGroups(nlas) => {
84                McastGroupList::from(nlas).as_slice().emit(buffer);
85            }
86            Policy(nla) => nla.emit_value(buffer),
87            OpPolicy(nla) => nla.emit_value(buffer),
88            Op(v) => NativeEndian::write_u32(buffer, *v),
89        }
90    }
91}
92
93impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>> for GenlCtrlAttrs {
94    type Error = DecodeError;
95    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
96        let payload = buf.value();
97        Ok(match buf.kind() {
98            CTRL_ATTR_FAMILY_ID => {
99                Self::FamilyId(parse_u16(payload).context("invalid CTRL_ATTR_FAMILY_ID value")?)
100            }
101            CTRL_ATTR_FAMILY_NAME => Self::FamilyName(
102                parse_string(payload).context("invalid CTRL_ATTR_FAMILY_NAME value")?,
103            ),
104            CTRL_ATTR_VERSION => {
105                Self::Version(parse_u32(payload).context("invalid CTRL_ATTR_VERSION value")?)
106            }
107            CTRL_ATTR_HDRSIZE => {
108                Self::HdrSize(parse_u32(payload).context("invalid CTRL_ATTR_HDRSIZE value")?)
109            }
110            CTRL_ATTR_MAXATTR => {
111                Self::MaxAttr(parse_u32(payload).context("invalid CTRL_ATTR_MAXATTR value")?)
112            }
113            CTRL_ATTR_OPS => {
114                let ops = NlasIterator::new(payload)
115                    .map(|nlas| {
116                        nlas.map_err(|err| DecodeError::from(err)).and_then(|nlas| {
117                            NlasIterator::new(nlas.value())
118                                .map(|nla| {
119                                    nla.map_err(|err| DecodeError::from(err))
120                                        .and_then(|nla| OpAttrs::parse(&nla))
121                                })
122                                .collect::<Result<Vec<_>, _>>()
123                        })
124                    })
125                    .collect::<Result<Vec<Vec<_>>, _>>()
126                    .context("failed to parse CTRL_ATTR_OPS")?;
127                Self::Ops(ops)
128            }
129            CTRL_ATTR_MCAST_GROUPS => {
130                let groups = NlasIterator::new(payload)
131                    .map(|nlas| {
132                        nlas.map_err(|err| DecodeError::from(err)).and_then(|nlas| {
133                            NlasIterator::new(nlas.value())
134                                .map(|nla| {
135                                    nla.map_err(|err| DecodeError::from(err))
136                                        .and_then(|nla| McastGrpAttrs::parse(&nla))
137                                })
138                                .collect::<Result<Vec<_>, _>>()
139                        })
140                    })
141                    .collect::<Result<Vec<Vec<_>>, _>>()
142                    .context("failed to parse CTRL_ATTR_MCAST_GROUPS")?;
143                Self::McastGroups(groups)
144            }
145            CTRL_ATTR_POLICY => Self::Policy(
146                PolicyAttr::parse(&NlaBuffer::new(payload)?)
147                    .context("failed to parse CTRL_ATTR_POLICY")?,
148            ),
149            CTRL_ATTR_OP_POLICY => Self::OpPolicy(
150                OppolicyAttr::parse(&NlaBuffer::new(payload)?)
151                    .context("failed to parse CTRL_ATTR_OP_POLICY")?,
152            ),
153            CTRL_ATTR_OP => Self::Op(parse_u32(payload)?),
154            kind => return Err(DecodeError::from(format!("Unknown NLA type: {kind}"))),
155        })
156    }
157}
158
// Byte fixtures below spell multi-byte integers least-significant byte first,
// while the implementation uses native byte order.
#[cfg(test)]
mod tests {
    use super::*;

    /// A single multicast group entry decodes into one inner attribute list.
    #[test]
    fn mcast_groups_parse() {
        let mcast_bytes: [u8; 24] = [
            24, 0, // Netlink header length
            7, 0, // Netlink header kind (Mcast groups)
            20, 0, // Mcast group nested NLA length
            1, 0, // Mcast group kind
            8, 0, // Id length
            2, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Name length
            1, 0, // Name kind
            b't', b'e', b's', b't', // Name
        ];
        let nla_buffer = NlaBuffer::new(&mcast_bytes[..]).expect("Failed to create NlaBuffer");
        let result_attr =
            GenlCtrlAttrs::parse(&nla_buffer).expect("Failed to parse encoded McastGroups");
        let expected_attr = GenlCtrlAttrs::McastGroups(vec![vec![
            McastGrpAttrs::Id(1),
            McastGrpAttrs::Name("test".to_string()),
        ]]);
        assert_eq!(expected_attr, result_attr);
    }

    /// Two multicast groups are emitted as consecutive nested entries.
    #[test]
    fn mcast_groups_emit() {
        let mcast_attr = GenlCtrlAttrs::McastGroups(vec![
            vec![McastGrpAttrs::Id(7), McastGrpAttrs::Name("group1".to_string())],
            vec![McastGrpAttrs::Id(8), McastGrpAttrs::Name("group2".to_string())],
        ]);
        let expected_bytes: [u8; 52] = [
            52, 0, // Netlink header length
            7, 0, // Netlink header kind (Mcast groups)
            24, 0, // Mcast group nested NLA length
            1, 0, // Mcast group kind (index 1)
            8, 0, // Id length
            2, 0, // Id kind
            7, 0, 0, 0, // Id
            11, 0, // Name length
            1, 0, // Name kind
            b'g', b'r', b'o', b'u', b'p', b'1', 0, // Name
            0, // mcast group padding
            24, 0, // Mcast group nested NLA length
            2, 0, // Mcast group kind (index 2)
            8, 0, // Id length
            2, 0, // Id kind
            8, 0, 0, 0, // Id
            11, 0, // Name length
            1, 0, // Name kind
            b'g', b'r', b'o', b'u', b'p', b'2', 0, // Name
            0, // padding
        ];
        // Oversized scratch buffer; only the emitted prefix is compared.
        let mut buf = vec![0u8; 100];
        mcast_attr.emit(&mut buf);

        assert_eq!(&expected_bytes[..], &buf[..expected_bytes.len()]);
    }

    /// A single op entry decodes into one inner attribute list.
    #[test]
    fn ops_parse() {
        let ops_bytes: [u8; 24] = [
            24, 0, // Netlink header length
            6, 0, // Netlink header kind (Ops)
            20, 0, // Op nested NLA length
            0, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            123, 0, 0, 0, // Flags
        ];
        let nla_buffer = NlaBuffer::new(&ops_bytes[..]).expect("Failed to create NlaBuffer");
        let result_attr = GenlCtrlAttrs::parse(&nla_buffer).expect("Failed to parse encoded Ops");
        let expected_attr = GenlCtrlAttrs::Ops(vec![vec![OpAttrs::Id(1), OpAttrs::Flags(123)]]);
        assert_eq!(expected_attr, result_attr);
    }

    /// Two ops are emitted as consecutive nested entries.
    #[test]
    fn ops_emit() {
        let ops = GenlCtrlAttrs::Ops(vec![
            vec![OpAttrs::Id(1), OpAttrs::Flags(11)],
            vec![OpAttrs::Id(3), OpAttrs::Flags(33)],
        ]);
        let expected_bytes: [u8; 44] = [
            44, 0, // Netlink header length
            6, 0, // Netlink header kind (Ops)
            20, 0, // Op nested NLA length
            1, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            11, 0, 0, 0, // Flags
            20, 0, // Op nested NLA length
            2, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            3, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            33, 0, 0, 0, // Flags
        ];
        // Oversized scratch buffer; only the emitted prefix is compared.
        let mut buf = vec![0u8; 100];
        ops.emit(&mut buf);

        assert_eq!(&expected_bytes[..], &buf[..expected_bytes.len()]);
    }
}