1use crate::rand;
2use crate::server::ProducesTickets;
3use crate::Error;
4
5use ring::aead;
6use std::mem;
7use std::sync::{Arc, Mutex, MutexGuard};
8use std::time;
9
/// Wall-clock timestamp, stored as the elapsed duration since the Unix epoch.
#[derive(Clone, Copy, Debug)]
pub struct TimeBase(time::Duration);

impl TimeBase {
    /// Captures the current system time.
    ///
    /// # Errors
    ///
    /// Returns [`time::SystemTimeError`] if the system clock reports a time
    /// earlier than [`time::UNIX_EPOCH`].
    #[inline]
    pub fn now() -> Result<Self, time::SystemTimeError> {
        time::SystemTime::now()
            .duration_since(time::UNIX_EPOCH)
            .map(Self)
    }

    /// Whole seconds since the Unix epoch; any sub-second remainder is dropped.
    #[inline]
    pub fn as_secs(&self) -> u64 {
        let Self(since_epoch) = self;
        since_epoch.as_secs()
    }
}
30
31struct AeadTicketer {
36 alg: &'static aead::Algorithm,
37 key: aead::LessSafeKey,
38 lifetime: u32,
39}
40
41impl AeadTicketer {
42 fn new() -> Result<Self, rand::GetRandomFailed> {
44 let mut key = [0u8; 32];
45 rand::fill_random(&mut key)?;
46
47 let alg = &aead::CHACHA20_POLY1305;
48 let key = aead::UnboundKey::new(alg, &key).unwrap();
49
50 Ok(Self {
51 alg,
52 key: aead::LessSafeKey::new(key),
53 lifetime: 60 * 60 * 12,
54 })
55 }
56}
57
58impl ProducesTickets for AeadTicketer {
59 fn enabled(&self) -> bool {
60 true
61 }
62 fn lifetime(&self) -> u32 {
63 self.lifetime
64 }
65
66 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
68 let mut nonce_buf = [0u8; 12];
70 rand::fill_random(&mut nonce_buf).ok()?;
71 let nonce = aead::Nonce::assume_unique_for_key(nonce_buf);
72 let aad = aead::Aad::empty();
73
74 let mut ciphertext =
75 Vec::with_capacity(nonce_buf.len() + message.len() + self.key.algorithm().tag_len());
76 ciphertext.extend(nonce_buf);
77 ciphertext.extend(message);
78 self.key
79 .seal_in_place_separate_tag(nonce, aad, &mut ciphertext[nonce_buf.len()..])
80 .map(|tag| {
81 ciphertext.extend(tag.as_ref());
82 ciphertext
83 })
84 .ok()
85 }
86
87 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
89 let nonce = ciphertext.get(..self.alg.nonce_len())?;
91 let ciphertext = ciphertext.get(nonce.len()..)?;
92
93 let nonce = aead::Nonce::try_assume_unique_for_key(nonce).ok()?;
95
96 let mut out = Vec::from(ciphertext);
97
98 let plain_len = self
99 .key
100 .open_in_place(nonce, aead::Aad::empty(), &mut out)
101 .ok()?
102 .len();
103 out.truncate(plain_len);
104
105 Some(out)
106 }
107}
108
109struct TicketSwitcherState {
110 next: Option<Box<dyn ProducesTickets>>,
111 current: Box<dyn ProducesTickets>,
112 previous: Option<Box<dyn ProducesTickets>>,
113 next_switch_time: u64,
114}
115
116struct TicketSwitcher {
120 generator: fn() -> Result<Box<dyn ProducesTickets>, rand::GetRandomFailed>,
121 lifetime: u32,
122 state: Mutex<TicketSwitcherState>,
123}
124
impl TicketSwitcher {
    /// Creates a switcher that rotates its inner ticketer every `lifetime`
    /// seconds, using `generator` to create each inner ticketer.
    ///
    /// Two ticketers are generated up front: a pre-generated `next` and the
    /// initial `current` (in that order). The first rotation is scheduled for
    /// `lifetime` seconds after the current time.
    ///
    /// # Errors
    ///
    /// Fails if the current time cannot be read or if either call to
    /// `generator` fails; both are converted into `Error` by `?`.
    fn new(
        lifetime: u32,
        generator: fn() -> Result<Box<dyn ProducesTickets>, rand::GetRandomFailed>,
    ) -> Result<Self, Error> {
        let now = TimeBase::now()?;
        Ok(Self {
            generator,
            lifetime,
            state: Mutex::new(TicketSwitcherState {
                next: Some(generator()?),
                current: generator()?,
                previous: None,
                next_switch_time: now
                    .as_secs()
                    .saturating_add(u64::from(lifetime)),
            }),
        })
    }

