use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::fs::{self, File};
use std::future::Future;
use std::io;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
/// Holds an optional regulatory region code and notifies watchers when it changes.
///
/// The hub is not thread-safe: its state lives in a `RefCell`, and the futures returned by
/// `watch_for_change` borrow that state by shared reference.
pub struct PubSubHub {
    /// Current value, next watcher id and registered wakers.
    inner: RefCell<PubSubHubInner>,
    /// JSON cache file that the initial value is loaded from and published values are written to.
    storage_path: PathBuf,
}
/// Future returned by `PubSubHub::watch_for_change`.
///
/// It resolves to the hub's current value once that value differs from `last_value`.
pub struct PubSubFuture<'a> {
    /// State of the hub being watched.
    hub: &'a RefCell<PubSubHubInner>,
    /// Key of this future in the hub's waker map; ids come from a strictly increasing counter.
    id: usize,
    /// The value the caller last observed; the future stays pending while the hub still holds it.
    last_value: Option<String>,
}
/// Mutable state shared between a `PubSubHub` and the futures it hands out.
struct PubSubHubInner {
    /// Most recently published value, or the value loaded from the cache file at construction.
    item: Option<String>,
    /// Id assigned to the next future created by `watch_for_change`.
    next_future_id: usize,
    /// Wakers of pending futures, keyed by future id; emptied on every publish.
    wakers: BTreeMap<usize, Waker>,
}
impl PubSubHub {
    /// Creates a hub backed by the cache file at `storage_path`.
    ///
    /// The initial value is whatever `load_region_code` reads from that file (`None` if the
    /// file is missing or unusable).
    pub fn new(storage_path: PathBuf) -> Self {
        let initial_value = load_region_code(&storage_path);
        Self {
            inner: RefCell::new(PubSubHubInner {
                item: initial_value,
                next_future_id: 0,
                wakers: BTreeMap::new(),
            }),
            storage_path,
        }
    }

    /// Sets the current value, wakes every registered watcher and writes the value to the cache.
    ///
    /// Watchers are woken even when `new_value` equals the current value. Their futures simply
    /// return `Pending` again when re-polled. Persistence failures are not reported to the
    /// caller; the in-memory value is updated regardless.
    pub fn publish<S>(&self, new_value: S)
    where
        S: Into<String>,
    {
        let new_value = new_value.into();
        // Update the value and take ownership of all pending wakers under a single borrow.
        let wakers = {
            let mut inner = self.inner.borrow_mut();
            inner.item = Some(new_value.clone());
            std::mem::take(&mut inner.wakers)
        };
        // Wake only after the borrow is released. `Waker::wake` runs executor-provided code,
        // which must not be able to hit an outstanding mutable borrow of the hub (that would
        // panic).
        for (_, waker) in wakers {
            waker.wake();
        }
        write_region_code(new_value, &self.storage_path);
    }

    /// Returns a future that resolves once the hub's value differs from `last_value`.
    ///
    /// Each returned future gets a distinct id.
    ///
    /// # Panics
    ///
    /// Panics if the id counter would overflow `usize`.
    pub fn watch_for_change<S>(&self, last_value: Option<S>) -> PubSubFuture<'_>
    where
        S: Into<String>,
    {
        let id = {
            let mut inner = self.inner.borrow_mut();
            let id = inner.next_future_id;
            inner.next_future_id = id.checked_add(1).expect("`id` is impossibly large");
            id
        };
        PubSubFuture { hub: &self.inner, id, last_value: last_value.map(Into::into) }
    }

