1use crate::{with_output_array, with_output_vec, with_output_vec_fallible, FfiMutSlice, FfiSlice};
49use alloc::vec::Vec;
50
/// Error returned by [`Aead::open_in_place`] when BoringSSL rejects the input
/// (for example because the tag does not authenticate the ciphertext and
/// associated data under the given key and nonce).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct InvalidCiphertext;

impl core::fmt::Display for InvalidCiphertext {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("invalid ciphertext")
    }
}
54
55pub trait Aead {
57 type Tag: AsRef<[u8]>;
60
61 type Nonce: AsRef<[u8]>;
64
65 fn seal(&self, nonce: &Self::Nonce, plaintext: &[u8], ad: &[u8]) -> Vec<u8>;
69
70 fn seal_in_place(&self, nonce: &Self::Nonce, plaintext: &mut [u8], ad: &[u8]) -> Self::Tag;
75
76 fn open(&self, nonce: &Self::Nonce, ciphertext: &[u8], ad: &[u8]) -> Option<Vec<u8>>;
81
82 fn open_in_place(
86 &self,
87 nonce: &Self::Nonce,
88 ciphertext: &mut [u8],
89 tag: &Self::Tag,
90 ad: &[u8],
91 ) -> Result<(), InvalidCiphertext>;
92}
93
/// AES-128 in Galois/Counter Mode, backed by BoringSSL's `EVP_aead_aes_128_gcm`.
/// Uses a 16-byte key, a 12-byte nonce and a 16-byte tag.
pub struct Aes128Gcm(EvpAead<16, 12, 16>);
// NOTE(review): the lengths passed here duplicate the `EvpAead` parameters
// above; presumably they must agree — confirm against `aead_algo!`.
aead_algo!(Aes128Gcm, EVP_aead_aes_128_gcm, 16, 12, 16);
97
/// AES-256 in Galois/Counter Mode, backed by BoringSSL's `EVP_aead_aes_256_gcm`.
/// Uses a 32-byte key, a 12-byte nonce and a 16-byte tag.
pub struct Aes256Gcm(EvpAead<32, 12, 16>);
// NOTE(review): the lengths passed here duplicate the `EvpAead` parameters
// above; presumably they must agree — confirm against `aead_algo!`.
aead_algo!(Aes256Gcm, EVP_aead_aes_256_gcm, 32, 12, 16);
101
/// AES-128 in GCM-SIV mode, backed by BoringSSL's `EVP_aead_aes_128_gcm_siv`.
/// Uses a 16-byte key, a 12-byte nonce and a 16-byte tag.
pub struct Aes128GcmSiv(EvpAead<16, 12, 16>);
// NOTE(review): the lengths passed here duplicate the `EvpAead` parameters
// above; presumably they must agree — confirm against `aead_algo!`.
aead_algo!(Aes128GcmSiv, EVP_aead_aes_128_gcm_siv, 16, 12, 16);
105
/// AES-256 in GCM-SIV mode, backed by BoringSSL's `EVP_aead_aes_256_gcm_siv`.
/// Uses a 32-byte key, a 12-byte nonce and a 16-byte tag.
pub struct Aes256GcmSiv(EvpAead<32, 12, 16>);
// NOTE(review): the lengths passed here duplicate the `EvpAead` parameters
// above; presumably they must agree — confirm against `aead_algo!`.
aead_algo!(Aes256GcmSiv, EVP_aead_aes_256_gcm_siv, 32, 12, 16);
109
/// ChaCha20-Poly1305, backed by BoringSSL's `EVP_aead_chacha20_poly1305`.
/// Uses a 32-byte key, a 12-byte nonce and a 16-byte tag.
pub struct Chacha20Poly1305(EvpAead<32, 12, 16>);
// NOTE(review): the lengths passed here duplicate the `EvpAead` parameters
// above; presumably they must agree — confirm against `aead_algo!`.
aead_algo!(Chacha20Poly1305, EVP_aead_chacha20_poly1305, 32, 12, 16);
113
/// XChaCha20-Poly1305, backed by BoringSSL's `EVP_aead_xchacha20_poly1305`.
/// Uses a 32-byte key, a 24-byte (extended) nonce and a 16-byte tag.
pub struct XChacha20Poly1305(EvpAead<32, 24, 16>);
// NOTE(review): the lengths passed here duplicate the `EvpAead` parameters
// above; presumably they must agree — confirm against `aead_algo!`.
aead_algo!(XChacha20Poly1305, EVP_aead_xchacha20_poly1305, 32, 24, 16);
117
/// Owning wrapper around a BoringSSL `EVP_AEAD_CTX` (freed in `Drop`).
///
/// The const parameters fix, in bytes, the key length accepted by `new`, the
/// nonce length accepted by every operation, and the tag length requested at
/// context creation and produced/consumed by the in-place operations.
struct EvpAead<const KEY_LEN: usize, const NONCE_LEN: usize, const TAG_LEN: usize>(
    // Non-null context pointer obtained from `EVP_AEAD_CTX_new` in `new`.
    *mut bssl_sys::EVP_AEAD_CTX,
);
122
123#[allow(clippy::unwrap_used)]
124impl<const KEY_LEN: usize, const NONCE_LEN: usize, const TAG_LEN: usize>
125 EvpAead<KEY_LEN, NONCE_LEN, TAG_LEN>
126{
127 unsafe fn new(key: &[u8; KEY_LEN], evp_aead: *const bssl_sys::EVP_AEAD) -> Self {
129 let ptr =
133 unsafe { bssl_sys::EVP_AEAD_CTX_new(evp_aead, key.as_ffi_ptr(), key.len(), TAG_LEN) };
134 assert!(!ptr.is_null());
135 Self(ptr)
136 }
137
138 fn seal(&self, nonce: &[u8; NONCE_LEN], plaintext: &[u8], ad: &[u8]) -> Vec<u8> {
139 let max_output = plaintext.len() + TAG_LEN;
140 unsafe {
141 with_output_vec(max_output, |out_buf| {
142 let mut out_len = 0usize;
143 let result = bssl_sys::EVP_AEAD_CTX_seal(
148 self.0,
149 out_buf,
150 &mut out_len,
151 max_output,
152 nonce.as_ffi_ptr(),
153 nonce.len(),
154 plaintext.as_ffi_ptr(),
155 plaintext.len(),
156 ad.as_ffi_ptr(),
157 ad.len(),
158 );
159 assert_eq!(result, 1);
161 assert_eq!(out_len, max_output);
164 out_len
166 })
167 }
168 }
169
170 fn seal_in_place(
171 &self,
172 nonce: &[u8; NONCE_LEN],
173 plaintext: &mut [u8],
174 ad: &[u8],
175 ) -> [u8; TAG_LEN] {
176 unsafe {
180 with_output_array(|tag, tag_len| {
181 let mut out_tag_len = 0usize;
182 let result = bssl_sys::EVP_AEAD_CTX_seal_scatter(
183 self.0,
184 plaintext.as_mut_ffi_ptr(),
185 tag,
186 &mut out_tag_len,
187 tag_len,
188 nonce.as_ffi_ptr(),
189 nonce.len(),
190 plaintext.as_ffi_ptr(),
191 plaintext.len(),
192 core::ptr::null(),
193 0,
194 ad.as_ffi_ptr(),
195 ad.len(),
196 );
197 assert_eq!(result, 1);
200 assert_eq!(out_tag_len, TAG_LEN);
202 })
203 }
204 }
205
206 fn open(&self, nonce: &[u8; NONCE_LEN], ciphertext: &[u8], ad: &[u8]) -> Option<Vec<u8>> {
207 if ciphertext.len() < TAG_LEN {
208 return None;
209 }
210 let max_output = ciphertext.len() - TAG_LEN;
211
212 unsafe {
213 with_output_vec_fallible(max_output, |out_buf| {
214 let mut out_len = 0usize;
215 let result = bssl_sys::EVP_AEAD_CTX_open(
220 self.0,
221 out_buf,
222 &mut out_len,
223 max_output,
224 nonce.as_ffi_ptr(),
225 nonce.len(),
226 ciphertext.as_ffi_ptr(),
227 ciphertext.len(),
228 ad.as_ffi_ptr(),
229 ad.len(),
230 );
231 if result == 1 {
232 Some(out_len)
234 } else {
235 None
236 }
237 })
238 }
239 }
240
241 fn open_in_place(
242 &self,
243 nonce: &[u8; NONCE_LEN],
244 ciphertext: &mut [u8],
245 tag: &[u8; TAG_LEN],
246 ad: &[u8],
247 ) -> Result<(), InvalidCiphertext> {
248 let result = unsafe {
251 bssl_sys::EVP_AEAD_CTX_open_gather(
252 self.0,
253 ciphertext.as_mut_ffi_ptr(),
254 nonce.as_ffi_ptr(),
255 nonce.len(),
256 ciphertext.as_ffi_ptr(),
257 ciphertext.len(),
258 tag.as_ffi_ptr(),
259 tag.len(),
260 ad.as_ffi_ptr(),
261 ad.len(),
262 )
263 };
264 if result == 1 {
265 Ok(())
266 } else {
267 Err(InvalidCiphertext)
268 }
269 }
270}
271
272impl<const KEY_LEN: usize, const NONCE_LEN: usize, const TAG_LEN: usize> Drop
273 for EvpAead<KEY_LEN, NONCE_LEN, TAG_LEN>
274{
275 fn drop(&mut self) {
276 unsafe { bssl_sys::EVP_AEAD_CTX_free(self.0) }
279 }
280}
281
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_helpers::{decode_hex, decode_hex_into_vec};

    /// Checks the generic `Aead` contract for one algorithm using fixed inputs:
    /// seal/open round-trips; tampering with the ciphertext body, the tag, or
    /// the associated data is rejected; the in-place APIs interoperate with
    /// the allocating ones and reject a bad tag; inputs shorter than a tag are
    /// rejected.
    fn check_aead_invariants<
        const NONCE_LEN: usize,
        const TAG_LEN: usize,
        A: Aead<Nonce = [u8; NONCE_LEN], Tag = [u8; TAG_LEN]>,
    >(
        aead: A,
    ) {
        let plaintext = b"plaintext";
        let ad = b"additional data";
        let nonce: A::Nonce = [0u8; NONCE_LEN];

        let mut ciphertext = aead.seal(&nonce, plaintext, ad);
        let plaintext2 = aead
            .open(&nonce, ciphertext.as_slice(), ad)
            .expect("should decrypt");
        assert_eq!(plaintext, plaintext2.as_slice());

        // A flipped bit in the ciphertext body must be detected.
        ciphertext[0] ^= 1;
        assert!(aead.open(&nonce, ciphertext.as_slice(), ad).is_none());
        ciphertext[0] ^= 1;

        // A flipped bit in the trailing tag must be detected.
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 1;
        assert!(aead.open(&nonce, ciphertext.as_slice(), ad).is_none());
        ciphertext[last] ^= 1;

        // Different associated data must be detected.
        assert!(aead
            .open(&nonce, ciphertext.as_slice(), b"other data")
            .is_none());

        let (ciphertext_in_place, tag_slice) =
            ciphertext.as_mut_slice().split_at_mut(plaintext.len());
        let tag: [u8; TAG_LEN] = tag_slice.try_into().unwrap();
        aead.open_in_place(&nonce, ciphertext_in_place, &tag, ad)
            .expect("should decrypt");
        assert_eq!(plaintext, ciphertext_in_place);

        let tag = aead.seal_in_place(&nonce, ciphertext_in_place, ad);

        // A corrupted tag must also be rejected by the in-place API. Work on a
        // copy so the buffer used below is unaffected by whatever a failed
        // call leaves behind.
        let mut scratch = ciphertext_in_place.to_vec();
        let mut bad_tag = tag;
        bad_tag[0] ^= 1;
        assert!(aead
            .open_in_place(&nonce, scratch.as_mut_slice(), &bad_tag, ad)
            .is_err());

        aead.open_in_place(&nonce, ciphertext_in_place, &tag, ad)
            .expect("should decrypt");
        assert_eq!(plaintext, ciphertext_in_place);

        assert!(aead.open(&nonce, b"tooshort", b"").is_none());
    }

    #[test]
    fn aes_128_gcm_invariants() {
        check_aead_invariants(Aes128Gcm::new(&[0u8; 16]));
    }

    #[test]
    fn aes_256_gcm_invariants() {
        check_aead_invariants(Aes256Gcm::new(&[0u8; 32]));
    }

    #[test]
    fn aes_128_gcm_siv_invariants() {
        check_aead_invariants(Aes128GcmSiv::new(&[0u8; 16]));
    }

    #[test]
    fn aes_256_gcm_siv_invariants() {
        check_aead_invariants(Aes256GcmSiv::new(&[0u8; 32]));
    }

    #[test]
    fn chacha20_poly1305_invariants() {
        check_aead_invariants(Chacha20Poly1305::new(&[0u8; 32]));
    }

    #[test]
    fn xchacha20_poly1305_invariants() {
        check_aead_invariants(XChacha20Poly1305::new(&[0u8; 32]));
    }

    /// One known-answer vector; `ciphertext` is the encrypted message followed
    /// by the tag.
    struct TestCase<const KEY_LEN: usize, const NONCE_LEN: usize> {
        key: [u8; KEY_LEN],
        nonce: [u8; NONCE_LEN],
        msg: Vec<u8>,
        ad: Vec<u8>,
        ciphertext: Vec<u8>,
    }

    /// Runs known-answer vectors through both the allocating APIs (`seal` /
    /// `open`) and the in-place APIs (`seal_in_place` / `open_in_place`).
    fn check_test_cases<
        const KEY_LEN: usize,
        const NONCE_LEN: usize,
        const TAG_LEN: usize,
        F: Fn(&[u8; KEY_LEN]) -> Box<dyn Aead<Nonce = [u8; NONCE_LEN], Tag = [u8; TAG_LEN]>>,
    >(
        new_func: F,
        test_cases: &[TestCase<KEY_LEN, NONCE_LEN>],
    ) {
        for (test_num, test) in test_cases.iter().enumerate() {
            let ctx = new_func(&test.key);
            let ciphertext = ctx.seal(&test.nonce, test.msg.as_slice(), test.ad.as_slice());
            assert_eq!(ciphertext, test.ciphertext, "Failed on test #{}", test_num);

            let plaintext = ctx
                .open(&test.nonce, ciphertext.as_slice(), test.ad.as_slice())
                .unwrap();
            assert_eq!(plaintext, test.msg, "Decrypt failed on test #{}", test_num);

            // The in-place APIs must reproduce the same vector: the body is the
            // first `msg.len()` bytes and the tag is the remainder.
            let (expected_body, expected_tag) = test.ciphertext.split_at(test.msg.len());
            let mut buf = test.msg.clone();
            let tag = ctx.seal_in_place(&test.nonce, buf.as_mut_slice(), test.ad.as_slice());
            assert_eq!(
                buf.as_slice(),
                expected_body,
                "In-place encrypt failed on test #{}",
                test_num
            );
            assert_eq!(&tag[..], expected_tag, "Tag mismatch on test #{}", test_num);

            ctx.open_in_place(&test.nonce, buf.as_mut_slice(), &tag, test.ad.as_slice())
                .unwrap();
            assert_eq!(buf, test.msg, "In-place decrypt failed on test #{}", test_num);
        }
    }

    #[test]
    fn aes_128_gcm_siv() {
        let test_cases: &[TestCase<16, 12>] = &[
            TestCase {
                key: decode_hex("01000000000000000000000000000000"),
                nonce: decode_hex("030000000000000000000000"),
                msg: Vec::new(),
                ad: Vec::new(),
                ciphertext: decode_hex_into_vec("dc20e2d83f25705bb49e439eca56de25"),
            },
            TestCase {
                key: decode_hex("01000000000000000000000000000000"),
                nonce: decode_hex("030000000000000000000000"),
                msg: decode_hex_into_vec("0100000000000000"),
                ad: Vec::new(),
                ciphertext: decode_hex_into_vec("b5d839330ac7b786578782fff6013b815b287c22493a364c"),
            },
            TestCase {
                key: decode_hex("01000000000000000000000000000000"),
                nonce: decode_hex("030000000000000000000000"),
                msg: decode_hex_into_vec("02000000"),
                ad: decode_hex_into_vec("010000000000000000000000"),
                ciphertext: decode_hex_into_vec("a8fe3e8707eb1f84fb28f8cb73de8e99e2f48a14"),
            },
        ];

        check_test_cases(|key| Box::new(Aes128GcmSiv::new(key)), test_cases);
    }

    #[test]
    fn aes_256_gcm_siv() {
        let test_cases: &[TestCase<32, 12>] = &[
            TestCase {
                key: decode_hex("0100000000000000000000000000000000000000000000000000000000000000"),
                nonce: decode_hex("030000000000000000000000"),
                msg: decode_hex_into_vec("0100000000000000"),
                ad: Vec::new(),
                ciphertext: decode_hex_into_vec("c2ef328e5c71c83b843122130f7364b761e0b97427e3df28"),
            },
            TestCase {
                key: decode_hex("0100000000000000000000000000000000000000000000000000000000000000"),
                nonce: decode_hex("030000000000000000000000"),
                msg: decode_hex_into_vec("010000000000000000000000"),
                ad: Vec::new(),
                ciphertext: decode_hex_into_vec(
                    "9aab2aeb3faa0a34aea8e2b18ca50da9ae6559e48fd10f6e5c9ca17e",
                ),
            },
            TestCase {
                key: decode_hex("0100000000000000000000000000000000000000000000000000000000000000"),
                nonce: decode_hex("030000000000000000000000"),
                msg: decode_hex_into_vec("02000000"),
                ad: decode_hex_into_vec("010000000000000000000000"),
                ciphertext: decode_hex_into_vec("22b3f4cd1835e517741dfddccfa07fa4661b74cf"),
            },
        ];

        check_test_cases(|key| Box::new(Aes256GcmSiv::new(key)), test_cases);
    }

    #[test]
    fn aes_128_gcm() {
        let test_cases: &[TestCase<16, 12>] = &[
            TestCase {
                key: decode_hex("d480429666d48b400633921c5407d1d1"),
                nonce: decode_hex("3388c676dc754acfa66e172a"),
                msg: Vec::new(),
                ad: Vec::new(),
                ciphertext: decode_hex_into_vec("7d7daf44850921a34e636b01adeb104f"),
            },
            TestCase {
                key: decode_hex("3881e7be1bb3bbcaff20bdb78e5d1b67"),
                nonce: decode_hex("dcf5b7ae2d7552e2297fcfa9"),
                msg: decode_hex_into_vec("0a2714aa7d"),
                ad: decode_hex_into_vec("c60c64bbf7"),
                ciphertext: decode_hex_into_vec("5626f96ecbff4c4f1d92b0abb1d0820833d9eb83c7"),
            },
        ];

        check_test_cases(|key| Box::new(Aes128Gcm::new(key)), test_cases);
    }

    #[test]
    fn aes_256_gcm() {
        let test_cases: &[TestCase<32, 12>] = &[
            TestCase {
                key: decode_hex("e5ac4a32c67e425ac4b143c83c6f161312a97d88d634afdf9f4da5bd35223f01"),
                nonce: decode_hex("5bf11a0951f0bfc7ea5c9e58"),
                msg: Vec::new(),
                ad: Vec::new(),
                ciphertext: decode_hex_into_vec("d7cba289d6d19a5af45dc13857016bac"),
            },
            TestCase {
                key: decode_hex("73ad7bbbbc640c845a150f67d058b279849370cd2c1f3c67c4dd6c869213e13a"),
                nonce: decode_hex("a330a184fc245812f4820caa"),
                msg: decode_hex_into_vec("f0535fe211"),
                ad: decode_hex_into_vec("e91428be04"),
                ciphertext: decode_hex_into_vec("e9b8a896da9115ed79f26a030c14947b3e454db9e7"),
            },
        ];

        check_test_cases(|key| Box::new(Aes256Gcm::new(key)), test_cases);
    }
}