1use crate::channel::{Cbw, Channel};
12use crate::ie::intersect::*;
13use crate::ie::{
14    self, HtCapabilities, SupportedRate, VhtCapabilities, parse_ht_capabilities,
15    parse_vht_capabilities,
16};
17use crate::mac::CapabilityInfo;
18use anyhow::{Context as _, Error, format_err};
19use {fidl_fuchsia_wlan_ieee80211 as fidl_ieee80211, fidl_fuchsia_wlan_mlme as fidl_mlme};
20
/// Value that `override_capability_info` writes into the ESS bit of the capability info.
const OVERRIDE_CAP_INFO_ESS: bool = true;
/// Value that `override_capability_info` writes into the IBSS bit of the capability info.
const OVERRIDE_CAP_INFO_IBSS: bool = false;

/// Value that `override_capability_info` writes into the CF-Pollable bit of the capability info.
const OVERRIDE_CAP_INFO_CF_POLLABLE: bool = false;
/// Value that `override_capability_info` writes into the CF-Poll-Request bit of the capability
/// info.
const OVERRIDE_CAP_INFO_CF_POLL_REQUEST: bool = false;

/// Value that `override_capability_info` writes into the Privacy bit of the capability info.
const OVERRIDE_CAP_INFO_PRIVACY: bool = false;

/// Value that `override_capability_info` writes into the Spectrum Management bit of the
/// capability info.
const OVERRIDE_CAP_INFO_SPECTRUM_MGMT: bool = false;

/// Value that `override_ht_capabilities` writes into the TX STBC bit of the HT capability info.
const OVERRIDE_HT_CAP_INFO_TX_STBC: bool = false;

