use crate::{
protocol::{self, request::InstallSource, Cohort},
storage::Storage,
time::PartialComplexTime,
version::Version,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::time::Duration;
use tracing::error;
use typed_builder::TypedBuilder;
/// User-counting state for an app.
///
/// The variant name is part of the serialized (serde) form, e.g.
/// `{"ClientRegulatedByDate": 123}`, so it must not be renamed without migrating stored data.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum UserCounting {
    /// Date-regulated counting. Holds the `elapsed_days` value taken from a server
    /// `DayStart` (see the `From<Option<DayStart>>` impl), or `None` when no value is known.
    ClientRegulatedByDate(
        /// Elapsed days reported by the server, if any.
        Option<u32>,
    ),
}
impl From<Option<protocol::response::DayStart>> for UserCounting {
    /// Builds `ClientRegulatedByDate`, carrying over the `DayStart`'s `elapsed_days`.
    ///
    /// A missing `DayStart` and a `DayStart` without `elapsed_days` both yield
    /// `ClientRegulatedByDate(None)`.
    fn from(opt_day_start: Option<protocol::response::DayStart>) -> Self {
        let elapsed_days = opt_day_start.and_then(|day_start| day_start.elapsed_days);
        UserCounting::ClientRegulatedByDate(elapsed_days)
    }
}
/// An app tracked by the client, together with its cohort and user-counting state.
///
/// Construct with `App::builder()`; only `id` and `version` are required.
#[derive(Clone, Debug, Eq, PartialEq, TypedBuilder)]
pub struct App {
    /// The app id. `load`/`persist` also use it as the storage key.
    #[builder(setter(into))]
    pub id: String,

    /// The app's version.
    #[builder(setter(into))]
    pub version: Version,

    /// Optional fingerprint. Defaults to `None`; the setter takes the value directly.
    #[builder(default, setter(into, strip_option))]
    pub fingerprint: Option<String>,

    /// Cohort (channel) information. Defaults to `Cohort::default()`.
    #[builder(default)]
    pub cohort: Cohort,

    /// User-counting state. Defaults to `ClientRegulatedByDate(None)`.
    #[builder(default = UserCounting::ClientRegulatedByDate(None))]
    pub user_counting: UserCounting,

