1use crate::key::Tk;
6use crate::{rsn_ensure, Error};
7use mundane::bytes;
8use std::hash::{Hash, Hasher};
9use wlan_common::ie::rsn::cipher::Cipher;
10
11#[derive(Debug)]
13pub struct GtkProvider(Gtk);
14
15impl GtkProvider {
16 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
17 pub fn new(cipher: Cipher, key_id: u8, key_rsc: u64) -> Result<GtkProvider, Error> {
18 Ok(GtkProvider(Gtk::generate_random(cipher, key_id, key_rsc)?))
19 }
20
21 pub fn get_gtk(&self) -> &Gtk {
22 &self.0
23 }
24}
25
/// A group temporal key (GTK) together with its cipher suite, key ID and key RSC.
///
/// Instances built through [`Gtk::generate_random`] or [`Gtk::from_bytes`] have a key ID in
/// `1..=3` and satisfy `bytes.len() >= tk_len`.
///
/// NOTE(review): the derived `Debug` output includes the raw key bytes, so formatting a `Gtk`
/// with `{:?}` exposes key material; avoid logging it.
#[derive(Debug, Clone, Eq)]
pub struct Gtk {
    /// Raw key material. May be longer than the TK; only the first `tk_len` bytes form the TK.
    ///
    /// NOTE(review): this field is public, so code outside the constructors can replace it with
    /// a buffer shorter than `tk_len`, after which `tk()` panics on the out-of-range slice.
    pub bytes: Box<[u8]>,
    /// Cipher suite this key is used with; `tk_len` is derived from it.
    cipher: Cipher,
    /// TK length in bytes, taken from `cipher.tk_bytes()` at construction.
    tk_len: usize,
    /// Key ID; the constructors only accept values 1 through 3.
    key_id: u8,
    /// Key RSC value supplied by the caller at construction.
    key_rsc: u64,
}
34
35impl PartialEq for Gtk {
40 fn eq(&self, other: &Self) -> bool {
41 self.bytes == other.bytes
42 && self.tk_len == other.tk_len
43 && self.key_id == other.key_id
44 && self.key_rsc == other.key_rsc
45 }
46}
47
48impl Hash for Gtk {
51 fn hash<H: Hasher>(&self, state: &mut H) {
52 self.key_id.hash(state);
53 self.tk().hash(state);
54 }
55}
56
57impl Gtk {
58 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
59 pub fn generate_random(cipher: Cipher, key_id: u8, key_rsc: u64) -> Result<Gtk, Error> {
60 rsn_ensure!(
62 0 < key_id && key_id < 4,
63 "GTK key ID must not be zero and must fit in a two bit field"
64 );
65
66 let tk_len: usize =
67 cipher.tk_bytes().ok_or(Error::GtkHierarchyUnsupportedCipherError)?.into();
68 let mut gtk_bytes: Box<[u8]> = vec![0; tk_len].into();
69 bytes::rand(&mut gtk_bytes[..]);
70
71 Ok(Gtk { bytes: gtk_bytes, cipher, tk_len, key_id, key_rsc })
72 }
73
74 #[allow(clippy::result_large_err, reason = "mass allow for https://fxbug.dev/381896734")]
75 pub fn from_bytes(
76 gtk_bytes: Box<[u8]>,
77 cipher: Cipher,
78 key_id: u8,
79 key_rsc: u64,
80 ) -> Result<Gtk, Error> {
81 rsn_ensure!(
83 0 < key_id && key_id < 4,
84 "GTK key ID must not be zero and must fit in a two bit field"
85 );
86
87 let tk_len: usize =
88 cipher.tk_bytes().ok_or(Error::GtkHierarchyUnsupportedCipherError)?.into();
89 rsn_ensure!(gtk_bytes.len() >= tk_len, "GTK must be larger than the resulting TK");
90
91 Ok(Gtk { bytes: gtk_bytes, cipher, tk_len, key_id, key_rsc })
92 }
93
94 pub fn cipher(&self) -> &Cipher {
95 &self.cipher
96 }
97
98 pub fn key_id(&self) -> u8 {
99 self.key_id
100 }
101
102 pub fn key_rsc(&self) -> u64 {
103 self.key_rsc
104 }
105}
106
107impl Tk for Gtk {
108 fn tk(&self) -> &[u8] {
109 &self.bytes[0..self.tk_len]
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use wlan_common::ie::rsn::cipher;
    use wlan_common::ie::rsn::suite_selector::OUI;

    /// The cipher suite used throughout these tests.
    fn ccmp128() -> Cipher {
        Cipher { oui: OUI, suite_type: cipher::CCMP_128 }
    }

    /// TK length, in bytes, that `Gtk` derives for `ccmp128()`.
    fn ccmp128_tk_len() -> usize {
        ccmp128().tk_bytes().expect("CCMP-128 must have a known TK length").into()
    }

    #[test]
    fn generated_gtks_are_not_zero_and_not_constant_with_high_probability() {
        let mut gtks = HashSet::new();
        for i in 0..10 {
            let provider =
                GtkProvider::new(ccmp128(), 2, 5).expect("failed creating GTK Provider");
            let gtk_bytes: Box<[u8]> = provider.get_gtk().tk().into();
            assert!(gtk_bytes.iter().any(|&x| x != 0));
            // Any GTK differing from an earlier one proves generation is not constant.
            if i > 0 && !gtks.contains(&gtk_bytes) {
                return;
            }
            gtks.insert(gtk_bytes);
        }
        panic!("GtkProvider::new() generated the same GTK 10 times in a row.");
    }

    #[test]
    fn generated_gtk_captures_key_id() {
        let provider = GtkProvider::new(ccmp128(), 1, 3).expect("failed creating GTK Provider");
        let gtk = provider.get_gtk();
        assert_eq!(gtk.key_id(), 1);
    }

    #[test]
    fn generated_gtk_captures_key_rsc() {
        let provider = GtkProvider::new(ccmp128(), 1, 3).expect("failed creating GTK Provider");
        let gtk = provider.get_gtk();
        assert_eq!(gtk.key_rsc(), 3);
    }

    #[test]
    fn generated_gtk_tk_has_cipher_tk_length() {
        let provider = GtkProvider::new(ccmp128(), 1, 0).expect("failed creating GTK Provider");
        assert_eq!(provider.get_gtk().tk().len(), ccmp128_tk_len());
    }

    #[test]
    fn gtk_generation_fails_with_key_id_zero() {
        GtkProvider::new(ccmp128(), 0, 4)
            .expect_err("GTK provider incorrectly accepts key ID 0");
    }

    #[test]
    fn gtk_generation_fails_with_key_id_four() {
        GtkProvider::new(ccmp128(), 4, 0)
            .expect_err("GTK provider incorrectly accepts key ID 4");
    }

    #[test]
    fn gtk_from_bytes_uses_leading_bytes_as_tk() {
        let tk_len = ccmp128_tk_len();
        let bytes: Vec<u8> = (0u8..).take(tk_len + 4).collect();
        let gtk = Gtk::from_bytes(bytes.clone().into(), ccmp128(), 1, 0)
            .expect("failed creating GTK from bytes");
        assert_eq!(gtk.tk(), &bytes[..tk_len]);
    }

    #[test]
    fn gtk_from_bytes_accepts_bytes_exactly_tk_length() {
        let tk_len = ccmp128_tk_len();
        let bytes: Box<[u8]> = vec![0x5a; tk_len].into();
        Gtk::from_bytes(bytes, ccmp128(), 1, 0)
            .expect("GTK exactly as long as the TK was rejected");
    }

    #[test]
    fn gtk_from_bytes_rejects_bytes_shorter_than_tk() {
        let tk_len = ccmp128_tk_len();
        let bytes: Box<[u8]> = vec![0x5a; tk_len - 1].into();
        Gtk::from_bytes(bytes, ccmp128(), 1, 0)
            .expect_err("GTK shorter than the TK was incorrectly accepted");
    }

    #[test]
    fn gtk_from_bytes_rejects_key_id_zero() {
        let bytes: Box<[u8]> = vec![0x5a; ccmp128_tk_len()].into();
        Gtk::from_bytes(bytes, ccmp128(), 0, 0)
            .expect_err("Gtk::from_bytes incorrectly accepts key ID 0");
    }
}