/// Supported channel width set that `override_vht_capabilities` writes into the VHT capability
/// info for every channel width other than 160 MHz and 80+80 MHz.
const OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET: u32 = 0;
47
48fn override_capability_info(capability_info: CapabilityInfo) -> CapabilityInfo {
51    capability_info
52        .with_ess(OVERRIDE_CAP_INFO_ESS)
53        .with_ibss(OVERRIDE_CAP_INFO_IBSS)
54        .with_cf_pollable(OVERRIDE_CAP_INFO_CF_POLLABLE)
55        .with_cf_poll_req(OVERRIDE_CAP_INFO_CF_POLL_REQUEST)
56        .with_privacy(OVERRIDE_CAP_INFO_PRIVACY)
57        .with_spectrum_mgmt(OVERRIDE_CAP_INFO_SPECTRUM_MGMT)
58}
59
60pub fn derive_join_capabilities(
65    bss_channel: Channel,
66    bss_rates: &[SupportedRate],
67    device_info: &fidl_mlme::DeviceInfo,
68) -> Result<ClientCapabilities, Error> {
69    let band_cap = get_band_cap_for_channel(&device_info.bands[..], bss_channel)
71        .context(format!("iface does not support BSS channel {}", bss_channel.primary))?;
72
73    let capability_info =
77        override_capability_info(CapabilityInfo(device_info.softmac_hardware_capability as u16));
78
79    let client_rates = band_cap.basic_rates.iter().map(|&r| SupportedRate(r)).collect::<Vec<_>>();
82    let rates = intersect_rates(ApRates(bss_rates), ClientRates(&client_rates))
83        .map_err(|error| format_err!("could not intersect rates: {:?}", error))
84        .context(format!("deriving rates: {:?} + {:?}", band_cap.basic_rates, bss_rates))?;
85
86    let (ht_cap, vht_cap) =
89        override_ht_vht(band_cap.ht_cap.as_ref(), band_cap.vht_cap.as_ref(), bss_channel.cbw)?;
90
91    Ok(ClientCapabilities(StaCapabilities { capability_info, rates, ht_cap, vht_cap }))
92}
93
94fn override_ht_vht(
97    fidl_ht_cap: Option<&Box<fidl_ieee80211::HtCapabilities>>,
98    fidl_vht_cap: Option<&Box<fidl_ieee80211::VhtCapabilities>>,
99    cbw: Cbw,
100) -> Result<(Option<HtCapabilities>, Option<VhtCapabilities>), Error> {
101    if fidl_ht_cap.is_none() && fidl_vht_cap.is_some() {
102        return Err(format_err!("VHT Cap without HT Cap is invalid."));
103    }
104
105    let ht_cap = match fidl_ht_cap {
106        Some(h) => {
107            let ht_cap = *parse_ht_capabilities(&h.bytes[..]).context("verifying HT Cap")?;
108            Some(override_ht_capabilities(ht_cap, cbw))
109        }
110        None => None,
111    };
112
113    let vht_cap = match fidl_vht_cap {
114        Some(v) => {
115            let vht_cap = *parse_vht_capabilities(&v.bytes[..]).context("verifying VHT Cap")?;
116            Some(override_vht_capabilities(vht_cap, cbw))
117        }
118        None => None,
119    };
120    Ok((ht_cap, vht_cap))
121}
122
123fn override_ht_capabilities(mut ht_cap: HtCapabilities, cbw: Cbw) -> HtCapabilities {
126    let mut ht_cap_info = ht_cap.ht_cap_info.with_tx_stbc(OVERRIDE_HT_CAP_INFO_TX_STBC);
127    match cbw {
128        Cbw::Cbw20 => ht_cap_info.set_chan_width_set(ie::ChanWidthSet::TWENTY_ONLY),
129        _ => (),
130    }
131    ht_cap.ht_cap_info = ht_cap_info;
132    ht_cap
133}
134
135fn override_vht_capabilities(mut vht_cap: VhtCapabilities, cbw: Cbw) -> VhtCapabilities {
138    let mut vht_cap_info = vht_cap.vht_cap_info;
139    if vht_cap_info.supported_cbw_set() != OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET {
140        match cbw {
145            Cbw::Cbw160 | Cbw::Cbw80P80 { secondary80: _ } => (),
146            _ => vht_cap_info.set_supported_cbw_set(OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET),
147        }
148    }
149    vht_cap.vht_cap_info = vht_cap_info;
150    vht_cap
151}
152
153pub fn get_band_cap_for_channel(
154    bands: &[fidl_mlme::BandCapability],
155    channel: Channel,
156) -> Result<&fidl_mlme::BandCapability, anyhow::Error> {
157    let target = channel.get_band().context("Failed to retrieve band capabilities")?;
158    bands
159        .iter()
160        .find(|b| b.band == target && b.operating_channels.contains(&channel.primary))
161        .ok_or_else(|| format_err!("No band capability for channel {channel:?}: {bands:?}"))
162}
163
/// Capabilities of one station: its capability info, rates, and optional HT/VHT capabilities.
#[derive(Debug, PartialEq)]
pub struct StaCapabilities {
    /// Capability info bits.
    pub capability_info: CapabilityInfo,
    /// Rates supported by the station.
    pub rates: Vec<SupportedRate>,
    /// HT capabilities, or `None` when not advertised.
    pub ht_cap: Option<HtCapabilities>,
    /// VHT capabilities, or `None` when not advertised.
    pub vht_cap: Option<VhtCapabilities>,
}

/// [`StaCapabilities`] of a station in the client role (e.g. produced by
/// [`derive_join_capabilities`]).
#[derive(Debug, PartialEq)]
pub struct ClientCapabilities(pub StaCapabilities);
/// [`StaCapabilities`] of a station in the access-point role.
#[derive(Debug, PartialEq)]
pub struct ApCapabilities(pub StaCapabilities);
181
182pub fn intersect_with_ap_as_client(
184    client: &ClientCapabilities,
185    ap: &ApCapabilities,
186) -> Result<StaCapabilities, Error> {
187    let rates = intersect_rates(ApRates(&ap.0.rates[..]), ClientRates(&client.0.rates[..]))
188        .map_err(|e| format_err!("could not intersect rates: {:?}", e))?;
189    let (capability_info, ht_cap, vht_cap) = intersect(&client.0, &ap.0);
190    Ok(StaCapabilities { rates, capability_info, ht_cap, vht_cap })
191}
192
193pub fn intersect_with_remote_client_as_ap(
195    ap: &ApCapabilities,
196    remote_client: &ClientCapabilities,
197) -> StaCapabilities {
198    let rates = intersect_rates(ApRates(&ap.0.rates[..]), ClientRates(&remote_client.0.rates[..]))
200        .unwrap_or(vec![]);
201    let (capability_info, ht_cap, vht_cap) = intersect(&ap.0, &remote_client.0);
202    StaCapabilities { rates, capability_info, ht_cap, vht_cap }
203}
204
205fn intersect(
206    ours: &StaCapabilities,
207    theirs: &StaCapabilities,
208) -> (CapabilityInfo, Option<HtCapabilities>, Option<VhtCapabilities>) {
209    let capability_info = CapabilityInfo(ours.capability_info.raw() & theirs.capability_info.raw());
211    let ht_cap = match (ours.ht_cap, theirs.ht_cap) {
212        (Some(ours), Some(theirs)) => Some(ours.intersect(&theirs)),
214        _ => None,
215    };
216    let vht_cap = match (ours.vht_cap, theirs.vht_cap) {
217        (Some(ours), Some(theirs)) => Some(ours.intersect(&theirs)),
219        _ => None,
220    };
221    (capability_info, ht_cap, vht_cap)
222}
223
#[cfg(test)]
mod tests {
    // Unit tests for the override helpers, band lookup, and capability intersection.
    use super::*;
    use crate::mac;
    use crate::test_utils::fake_capabilities::fake_5ghz_band_capability_ht;
    use assert_matches::assert_matches;
    use fidl_fuchsia_wlan_common as fidl_common;

