1use serde::{Deserialize, Serialize};
8use std::cell::RefCell;
9use std::collections::BTreeMap;
10use std::fs::{self, File};
11use std::future::Future;
12use std::io;
13use std::path::{Path, PathBuf};
14use std::pin::Pin;
15use std::task::{Context, Poll, Waker};
16
/// A single-value publish/subscribe hub whose value is also persisted to disk.
///
/// The hub holds at most one `String` value. Callers publish new values with
/// `publish`, read the current one with `get_value`, and wait for it to
/// differ from a value they already know with `watch_for_change`.
pub struct PubSubHub {
    /// Current value and waker bookkeeping. Kept in a `RefCell` so that the
    /// `&self` methods can mutate it.
    inner: RefCell<PubSubHubInner>,
    /// File the initial value is loaded from in `new` and that every
    /// published value is written to.
    storage_path: PathBuf,
}
27
/// Future returned by `PubSubHub::watch_for_change`.
///
/// Resolves with the hub's current value as soon as that value differs from
/// `last_value`; stays pending while the two are equal.
pub struct PubSubFuture<'a> {
    /// Shared state of the hub this future is watching.
    hub: &'a RefCell<PubSubHubInner>,
    /// Key under which this future registers its waker with the hub.
    id: usize,
    /// The value the caller already knows about.
    last_value: Option<String>,
}
35
/// Mutable state shared between a `PubSubHub` and the futures it hands out.
struct PubSubHubInner {
    /// Most recent value, or `None` if nothing was loaded or published yet.
    item: Option<String>,
    /// Id that will be assigned to the next future created by
    /// `watch_for_change`.
    next_future_id: usize,
    /// Wakers of pending futures, keyed by future id. A later registration
    /// under the same id replaces the earlier one.
    wakers: BTreeMap<usize, Waker>,
}
41
42impl PubSubHub {
43 pub fn new(storage_path: PathBuf) -> Self {
44 let initial_value = load_region_code(&storage_path);
45 Self {
46 inner: RefCell::new(PubSubHubInner {
47 item: initial_value,
48 next_future_id: 0,
49 wakers: BTreeMap::new(),
50 }),
51 storage_path,
52 }
53 }
54
55 pub fn publish<S>(&self, new_value: S)
59 where
60 S: Into<String>,
61 {
62 let hub = &self.inner;
63 let new_value = new_value.into();
64 hub.borrow_mut().item = Some(new_value.clone());
65 hub.borrow_mut().wakers.values().for_each(|w| w.wake_by_ref());
66 hub.borrow_mut().wakers.clear();
67 write_region_code(new_value, &self.storage_path);
69 }
70
71 pub fn watch_for_change<S>(&self, last_value: Option<S>) -> PubSubFuture<'_>
74 where
75 S: Into<String>,
76 {
77 let hub = &self.inner;
78 let id = hub.borrow().next_future_id;
79 hub.borrow_mut().next_future_id = id.checked_add(1).expect("`id` is impossibly large");
80 PubSubFuture { hub, id, last_value: last_value.map(|s| s.into()) }
81 }
82
83 pub fn get_value(&self) -> Option<String> {
84 let hub = &self.inner;
85 hub.borrow().get_value()
86 }
87}
88
/// Serialized form of the cached value: this is the structure that is written
/// to and parsed from the storage file as JSON.
#[derive(Debug, Deserialize, Serialize)]
struct RegulatoryRegion {
    /// The region code string held by the hub.
    region_code: String,
}
94
95fn load_region_code(path: impl AsRef<Path>) -> Option<String> {
99 let file = match File::open(path.as_ref()) {
100 Ok(file) => file,
101 Err(e) => match e.kind() {
102 io::ErrorKind::NotFound => return None,
103 _ => {
104 log::info!(
105 "Failed to read cached regulatory region, will initialize with none: {}",
106 e
107 );
108 try_delete_file(path);
109 return None;
110 }
111 },
112 };
113 match serde_json::from_reader::<_, RegulatoryRegion>(io::BufReader::new(file)) {
114 Ok(region) => Some(region.region_code),
115 Err(e) => {
116 log::info!("Error parsing stored regulatory region code: {}", e);
117 try_delete_file(path);
118 None
119 }
120 }
121}
122
123fn write_region_code(region_code: String, storage_path: impl AsRef<Path>) {
128 let write_val = RegulatoryRegion { region_code };
129 let file = match File::create(storage_path.as_ref()) {
130 Ok(file) => file,
131 Err(e) => {
132 log::info!("Failed to open file to write regulatory region: {}", e);
133 try_delete_file(storage_path);
134 return;
135 }
136 };
137 if let Err(e) = serde_json::to_writer(io::BufWriter::new(file), &write_val) {
138 log::info!("Failed to write regulatory region: {}", e);
139 try_delete_file(storage_path);
140 }
141}
142
143fn try_delete_file(storage_path: impl AsRef<Path>) {
144 if let Err(e) = fs::remove_file(&storage_path) {
145 log::info!("Failed to delete previously cached regulatory region: {}", e);
146 }
147}
148
149impl Future for PubSubFuture<'_> {
150 type Output = Option<String>;
151
152 fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
153 let hub = &self.hub;
154 if hub.borrow().has_value(&self.last_value) {
155 hub.borrow_mut().set_waker_for_future(self.id, context.waker().clone());
156 Poll::Pending
157 } else {
158 Poll::Ready(hub.borrow().get_value())
159 }
160 }
161}
162
163impl PubSubHubInner {
164 fn set_waker_for_future(&mut self, future_id: usize, waker: Waker) {
165 self.wakers.insert(future_id, waker);
166 }
167
168 fn has_value(&self, expected: &Option<String>) -> bool {
169 self.item == *expected
170 }
171
172 fn get_value(&self) -> Option<String> {
173 self.item.clone()
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use fuchsia_async as fasync;
    use futures_test::task::new_count_waker;
    use std::io::Write;
    use tempfile::TempDir;

    /// Creates a fresh temporary directory under `/cache/` and returns it
    /// together with the path of a cache file inside it. The file itself is
    /// not created.
    ///
    /// The returned `TempDir` has to be kept alive for as long as the path is
    /// in use.
    fn temp_storage() -> (TempDir, PathBuf) {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        (temp_dir, path)
    }

    /// Creates a hub backed by a fresh temporary storage location that has no
    /// cache file yet, so the hub starts out with no value.
    fn new_hub() -> (TempDir, PubSubHub) {
        let (temp_dir, path) = temp_storage();
        (temp_dir, PubSubHub::new(path))
    }

    #[fasync::run_until_stalled(test)]
    async fn watch_for_change_future_is_pending_when_both_values_are_none() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        assert_eq!(0, count.get());
    }

    #[fasync::run_until_stalled(test)]
    async fn watch_for_change_future_is_pending_when_values_are_same_and_not_none() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        hub.publish("US");

        let mut future = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        assert_eq!(0, count.get());
    }

    #[fasync::run_until_stalled(test)]
    async fn watch_for_change_future_is_immediately_ready_when_argument_differs_from_published_value(
    ) {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        hub.publish("US");

        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));
        assert_eq!(0, count.get());
    }

    #[fasync::run_until_stalled(test)]
    async fn single_watcher_is_woken_correctly_on_change_from_none_to_some() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));

        // Publishing must wake the pending watcher exactly once.
        hub.publish("US");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));
    }

    #[fasync::run_until_stalled(test)]
    async fn single_watcher_is_woken_correctly_on_change_from_some_to_new_some() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        hub.publish("US");

        let mut future = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));

        // Publishing a different value must wake the pending watcher exactly once.
        hub.publish("SU");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("SU".to_string())), Pin::new(&mut future).poll(&mut context));
    }

    #[fasync::run_until_stalled(test)]
    async fn multiple_watchers_are_woken_correctly_on_change_from_some_to_new_some() {
        let (_temp_dir, hub) = new_hub();
        let (waker_a, wake_count_a) = new_count_waker();
        let (waker_b, wake_count_b) = new_count_waker();
        let mut context_a = Context::from_waker(&waker_a);
        let mut context_b = Context::from_waker(&waker_b);
        hub.publish("US");

        let mut future_a = hub.watch_for_change(Some("US"));
        let mut future_b = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context_a), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context_b), "for future b");

        // Each watcher has its own waker; both must be woken exactly once.
        hub.publish("SU");
        assert_eq!(1, wake_count_a.get(), "for waker a");
        assert_eq!(1, wake_count_b.get(), "for waker b");
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_a).poll(&mut context_a),
            "for future a"
        );
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_b).poll(&mut context_b),
            "for future b"
        );
    }

    #[fasync::run_until_stalled(test)]
    async fn multiple_watchers_are_woken_correctly_after_spurious_update() {
        let (_temp_dir, hub) = new_hub();
        let (waker_a, wake_count_a) = new_count_waker();
        let (waker_b, wake_count_b) = new_count_waker();
        let mut context_a = Context::from_waker(&waker_a);
        let mut context_b = Context::from_waker(&waker_b);
        hub.publish("US");

        let mut future_a = hub.watch_for_change(Some("US"));
        let mut future_b = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context_a), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context_b), "for future b");

        // Re-publishing the same value must leave both futures pending when
        // they are polled again.
        hub.publish("US");
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context_a), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context_b), "for future b");

        // Only the wakes caused by the real change are counted from here on.
        let old_wake_count_a = wake_count_a.get();
        let old_wake_count_b = wake_count_b.get();
        hub.publish("SU");
        assert_eq!(1, wake_count_a.get() - old_wake_count_a);
        assert_eq!(1, wake_count_b.get() - old_wake_count_b);
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_a).poll(&mut context_a),
            "for future a"
        );
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_b).poll(&mut context_b),
            "for future b"
        );
    }

    #[fasync::run_until_stalled(test)]
    async fn multiple_watchers_can_share_a_waker() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future_a = hub.watch_for_change(Option::<String>::None);
        let mut future_b = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context), "for future b");

        // The shared waker is registered once per future, so it is woken twice.
        hub.publish("US");
        assert_eq!(2, count.get());
        assert_eq!(
            Poll::Ready(Some("US".to_string())),
            Pin::new(&mut future_a).poll(&mut context),
            "for future a"
        );
        assert_eq!(
            Poll::Ready(Some("US".to_string())),
            Pin::new(&mut future_b).poll(&mut context),
            "for future b"
        );
    }

    #[fasync::run_until_stalled(test)]
    async fn single_watcher_is_not_woken_again_after_future_is_ready() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));

        hub.publish("US");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));

        // The future has completed, so a later publish must not wake it again.
        hub.publish("SU");
        assert_eq!(1, count.get());
    }

    #[fasync::run_until_stalled(test)]
    async fn second_watcher_is_woken_for_second_update() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));

        hub.publish("US");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));

        // A new watcher created after the first update must see the second one.
        let mut future = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        hub.publish("SU");
        assert!(count.get() > 1, "Count should be >1, but is {}", count.get());
        assert_eq!(Poll::Ready(Some("SU".to_string())), Pin::new(&mut future).poll(&mut context));
    }

    #[fasync::run_until_stalled(test)]
    async fn multiple_polls_of_single_watcher_do_not_cause_multiple_wakes_when_waker_is_reused() {
        let (_temp_dir, hub) = new_hub();
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));

        // Two polls of one future register a single waker, hence a single wake.
        hub.publish("US");
        assert_eq!(1, count.get());
    }

    #[fasync::run_until_stalled(test)]
    async fn multiple_polls_of_single_watcher_do_not_cause_multiple_wakes_when_waker_is_replaced() {
        let (_temp_dir, hub) = new_hub();
        let (waker_a, wake_count_a) = new_count_waker();
        let (waker_b, wake_count_b) = new_count_waker();
        let mut context_a = Context::from_waker(&waker_a);
        let mut context_b = Context::from_waker(&waker_b);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context_a));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context_b));

        // Only the waker from the most recent poll is woken.
        hub.publish("US");
        assert_eq!(0, wake_count_a.get());
        assert_eq!(1, wake_count_b.get());
    }

    #[test]
    fn get_value_is_none() {
        let (_temp_dir, hub) = new_hub();
        assert_eq!(None, hub.get_value());
    }

    #[test]
    fn get_value_is_some() {
        let (_temp_dir, hub) = new_hub();
        hub.publish("US");
        assert_eq!(Some("US".to_string()), hub.get_value());
    }

    #[test]
    fn published_value_is_saved_and_loaded_on_creation() {
        let (_temp_dir, path) = temp_storage();
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), None);
        hub.publish("WW");
        assert_eq!(hub.get_value(), Some("WW".to_string()));

        // A second hub created on the same path starts with the published value.
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), Some("WW".to_string()));

        // The file on disk holds the published value in the expected format.
        let file = File::open(&path).expect("Failed to open file");
        assert_matches!(
            serde_json::from_reader(io::BufReader::new(file)),
            Ok(RegulatoryRegion{ region_code }) if region_code.as_str() == "WW"
        );
    }

    #[test]
    fn publishing_over_previously_saved_value_overwrites_cache() {
        let (_temp_dir, path) = temp_storage();

        // Seed the cache file before the hub is created.
        let cache_val = RegulatoryRegion { region_code: "WW".to_string() };
        let file = File::create(&path).expect("failed to create file");
        serde_json::to_writer(io::BufWriter::new(file), &cache_val)
            .expect("Failed to write JSON to file");

        // The hub picks up the seeded value.
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), Some("WW".to_string()));

        // Publishing replaces the file contents with the new value.
        hub.publish("US");
        let file = File::open(&path).expect("Failed to open file");
        assert_matches!(
            serde_json::from_reader(io::BufReader::new(file)),
            Ok(RegulatoryRegion{ region_code }) if region_code.as_str() == "US"
        );
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), Some("US".to_string()));
    }

    #[test]
    fn load_as_none_if_cache_file_is_bad() {
        let (_temp_dir, path) = temp_storage();
        assert!(!path.exists());
        let mut file = File::create(&path).expect("failed to create file");
        let bad_contents = b"{\"region_code\": ";
        file.write_all(bad_contents).expect("failed to write to file");
        file.flush().expect("failed to flush file");

        // Truncated JSON must be treated as "no cached value".
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), None);

        // The unusable cache file must have been removed.
        assert_matches!(File::open(&path), Err(io_err) if io_err.kind() == io::ErrorKind::NotFound);
    }
}