1use omaha_client::protocol::Cohort;
6use serde::{Deserialize, Serialize};
7use std::io;
8
9#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
11pub struct ChannelConfigs {
12 pub default_channel: Option<String>,
13 #[serde(rename = "channels")]
14 pub known_channels: Vec<ChannelConfig>,
15}
16
17impl ChannelConfigs {
18 pub fn validate(&self) -> Result<(), io::Error> {
19 let names: Vec<&str> = self.known_channels.iter().map(|c| c.name.as_str()).collect();
20 if !names.iter().all(|n| Cohort::validate_name(n)) {
21 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid channel name"));
22 }
23 if let Some(default) = &self.default_channel {
24 if !names.contains(&default.as_str()) {
25 return Err(io::Error::new(
26 io::ErrorKind::InvalidData,
27 "default channel not a known channel",
28 ));
29 }
30 }
31 Ok(())
32 }
33
34 pub fn get_default_channel(&self) -> Option<ChannelConfig> {
35 self.default_channel.as_ref().and_then(|default| self.get_channel(default))
36 }
37
38 pub fn get_channel(&self, name: &str) -> Option<ChannelConfig> {
39 self.known_channels.iter().find(|channel_config| channel_config.name == name).cloned()
40 }
41}
42
43#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
44pub struct ChannelConfig {
45 pub name: String,
46 pub repo: String,
47 pub appid: Option<String>,
48 pub check_interval_secs: Option<u64>,
49}
50
51impl ChannelConfig {
52 pub fn new_for_test(name: &str) -> Self {
53 testing::ChannelConfigBuilder::new(name, name.to_owned() + "-repo").build()
54 }
55
56 pub fn with_appid_for_test(name: &str, appid: &str) -> Self {
57 testing::ChannelConfigBuilder::new(name, name.to_owned() + "-repo").appid(appid).build()
58 }
59}
60
61pub mod testing {
62 use super::*;
63 #[derive(Debug, Default)]
64 pub struct ChannelConfigBuilder {
65 name: String,
66 repo: String,
67 appid: Option<String>,
68 check_interval_secs: Option<u64>,
69 }
70
71 impl ChannelConfigBuilder {
72 pub fn new(name: impl Into<String>, repo: impl Into<String>) -> Self {
73 ChannelConfigBuilder {
74 name: name.into(),
75 repo: repo.into(),
76 ..ChannelConfigBuilder::default()
77 }
78 }
79
80 pub fn appid(mut self, appid: impl Into<String>) -> Self {
81 self.appid = Some(appid.into());
82 self
83 }
84
85 pub fn check_interval_secs(mut self, check_interval_secs: u64) -> Self {
86 self.check_interval_secs = Some(check_interval_secs);
87 self
88 }
89
90 pub fn build(self) -> ChannelConfig {
91 ChannelConfig {
92 name: self.name,
93 repo: self.repo,
94 appid: self.appid,
95 check_interval_secs: self.check_interval_secs,
96 }
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Three known channels, with `default_channel` naming the second one.
    fn three_channels() -> ChannelConfigs {
        ChannelConfigs {
            default_channel: Some("default_channel".to_string()),
            known_channels: vec![
                ChannelConfig::new_for_test("some_channel"),
                ChannelConfig::new_for_test("default_channel"),
                ChannelConfig::new_for_test("other"),
            ],
        }
    }

    /// A single known channel called `channel_name`, with the given default (if any).
    fn single_channel(channel_name: &str, default_channel: Option<&str>) -> ChannelConfigs {
        ChannelConfigs {
            default_channel: default_channel.map(str::to_string),
            known_channels: vec![ChannelConfig::new_for_test(channel_name)],
        }
    }

    #[test]
    fn test_channel_configs_get_default() {
        let configs = three_channels();
        assert_eq!(configs.get_default_channel().unwrap(), configs.known_channels[1]);
    }

    #[test]
    fn test_channel_configs_get_default_none() {
        let configs = single_channel("some_channel", None);
        assert_eq!(configs.get_default_channel(), None);
    }

    #[test]
    fn test_channel_configs_get_channel() {
        let configs = three_channels();
        assert_eq!(configs.get_channel("other").unwrap(), configs.known_channels[2]);
    }

    #[test]
    fn test_channel_configs_get_channel_missing() {
        let configs = three_channels();
        assert_eq!(configs.get_channel("missing"), None);
    }

    #[test]
    fn test_channel_configs_validate_ok() {
        assert!(three_channels().validate().is_ok());
        assert!(single_channel("some_channel", None).validate().is_ok());
    }

    #[test]
    fn test_channel_configs_validate_invalid_name() {
        // Assumes `Cohort::validate_name` rejects control characters such as '\n'
        // (cohort names are limited to printable ASCII) — confirm if this starts failing.
        let err = single_channel("bad\nchannel", None).validate().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_channel_configs_validate_unknown_default() {
        let err = single_channel("some_channel", Some("missing")).validate().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_channel_cfg_builder_app_id() {
        let config = testing::ChannelConfigBuilder::new("name", "repo").appid("appid").build();
        assert_eq!("name", config.name);
        assert_eq!("repo", config.repo);
        assert_eq!(Some("appid".to_owned()), config.appid);
        assert_eq!(None, config.check_interval_secs);
    }

    #[test]
    fn test_channel_cfg_builder_check_interval() {
        let config = testing::ChannelConfigBuilder::new("name", "repo")
            .appid("appid")
            .check_interval_secs(3600)
            .build();
        assert_eq!("name", config.name);
        assert_eq!("repo", config.repo);
        assert_eq!(Some("appid".to_owned()), config.appid);
        assert_eq!(Some(3600), config.check_interval_secs);
    }
}