    // Every overridden bit starts opposite to its override constant; after
    // `override_capability_info` each one must equal the override.
    #[test]
    fn test_build_cap_info() {
        let capability_info = CapabilityInfo(0)
            .with_ess(!OVERRIDE_CAP_INFO_ESS)
            .with_ibss(!OVERRIDE_CAP_INFO_IBSS)
            .with_cf_pollable(!OVERRIDE_CAP_INFO_CF_POLLABLE)
            .with_cf_poll_req(!OVERRIDE_CAP_INFO_CF_POLL_REQUEST)
            .with_privacy(!OVERRIDE_CAP_INFO_PRIVACY)
            .with_spectrum_mgmt(!OVERRIDE_CAP_INFO_SPECTRUM_MGMT);
        let capability_info = override_capability_info(capability_info);
        assert_eq!(capability_info.ess(), OVERRIDE_CAP_INFO_ESS);
        assert_eq!(capability_info.ibss(), OVERRIDE_CAP_INFO_IBSS);
        assert_eq!(capability_info.cf_pollable(), OVERRIDE_CAP_INFO_CF_POLLABLE);
        assert_eq!(capability_info.cf_poll_req(), OVERRIDE_CAP_INFO_CF_POLL_REQUEST);
        assert_eq!(capability_info.privacy(), OVERRIDE_CAP_INFO_PRIVACY);
        assert_eq!(capability_info.spectrum_mgmt(), OVERRIDE_CAP_INFO_SPECTRUM_MGMT);
    }

    // TX STBC is always overridden. The channel width set is narrowed to TWENTY_ONLY on a Cbw20
    // channel and left at TWENTY_FORTY on a Cbw40 channel.
    #[test]
    fn test_override_ht_cap() {
        let mut ht_cap = ie::fake_ht_capabilities();
        let ht_cap_info = ht_cap
            .ht_cap_info
            .with_tx_stbc(!OVERRIDE_HT_CAP_INFO_TX_STBC)
            .with_chan_width_set(ie::ChanWidthSet::TWENTY_FORTY);
        ht_cap.ht_cap_info = ht_cap_info;
        let mut channel = Channel { primary: 153, cbw: Cbw::Cbw20 };

        let ht_cap_info = override_ht_capabilities(ht_cap, channel.cbw).ht_cap_info;
        assert_eq!(ht_cap_info.tx_stbc(), OVERRIDE_HT_CAP_INFO_TX_STBC);
        assert_eq!(ht_cap_info.chan_width_set(), ie::ChanWidthSet::TWENTY_ONLY);

        channel.cbw = Cbw::Cbw40;
        let ht_cap_info = override_ht_capabilities(ht_cap, channel.cbw).ht_cap_info;
        assert_eq!(ht_cap_info.chan_width_set(), ie::ChanWidthSet::TWENTY_FORTY);
    }

