1use mundane::hash::{Digest, Hasher};
6use mundane::hmac::Hmac;
7
/// Returns `x / y` rounded up to the next whole number.
///
/// The result is built from the quotient and the remainder instead of
/// `(x + (y - 1)) / y`, so there is no intermediate sum that can overflow when `x`
/// is close to `usize::MAX`.
///
/// # Panics
///
/// Panics if `y` is zero (division by zero).
fn div_ceil(x: usize, y: usize) -> usize {
    // Add one extra unit only when the division leaves a remainder.
    x / y + usize::from(x % y != 0)
}
17
18pub fn kdf_hash_length<H>(key: &[u8], label: &str, context: &[u8], bits: usize) -> Vec<u8>
21where
22 H: Hasher,
23 <<H as Hasher>::Digest as Digest>::Bytes: AsRef<[u8]> + AsMut<[u8]>,
24{
25 let byte_length = div_ceil(bits, 8);
26 let iterations = div_ceil(bits, H::Digest::DIGEST_LEN * 8);
27 let mut result = Vec::with_capacity(byte_length);
28 let mut copied: usize = 0;
29 for i in 1..=iterations {
30 let to_copy = std::cmp::min(H::Digest::DIGEST_LEN, byte_length - copied);
31 {
32 let mut hmac: Hmac<H> = Hmac::new(key);
33 hmac.update(&(i as u16).to_le_bytes());
34 hmac.update(label.as_bytes());
35 hmac.update(context);
36 hmac.update(&(bits as u16).to_le_bytes());
37 let mut digest = hmac.finish().bytes();
38 digest.as_mut()[to_copy..H::Digest::DIGEST_LEN].fill(0);
40 result.extend_from_slice(&digest.as_ref()[..to_copy]);
41 }
42 copied += to_copy;
43 }
44 result
45}
46
47pub fn hkdf_extract<H>(salt: &[u8], ikm: &[u8]) -> Vec<u8>
50where
51 H: Hasher,
52 <<H as Hasher>::Digest as Digest>::Bytes: AsRef<[u8]>,
53{
54 let mut hmac: Hmac<H> = Hmac::new(salt);
57 hmac.update(ikm);
58 hmac.finish().bytes().as_ref().to_vec()
59}
60
61pub fn hkdf_expand<H>(prk: &[u8], info: &str, length: usize) -> Vec<u8>
64where
65 H: Hasher,
66 <<H as Hasher>::Digest as Digest>::Bytes: AsRef<[u8]>,
67{
68 let mut result: Vec<u8> = Vec::with_capacity(length);
69 let mut copied: usize = 0;
70 let mut prev_digest: Option<H::Digest> = None;
71 let digest_count = div_ceil(length, H::Digest::DIGEST_LEN);
72 for counter in 1..digest_count + 1 {
73 let to_copy = std::cmp::min(H::Digest::DIGEST_LEN, length - copied);
74 let digest = {
75 let mut hmac: Hmac<H> = Hmac::new(prk);
76 if let Some(prev_digest) = &prev_digest {
77 hmac.update(prev_digest.bytes().as_ref());
78 }
79 hmac.update(info.as_bytes());
80 hmac.update(&(counter as u8).to_le_bytes());
81 hmac.finish()
82 };
83 result.extend_from_slice(&digest.bytes().as_ref()[..to_copy]);
84 copied += to_copy;
85 let _ = prev_digest.insert(digest);
86 }
87 result
88}
89
90pub fn confirm<H>(key: &[u8], counter: u16, data: &[&[u8]]) -> Vec<u8>
93where
94 H: Hasher,
95 <<H as Hasher>::Digest as Digest>::Bytes: AsRef<[u8]>,
96{
97 let mut hmac: Hmac<H> = Hmac::new(key);
98 hmac.update(&counter.to_le_bytes());
99 for data_part in data {
100 hmac.update(data_part);
101 }
102 hmac.finish().bytes().as_ref().to_vec()
103}
104
/// Interface over the HMAC helper functions in this file that does not expose the
/// hash type in its method signatures.
///
/// The implementation in this file, `HmacUtilsImpl<H>`, forwards each method to
/// the free function of the same name instantiated with `H`.
pub trait HmacUtils {
    /// Returns the digest size of the underlying hash, in bits.
    fn bits(&self) -> usize;
    /// Derives `bits` bits of key material from `key`, `label` and `context`.
    /// The returned vector holds `ceil(bits / 8)` bytes.
    fn kdf_hash_length(&self, key: &[u8], label: &str, context: &[u8], bits: usize) -> Vec<u8>;
    /// Returns the HMAC of `ikm` keyed with `salt`.
    fn hkdf_extract(&self, salt: &[u8], ikm: &[u8]) -> Vec<u8>;
    /// Expands `prk` into `length` bytes of output, bound to `info`.
    fn hkdf_expand(&self, prk: &[u8], info: &str, length: usize) -> Vec<u8>;
    /// Returns the HMAC, keyed with `key`, of the little-endian `counter` followed
    /// by every slice in `data` in order.
    fn confirm(&self, key: &[u8], counter: u16, data: &[&[u8]]) -> Vec<u8>;
}
118
/// `HmacUtils` implementation that selects the hash function through the type
/// parameter `H`. It carries no runtime data.
#[derive(Debug, Clone)]
pub struct HmacUtilsImpl<H>
where
    H: Hasher,
    <<H as Hasher>::Digest as Digest>::Bytes: AsRef<[u8]>,
{
    // Marker that ties the struct to `H` without storing a value of type `H`.
    hasher_type: std::marker::PhantomData<fn(H)>,
}
129
130impl<H> HmacUtilsImpl<H>
131where
132 H: Hasher,
133 <<H as Hasher>::Digest as Digest>::Bytes: AsRef<[u8]>,
134{
135 pub fn new() -> Self {
136 Self { hasher_type: std::marker::PhantomData }
137 }
138}
139
/// Forwards every `HmacUtils` method to the free function of the same name in this
/// file, instantiated with hasher `H`.
impl<H> HmacUtils for HmacUtilsImpl<H>
where
    H: Hasher,
    <<H as Hasher>::Digest as Digest>::Bytes: AsMut<[u8]> + AsRef<[u8]>,
{
    /// Digest length of `H` in bits (`DIGEST_LEN` is in bytes).
    fn bits(&self) -> usize {
        H::Digest::DIGEST_LEN * 8
    }

    /// See the free function `kdf_hash_length`.
    fn kdf_hash_length(&self, key: &[u8], label: &str, context: &[u8], bits: usize) -> Vec<u8> {
        // Inside the method body this name resolves to the free function, not to
        // this method, so the call does not recurse.
        kdf_hash_length::<H>(key, label, context, bits)
    }

    /// See the free function `hkdf_extract`.
    fn hkdf_extract(&self, salt: &[u8], ikm: &[u8]) -> Vec<u8> {
        hkdf_extract::<H>(salt, ikm)
    }

    /// See the free function `hkdf_expand`.
    fn hkdf_expand(&self, prk: &[u8], info: &str, length: usize) -> Vec<u8> {
        hkdf_expand::<H>(prk, info, length)
    }

    /// See the free function `confirm`.
    fn confirm(&self, key: &[u8], counter: u16, data: &[&[u8]]) -> Vec<u8> {
        confirm::<H>(key, counter, data)
    }
}
165
/// Unit tests that run the helpers with SHA-256 and compare the results against
/// fixed, hex-encoded expected values.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::boringssl::{Bignum, BignumCtx, EcGroup, EcGroupId};

    use mundane::hash::Sha256;
    use std::convert::TryFrom;

    // NOTE(review): the constants below look like published SAE test vectors; the
    // source document is not named in this file — confirm and cite it here.

    // Elliptic curve group whose prime and order are used as inputs below.
    const TEST_GROUP: EcGroupId = EcGroupId::P256;
    // Label, keys (`TEST_H_*`) and expected outputs (`TEST_CAND_*`) for
    // `test_kdf_sha256_256`.
    const TEST_LABEL: &'static str = "SAE Hunting and Pecking";
    const TEST_H_1: &'static str =
        "a9025368ef78f7d65e8d4d556f0d1d0d758f2f7f1e116eb1d11307a7e8a9621a";
    const TEST_CAND_1: &'static str =
        "b8e89a725c57f18e8f68a7f72613e15f1c904938c38800efa01f1306f5e454b5";
    const TEST_H_2: &'static str =
        "954bbbf8923284e4ca164e3af0b9520ce53aa35be39020e9ccb23aff86df2226";
    const TEST_CAND_2: &'static str =
        "da6eb7b06a1ac5624974f90afdd6a8e9d5722634cf987c34defc91a9874e5658";
    const TEST_H_40: &'static str =
        "cde7b81eb539c87af5bf1be2402d315c45ad4c3db06c9c56b7f8b7daae5e5842";
    const TEST_CAND_40: &'static str =
        "2e12a1d615647963fd7aa4a905fd51b6f49a902fd917ef8f0ff200102699ecdb";
    // Inputs and expected outputs for the HKDF extract/expand tests.
    const TEST_SSID: &'static str = "byteme";
    const TEST_IDENTIFIER: &'static str = "psk4internet";
    const TEST_PASSWORD: &'static str = "mekmitasdigoat";
    const TEST_INFO_1: &'static str = "SAE Hash to Element u1 P1";
    const TEST_INFO_2: &'static str = "SAE Hash to Element u2 P2";
    const TEST_PRK: &'static str =
        "3bd53fe9223dc0280fbfce17d7a3564064e20f48c6ec72246ce367b5569a22af";
    const TEST_OKM_1: &'static str = "a5044469ab16f25b6abf1e0e37a36b56f50be73369053df8db87989a6b66fd1a\
        491f1cdacbd07931620f83008ffc0ecc";
    const TEST_OKM_2: &'static str = "9b4e0d5b1879f253c5319615099b05aec5b06fa5e788bcfd1e9ea60d33436927\
        190814c322a62585c93c577bbaa3d307";

    // Inputs and expected outputs for `test_kdf_sha256_512` and `test_confirm`.
    const TEST_LABEL_2: &'static str = "SAE KCK and PMK";
    // Hex-encoded frames. The tests skip the first two bytes and read a scalar of
    // `r.len()` bytes followed by an element of `2 * r.len()` bytes; any trailing
    // bytes are not used by the tests.
    const TEST_LOCAL_COMMIT: &'static str = "1300eb3bab1964e4a0ab05925ddf3339519138bc65d6cdc0f813dd6fd4344eb4\
        bfe44b5c21597658f4e3eddfb4b99f25b4d6540f32ff1fd5c530c60a79444861\
        0bc6de3d92bdbbd47d935980ca6cf8988ab6630be6764c885ceb9793970f6952\
        17eeff0d2170736b34696e7465726e6574";
    const TEST_PEER_COMMIT: &'static str = "13005564f045b2ea1e566cf1dd741f70d9be35d2df5b9a5502946ee03cf8dae2\
        7e1e05b8430eb7a99e24877ce69baf3dc580e309633d6b385f83ee1c3ec3591f\
        1a5393c06e805ddceb2fde50930dd7cfebb987c6ff9666af164eb5184d8e6662\
        ed6aff0d2170736b34696e7465726e6574";
    const TEST_KEYSEED: &'static str =
        "7457a00754dcc4e3dc2850c124d6bb8fa1699d7fa33bb0667d9c34eeb513deb9";
    const TEST_KCK: &'static str =
        "599d6f1e27548be8499dceed2feccf94818ce1c79f1b4eb3d6a53228a09bf3ed";
    // Hex-encoded frames: a 2-byte little-endian counter followed by the expected
    // HMAC value, as parsed in `test_confirm`.
    const TEST_LOCAL_CONFIRM: &'static str = "010012d9d5c78c500526d36c41dbc56aedf2914cedddd7cad4a58c48f83dbde9\
        fc77";
    const TEST_PEER_CONFIRM: &'static str = "010002871cf906898b8060ec184143be77b8c08a8019b13eb6d0aef0d8383dfa\
        c2fd";
    const TEST_PMK: &'static str =
        "7aead86fba4c3221fc437f5f14d70d854ea5d5aac1690116793081eda4d557c5";

    // Derives `p.bits()` bits for three different keys, using the group prime as
    // the context, and checks each result against its expected value.
    #[test]
    fn test_kdf_sha256_256() {
        let bignumctx = BignumCtx::new().unwrap();
        let group = EcGroup::new(TEST_GROUP).unwrap();
        let p = group.get_params(&bignumctx).unwrap().p;
        let p_vec = p.to_be_vec(p.len());
        let p_bits = p.bits();

        let cand_1 =
            kdf_hash_length::<Sha256>(&hex::decode(TEST_H_1).unwrap(), TEST_LABEL, &p_vec, p_bits);
        assert_eq!(hex::encode(&cand_1), TEST_CAND_1);
        let cand_2 =
            kdf_hash_length::<Sha256>(&hex::decode(TEST_H_2).unwrap(), TEST_LABEL, &p_vec, p_bits);
        assert_eq!(hex::encode(&cand_2), TEST_CAND_2);
        let cand_40 =
            kdf_hash_length::<Sha256>(&hex::decode(TEST_H_40).unwrap(), TEST_LABEL, &p_vec, p_bits);
        assert_eq!(hex::encode(&cand_40), TEST_CAND_40);
    }

    // Requests 512 bits, which is more than one SHA-256 digest, and checks both
    // halves of the output against the expected values.
    #[test]
    fn test_kdf_sha256_512() {
        let bignumctx = BignumCtx::new().unwrap();
        let group = EcGroup::new(TEST_GROUP).unwrap();
        let r = group.get_order(&bignumctx).unwrap();

        let local_commit_scalar = hex::decode(TEST_LOCAL_COMMIT).unwrap();
        // Skip the first two bytes; the scalar is the next `r.len()` bytes.
        let local_commit_scalar = &local_commit_scalar[2..2 + r.len()];
        let local_commit_scalar = Bignum::new_from_slice(local_commit_scalar).unwrap();
        let peer_commit_scalar = hex::decode(TEST_PEER_COMMIT).unwrap();
        let peer_commit_scalar = &peer_commit_scalar[2..2 + r.len()];
        let peer_commit_scalar = Bignum::new_from_slice(peer_commit_scalar).unwrap();

        // The KDF context is the sum of both scalars modulo the group order.
        let context = local_commit_scalar.mod_add(&peer_commit_scalar, &r, &bignumctx).unwrap();
        let q = 256;
        let kck_and_pmk = kdf_hash_length::<Sha256>(
            &hex::decode(TEST_KEYSEED).unwrap(),
            TEST_LABEL_2,
            &context.to_be_vec(r.len()),
            q + 256,
        );
        assert_eq!(kck_and_pmk.len(), (q + 256) / 8);
        // The first `q` bits are compared with TEST_KCK, the rest with TEST_PMK.
        assert_eq!(hex::encode(&kck_and_pmk[0..q / 8]), TEST_KCK);
        assert_eq!(hex::encode(&kck_and_pmk[q / 8..(q + 256) / 8]), TEST_PMK);
    }

    // A request shorter than one digest returns only the requested bytes.
    #[test]
    fn test_kdf_sha256_short() {
        let key = hex::decode("f0f0f0f0").unwrap();
        let label = "LABELED!";
        let context = hex::decode("babababa").unwrap();
        let hash = kdf_hash_length::<Sha256>(&key[..], label, &context[..], 128);
        assert_eq!(hash.len(), 16);
    }

    // Requesting zero bits yields an empty output.
    #[test]
    fn test_kdf_sha256_empty_data() {
        let key = hex::decode("f0f0f0f0").unwrap();
        let label = "LABELED!";
        let context = hex::decode("babababa").unwrap();
        let hash = kdf_hash_length::<Sha256>(&key[..], label, &context[..], 0);
        assert_eq!(hash.len(), 0);
    }

    // Zero bits with an empty key, label and context also yields an empty output.
    #[test]
    fn test_kdf_sha256_all_empty() {
        let key = vec![];
        let label = "";
        let context = vec![];
        let hash = kdf_hash_length::<Sha256>(&key[..], label, &context[..], 0);
        assert_eq!(hash.len(), 0);
    }

    // Uses the SSID as salt and the password followed by the identifier as input
    // keying material.
    #[test]
    fn test_hkdf_extract() {
        let mut password_with_id: String = String::from(TEST_PASSWORD);
        password_with_id.push_str(TEST_IDENTIFIER);
        let pwd_seed = hkdf_extract::<Sha256>(TEST_SSID.as_bytes(), password_with_id.as_bytes());
        assert_eq!(hex::encode(pwd_seed), TEST_PRK);
    }

    // Expands TEST_PRK with two different info strings to one and a half times the
    // byte length of the group prime.
    #[test]
    fn test_hkdf_expand() {
        let bignumctx = BignumCtx::new().unwrap();
        let group = EcGroup::new(TEST_GROUP).unwrap();
        let p = group.get_params(&bignumctx).unwrap().p;
        let p_len = p.len();

        let okm_1 = hkdf_expand::<Sha256>(
            &hex::decode(TEST_PRK).unwrap(),
            TEST_INFO_1,
            p_len + (p_len / 2),
        );
        assert_eq!(hex::encode(&okm_1), TEST_OKM_1);
        let okm_2 = hkdf_expand::<Sha256>(
            &hex::decode(TEST_PRK).unwrap(),
            TEST_INFO_2,
            p_len + (p_len / 2),
        );
        assert_eq!(hex::encode(&okm_2), TEST_OKM_2);
    }

    // Recomputes both confirm values from the commit frames and compares them with
    // the values carried in the confirm frames.
    #[test]
    fn test_confirm() {
        let bignumctx = BignumCtx::new().unwrap();
        let group = EcGroup::new(TEST_GROUP).unwrap();
        let r = group.get_order(&bignumctx).unwrap();

        let local_commit_bytes = hex::decode(TEST_LOCAL_COMMIT).unwrap();
        // Layout used here: 2 skipped bytes, scalar (`r.len()` bytes), element
        // (`2 * r.len()` bytes).
        let local_commit_scalar = &local_commit_bytes[2..2 + r.len()];
        let local_commit_element = &local_commit_bytes[2 + r.len()..2 + r.len() * 3];
        let peer_commit_bytes = hex::decode(TEST_PEER_COMMIT).unwrap();
        let peer_commit_scalar = &peer_commit_bytes[2..2 + r.len()];
        let peer_commit_element = &peer_commit_bytes[2 + r.len()..2 + r.len() * 3];

        let local_confirm_bytes = hex::decode(TEST_LOCAL_CONFIRM).unwrap();
        // The first two bytes are the little-endian counter; the expected HMAC
        // value follows.
        let local_send_confirm =
            u16::from_le_bytes(*<&[u8; 2]>::try_from(&local_confirm_bytes[0..2]).unwrap());
        let local_confirm_element = &local_confirm_bytes[2..2 + r.len()];
        let peer_confirm_bytes = hex::decode(TEST_PEER_CONFIRM).unwrap();
        let peer_send_confirm =
            u16::from_le_bytes(*<&[u8; 2]>::try_from(&peer_confirm_bytes[0..2]).unwrap());
        let peer_confirm_element = &peer_confirm_bytes[2..2 + r.len()];

        // Each side hashes its own scalar and element first, then the other side's.
        let local_confirm = confirm::<Sha256>(
            &hex::decode(TEST_KCK).unwrap(),
            local_send_confirm,
            &[local_commit_scalar, local_commit_element, peer_commit_scalar, peer_commit_element],
        );
        assert_eq!(hex::encode(&local_confirm), hex::encode(local_confirm_element));

        let peer_confirm = confirm::<Sha256>(
            &hex::decode(TEST_KCK).unwrap(),
            peer_send_confirm,
            &[peer_commit_scalar, peer_commit_element, local_commit_scalar, local_commit_element],
        );
        assert_eq!(hex::encode(&peer_confirm), hex::encode(peer_confirm_element));
    }
}