    /// Arbitrary extra key/value pairs. Defaults to an empty map.
    #[builder(default, setter(into))]
    pub extra_fields: HashMap<String, String>,
}
/// The subset of `App` state that is saved to, and restored from, storage as JSON.
///
/// The field names are the JSON keys (`"cohort"`, `"user_counting"`) of the stored record, so
/// renaming them would make previously persisted data unreadable.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct PersistedApp {
    /// The app's cohort (id, hint and name).
    pub cohort: Cohort,
    /// The app's user-counting state.
    pub user_counting: UserCounting,
}
impl From<&App> for PersistedApp {
fn from(app: &App) -> Self {
PersistedApp {
cohort: app.cohort.clone(),
user_counting: app.user_counting.clone(),
}
}
}
impl App {
    /// Fills in unset state on this app from the record stored under `self.id`.
    ///
    /// Values already present on the app take precedence:
    /// - each cohort field (`id`, `hint`, `name`) is copied from storage only when it is
    ///   currently `None`;
    /// - `user_counting` is copied only when it still equals
    ///   `UserCounting::ClientRegulatedByDate(None)` (the builder default).
    ///
    /// A missing record leaves the app unchanged. A record that does not deserialize as a
    /// `PersistedApp` is logged and otherwise ignored.
    pub async fn load<'a>(&'a mut self, storage: &'a impl Storage) {
        if let Some(app_json) = storage.get_string(&self.id).await {
            match serde_json::from_str::<PersistedApp>(&app_json) {
                Ok(persisted_app) => {
                    if self.cohort.id.is_none() {
                        self.cohort.id = persisted_app.cohort.id;
                    }
                    if self.cohort.hint.is_none() {
                        self.cohort.hint = persisted_app.cohort.hint;
                    }
                    if self.cohort.name.is_none() {
                        self.cohort.name = persisted_app.cohort.name;
                    }
                    if self.user_counting == UserCounting::ClientRegulatedByDate(None) {
                        self.user_counting = persisted_app.user_counting;
                    }
                }
                Err(e) => {
                    error!(
                        "Unable to deserialize PersistedApp from json {}: {}",
                        app_json, e
                    );
                }
            }
        }
    }

    /// Serializes this app's cohort and user-counting state (a `PersistedApp`) to JSON and
    /// stores it under `self.id`.
    ///
    /// This only calls `set_string`; it does not commit the storage. Serialization and
    /// storage failures are logged, not returned.
    pub async fn persist<'a>(&'a self, storage: &'a mut impl Storage) {
        let persisted_app = PersistedApp::from(self);
        match serde_json::to_string(&persisted_app) {
            Ok(json) => {
                if let Err(e) = storage.set_string(&self.id, &json).await {
                    // The record holds the whole PersistedApp (cohort id/hint/name and user
                    // counting), not only the cohort id, so name it accurately and include
                    // the app id to identify which record failed.
                    error!(
                        "Unable to persist PersistedApp for app {}: {}",
                        self.id, e
                    );
                }
            }
            Err(e) => {
                error!(
                    "Unable to serialize PersistedApp {:?}: {}",
                    persisted_app, e
                );
            }
        }
    }

    /// Returns the current channel (the cohort name), or `""` when no name is set.
    pub fn get_current_channel(&self) -> &str {
        self.cohort.name.as_deref().unwrap_or("")
    }

    /// Returns the target channel (the cohort hint), falling back to the current channel
    /// when no hint is set.
    pub fn get_target_channel(&self) -> &str {
        self.cohort
            .hint
            .as_deref()
            .unwrap_or_else(|| self.get_current_channel())
    }

    /// Sets the target channel (cohort hint) to `channel`, which may be `None` to clear it.
    ///
    /// When `id` is `Some`, the app id is replaced as well; `None` leaves the id unchanged.
    pub fn set_target_channel(&mut self, channel: Option<String>, id: Option<String>) {
        self.cohort.hint = channel;
        if let Some(id) = id {
            self.id = id;
        }
    }

    /// Returns `true` when the app has a non-empty id and a version other than `0`.
    pub fn valid(&self) -> bool {
        !self.id.is_empty() && self.version != Version::from([0])
    }
}
/// Options for an update check.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CheckOptions {
    /// The install source (`protocol::request::InstallSource`) associated with the check.
    pub source: InstallSource,
}
/// Timing state for update checks.
///
/// Every field defaults to `None` (both via `Default` and via the builder), and every builder
/// setter accepts anything convertible into the field's `Option` type. `Debug` is not derived;
/// it is implemented by hand in this module.
#[derive(Clone, Copy, Default, PartialEq, Eq, TypedBuilder)]
pub struct UpdateCheckSchedule {
    /// Time of the last update, if known.
    #[builder(default, setter(into))]
    pub last_update_time: Option<PartialComplexTime>,
    /// Time of the last update check, if known.
    #[builder(default, setter(into))]
    pub last_update_check_time: Option<PartialComplexTime>,
    /// Timing of the next planned update check, if any.
    #[builder(default, setter(into))]
    pub next_update_time: Option<CheckTiming>,
}
/// The time of a planned update check, with an optional minimum wait.
#[derive(Clone, Copy, Debug, PartialEq, Eq, TypedBuilder)]
pub struct CheckTiming {
    /// The planned check time. Required by the builder; the setter accepts anything
    /// convertible into `PartialComplexTime`.
    #[builder(setter(into))]
    pub time: PartialComplexTime,
    /// Optional minimum wait (NOTE(review): presumably the least time to wait before the
    /// check — confirm with the scheduler). Defaults to `None`; the builder setter takes a
    /// plain `Duration` (`strip_option`).
    #[builder(default, setter(strip_option))]
    pub minimum_wait: Option<Duration>,
}
/// Wrapper that formats an `Option<T>` as the inner value's `Display` output, or `None`.
///
/// Both `Display` and `Debug` are implemented (for `T: fmt::Display`) to print this form, so
/// it can be used with `debug_struct` fields. The bound lives on those impls rather than on
/// the struct itself, so the wrapper can be named and constructed generically without
/// restating it.
pub struct PrettyOptionDisplay<T>(pub Option<T>);
impl<T: fmt::Display> fmt::Display for PrettyOptionDisplay<T> {
    /// Writes the wrapped value with its own `Display` impl, forwarding the caller's
    /// formatter (so width/precision flags apply to it). Writes the literal `None` when
    /// there is no value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(inner) = &self.0 {
            fmt::Display::fmt(inner, f)
        } else {
            f.write_str("None")
        }
    }
}
impl<T: fmt::Display> fmt::Debug for PrettyOptionDisplay<T> {
    /// Uses exactly the `Display` output, so wrapped values print without a `Some(...)`
    /// wrapper and strings print without quotes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl fmt::Debug for UpdateCheckSchedule {
    /// Formats every field of the schedule. Optional times are rendered through
    /// `PrettyOptionDisplay`, i.e. with their `Display` form or `None`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UpdateCheckSchedule")
            .field(
                "last_update_time",
                &PrettyOptionDisplay(self.last_update_time),
            )
            // Listed so diagnostic output covers every field of the schedule.
            .field(
                "last_update_check_time",
                &PrettyOptionDisplay(self.last_update_check_time),
            )
            .field(
                "next_update_time",
                &PrettyOptionDisplay(self.next_update_time),
            )
            .finish()
    }
}