    /// Rotates the inner ticketers if the deadline has passed at `now`, then
    /// returns the locked state.
    ///
    /// If `now` is not past `next_switch_time`, the state is returned as-is.
    /// Otherwise the old `current` becomes `previous`, `next` becomes
    /// `current`, the deadline moves to `now + lifetime`, and a replacement
    /// `next` is generated.
    ///
    /// Returns `None` if the mutex is poisoned or if the generator fails.
    /// Note that on the normal (non-recovering) path a generator failure
    /// happens *after* the rotation has already been applied. The state is
    /// then left rotated but with `next == None`, so the following rotation
    /// goes through the recovery path below.
    fn maybe_roll(&self, now: TimeBase) -> Option<MutexGuard<TicketSwitcherState>> {
        // The generator is never called while the mutex is held, so other
        // threads can keep using the state while new key material is
        // produced. The lock is therefore taken up to twice:
        //  1. Locked: return immediately if no switch is due; otherwise rotate
        //     using the pre-generated `next`, leaving `next` empty.
        //  2. Unlocked: generate a replacement `next`.
        //  3. Locked again: install the new `next` and return the guard.
        //
        // If `next` was already empty in step 1 (an earlier generation
        // failed), no rotation is possible, so we enter recovery mode. There
        // we also generate a new `current`, and re-check the deadline after
        // re-locking, because another thread may have rotated in the
        // meantime.
        //
        // The new deadline is computed from the `now` passed in, not from the
        // time at which generation finishes.
        let now = now.as_secs();
        // Set when `next` was missing at rotation time (recovery mode).
        let mut are_recovering = false; {
            // Scope the first guard so the mutex is released before generation.
            let mut state = self.state.lock().ok()?;

            // Fast path: no switch due yet.
            if now <= state.next_switch_time {
                return Some(state);
            }

            if let Some(next) = state.next.take() {
                // Normal rotation: previous <- current <- next.
                state.previous = Some(mem::replace(&mut state.current, next));
                state.next_switch_time = now.saturating_add(u64::from(self.lifetime));
            } else {
                are_recovering = true;
            }
        }

        // A replacement `next` is needed on both paths; generate it unlocked.
        let next = (self.generator)().ok()?;
        if !are_recovering {
            let mut state = self.state.lock().ok()?;
            state.next = Some(next);
            Some(state)
        } else {
            // Recovery: also build a new `current`. Any `next` another thread
            // installed in the meantime is overwritten below.
            let new_current = (self.generator)().ok()?;
            let mut state = self.state.lock().ok()?;
            state.next = Some(next);
            // Re-check the deadline so that concurrent recoveries do not
            // rotate twice in quick succession.
            if now > state.next_switch_time {
                state.previous = Some(mem::replace(&mut state.current, new_current));
                state.next_switch_time = now.saturating_add(u64::from(self.lifetime));
            }
            Some(state)
        }
    }
}
229
230impl ProducesTickets for TicketSwitcher {
231 fn lifetime(&self) -> u32 {
232 self.lifetime * 2
233 }
234
235 fn enabled(&self) -> bool {
236 true
237 }
238
239 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
240 let state = self.maybe_roll(TimeBase::now().ok()?)?;
241
242 state.current.encrypt(message)
243 }
244
245 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
246 let state = self.maybe_roll(TimeBase::now().ok()?)?;
247
248 state
250 .current
251 .decrypt(ciphertext)
252 .or_else(|| {
253 state
254 .previous
255 .as_ref()
256 .and_then(|previous| previous.decrypt(ciphertext))
257 })
258 }
259}
260
261pub struct Ticketer {}
263
264fn generate_inner() -> Result<Box<dyn ProducesTickets>, rand::GetRandomFailed> {
265 Ok(Box::new(AeadTicketer::new()?))
266}
267
268impl Ticketer {
269 pub fn new() -> Result<Arc<dyn ProducesTickets>, Error> {
274 Ok(Arc::new(TicketSwitcher::new(6 * 60 * 60, generate_inner)?))
275 }
276}
277
/// A ticket encrypted by the default ticketer decrypts back to the original.
#[test]
fn basic_pairwise_test() {
    let ticketer = Ticketer::new().unwrap();
    assert!(ticketer.enabled());

    let message: &[u8] = b"hello world";
    let sealed = ticketer.encrypt(message).unwrap();
    let opened = ticketer.decrypt(&sealed).unwrap();
    assert_eq!(opened, message);
}
286
/// Tickets survive exactly one forced rotation and are rejected after two.
#[test]
fn ticketswitcher_switching_test() {
    let switcher = TicketSwitcher::new(1, generate_inner).unwrap();
    let start = TimeBase::now().unwrap();
    // A timestamp `secs` seconds after `start`, used to force rotations
    // without waiting on the real clock.
    let at = |secs: u64| TimeBase(start.0 + time::Duration::from_secs(secs));

    let cipher1 = switcher.encrypt(b"ticket 1").unwrap();
    assert_eq!(switcher.decrypt(&cipher1).unwrap(), b"ticket 1");

    // First rotation: ticket 1's key becomes `previous`. The returned guard is
    // dropped right away so the following calls can take the lock.
    drop(switcher.maybe_roll(at(10)));
    let cipher2 = switcher.encrypt(b"ticket 2").unwrap();
    assert_eq!(switcher.decrypt(&cipher1).unwrap(), b"ticket 1");
    assert_eq!(switcher.decrypt(&cipher2).unwrap(), b"ticket 2");

    // Second rotation: ticket 1's key is gone, ticket 2's is now `previous`.
    drop(switcher.maybe_roll(at(20)));
    let cipher3 = switcher.encrypt(b"ticket 3").unwrap();
    assert!(switcher.decrypt(&cipher1).is_none());
    assert_eq!(switcher.decrypt(&cipher2).unwrap(), b"ticket 2");
    assert_eq!(switcher.decrypt(&cipher3).unwrap(), b"ticket 3");
}
309
/// Test-only generator that always reports a random-source failure, used to
/// drive `TicketSwitcher` into its recovery path.
#[cfg(test)]
fn fail_generator() -> Result<Box<dyn ProducesTickets>, rand::GetRandomFailed> {
    let failure = rand::GetRandomFailed;
    Err(failure)
}
314
/// A failed generation during rotation is recovered from at the next rotation.
#[test]
fn ticketswitcher_recover_test() {
    let mut switcher = TicketSwitcher::new(1, generate_inner).unwrap();
    let start = TimeBase::now().unwrap();
    // A timestamp `secs` seconds after `start`, used to force rotations
    // without waiting on the real clock.
    let at = |secs: u64| TimeBase(start.0 + time::Duration::from_secs(secs));

    let cipher1 = switcher.encrypt(b"ticket 1").unwrap();
    assert_eq!(switcher.decrypt(&cipher1).unwrap(), b"ticket 1");

    // Rotate while the generator is broken. The rotation itself still happens,
    // but no replacement `next` can be produced, so `next` stays empty.
    switcher.generator = fail_generator;
    drop(switcher.maybe_roll(at(10)));
    switcher.generator = generate_inner;

    let cipher2 = switcher.encrypt(b"ticket 2").unwrap();
    assert_eq!(switcher.decrypt(&cipher1).unwrap(), b"ticket 1");
    assert_eq!(switcher.decrypt(&cipher2).unwrap(), b"ticket 2");

    // With `next` missing, this rotation takes the recovery path and
    // generates a fresh `current` as well.
    drop(switcher.maybe_roll(at(20)));
    let cipher3 = switcher.encrypt(b"ticket 3").unwrap();
    assert!(switcher.decrypt(&cipher1).is_none());
    assert_eq!(switcher.decrypt(&cipher2).unwrap(), b"ticket 2");
    assert_eq!(switcher.decrypt(&cipher3).unwrap(), b"ticket 3");
}