use crate::key::Tk;
use crate::{rsn_ensure, Error};
use mundane::bytes;
use std::hash::{Hash, Hasher};
use wlan_common::ie::rsn::cipher::Cipher;
/// Holds a single [`Gtk`] whose key material is generated randomly when the provider is built.
#[derive(Debug)]
pub struct GtkProvider(Gtk);

impl GtkProvider {
    /// Builds a provider around a freshly generated random GTK.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Gtk::generate_random`] reports for the given `cipher`,
    /// `key_id` and `key_rsc`.
    #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
    pub fn new(cipher: Cipher, key_id: u8, key_rsc: u64) -> Result<GtkProvider, Error> {
        Gtk::generate_random(cipher, key_id, key_rsc).map(GtkProvider)
    }

    /// Borrows the GTK held by this provider.
    pub fn get_gtk(&self) -> &Gtk {
        let GtkProvider(gtk) = self;
        gtk
    }
}
/// A GTK: key material plus the cipher, key ID and key RSC it was created with.
///
/// `PartialEq` and `Hash` are implemented by hand; `Eq` is derived on top of that `PartialEq`.
#[derive(Debug, Clone, Eq)]
pub struct Gtk {
    /// Key material. The constructors guarantee it is at least `tk_len` bytes long, and the TK
    /// is its first `tk_len` bytes.
    /// NOTE(review): this field is public, so code outside this module can replace it with a
    /// shorter buffer, after which `tk()` would panic on its slice.
    pub bytes: Box<[u8]>,
    /// Cipher this key was created for; `tk_len` is derived from it at construction.
    cipher: Cipher,
    /// TK length in bytes, taken from `cipher.tk_bytes()` when the GTK is constructed.
    tk_len: usize,
    /// Key ID; the constructors only accept 1, 2 or 3.
    key_id: u8,
    /// Key RSC value supplied by the caller at construction.
    key_rsc: u64,
}
// Field-by-field equality, written by hand next to the custom `Hash` impl for `Gtk`.
//
// Every field participates, including `cipher`. Previously `cipher` was skipped, so two GTKs
// with identical bytes but different ciphers compared equal. `#[derive(Eq)]` on `Gtk` already
// requires `Cipher: Eq`, so the comparison is always available.
//
// Consistency with `Hash`: equal values share `key_id`, `bytes` and `tk_len`, and therefore
// `tk()`. Those are exactly the inputs `Hash` consumes, so `a == b` implies
// `hash(a) == hash(b)`.
impl PartialEq for Gtk {
    fn eq(&self, other: &Self) -> bool {
        self.bytes == other.bytes
            && self.cipher == other.cipher
            && self.tk_len == other.tk_len
            && self.key_id == other.key_id
            && self.key_rsc == other.key_rsc
    }
}
impl Hash for Gtk {
fn hash<H: Hasher>(&self, state: &mut H) {
self.key_id.hash(state);
self.tk().hash(state);
}
}
impl Gtk {
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
pub fn generate_random(cipher: Cipher, key_id: u8, key_rsc: u64) -> Result<Gtk, Error> {
rsn_ensure!(
0 < key_id && key_id < 4,
"GTK key ID must not be zero and must fit in a two bit field"
);
let tk_len: usize =
cipher.tk_bytes().ok_or(Error::GtkHierarchyUnsupportedCipherError)?.into();
let mut gtk_bytes: Box<[u8]> = vec![0; tk_len].into();
bytes::rand(&mut gtk_bytes[..]);
Ok(Gtk { bytes: gtk_bytes, cipher, tk_len, key_id, key_rsc })
}
#[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
pub fn from_bytes(
gtk_bytes: Box<[u8]>,
cipher: Cipher,
key_id: u8,
key_rsc: u64,
) -> Result<Gtk, Error> {
rsn_ensure!(
0 < key_id && key_id < 4,
"GTK key ID must not be zero and must fit in a two bit field"
);
let tk_len: usize =
cipher.tk_bytes().ok_or(Error::GtkHierarchyUnsupportedCipherError)?.into();
rsn_ensure!(gtk_bytes.len() >= tk_len, "GTK must be larger than the resulting TK");
Ok(Gtk { bytes: gtk_bytes, cipher, tk_len, key_id, key_rsc })
}
pub fn cipher(&self) -> &Cipher {
&self.cipher
}
pub fn key_id(&self) -> u8 {
self.key_id
}
pub fn key_rsc(&self) -> u64 {
self.key_rsc
}
}
impl Tk for Gtk {
    /// Returns the TK, which is the first `tk_len` bytes of the key material.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` has been replaced with a buffer shorter than `tk_len` after
    /// construction. This is possible because `bytes` is public; the constructors never
    /// produce such a value.
    fn tk(&self) -> &[u8] {
        let Self { bytes, tk_len, .. } = self;
        &bytes[..*tk_len]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use wlan_common::ie::rsn::cipher;
    use wlan_common::ie::rsn::suite_selector::OUI;

    /// CCMP-128 cipher selector used throughout these tests.
    fn ccmp_128() -> Cipher {
        Cipher { oui: OUI, suite_type: cipher::CCMP_128 }
    }

    /// TK length reported by the CCMP-128 cipher, so tests don't hard-code it.
    fn ccmp_128_tk_len() -> usize {
        ccmp_128().tk_bytes().expect("CCMP-128 must report a TK length").into()
    }

    #[test]
    fn generated_gtks_are_not_zero_and_not_constant_with_high_probability() {
        let mut gtks = HashSet::new();
        for i in 0..10 {
            let provider =
                GtkProvider::new(ccmp_128(), 2, 5).expect("failed creating GTK Provider");
            let gtk_bytes: Box<[u8]> = provider.get_gtk().tk().into();
            assert!(gtk_bytes.iter().any(|&x| x != 0));
            // As soon as a GTK differs from every earlier one, generation is not constant.
            if i > 0 && !gtks.contains(&gtk_bytes) {
                return;
            }
            gtks.insert(gtk_bytes);
        }
        panic!("Gtk::generate_random() generated the same GTK 10 times in a row.");
    }

    #[test]
    fn generated_gtk_captures_key_id() {
        let provider = GtkProvider::new(ccmp_128(), 1, 3).expect("failed creating GTK Provider");
        let gtk = provider.get_gtk();
        assert_eq!(gtk.key_id(), 1);
    }

    #[test]
    fn generated_gtk_captures_key_rsc() {
        let provider = GtkProvider::new(ccmp_128(), 1, 3).expect("failed creating GTK Provider");
        let gtk = provider.get_gtk();
        assert_eq!(gtk.key_rsc(), 3);
    }

    #[test]
    fn gtk_generation_fails_with_key_id_zero() {
        GtkProvider::new(ccmp_128(), 0, 4)
            .expect_err("GTK provider incorrectly accepts key ID 0");
    }

    #[test]
    fn gtk_generation_fails_with_key_id_four() {
        GtkProvider::new(ccmp_128(), 4, 0)
            .expect_err("GTK provider incorrectly accepts key ID 4");
    }

    #[test]
    fn gtk_from_bytes_rejects_key_material_shorter_than_tk() {
        let short: Box<[u8]> = vec![0x11; ccmp_128_tk_len() - 1].into();
        Gtk::from_bytes(short, ccmp_128(), 1, 0)
            .expect_err("Gtk::from_bytes incorrectly accepts key material shorter than the TK");
    }

    #[test]
    fn gtk_from_bytes_tk_is_prefix_of_key_material() {
        let tk_len = ccmp_128_tk_len();
        // Distinct byte values so a wrong offset in tk() would be caught.
        let material: Box<[u8]> = (0..tk_len + 4).map(|b| b as u8).collect();
        let gtk = Gtk::from_bytes(material.clone(), ccmp_128(), 3, 7)
            .expect("failed creating GTK from bytes");
        assert_eq!(gtk.tk(), &material[..tk_len]);
        assert_eq!(&gtk.bytes[..], &material[..]);
        assert_eq!(gtk.key_id(), 3);
        assert_eq!(gtk.key_rsc(), 7);
    }
}