    // A supported_cbw_set of 2 is reset to the override for Cbw20/Cbw40/Cbw80 and preserved
    // for Cbw160 and Cbw80P80.
    #[test]
    fn test_override_vht_cap() {
        let mut vht_cap = ie::fake_vht_capabilities();
        let vht_cap_info = vht_cap.vht_cap_info.with_supported_cbw_set(2);
        vht_cap.vht_cap_info = vht_cap_info;
        let mut channel = Channel { primary: 153, cbw: Cbw::Cbw20 };

        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET);

        channel.cbw = Cbw::Cbw40;
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET);

        channel.cbw = Cbw::Cbw80;
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET);

        channel.cbw = Cbw::Cbw160;
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), 2);

        channel.cbw = Cbw::Cbw80P80 { secondary80: 42 };
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), 2);
    }

    // Channel 36 must resolve to the (only) fake 5 GHz band capability.
    #[test]
    fn test_get_device_band_cap() {
        let device_info = fidl_mlme::DeviceInfo {
            sta_addr: [0; 6],
            role: fidl_common::WlanMacRole::Client,
            bands: vec![fake_5ghz_band_capability_ht(ie::ChanWidthSet::TWENTY_FORTY)],
            softmac_hardware_capability: 0,
            qos_capable: true,
        };
        assert_eq!(
            fidl_ieee80211::WlanBand::FiveGhz,
            get_band_cap_for_channel(&device_info.bands[..], Channel::new(36, Cbw::Cbw20))
                .unwrap()
                .band
        );
    }

    // Client-side fixture: capability info 0x1234, rates 101..=104, HT with rx_stbc 2 and
    // tx_stbc off.
    fn fake_client_join_cap() -> ClientCapabilities {
        ClientCapabilities(StaCapabilities {
            capability_info: mac::CapabilityInfo(0x1234),
            rates: [101, 102, 103, 104].iter().cloned().map(SupportedRate).collect(),
            ht_cap: Some(HtCapabilities {
                ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(2).with_tx_stbc(false),
                ..ie::fake_ht_capabilities()
            }),
            vht_cap: Some(ie::fake_vht_capabilities()),
        })
    }

    // AP-side fixture: capability info 0x4321, rates {101 | 0x80, 102, 9}, HT with rx_stbc 1 and
    // tx_stbc on. The expected results below show 101 matching 229 (= 101 + 128), so rate
    // matching presumably ignores the 0x80 bit.
    fn fake_ap_join_cap() -> ApCapabilities {
        ApCapabilities(StaCapabilities {
            capability_info: mac::CapabilityInfo(0x4321),
            rates: [101 + 128, 102, 9].iter().cloned().map(SupportedRate).collect(),
            ht_cap: Some(HtCapabilities {
                ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(1).with_tx_stbc(true),
                ..ie::fake_ht_capabilities()
            }),
            vht_cap: Some(ie::fake_vht_capabilities()),
        })
    }

    // Capability info is 0x1234 & 0x4321 == 0x0220. The common rates are [229, 102]. The HT
    // result differs from `ap_intersect_with_remote_client`, showing that the HT intersection
    // depends on which side is "ours".
    #[test]
    fn client_intersect_with_ap() {
        let caps = assert_matches!(
            intersect_with_ap_as_client(&fake_client_join_cap(), &fake_ap_join_cap()),
            Ok(caps) => caps
        );
        assert_eq!(
            caps,
            StaCapabilities {
                capability_info: mac::CapabilityInfo(0x0220),
                rates: [229, 102].iter().cloned().map(SupportedRate).collect(),
                ht_cap: Some(HtCapabilities {
                    ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(2).with_tx_stbc(false),
                    ..ie::fake_ht_capabilities()
                }),
                ..fake_client_join_cap().0
            }
        )
    }

    // Same fixtures with the AP as "ours": identical capability info and rates, but a different
    // HT result (rx_stbc 0, tx_stbc on).
    #[test]
    fn ap_intersect_with_remote_client() {
        assert_eq!(
            intersect_with_remote_client_as_ap(&fake_ap_join_cap(), &fake_client_join_cap()),
            StaCapabilities {
                capability_info: mac::CapabilityInfo(0x0220),
                rates: [229, 102].iter().cloned().map(SupportedRate).collect(),
                ht_cap: Some(HtCapabilities {
                    ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(0).with_tx_stbc(true),
                    ..ie::fake_ht_capabilities()
                }),
                ..fake_ap_join_cap().0
            }
        );
    }
}