1use crate::ie::*;
6use std::cmp::{max, min};
7use std::collections::HashSet;
8use std::ops::BitAnd;
9use zerocopy::Ref;
10
/// Implements [`Intersect`] for a bitfield newtype, one field at a time.
///
/// Invocation form: `impl_intersect!(Type { combiner: field, ... })`. Each entry expands to a
/// builder call `.with_<field>(combiner(self.<field>(), other.<field>()))`. The calls are
/// chained in the listed order onto `Type(0)`, so any field that is not listed stays zero in
/// the result. `combiner` may be any function of the two field values; this file uses `and`,
/// `min`, `max` and `intersect`. The `with_<field>` identifier is built with `paste`.
macro_rules! impl_intersect {
    ($target:ident { $($combine:ident: $field_name:ident),* $(,)? }) => {
        paste::paste! {
            impl Intersect for $target {
                fn intersect(&self, other: &Self) -> Self {
                    Self(0)
                    $(
                        .[<with_ $field_name>](
                            $combine(self.$field_name(), other.$field_name())
                        )
                    )*
                }
            }
        }
    };
}
35
/// Combines two values of the same type into the part they have in common.
///
/// What "in common" means is decided by each implementation. In this file it is bitwise AND,
/// minimum, maximum, or a type-specific rule. Neither operand is modified.
pub trait Intersect {
    /// Returns the combination of `self` with `rhs`.
    fn intersect(&self, rhs: &Self) -> Self;
}
42
43fn intersect<I: Intersect>(a: I, b: I) -> I {
44 a.intersect(&b)
45}
46
/// Bitwise AND of two values, used as an `impl_intersect!` combiner.
///
/// For `bool` fields this is a logical AND. For multi-bit fields it keeps only the bits
/// that are set on both sides.
fn and<B: BitAnd>(a: B, b: B) -> B::Output {
    BitAnd::bitand(a, b)
}
50
51impl Intersect for SmPowerSave {
54 fn intersect(&self, other: &Self) -> Self {
55 if *self == Self::DISABLED || *other == Self::DISABLED {
56 Self::DISABLED
57 } else {
58 Self(min(self.0, other.0))
59 }
60 }
61}
62
// Field-by-field intersection of the HT Capability Information field.
// `and` keeps a flag only when both sides set it (bitwise AND for multi-bit `_raw` fields),
// `min` keeps the smaller raw value, and `intersect` defers to the field type's own
// `Intersect` impl. Fields that are not listed end up zero in the result.
// `tx_stbc`/`rx_stbc` are overridden afterwards by `impl Intersect for HtCapabilities`.
impl_intersect!(HtCapabilityInfo {
    and: ldpc_coding_cap,
    // Compared as raw numbers; TWENTY_ONLY intersected with TWENTY_FORTY yields TWENTY_ONLY.
    min: chan_width_set_raw,
    intersect: sm_power_save,
    and: greenfield,
    and: short_gi_20,
    and: short_gi_40,
    and: tx_stbc,
    min: rx_stbc,
    and: delayed_block_ack,
    // Bitwise AND on the raw encoding.
    and: max_amsdu_len_raw,
    and: dsss_in_40,
    // NOTE(review): AND means the result reports 40 MHz intolerance only when both sides
    // do; confirm this is the intended combination for an "intolerant" flag.
    and: intolerant_40,
    and: lsig_txop_protect,
});
78
// A-MPDU parameters: keep the smaller max A-MPDU length exponent and the larger minimum
// MPDU start spacing (presumably the more conservative setting of each — confirm).
impl_intersect!(AmpduParams {
    min: max_ampdu_exponent_raw,
    max: min_start_spacing_raw,
});
83
// Supported MCS set: the RX MCS bitmask is ANDed, so only MCS indices set on both sides
// survive. Numeric limits take the smaller value, and single-bit flags are ANDed.
impl_intersect!(SupportedMcsSet {
    and: rx_mcs_raw,
    min: rx_highest_rate,
    and: tx_set_defined,
    and: tx_rx_diff,
    min: tx_max_ss_raw,
    and: tx_ueqm,
});
93
94impl Intersect for PcoTransitionTime {
97 fn intersect(&self, other: &Self) -> Self {
98 if *self == Self::PCO_RESERVED || *other == Self::PCO_RESERVED {
99 Self::PCO_RESERVED
100 } else {
101 Self(max(self.0, other.0))
102 }
103 }
104}
105
// HT extended capabilities. `pco_transition` uses PcoTransitionTime's own rule:
// reserved wins, otherwise the larger value is kept.
impl_intersect!(HtExtCapabilities {
    and: pco,
    intersect: pco_transition,
    min: mcs_feedback_raw,
    and: htc_ht_support,
    and: rd_responder,
});
113
// Transmit beamforming capabilities.
impl_intersect!(TxBfCapability {
    // Single-bit capability flags: kept only if both sides set them.
    and: implicit_rx,
    and: rx_stag_sounding,
    and: tx_stag_sounding,
    and: rx_ndp,
    and: tx_ndp,
    and: implicit,
    min: calibration_raw,
    and: csi,
    and: noncomp_steering,
    and: comp_steering,

    // Multi-bit feedback fields are ANDed bitwise on their raw encoding.
    and: csi_feedback_raw,
    and: noncomp_feedback_raw,
    and: comp_feedback_raw,
    // Numeric fields keep the smaller raw value.
    min: min_grouping_raw,
    min: csi_antennas_raw,
    min: noncomp_steering_ants_raw,
    min: comp_steering_ants_raw,
    min: csi_rows_raw,
    min: chan_estimation_raw,
});
139
// Antenna selection capabilities: every field is a flag kept only if both sides set it.
impl_intersect!(AselCapability {
    and: asel,
    and: csi_feedback_tx_asel,
    and: ant_idx_feedback_tx_asel,
    and: explicit_csi_feedback,
    and: antenna_idx_feedback,
    and: rx_asel,
    and: tx_sounding_ppdu,
});
149
150impl Intersect for HtCapabilities {
151 fn intersect(&self, other: &Self) -> Self {
152 let mut out = Self {
153 ht_cap_info: { self.ht_cap_info }.intersect(&{ other.ht_cap_info }),
154 ampdu_params: { self.ampdu_params }.intersect(&{ other.ampdu_params }),
155 mcs_set: { self.mcs_set }.intersect(&{ other.mcs_set }),
156 ht_ext_cap: { self.ht_ext_cap }.intersect(&{ other.ht_ext_cap }),
157 txbf_cap: { self.txbf_cap }.intersect(&{ other.txbf_cap }),
158 asel_cap: { self.asel_cap }.intersect(&{ other.asel_cap }),
159 };
160 out.ht_cap_info = out
165 .ht_cap_info
166 .with_tx_stbc(if { other.ht_cap_info }.rx_stbc() != 0 {
167 { self.ht_cap_info }.tx_stbc()
168 } else {
169 false
170 })
171 .with_rx_stbc(if { other.ht_cap_info }.tx_stbc() {
172 { self.ht_cap_info }.rx_stbc()
173 } else {
174 0
175 });
176 out
177 }
178}
179
// Field-by-field intersection of the VHT Capabilities Information field.
// Numeric/raw fields keep the smaller value, and single-bit flags are ANDed.
// NOTE(review): unlike HtCapabilities, the VhtCapabilities impl in this file does not
// cross-check tx_stbc/rx_stbc against the peer; they are combined symmetrically here.
// Confirm this difference is intended.
impl_intersect!(VhtCapabilitiesInfo {
    min: max_mpdu_len_raw,
    min: supported_cbw_set,
    and: rx_ldpc,
    and: sgi_cbw80,
    and: sgi_cbw160,
    and: tx_stbc,
    min: rx_stbc,
    and: su_bfer,
    and: su_bfee,
    min: bfee_sts,
    min: num_sounding,
    and: mu_bfer,
    and: mu_bfee,
    and: txop_ps,
    and: htc_vht,
    min: max_ampdu_exponent_raw,
    min: link_adapt_raw,
    and: rx_ant_pattern,
    and: tx_ant_pattern,
    min: ext_nss_bw,
});
203
204impl Intersect for VhtMcsSet {
205 fn intersect(&self, other: &Self) -> Self {
206 if *self == Self::NONE || *other == Self::NONE {
207 Self::NONE
208 } else {
209 Self(min(self.0, other.0))
210 }
211 }
212}
213
// One VhtMcsSet per spatial stream (1..=8). Each is combined with VhtMcsSet's rule:
// NONE wins, otherwise the smaller value.
impl_intersect!(VhtMcsNssMap {
    intersect: ss1,
    intersect: ss2,
    intersect: ss3,
    intersect: ss4,
    intersect: ss5,
    intersect: ss6,
    intersect: ss7,
    intersect: ss8,
});
224
// VHT supported MCS/NSS set. The RX and TX per-stream maps are intersected element-wise
// via VhtMcsNssMap, data-rate and NSTS limits keep the smaller value, and the
// ext_nss_bw flag is ANDed.
impl_intersect!(VhtMcsNssSet {
    intersect: rx_max_mcs,
    min: rx_max_data_rate,
    min: max_nsts,
    intersect: tx_max_mcs,
    min: tx_max_data_rate,
    and: ext_nss_bw,
});
233
234impl Intersect for VhtCapabilities {
235 fn intersect(&self, other: &Self) -> Self {
236 Self {
237 vht_cap_info: { self.vht_cap_info }.intersect(&{ other.vht_cap_info }),
238 vht_mcs_nss: { self.vht_mcs_nss }.intersect(&{ other.vht_mcs_nss }),
239 }
240 }
241}
242
243pub struct ApRates<'a>(pub &'a [SupportedRate]);
244pub struct ClientRates<'a>(pub &'a [SupportedRate]);
245
246impl<'a> From<&'a [u8]> for ApRates<'a> {
247 fn from(rates: &'a [u8]) -> Self {
248 Self(Ref::into_ref(Ref::from_bytes(rates).unwrap()))
250 }
251}
252
253impl<'a> From<&'a [u8]> for ClientRates<'a> {
254 fn from(rates: &'a [u8]) -> Self {
255 Self(Ref::into_ref(Ref::from_bytes(rates).unwrap()))
257 }
258}
259
/// Reasons why `intersect_rates` can fail.
#[derive(Debug, PartialEq, Eq)]
pub enum IntersectRatesError {
    /// The AP flags a rate as basic, but the client does not list that rate.
    BasicRatesMismatch,
    /// None of the AP's rates (BSS membership selectors excluded) is listed by the client.
    NoApRatesSupported,
}
265
266pub fn intersect_rates(
271 ap_rates: ApRates<'_>,
272 client_rates: ClientRates<'_>,
273) -> Result<Vec<SupportedRate>, IntersectRatesError> {
274 let mut rates: Vec<_> =
276 ap_rates.0.iter().copied().filter(|rate| !rate.is_bss_membership_selector()).collect();
277 let client_rates = client_rates.0.iter().map(|r| r.rate()).collect::<HashSet<_>>();
278 if rates.iter().any(|ra| ra.basic() && !client_rates.contains(&ra.rate())) {
280 return Err(IntersectRatesError::BasicRatesMismatch);
281 }
282
283 rates.retain(|ra| client_rates.contains(&ra.rate()));
285 if rates.is_empty() {
286 Err(IntersectRatesError::NoApRatesSupported)
287 } else {
288 Ok(rates)
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    impl SupportedRate {
        /// Test helper: builds a rate with the basic flag set.
        fn new_basic(rate: u8) -> Self {
            Self(rate).with_basic(true)
        }
    }

    #[test]
    fn some_basic_rate_missing() {
        // The AP requires 120 as a basic rate, but the client only lists 111.
        let error = intersect_rates(
            ApRates(&[SupportedRate::new_basic(120), SupportedRate::new_basic(111)][..]),
            ClientRates(&[SupportedRate(111)][..]),
        )
        .unwrap_err();
        assert_eq!(error, IntersectRatesError::BasicRatesMismatch);
    }

    #[test]
    fn all_basic_rates_supported() {
        // The client supports the AP's only basic rate; the unsupported non-basic 111 is dropped.
        assert_eq!(
            vec![SupportedRate::new_basic(120)],
            intersect_rates(
                ApRates(&[SupportedRate::new_basic(120), SupportedRate(111)][..]),
                ClientRates(&[SupportedRate(120)][..])
            )
            .unwrap()
        );
    }

    #[test]
    fn all_basic_and_non_basic_rates_supported() {
        // The client supports both the basic and the non-basic AP rate, so both are kept
        // in AP order with the AP's flags.
        assert_eq!(
            vec![SupportedRate::new_basic(120), SupportedRate(111)],
            intersect_rates(
                ApRates(&[SupportedRate::new_basic(120), SupportedRate(111)][..]),
                ClientRates(&[SupportedRate(120), SupportedRate(111)][..])
            )
            .unwrap()
        );
    }

    #[test]
    fn no_rates_are_supported() {
        let error =
            intersect_rates(ApRates(&[SupportedRate(120)][..]), ClientRates(&[][..])).unwrap_err();
        assert_eq!(error, IntersectRatesError::NoApRatesSupported);
    }

    #[test]
    fn preserve_ap_rates_basicness() {
        // The result copies the AP's flags: the client's basic flag on 120 is not carried over.
        assert_eq!(
            vec![SupportedRate(120)],
            intersect_rates(
                ApRates(&[SupportedRate(120), SupportedRate(111)][..]),
                ClientRates(&[SupportedRate::new_basic(120)][..])
            )
            .unwrap()
        );
    }

    #[test]
    fn intersect_ht_cap_info_chan_width_set() {
        let a = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_ONLY);
        let b = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_FORTY);
        assert_eq!(ChanWidthSet::TWENTY_ONLY, a.intersect(&b).chan_width_set());

        let a = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_FORTY);
        let b = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_FORTY);
        assert_eq!(ChanWidthSet::TWENTY_FORTY, a.intersect(&b).chan_width_set());
    }

    #[test]
    fn intersect_supported_mcs_set() {
        let a = SupportedMcsSet(0).with_rx_mcs_raw(0xffff);
        let b = SupportedMcsSet(0).with_rx_mcs_raw(0x0304);
        assert_eq!(RxMcsBitmask(0x0304), a.intersect(&b).rx_mcs());
    }

    #[test]
    fn intersect_sm_power_save() {
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DISABLED.intersect(&SmPowerSave::DISABLED));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::STATIC.intersect(&SmPowerSave::DISABLED));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DYNAMIC.intersect(&SmPowerSave::DISABLED));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DISABLED.intersect(&SmPowerSave::STATIC));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DISABLED.intersect(&SmPowerSave::DYNAMIC));

        assert_eq!(SmPowerSave::STATIC, SmPowerSave::STATIC.intersect(&SmPowerSave::DYNAMIC));
        assert_eq!(SmPowerSave::STATIC, SmPowerSave::DYNAMIC.intersect(&SmPowerSave::STATIC));

        assert_eq!(SmPowerSave::DYNAMIC, SmPowerSave::DYNAMIC.intersect(&SmPowerSave::DYNAMIC));
    }

    #[test]
    fn intersect_pco_transition() {
        type PTT = PcoTransitionTime;
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_RESERVED));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_400_USEC));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_1500_USEC));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_5000_USEC));

        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_400_USEC.intersect(&PTT::PCO_RESERVED));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_1500_USEC.intersect(&PTT::PCO_RESERVED));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_5000_USEC.intersect(&PTT::PCO_RESERVED));

        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_400_USEC.intersect(&PTT::PCO_5000_USEC));
        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_1500_USEC.intersect(&PTT::PCO_5000_USEC));
        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_5000_USEC.intersect(&PTT::PCO_5000_USEC));

        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_5000_USEC.intersect(&PTT::PCO_400_USEC));
        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_5000_USEC.intersect(&PTT::PCO_1500_USEC));

        assert_eq!(PTT::PCO_1500_USEC, PTT::PCO_400_USEC.intersect(&PTT::PCO_1500_USEC));
        assert_eq!(PTT::PCO_1500_USEC, PTT::PCO_1500_USEC.intersect(&PTT::PCO_400_USEC));

        assert_eq!(PTT::PCO_400_USEC, PTT::PCO_400_USEC.intersect(&PTT::PCO_400_USEC));
    }

    #[test]
    fn intersect_ht_cap_info_stbc() {
        let mut ht_cap_a = fake_ht_capabilities();
        let mut ht_cap_b = fake_ht_capabilities();

        ht_cap_a.ht_cap_info = HtCapabilityInfo(0).with_tx_stbc(true).with_rx_stbc(2);
        ht_cap_b.ht_cap_info = HtCapabilityInfo(0).with_tx_stbc(false).with_rx_stbc(1);

        // STBC is directional, so the result depends on which side is `self`.
        let intersected_ht_cap_info = ht_cap_a.intersect(&ht_cap_b).ht_cap_info;
        assert!(intersected_ht_cap_info.tx_stbc());
        assert_eq!(0, intersected_ht_cap_info.rx_stbc());

        let intersected_ht_cap_info = ht_cap_b.intersect(&ht_cap_a).ht_cap_info;
        assert!(!intersected_ht_cap_info.tx_stbc());
        assert_eq!(1, intersected_ht_cap_info.rx_stbc());
    }

    #[test]
    fn intersect_vht_mcs_set() {
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::UP_TO_7));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::UP_TO_8));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::UP_TO_9));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::NONE));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::UP_TO_7.intersect(&VhtMcsSet::NONE));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::UP_TO_8.intersect(&VhtMcsSet::NONE));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::UP_TO_9.intersect(&VhtMcsSet::NONE));

        assert_eq!(VhtMcsSet::UP_TO_7, VhtMcsSet::UP_TO_7.intersect(&VhtMcsSet::UP_TO_8));
        assert_eq!(VhtMcsSet::UP_TO_7, VhtMcsSet::UP_TO_8.intersect(&VhtMcsSet::UP_TO_7));

        assert_eq!(VhtMcsSet::UP_TO_8, VhtMcsSet::UP_TO_8.intersect(&VhtMcsSet::UP_TO_9));
        assert_eq!(VhtMcsSet::UP_TO_8, VhtMcsSet::UP_TO_9.intersect(&VhtMcsSet::UP_TO_8));

        assert_eq!(VhtMcsSet::UP_TO_9, VhtMcsSet::UP_TO_9.intersect(&VhtMcsSet::UP_TO_9));
    }
}