1use crate::key::Tk;
6use crate::{Error, rsn_ensure};
7use mundane::bytes;
8use std::hash::{Hash, Hasher};
9use wlan_common::ie::rsn::cipher::Cipher;
10
/// Owns a single randomly generated GTK, created by `GtkProvider::new`, and hands out
/// shared references to it via `GtkProvider::get_gtk`.
#[derive(Debug)]
pub struct GtkProvider(Gtk);
14
15impl GtkProvider {
16 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
17 pub fn new(cipher: Cipher, key_id: u8, key_rsc: u64) -> Result<GtkProvider, Error> {
18 Ok(GtkProvider(Gtk::generate_random(cipher, key_id, key_rsc)?))
19 }
20
21 pub fn get_gtk(&self) -> &Gtk {
22 &self.0
23 }
24}
25
/// A GTK (group temporal key) together with the metadata the constructors record for it.
///
/// `bytes` may be longer than the TK; only its first `tk_len` bytes are returned by
/// [`Tk::tk`].
#[derive(Debug, Clone, Eq)]
pub struct Gtk {
    /// Raw key material.
    ///
    /// NOTE(review): this field is public, so code outside this module can replace it with a
    /// buffer shorter than `tk_len`, which would make `tk()` panic when slicing.
    pub bytes: Box<[u8]>,
    /// Cipher suite the key was created for.
    cipher: Cipher,
    /// Number of leading bytes of `bytes` that form the TK, taken from `Cipher::tk_bytes`.
    tk_len: usize,
    /// Key ID; the constructors only accept values 1 through 3.
    key_id: u8,
    /// Key RSC value supplied at construction and stored unchanged.
    key_rsc: u64,
}
34
35impl PartialEq for Gtk {
40 fn eq(&self, other: &Self) -> bool {
41 self.bytes == other.bytes
42 && self.tk_len == other.tk_len
43 && self.key_id == other.key_id
44 && self.key_rsc == other.key_rsc
45 }
46}
47
48impl Hash for Gtk {
51 fn hash<H: Hasher>(&self, state: &mut H) {
52 self.key_id.hash(state);
53 self.tk().hash(state);
54 }
55}
56
57impl Gtk {
58 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
59 pub fn generate_random(cipher: Cipher, key_id: u8, key_rsc: u64) -> Result<Gtk, Error> {
60 rsn_ensure!(
62 0 < key_id && key_id < 4,
63 "GTK key ID must not be zero and must fit in a two bit field"
64 );
65
66 let tk_len: usize =
67 cipher.tk_bytes().ok_or(Error::GtkHierarchyUnsupportedCipherError)?.into();
68 let mut gtk_bytes: Box<[u8]> = vec![0; tk_len].into();
69 bytes::rand(&mut gtk_bytes[..]);
70
71 Ok(Gtk { bytes: gtk_bytes, cipher, tk_len, key_id, key_rsc })
72 }
73
74 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
75 pub fn from_bytes(
76 gtk_bytes: Box<[u8]>,
77 cipher: Cipher,
78 key_id: u8,
79 key_rsc: u64,
80 ) -> Result<Gtk, Error> {
81 rsn_ensure!(
83 0 < key_id && key_id < 4,
84 "GTK key ID must not be zero and must fit in a two bit field"
85 );
86
87 let tk_len: usize =
88 cipher.tk_bytes().ok_or(Error::GtkHierarchyUnsupportedCipherError)?.into();
89 rsn_ensure!(gtk_bytes.len() >= tk_len, "GTK must be larger than the resulting TK");
91
92 Ok(Gtk { bytes: gtk_bytes, cipher, tk_len, key_id, key_rsc })
93 }
94
95 pub fn cipher(&self) -> &Cipher {
96 &self.cipher
97 }
98
99 pub fn key_id(&self) -> u8 {
100 self.key_id
101 }
102
103 pub fn key_rsc(&self) -> u64 {
104 self.key_rsc
105 }
106}
107
108impl Tk for Gtk {
109 fn tk(&self) -> &[u8] {
110 &self.bytes[0..self.tk_len]
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use wlan_common::ie::rsn::cipher;
    use wlan_common::ie::rsn::suite_selector::OUI;

    /// CCMP-128 cipher shared by the tests below.
    fn ccmp_128() -> Cipher {
        Cipher { oui: OUI, suite_type: cipher::CCMP_128 }
    }

    #[test]
    fn generated_gtks_are_not_zero_and_not_constant_with_high_probability() {
        let mut gtks = HashSet::new();
        for i in 0..10 {
            let provider =
                GtkProvider::new(ccmp_128(), 2, 5).expect("failed creating GTK Provider");
            let gtk_bytes: Box<[u8]> = provider.get_gtk().tk().into();
            assert!(gtk_bytes.iter().any(|&x| x != 0));
            // Succeed as soon as a GTK differs from every previously generated one.
            if i > 0 && !gtks.contains(&gtk_bytes) {
                return;
            }
            gtks.insert(gtk_bytes);
        }
        panic!("GtkProvider::generate_gtk() generated the same GTK 10 times in a row.");
    }

    #[test]
    fn generated_gtk_captures_key_id() {
        let provider = GtkProvider::new(ccmp_128(), 1, 3).expect("failed creating GTK Provider");
        let gtk = provider.get_gtk();
        assert_eq!(gtk.key_id(), 1);
    }

    #[test]
    fn generated_gtk_captures_key_rsc() {
        let provider = GtkProvider::new(ccmp_128(), 1, 3).expect("failed creating GTK Provider");
        let gtk = provider.get_gtk();
        assert_eq!(gtk.key_rsc(), 3);
    }

    #[test]
    fn gtk_generation_fails_with_key_id_zero() {
        GtkProvider::new(ccmp_128(), 0, 4)
            .expect_err("GTK provider incorrectly accepts key ID 0");
    }

    #[test]
    fn gtk_generation_fails_with_key_id_four() {
        GtkProvider::new(ccmp_128(), 4, 4)
            .expect_err("GTK provider incorrectly accepts key ID 4");
    }

    #[test]
    fn gtk_from_bytes_rejects_material_shorter_than_tk() {
        Gtk::from_bytes(vec![0xAA; 8].into(), ccmp_128(), 1, 0)
            .expect_err("GTK incorrectly accepts key material shorter than the TK");
    }

    #[test]
    fn gtk_from_bytes_exposes_only_the_tk_prefix() {
        let material: Box<[u8]> = (0..32).collect();
        let gtk =
            Gtk::from_bytes(material.clone(), ccmp_128(), 1, 0).expect("failed creating GTK");
        // CCMP-128 uses a 16-byte TK.
        assert_eq!(gtk.tk(), &material[..16]);
    }

    #[test]
    fn gtks_compare_by_material_and_metadata() {
        let material: Box<[u8]> = vec![0x5A; 16].into();
        let a = Gtk::from_bytes(material.clone(), ccmp_128(), 1, 7).expect("failed creating GTK");
        let b = Gtk::from_bytes(material.clone(), ccmp_128(), 1, 7).expect("failed creating GTK");
        let c = Gtk::from_bytes(material, ccmp_128(), 2, 7).expect("failed creating GTK");
        assert_eq!(a, b);
        assert_ne!(a, c);
    }
}