    /// Returns a copy of the current value, if any.
    pub fn get_value(&self) -> Option<String> {
        self.inner.borrow().get_value()
    }
}
/// JSON representation of the cached regulatory region, e.g. `{"region_code":"US"}`.
#[derive(Debug, Deserialize, Serialize)]
struct RegulatoryRegion {
    /// The region code string as it was published.
    region_code: String,
}
fn load_region_code(path: impl AsRef<Path>) -> Option<String> {
let file = match File::open(path.as_ref()) {
Ok(file) => file,
Err(e) => match e.kind() {
io::ErrorKind::NotFound => return None,
_ => {
log::info!(
"Failed to read cached regulatory region, will initialize with none: {}",
e
);
try_delete_file(path);
return None;
}
},
};
match serde_json::from_reader::<_, RegulatoryRegion>(io::BufReader::new(file)) {
Ok(region) => Some(region.region_code),
Err(e) => {
log::info!("Error parsing stored regulatory region code: {}", e);
try_delete_file(path);
None
}
}
}
/// Writes `region_code` as JSON to `storage_path`, replacing any existing file.
///
/// Errors are logged rather than returned. On any failure, the file is deleted (best effort) so
/// that a stale or partially written value is not loaded later.
fn write_region_code(region_code: String, storage_path: impl AsRef<Path>) {
    let write_val = RegulatoryRegion { region_code };
    let file = match File::create(storage_path.as_ref()) {
        Ok(file) => file,
        Err(e) => {
            log::info!("Failed to open file to write regulatory region: {}", e);
            try_delete_file(storage_path);
            return;
        }
    };

    // Serialize through a borrowed writer so the buffer can be flushed explicitly afterwards.
    // Dropping a `BufWriter` flushes it but discards any error, which would hide a truncated
    // cache file.
    let mut writer = io::BufWriter::new(file);
    let result = serde_json::to_writer(&mut writer, &write_val)
        .map_err(|e| e.to_string())
        .and_then(|()| io::Write::flush(&mut writer).map_err(|e| e.to_string()));
    // Close the file before any cleanup so nothing is written to it after it is deleted.
    drop(writer);

    if let Err(e) = result {
        log::info!("Failed to write regulatory region: {}", e);
        try_delete_file(storage_path);
    }
}
fn try_delete_file(storage_path: impl AsRef<Path>) {
if let Err(e) = fs::remove_file(&storage_path) {
log::info!("Failed to delete previously cached regulatory region: {}", e);
}
}
impl Future for PubSubFuture<'_> {
type Output = Option<String>;
fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
let hub = &self.hub;
if hub.borrow().has_value(&self.last_value) {
hub.borrow_mut().set_waker_for_future(self.id, context.waker().clone());
Poll::Pending
} else {
Poll::Ready(hub.borrow().get_value())
}
}
}
impl PubSubHubInner {
    /// Stores `waker` for future `future_id`, replacing any waker stored by an earlier poll.
    fn set_waker_for_future(&mut self, future_id: usize, waker: Waker) {
        self.wakers.insert(future_id, waker);
    }

    /// Reports whether the current value is exactly `expected` (both `None` counts as equal).
    fn has_value(&self, expected: &Option<String>) -> bool {
        expected == &self.item
    }