impl fmt::Display for CheckTiming {
    /// Writes the check time; when a minimum wait is set, appends ` wait: ` followed by the
    /// wait's `Debug` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.minimum_wait {
            None => fmt::Display::fmt(&self.time, f),
            Some(wait) => write!(f, "{} wait: {:?}", self.time, wait),
        }
    }
}

/// Protocol-level counters and settings carried between update checks.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ProtocolState {
    /// Poll interval dictated by the server, if one was given.
    pub server_dictated_poll_interval: Option<std::time::Duration>,
    /// Number of consecutive failed update checks.
    pub consecutive_failed_update_checks: u32,
    /// Number of consecutive proxied requests.
    pub consecutive_proxied_requests: u32,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        storage::MemStorage,
        time::{MockTimeSource, TimeSource},
    };
    use futures::executor::block_on;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;
    use std::time::SystemTime;

    #[test]
    fn test_app_new_version() {
        let app = App::builder()
            .id("some_id")
            .version([1, 2])
            .cohort(Cohort::from_hint("some-channel"))
            .build();
        assert_eq!(app.id, "some_id");
        assert_eq!(app.version, [1, 2].into());
        assert_eq!(app.fingerprint, None);
        assert_eq!(app.cohort.hint, Some("some-channel".to_string()));
        assert_eq!(app.cohort.name, None);
        assert_eq!(app.cohort.id, None);
        assert_eq!(app.user_counting, UserCounting::ClientRegulatedByDate(None));
        assert!(app.extra_fields.is_empty(), "Extra fields are not empty");
    }

    #[test]
    fn test_app_with_fingerprint() {
        let app = App::builder()
            .id("some_id_2")
            .version([4, 6])
            .cohort(Cohort::from_hint("test-channel"))
            .fingerprint("some_fp")
            .build();
        assert_eq!(app.id, "some_id_2");
        assert_eq!(app.version, [4, 6].into());
        assert_eq!(app.fingerprint, Some("some_fp".to_string()));
        assert_eq!(app.cohort.hint, Some("test-channel".to_string()));
        assert_eq!(app.cohort.name, None);
        assert_eq!(app.cohort.id, None);
        assert_eq!(app.user_counting, UserCounting::ClientRegulatedByDate(None));
        assert!(app.extra_fields.is_empty(), "Extra fields are not empty");
    }

    #[test]
    fn test_app_with_user_counting() {
        let app = App::builder()
            .id("some_id_2")
            .version([4, 6])
            .cohort(Cohort::from_hint("test-channel"))
            .user_counting(UserCounting::ClientRegulatedByDate(Some(42)))
            .build();
        assert_eq!(app.id, "some_id_2");
        assert_eq!(app.version, [4, 6].into());
        assert_eq!(app.cohort.hint, Some("test-channel".to_string()));
        assert_eq!(app.cohort.name, None);
        assert_eq!(app.cohort.id, None);
        assert_eq!(
            app.user_counting,
            UserCounting::ClientRegulatedByDate(Some(42))
        );
        assert!(app.extra_fields.is_empty(), "Extra fields are not empty");
    }

    #[test]
    fn test_app_with_extras() {
        let app = App::builder()
            .id("some_id_2")
            .version([4, 6])
            .cohort(Cohort::from_hint("test-channel"))
            .extra_fields([
                ("key1".to_string(), "value1".to_string()),
                ("key2".to_string(), "value2".to_string()),
            ])
            .build();
        assert_eq!(app.id, "some_id_2");
        assert_eq!(app.version, [4, 6].into());
        assert_eq!(app.cohort.hint, Some("test-channel".to_string()));
        assert_eq!(app.cohort.name, None);
        assert_eq!(app.cohort.id, None);
        assert_eq!(app.user_counting, UserCounting::ClientRegulatedByDate(None));
        assert_eq!(app.extra_fields.len(), 2);
        assert_eq!(app.extra_fields["key1"], "value1");
        assert_eq!(app.extra_fields["key2"], "value2");
    }

    #[test]
    fn test_app_load() {
        block_on(async {
            let mut storage = MemStorage::new();
            let json = serde_json::json!({
            "cohort": {
                "cohort": "some_id",
                "cohorthint":"some_hint",
                "cohortname": "some_name"
            },
            "user_counting": {
                "ClientRegulatedByDate":123
            }});
            let json = serde_json::to_string(&json).unwrap();
            let mut app = App::builder().id("some_id").version([1, 2]).build();
            storage.set_string(&app.id, &json).await.unwrap();
            app.load(&storage).await;
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            assert_eq!(cohort, app.cohort);
            assert_eq!(
                UserCounting::ClientRegulatedByDate(Some(123)),
                app.user_counting
            );
        });
    }

    #[test]
    fn test_app_load_empty_storage() {
        block_on(async {
            let storage = MemStorage::new();
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            let mut app = App::builder()
                .id("some_id")
                .version([1, 2])
                .cohort(cohort)
                .user_counting(UserCounting::ClientRegulatedByDate(Some(123)))
                .build();
            app.load(&storage).await;
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            assert_eq!(cohort, app.cohort);
            assert_eq!(
                UserCounting::ClientRegulatedByDate(Some(123)),
                app.user_counting
            );
        });
    }

    #[test]
    fn test_app_load_malformed() {
        block_on(async {
            let mut storage = MemStorage::new();
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            let mut app = App::builder()
                .id("some_id")
                .version([1, 2])
                .cohort(cohort)
                .user_counting(UserCounting::ClientRegulatedByDate(Some(123)))
                .build();
            storage.set_string(&app.id, "not a json").await.unwrap();
            app.load(&storage).await;
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            assert_eq!(cohort, app.cohort);
            assert_eq!(
                UserCounting::ClientRegulatedByDate(Some(123)),
                app.user_counting
            );
        });
    }

    #[test]
    fn test_app_load_partial() {
        block_on(async {
            let mut storage = MemStorage::new();
            let json = serde_json::json!({
            "cohort": {
                "cohorthint":"some_hint_2",
                "cohortname": "some_name_2"
            },
            "user_counting": {
                "ClientRegulatedByDate":null
            }});
            let json = serde_json::to_string(&json).unwrap();
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            let mut app = App::builder()
                .id("some_id")
                .version([1, 2])
                .cohort(cohort)
                .user_counting(UserCounting::ClientRegulatedByDate(Some(123)))
                .build();
            storage.set_string(&app.id, &json).await.unwrap();
            app.load(&storage).await;
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            assert_eq!(cohort, app.cohort);
            assert_eq!(
                UserCounting::ClientRegulatedByDate(Some(123)),
                app.user_counting
            );
        });
    }

    #[test]
    fn test_app_load_override() {
        block_on(async {
            let mut storage = MemStorage::new();
            let json = serde_json::json!({
            "cohort": {
                "cohort": "some_id_2",
                "cohorthint":"some_hint_2",
                "cohortname": "some_name_2"
            },
            "user_counting": {
                "ClientRegulatedByDate":123
            }});
            let json = serde_json::to_string(&json).unwrap();
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: None,
            };
            let mut app = App::builder()
                .id("some_id")
                .version([1, 2])
                .cohort(cohort)
                .user_counting(UserCounting::ClientRegulatedByDate(Some(123)))
                .build();
            storage.set_string(&app.id, &json).await.unwrap();
            app.load(&storage).await;
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name_2".to_string()),
            };
            assert_eq!(cohort, app.cohort);
            assert_eq!(
                UserCounting::ClientRegulatedByDate(Some(123)),
                app.user_counting
            );
        });
    }

    #[test]
    fn test_app_persist() {
        block_on(async {
            let mut storage = MemStorage::new();
            let cohort = Cohort {
                id: Some("some_id".to_string()),
                hint: Some("some_hint".to_string()),
                name: Some("some_name".to_string()),
            };
            let app = App::builder()
                .id("some_id")
                .version([1, 2])
                .cohort(cohort)
                .user_counting(UserCounting::ClientRegulatedByDate(Some(123)))
                .build();
            app.persist(&mut storage).await;
            let expected = serde_json::json!({
            "cohort": {
                "cohort": "some_id",
                "cohorthint":"some_hint",
                "cohortname": "some_name"
            },
            "user_counting": {
                "ClientRegulatedByDate":123
            }});
            let json = storage.get_string(&app.id).await.unwrap();
            assert_eq!(expected, serde_json::Value::from_str(&json).unwrap());
            assert!(!storage.committed());
        });
    }

    #[test]
    fn test_app_persist_empty() {
        block_on(async {
            let mut storage = MemStorage::new();
            let cohort = Cohort {
                id: None,
                hint: None,
                name: None,
            };
            let app = App::builder()
                .id("some_id")
                .version([1, 2])
                .cohort(cohort)
                .build();
            app.persist(&mut storage).await;
            let expected = serde_json::json!({
            "cohort": {},
            "user_counting": {
                "ClientRegulatedByDate":null
            }});
            let json = storage.get_string(&app.id).await.unwrap();
            assert_eq!(expected, serde_json::Value::from_str(&json).unwrap());
            assert!(!storage.committed());
        });
    }

    #[test]
    fn test_app_get_current_channel() {
        let cohort = Cohort {
            name: Some("current-channel-123".to_string()),
            ..Cohort::default()
        };
        let app = App::builder()
            .id("some_id")
            .version([0, 1])
            .cohort(cohort)
            .build();
        assert_eq!("current-channel-123", app.get_current_channel());
    }

    #[test]
    fn test_app_get_current_channel_default() {
        let app = App::builder().id("some_id").version([0, 1]).build();
        assert_eq!("", app.get_current_channel());
    }

    #[test]
    fn test_app_get_target_channel() {
        let cohort = Cohort::from_hint("target-channel-456");
        let app = App::builder()
            .id("some_id")
            .version([0, 1])
            .cohort(cohort)
            .build();
        assert_eq!("target-channel-456", app.get_target_channel());
    }

    #[test]
    fn test_app_get_target_channel_fallback() {
        let cohort = Cohort {
            name: Some("current-channel-123".to_string()),
            ..Cohort::default()
        };
        let app = App::builder()
            .id("some_id")
            .version([0, 1])
            .cohort(cohort)
            .build();
        assert_eq!("current-channel-123", app.get_target_channel());
    }

    #[test]
    fn test_app_get_target_channel_default() {
        let app = App::builder().id("some_id").version([0, 1]).build();
        assert_eq!("", app.get_target_channel());
    }

    #[test]
    fn test_app_set_target_channel() {
        let mut app = App::builder().id("some_id").version([0, 1]).build();
        assert_eq!("", app.get_target_channel());
        app.set_target_channel(Some("new-target-channel".to_string()), None);
        assert_eq!("new-target-channel", app.get_target_channel());
        app.set_target_channel(None, None);
        assert_eq!("", app.get_target_channel());
    }

    #[test]
    fn test_app_set_target_channel_and_id() {
        let mut app = App::builder().id("some_id").version([0, 1]).build();
        assert_eq!("", app.get_target_channel());
        app.set_target_channel(
            Some("new-target-channel".to_string()),
            Some("new-id".to_string()),
        );
        assert_eq!("new-target-channel", app.get_target_channel());
        assert_eq!("new-id", app.id);
        app.set_target_channel(None, None);
        assert_eq!("", app.get_target_channel());
        assert_eq!("new-id", app.id);
    }

    #[test]
    fn test_app_valid() {
        let app = App::builder().id("some_id").version([0, 1]).build();
        assert!(app.valid());
    }

    #[test]
    fn test_app_not_valid() {
        let app = App::builder().id("").version([0, 1]).build();
        assert!(!app.valid());
        let app = App::builder().id("some_id").version([0]).build();
        assert!(!app.valid());
    }

    #[test]
    fn test_pretty_option_display_with_none() {
        assert_eq!(
            "None",
            format!("{:?}", PrettyOptionDisplay(Option::<String>::None))
        );
    }

    #[test]
    fn test_pretty_option_display_with_some() {
        assert_eq!(
            "this is a test",
            format!("{:?}", PrettyOptionDisplay(Some("this is a test")))
        );
    }

    #[test]
    fn test_update_check_schedule_debug_with_defaults() {
        assert_eq!(
            "UpdateCheckSchedule { \
             last_update_time: None, \
             last_update_check_time: None, \
             next_update_time: None \
             }",
            format!("{:?}", UpdateCheckSchedule::default())
        );
    }

    #[test]
    fn test_update_check_schedule_debug_with_values() {
        let mock_time = MockTimeSource::new_from_now();
        let last = mock_time.now();
        let next = last + Duration::from_secs(1000);
        assert_eq!(
            format!(
                "UpdateCheckSchedule {{ \
                 last_update_time: {}, \
                 last_update_check_time: {}, \
                 next_update_time: {} \
                 }}",
                PartialComplexTime::from(last),
                PartialComplexTime::from(last),
                next
            ),
            format!(
                "{:?}",
                UpdateCheckSchedule::builder()
                    .last_update_time(last)
                    .last_update_check_time(last)
                    .next_update_time(CheckTiming::builder().time(next).build())
                    .build()
            )
        );
    }

    #[test]
    fn test_update_check_schedule_builder_all_fields() {
        let mock_time = MockTimeSource::new_from_now();
        let now = PartialComplexTime::from(mock_time.now());
        assert_eq!(
            UpdateCheckSchedule::builder()
                .last_update_time(PartialComplexTime::from(
                    SystemTime::UNIX_EPOCH + Duration::from_secs(100000)
                ))
                .next_update_time(
                    CheckTiming::builder()
                        .time(now)
                        .minimum_wait(Duration::from_secs(100))
                        .build()
                )
                .build(),
            UpdateCheckSchedule {
                last_update_time: Some(PartialComplexTime::from(
                    SystemTime::UNIX_EPOCH + Duration::from_secs(100000)
                )),
                next_update_time: Some(CheckTiming {
                    time: now,
                    minimum_wait: Some(Duration::from_secs(100))
                }),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_update_check_schedule_builder_all_fields_from_options() {
        let next_time = PartialComplexTime::from(MockTimeSource::new_from_now().now());
        assert_eq!(
            UpdateCheckSchedule::builder()
                .last_update_time(Some(PartialComplexTime::from(
                    SystemTime::UNIX_EPOCH + Duration::from_secs(100000)
                )))
                .next_update_time(Some(
                    CheckTiming::builder()
                        .time(next_time)
                        .minimum_wait(Duration::from_secs(100))
                        .build()
                ))
                .build(),
            UpdateCheckSchedule {
                last_update_time: Some(PartialComplexTime::from(
                    SystemTime::UNIX_EPOCH + Duration::from_secs(100000)
                )),
                next_update_time: Some(CheckTiming {
                    time: next_time,
                    minimum_wait: Some(Duration::from_secs(100))
                }),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_update_check_schedule_builder_subset_fields() {
        assert_eq!(
            UpdateCheckSchedule::builder()
                .last_update_time(PartialComplexTime::from(
                    SystemTime::UNIX_EPOCH + Duration::from_secs(100000)
                ))
                .build(),
            UpdateCheckSchedule {
                last_update_time: Some(PartialComplexTime::from(
                    SystemTime::UNIX_EPOCH + Duration::from_secs(100000)
                )),
                ..Default::default()
            }
        );
        let next_time = PartialComplexTime::from(MockTimeSource::new_from_now().now());
        assert_eq!(
            UpdateCheckSchedule::builder()
                .next_update_time(
                    CheckTiming::builder()
                        .time(next_time)
                        .minimum_wait(Duration::from_secs(5))
                        .build()
                )
                .build(),
            UpdateCheckSchedule {
                next_update_time: Some(CheckTiming {
                    time: next_time,
                    minimum_wait: Some(Duration::from_secs(5))
                }),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_update_check_schedule_builder_defaults_are_same_as_default_impl() {
        assert_eq!(
            UpdateCheckSchedule::builder().build(),
            UpdateCheckSchedule::default()
        );
    }
}