wlan_common/ie/
intersect.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::ie::*;
6use std::cmp::{max, min};
7use std::collections::HashSet;
8use std::ops::BitAnd;
9use zerocopy::Ref;
10
// TODO(https://fxbug.dev/42118992): HT and VHT intersections defined here are best effort only.

// Implements `Intersect` for a bitfield-style newtype by combining it one field at a time.
//
// The target type must be constructible as `Self(0)` and must expose, for every listed field,
// a getter `field()` and a builder-style setter `with_field(value) -> Self`. Each entry pairs a
// combining function already in scope (this file uses `and`, `min`, `max` and `intersect`) with
// the field it applies to. Fields that are not listed stay zero in the result.
//
// For example, given a bitfield newtype `Foo` with fields `a` and `b`:
//   impl_intersect!(Foo {
//       intersect: a,
//       min: b,
//   });
// produces
//   Foo(0).with_a(intersect(self.a(), other.a())).with_b(min(self.b(), other.b()))
macro_rules! impl_intersect {
    ($target:ident { $($combine:ident: $field:ident),* $(,)? }) => {
        paste::paste! {
            impl Intersect for $target {
                fn intersect(&self, other: &Self) -> Self {
                    let merged = Self(0);
                    $(
                        let merged =
                            merged.[<with_ $field>]($combine(self.$field(), other.$field()));
                    )*
                    merged
                }
            }
        }
    };
}
35
/// Intersect capabilities between two entities, such as a client and an AP.
///
/// Note: `a.intersect(&b)` is not guaranteed to be the same as `b.intersect(&a)`. One such
/// example is intersecting `HtCapabilities`, where the TX_STBC and RX_STBC fields of
/// `HtCapabilityInfo` are each paired with the peer's opposite STBC capability.
pub trait Intersect {
    /// Returns the combination of `self` and `other` capabilities.
    fn intersect(&self, other: &Self) -> Self;
}
42
43fn intersect<I: Intersect>(a: I, b: I) -> I {
44    a.intersect(&b)
45}
46
/// Applies the `&` operator (bitwise AND for integers, logical AND for `bool`); used as the
/// `and:` combinator inside `impl_intersect!`.
fn and<B: BitAnd>(a: B, b: B) -> B::Output {
    BitAnd::bitand(a, b)
}
50
51// IEEE Std. 802.11-2016, 11.2.6 mentioned this but did not provide interpretation for DISABLED.
52// This is Fuchsia's interpretation.
53impl Intersect for SmPowerSave {
54    fn intersect(&self, other: &Self) -> Self {
55        if *self == Self::DISABLED || *other == Self::DISABLED {
56            Self::DISABLED
57        } else {
58            Self(min(self.0, other.0))
59        }
60    }
61}
62
// Field-wise intersection of the HT Capability Information field: `and` keeps a flag (or bits)
// only when both peers set it, `min` keeps the smaller raw value, and `intersect` defers to the
// field type's own `Intersect` impl.
// When full `HtCapabilities` are intersected, the `tx_stbc`/`rx_stbc` results computed here are
// replaced: each side's STBC capability is paired with the peer's opposite one instead.
// NOTE(review): `and: intolerant_40` reports 40 MHz intolerance only when BOTH peers are
// intolerant; confirm this is intended rather than OR.
impl_intersect!(HtCapabilityInfo {
    and: ldpc_coding_cap,
    min: chan_width_set_raw,  // ChanWidthSet(u8)
    intersect: sm_power_save, // SmPowerSave(u8)
    and: greenfield,
    and: short_gi_20,
    and: short_gi_40,
    and: tx_stbc,
    min: rx_stbc,
    and: delayed_block_ack,
    and: max_amsdu_len_raw, // MaxAmsduLen(u8)
    and: dsss_in_40,
    and: intolerant_40,
    and: lsig_txop_protect,
});
78
// A-MPDU parameters: use the smaller maximum A-MPDU length exponent and the larger minimum MPDU
// start spacing of the two peers.
// NOTE(review): assumes a larger raw spacing value means a longer (more conservative) spacing.
impl_intersect!(AmpduParams {
    min: max_ampdu_exponent_raw, // MaxAmpduExponent(u8)
    max: min_start_spacing_raw,  // MinMpduStartSpacing(u8)
});
83
// TODO(https://fxbug.dev/42104152): if tx_rx_diff is set, the intersection rule may be more complicated.
// Supported MCS set: the RX MCS bitmask is ANDed so only MCS indices supported by both peers
// remain. The highest supported RX data rate and the TX max spatial streams take the minimum,
// and the remaining flags are ANDed.
impl_intersect!(SupportedMcsSet {
    and: rx_mcs_raw, // RxMcsBitmask(u128)
    min: rx_highest_rate,
    and: tx_set_defined,
    and: tx_rx_diff,
    min: tx_max_ss_raw, // NumSpatialStreams(u8)
    and: tx_ueqm,
});
93
94// IEEE Std. 802.11-2016, 11.17.3
95// PcoTransitionTime can be dynamic so the best effort here is to use the slower transition time.
96impl Intersect for PcoTransitionTime {
97    fn intersect(&self, other: &Self) -> Self {
98        if *self == Self::PCO_RESERVED || *other == Self::PCO_RESERVED {
99            Self::PCO_RESERVED
100        } else {
101            Self(max(self.0, other.0))
102        }
103    }
104}
105
// HT Extended Capabilities: flags are ANDed, the PCO transition time uses
// `Intersect for PcoTransitionTime`, and the raw MCS feedback value takes the minimum.
impl_intersect!(HtExtCapabilities {
    and: pco,
    intersect: pco_transition, // PcoTransitionTime(u8)
    min: mcs_feedback_raw,     // McsFeedback(u8)
    and: htc_ht_support,
    and: rd_responder,
});
113
// Transmit beamforming capabilities: feature flags and feedback bitmasks are ANDed, and numeric
// limits (calibration, grouping, antenna/row/stream counts) take the minimum raw value.
impl_intersect!(TxBfCapability {
    and: implicit_rx,
    and: rx_stag_sounding,
    and: tx_stag_sounding,
    and: rx_ndp,
    and: tx_ndp,
    and: implicit,
    min: calibration_raw, // Calibration(u8)
    and: csi,
    and: noncomp_steering,
    and: comp_steering,

    // IEEE 802.11-2016 Table 9-166
    // xxx_feedback behaves like bitmask for delayed and immediate feedback
    and: csi_feedback_raw,     // Feedback(u8)
    and: noncomp_feedback_raw, // Feedback(u8)
    and: comp_feedback_raw,    // Feedback(u8)

    min: min_grouping_raw,          // MinGroup(u8)
    min: csi_antennas_raw,          // NumAntennas(u8)
    min: noncomp_steering_ants_raw, // NumAntennas(u8)
    min: comp_steering_ants_raw,    // NumAntennas(u8)
    min: csi_rows_raw,              // NumCsiRows(u8)
    min: chan_estimation_raw,       // NumSpaceTimeStreams(u8)
});
139
// Antenna selection capabilities: every flag is kept only when both peers set it.
impl_intersect!(AselCapability {
    and: asel,
    and: csi_feedback_tx_asel,
    and: ant_idx_feedback_tx_asel,
    and: explicit_csi_feedback,
    and: antenna_idx_feedback,
    and: rx_asel,
    and: tx_sounding_ppdu,
});
149
150impl Intersect for HtCapabilities {
151    fn intersect(&self, other: &Self) -> Self {
152        let mut out = Self {
153            ht_cap_info: { self.ht_cap_info }.intersect(&{ other.ht_cap_info }),
154            ampdu_params: { self.ampdu_params }.intersect(&{ other.ampdu_params }),
155            mcs_set: { self.mcs_set }.intersect(&{ other.mcs_set }),
156            ht_ext_cap: { self.ht_ext_cap }.intersect(&{ other.ht_ext_cap }),
157            txbf_cap: { self.txbf_cap }.intersect(&{ other.txbf_cap }),
158            asel_cap: { self.asel_cap }.intersect(&{ other.asel_cap }),
159        };
160        // IEEE Std. 802.11-2016, 10.17
161        // An STA can use rx_stbc if its peer supports tx_stbc. Similarly, an STA can use tx_stbc if
162        // its peer supports at least one(1) spatial stream for rx_stbc.
163        // TODO(https://fxbug.dev/42103849): Verify STBC behavior is correct.
164        out.ht_cap_info = out
165            .ht_cap_info
166            .with_tx_stbc(if { other.ht_cap_info }.rx_stbc() != 0 {
167                { self.ht_cap_info }.tx_stbc()
168            } else {
169                false
170            })
171            .with_rx_stbc(if { other.ht_cap_info }.tx_stbc() {
172                { self.ht_cap_info }.rx_stbc()
173            } else {
174                0
175            });
176        out
177    }
178}
179
// Field-wise intersection of VHT Capabilities Info: flags are ANDed and numeric fields take the
// minimum raw value.
// NOTE(review): unlike the HT case, `tx_stbc`/`rx_stbc` are combined symmetrically here (no
// cross-pairing with the peer's opposite STBC capability); confirm this is intended.
impl_intersect!(VhtCapabilitiesInfo {
    // TODO(https://fxbug.dev/42104152): IEEE 802.11-2016 Table 9-250 - supported_cbw_set needs to consider ext_nss_bw
    min: max_mpdu_len_raw, // MaxMpduLen(u8)
    min: supported_cbw_set,
    and: rx_ldpc,
    and: sgi_cbw80,
    and: sgi_cbw160,
    and: tx_stbc,
    min: rx_stbc,
    and: su_bfer,
    and: su_bfee,
    min: bfee_sts,
    min: num_sounding,
    and: mu_bfer,
    and: mu_bfee,
    and: txop_ps,
    and: htc_vht,
    min: max_ampdu_exponent_raw, // MaxAmpduExponent(u8)
    min: link_adapt_raw,         // VhtLinkAdaptation(u8)
    and: rx_ant_pattern,
    and: tx_ant_pattern,
    min: ext_nss_bw,
});
203
204impl Intersect for VhtMcsSet {
205    fn intersect(&self, other: &Self) -> Self {
206        if *self == Self::NONE || *other == Self::NONE {
207            Self::NONE
208        } else {
209            Self(min(self.0, other.0))
210        }
211    }
212}
213
// Per-spatial-stream intersection of a VHT MCS/NSS map: each of the eight `ssN` entries is
// combined with `Intersect for VhtMcsSet`.
impl_intersect!(VhtMcsNssMap {
    intersect: ss1, // VhtMcsSet(u8)
    intersect: ss2, // VhtMcsSet(u8)
    intersect: ss3, // VhtMcsSet(u8)
    intersect: ss4, // VhtMcsSet(u8)
    intersect: ss5, // VhtMcsSet(u8)
    intersect: ss6, // VhtMcsSet(u8)
    intersect: ss7, // VhtMcsSet(u8)
    intersect: ss8, // VhtMcsSet(u8)
});
224
// VHT Supported MCS and NSS Set: the RX and TX MCS maps are intersected per spatial stream, the
// highest supported data rates and `max_nsts` take the minimum, and `ext_nss_bw` is ANDed.
impl_intersect!(VhtMcsNssSet {
    intersect: rx_max_mcs, // VhtMcsNssMap(u16)
    min: rx_max_data_rate,
    min: max_nsts,
    intersect: tx_max_mcs, // VhtMcsNssMap(u16)
    min: tx_max_data_rate,
    and: ext_nss_bw,
});
233
234impl Intersect for VhtCapabilities {
235    fn intersect(&self, other: &Self) -> Self {
236        Self {
237            vht_cap_info: { self.vht_cap_info }.intersect(&{ other.vht_cap_info }),
238            vht_mcs_nss: { self.vht_mcs_nss }.intersect(&{ other.vht_mcs_nss }),
239        }
240    }
241}
242
/// Supported rates advertised by an AP. Used as the first argument of `intersect_rates`; the
/// distinct wrapper types keep AP and client rates from being swapped at the call site.
pub struct ApRates<'a>(pub &'a [SupportedRate]);
/// Supported rates of a client. Used as the second argument of `intersect_rates`.
pub struct ClientRates<'a>(pub &'a [SupportedRate]);
245
246impl<'a> From<&'a [u8]> for ApRates<'a> {
247    fn from(rates: &'a [u8]) -> Self {
248        // This is always safe, as SupportedRate is a newtype of u8.
249        Self(Ref::into_ref(Ref::from_bytes(rates).unwrap()))
250    }
251}
252
253impl<'a> From<&'a [u8]> for ClientRates<'a> {
254    fn from(rates: &'a [u8]) -> Self {
255        // This is always safe, as SupportedRate is a newtype of u8.
256        Self(Ref::into_ref(Ref::from_bytes(rates).unwrap()))
257    }
258}
259
/// Reasons `intersect_rates` can fail.
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub enum IntersectRatesError {
    /// At least one rate the AP marks as basic is not supported by the client.
    BasicRatesMismatch,
    /// None of the AP's rates (BSS membership selectors excluded) are supported by the client.
    NoApRatesSupported,
}
265
266/// Returns the rates specified by the AP that are also supported by the client, with basic bits
267/// following their values in the AP.
268/// Returns Error if intersection fails.
269/// Note: The client MUST support ALL the basic rates specified by the AP or the intersection fails.
270pub fn intersect_rates(
271    ap_rates: ApRates<'_>,
272    client_rates: ClientRates<'_>,
273) -> Result<Vec<SupportedRate>, IntersectRatesError> {
274    // Omit BSS membership selectors, which should not be interpreted as BSS rates.
275    let mut rates: Vec<_> =
276        ap_rates.0.iter().copied().filter(|rate| !rate.is_bss_membership_selector()).collect();
277    let client_rates = client_rates.0.iter().map(|r| r.rate()).collect::<HashSet<_>>();
278    // The client MUST support ALL basic rates specified by the AP.
279    if rates.iter().any(|ra| ra.basic() && !client_rates.contains(&ra.rate())) {
280        return Err(IntersectRatesError::BasicRatesMismatch);
281    }
282
283    // Remove rates that are not supported by the client.
284    rates.retain(|ra| client_rates.contains(&ra.rate()));
285    if rates.is_empty() {
286        Err(IntersectRatesError::NoApRatesSupported)
287    } else {
288        Ok(rates)
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    impl SupportedRate {
        /// Test helper: a `SupportedRate` with the basic bit set.
        fn new_basic(rate: u8) -> Self {
            Self(rate).with_basic(true)
        }
    }

    #[test]
    fn some_basic_rate_missing() {
        // AP basic rate 120 is not supported, resulting in an Error
        let error = intersect_rates(
            ApRates(&[SupportedRate::new_basic(120), SupportedRate::new_basic(111)][..]),
            ClientRates(&[SupportedRate(111)][..]),
        )
        .unwrap_err();
        assert_eq!(error, IntersectRatesError::BasicRatesMismatch);
    }

    #[test]
    fn all_basic_rates_supported() {
        assert_eq!(
            vec![SupportedRate::new_basic(120)],
            intersect_rates(
                ApRates(&[SupportedRate::new_basic(120), SupportedRate(111)][..]),
                ClientRates(&[SupportedRate(120)][..])
            )
            .unwrap()
        );
    }

    #[test]
    fn all_basic_and_non_basic_rates_supported() {
        // The client supports both the AP's basic rate (120) and its non-basic rate (111), so
        // both are kept, each with the basicness the AP advertised.
        assert_eq!(
            vec![SupportedRate::new_basic(120), SupportedRate(111)],
            intersect_rates(
                ApRates(&[SupportedRate::new_basic(120), SupportedRate(111)][..]),
                ClientRates(&[SupportedRate(120), SupportedRate(111)][..])
            )
            .unwrap()
        );
    }

    #[test]
    fn no_rates_are_supported() {
        let error =
            intersect_rates(ApRates(&[SupportedRate(120)][..]), ClientRates(&[][..])).unwrap_err();
        assert_eq!(error, IntersectRatesError::NoApRatesSupported);
    }

    #[test]
    fn preserve_ap_rates_basicness() {
        // AP side 120 is not basic so the result should be non-basic.
        assert_eq!(
            vec![SupportedRate(120)],
            intersect_rates(
                ApRates(&[SupportedRate(120), SupportedRate(111)][..]),
                ClientRates(&[SupportedRate::new_basic(120)][..])
            )
            .unwrap()
        );
    }

    // TODO(https://fxbug.dev/42118992): Currently, MCS set and channel bandwidth are the most important ones. Revisit
    // other fields when the use cases arise or we have more understanding.
    #[test]
    fn intersect_ht_cap_info_chan_width_set() {
        let a = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_ONLY);
        let b = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_FORTY);
        assert_eq!(ChanWidthSet::TWENTY_ONLY, a.intersect(&b).chan_width_set());

        let a = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_FORTY);
        let b = HtCapabilityInfo(0).with_chan_width_set(ChanWidthSet::TWENTY_FORTY);
        assert_eq!(ChanWidthSet::TWENTY_FORTY, a.intersect(&b).chan_width_set());
    }

    #[test]
    fn intersect_supported_mcs_set() {
        let a = SupportedMcsSet(0).with_rx_mcs_raw(0xffff);
        let b = SupportedMcsSet(0).with_rx_mcs_raw(0x0304);
        assert_eq!(RxMcsBitmask(0x0304), a.intersect(&b).rx_mcs());
    }

    #[test]
    fn intersect_sm_power_save() {
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DISABLED.intersect(&SmPowerSave::DISABLED));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::STATIC.intersect(&SmPowerSave::DISABLED));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DYNAMIC.intersect(&SmPowerSave::DISABLED));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DISABLED.intersect(&SmPowerSave::STATIC));
        assert_eq!(SmPowerSave::DISABLED, SmPowerSave::DISABLED.intersect(&SmPowerSave::DYNAMIC));

        assert_eq!(SmPowerSave::STATIC, SmPowerSave::STATIC.intersect(&SmPowerSave::DYNAMIC));
        assert_eq!(SmPowerSave::STATIC, SmPowerSave::DYNAMIC.intersect(&SmPowerSave::STATIC));

        assert_eq!(SmPowerSave::DYNAMIC, SmPowerSave::DYNAMIC.intersect(&SmPowerSave::DYNAMIC));
    }

    #[test]
    fn intersect_pco_transition() {
        type PTT = PcoTransitionTime;
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_RESERVED));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_400_USEC));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_1500_USEC));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_RESERVED.intersect(&PTT::PCO_5000_USEC));

        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_400_USEC.intersect(&PTT::PCO_RESERVED));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_1500_USEC.intersect(&PTT::PCO_RESERVED));
        assert_eq!(PTT::PCO_RESERVED, PTT::PCO_5000_USEC.intersect(&PTT::PCO_RESERVED));

        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_400_USEC.intersect(&PTT::PCO_5000_USEC));
        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_1500_USEC.intersect(&PTT::PCO_5000_USEC));
        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_5000_USEC.intersect(&PTT::PCO_5000_USEC));

        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_5000_USEC.intersect(&PTT::PCO_400_USEC));
        assert_eq!(PTT::PCO_5000_USEC, PTT::PCO_5000_USEC.intersect(&PTT::PCO_1500_USEC));

        assert_eq!(PTT::PCO_1500_USEC, PTT::PCO_400_USEC.intersect(&PTT::PCO_1500_USEC));
        assert_eq!(PTT::PCO_1500_USEC, PTT::PCO_1500_USEC.intersect(&PTT::PCO_400_USEC));

        assert_eq!(PTT::PCO_400_USEC, PTT::PCO_400_USEC.intersect(&PTT::PCO_400_USEC));
    }

    // Check TX_STBC and RX_STBC too because they involve multiple fields.
    #[test]
    fn intersect_ht_cap_info_stbc() {
        let mut ht_cap_a = fake_ht_capabilities();
        let mut ht_cap_b = fake_ht_capabilities();

        ht_cap_a.ht_cap_info = HtCapabilityInfo(0).with_tx_stbc(true).with_rx_stbc(2);
        ht_cap_b.ht_cap_info = HtCapabilityInfo(0).with_tx_stbc(false).with_rx_stbc(1);

        let intersected_ht_cap_info = ht_cap_a.intersect(&ht_cap_b).ht_cap_info;
        assert!(intersected_ht_cap_info.tx_stbc());
        assert_eq!(0, intersected_ht_cap_info.rx_stbc());

        let intersected_ht_cap_info = ht_cap_b.intersect(&ht_cap_a).ht_cap_info;
        assert!(!intersected_ht_cap_info.tx_stbc());
        assert_eq!(1, intersected_ht_cap_info.rx_stbc());
    }

    #[test]
    fn intersect_vht_mcs_set() {
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::UP_TO_7));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::UP_TO_8));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::UP_TO_9));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::NONE.intersect(&VhtMcsSet::NONE));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::UP_TO_7.intersect(&VhtMcsSet::NONE));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::UP_TO_8.intersect(&VhtMcsSet::NONE));
        assert_eq!(VhtMcsSet::NONE, VhtMcsSet::UP_TO_9.intersect(&VhtMcsSet::NONE));

        assert_eq!(VhtMcsSet::UP_TO_7, VhtMcsSet::UP_TO_7.intersect(&VhtMcsSet::UP_TO_8));
        assert_eq!(VhtMcsSet::UP_TO_7, VhtMcsSet::UP_TO_8.intersect(&VhtMcsSet::UP_TO_7));

        assert_eq!(VhtMcsSet::UP_TO_8, VhtMcsSet::UP_TO_8.intersect(&VhtMcsSet::UP_TO_9));
        assert_eq!(VhtMcsSet::UP_TO_8, VhtMcsSet::UP_TO_9.intersect(&VhtMcsSet::UP_TO_8));

        assert_eq!(VhtMcsSet::UP_TO_9, VhtMcsSet::UP_TO_9.intersect(&VhtMcsSet::UP_TO_9));
    }
}