    /// Returns an owned copy of the current value.
    fn get_value(&self) -> Option<String> {
        self.item.as_ref().cloned()
    }
}
// Unit tests for `PubSubHub` and `PubSubFuture`. Futures are polled by hand with a counting waker
// (`new_count_waker`), so each test can assert how many wake-ups a publish produced. Every test
// uses its own temporary directory under `/cache/` for the cache file.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use fuchsia_async as fasync;
    use futures_test::task::new_count_waker;
    use std::io::Write;
    use tempfile::TempDir;

    // Watching for a change away from `None` while the hub holds `None` must not resolve or wake.
    #[fasync::run_until_stalled(test)]
    async fn watch_for_change_future_is_pending_when_both_values_are_none() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        assert_eq!(0, count.get());
    }

    // Same as above, but with a concrete value on both sides.
    #[fasync::run_until_stalled(test)]
    async fn watch_for_change_future_is_pending_when_values_are_same_and_not_none() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        hub.publish("US");
        let mut future = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        assert_eq!(0, count.get());
    }

    // A watcher whose last-seen value is already stale resolves on the first poll.
    #[fasync::run_until_stalled(test)]
    async fn watch_for_change_future_is_immediately_ready_when_argument_differs_from_published_value(
    ) {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        hub.publish("US");
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));
        assert_eq!(0, count.get());
    }

    // None -> Some: the pending watcher is woken once and then resolves to the new value.
    #[fasync::run_until_stalled(test)]
    async fn single_watcher_is_woken_correctly_on_change_from_none_to_some() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        hub.publish("US");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));
    }

    // Some -> different Some: the pending watcher is woken once and resolves to the new value.
    #[fasync::run_until_stalled(test)]
    async fn single_watcher_is_woken_correctly_on_change_from_some_to_new_some() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        hub.publish("US");
        let mut future = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        hub.publish("SU");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("SU".to_string())), Pin::new(&mut future).poll(&mut context));
    }

    // Two watchers with independent wakers are each woken once by a single change.
    #[fasync::run_until_stalled(test)]
    async fn multiple_watchers_are_woken_correctly_on_change_from_some_to_new_some() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker_a, wake_count_a) = new_count_waker();
        let (waker_b, wake_count_b) = new_count_waker();
        let mut context_a = Context::from_waker(&waker_a);
        let mut context_b = Context::from_waker(&waker_b);
        hub.publish("US");
        let mut future_a = hub.watch_for_change(Some("US"));
        let mut future_b = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context_a), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context_b), "for future b");
        hub.publish("SU");
        assert_eq!(1, wake_count_a.get(), "for waker a");
        assert_eq!(1, wake_count_b.get(), "for waker b");
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_a).poll(&mut context_a),
            "for future a"
        );
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_b).poll(&mut context_b),
            "for future b"
        );
    }

    // Re-publishing the same value may wake watchers spuriously, but they stay pending on
    // re-poll and are still woken exactly once by the next real change.
    #[fasync::run_until_stalled(test)]
    async fn multiple_watchers_are_woken_correctly_after_spurious_update() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker_a, wake_count_a) = new_count_waker();
        let (waker_b, wake_count_b) = new_count_waker();
        let mut context_a = Context::from_waker(&waker_a);
        let mut context_b = Context::from_waker(&waker_b);
        hub.publish("US");
        let mut future_a = hub.watch_for_change(Some("US"));
        let mut future_b = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context_a), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context_b), "for future b");
        hub.publish("US");
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context_a), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context_b), "for future b");
        // Measure wake-ups relative to this point so the spurious publish above is not counted.
        let old_wake_count_a = wake_count_a.get();
        let old_wake_count_b = wake_count_b.get();
        hub.publish("SU");
        assert_eq!(1, wake_count_a.get() - old_wake_count_a);
        assert_eq!(1, wake_count_b.get() - old_wake_count_b);
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_a).poll(&mut context_a),
            "for future a"
        );
        assert_eq!(
            Poll::Ready(Some("SU".to_string())),
            Pin::new(&mut future_b).poll(&mut context_b),
            "for future b"
        );
    }

    // Two futures polled with the same waker each register it, so it is woken twice.
    #[fasync::run_until_stalled(test)]
    async fn multiple_watchers_can_share_a_waker() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future_a = hub.watch_for_change(Option::<String>::None);
        let mut future_b = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future_a).poll(&mut context), "for future a");
        assert_eq!(Poll::Pending, Pin::new(&mut future_b).poll(&mut context), "for future b");
        hub.publish("US");
        assert_eq!(2, count.get());
        assert_eq!(
            Poll::Ready(Some("US".to_string())),
            Pin::new(&mut future_a).poll(&mut context),
            "for future a"
        );
        assert_eq!(
            Poll::Ready(Some("US".to_string())),
            Pin::new(&mut future_b).poll(&mut context),
            "for future b"
        );
    }

    // Once a future has resolved, later publishes must not wake its waker again.
    #[fasync::run_until_stalled(test)]
    async fn single_watcher_is_not_woken_again_after_future_is_ready() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        hub.publish("US");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));
        hub.publish("SU");
        assert_eq!(1, count.get());
    }

    // A fresh watcher created after the first update is woken by the second update.
    #[fasync::run_until_stalled(test)]
    async fn second_watcher_is_woken_for_second_update() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        hub.publish("US");
        assert_eq!(1, count.get());
        assert_eq!(Poll::Ready(Some("US".to_string())), Pin::new(&mut future).poll(&mut context));
        let mut future = hub.watch_for_change(Some("US"));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        hub.publish("SU");
        assert!(count.get() > 1, "Count should be >1, but is {}", count.get());
        assert_eq!(Poll::Ready(Some("SU".to_string())), Pin::new(&mut future).poll(&mut context));
    }

    // Polling the same future twice with the same waker registers it only once.
    #[fasync::run_until_stalled(test)]
    async fn multiple_polls_of_single_watcher_do_not_cause_multiple_wakes_when_waker_is_reused() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker, count) = new_count_waker();
        let mut context = Context::from_waker(&waker);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context));
        hub.publish("US");
        assert_eq!(1, count.get());
    }

    // Re-polling with a different waker replaces the earlier one; only the latest is woken.
    #[fasync::run_until_stalled(test)]
    async fn multiple_polls_of_single_watcher_do_not_cause_multiple_wakes_when_waker_is_replaced() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        let (waker_a, wake_count_a) = new_count_waker();
        let (waker_b, wake_count_b) = new_count_waker();
        let mut context_a = Context::from_waker(&waker_a);
        let mut context_b = Context::from_waker(&waker_b);
        let mut future = hub.watch_for_change(Option::<String>::None);
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context_a));
        assert_eq!(Poll::Pending, Pin::new(&mut future).poll(&mut context_b));
        hub.publish("US");
        assert_eq!(0, wake_count_a.get());
        assert_eq!(1, wake_count_b.get());
    }

    // With no cache file present, a new hub starts with no value.
    #[test]
    fn get_value_is_none() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        assert_eq!(None, hub.get_value());
    }

    // `get_value` reflects the most recent publish.
    #[test]
    fn get_value_is_some() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path);
        hub.publish("US");
        assert_eq!(Some("US".to_string()), hub.get_value());
    }

    // A published value is written to the cache file and picked up by a new hub on that path.
    #[test]
    fn published_value_is_saved_and_loaded_on_creation() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), None);
        hub.publish("WW");
        assert_eq!(hub.get_value(), Some("WW".to_string()));
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), Some("WW".to_string()));
        let file = File::open(&path).expect("Failed to open file");
        assert_matches!(
            serde_json::from_reader(io::BufReader::new(file)),
            Ok(RegulatoryRegion{ region_code }) if region_code.as_str() == "WW"
        );
    }

    // A pre-existing cache file is loaded, then overwritten by a subsequent publish.
    #[test]
    fn publishing_over_previously_saved_value_overwrites_cache() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        let cache_val = RegulatoryRegion { region_code: "WW".to_string() };
        let file = File::create(&path).expect("failed to create file");
        serde_json::to_writer(io::BufWriter::new(file), &cache_val)
            .expect("Failed to write JSON to file");
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), Some("WW".to_string()));
        hub.publish("US");
        let file = File::open(&path).expect("Failed to open file");
        assert_matches!(
            serde_json::from_reader(io::BufReader::new(file)),
            Ok(RegulatoryRegion{ region_code }) if region_code.as_str() == "US"
        );
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), Some("US".to_string()));
    }

    // A truncated/unparseable cache file yields no value and is deleted.
    #[test]
    fn load_as_none_if_cache_file_is_bad() {
        let temp_dir = TempDir::new_in("/cache/").expect("failed to create temporary directory");
        let path = temp_dir.path().join("regulatory_region.json");
        assert!(!path.exists());
        let mut file = File::create(&path).expect("failed to create file");
        let bad_contents = b"{\"region_code\": ";
        file.write_all(bad_contents).expect("failed to write to file");
        file.flush().expect("failed to flush file");
        let hub = PubSubHub::new(path.to_path_buf());
        assert_eq!(hub.get_value(), None);
        assert_matches!(File::open(&path), Err(io_err) if io_err.kind() == io::ErrorKind::NotFound);
    }
}