channel_config/
lib.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use omaha_client::protocol::Cohort;
6use serde::{Deserialize, Serialize};
7use std::io;
8
9/// Wrapper for deserializing repository configs to the on-disk JSON format.
10#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
11pub struct ChannelConfigs {
12    pub default_channel: Option<String>,
13    #[serde(rename = "channels")]
14    pub known_channels: Vec<ChannelConfig>,
15}
16
17impl ChannelConfigs {
18    pub fn validate(&self) -> Result<(), io::Error> {
19        let names: Vec<&str> = self.known_channels.iter().map(|c| c.name.as_str()).collect();
20        if !names.iter().all(|n| Cohort::validate_name(n)) {
21            return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid channel name"));
22        }
23        if let Some(default) = &self.default_channel {
24            if !names.contains(&default.as_str()) {
25                return Err(io::Error::new(
26                    io::ErrorKind::InvalidData,
27                    "default channel not a known channel",
28                ));
29            }
30        }
31        Ok(())
32    }
33
34    pub fn get_default_channel(&self) -> Option<ChannelConfig> {
35        self.default_channel.as_ref().and_then(|default| self.get_channel(default))
36    }
37
38    pub fn get_channel(&self, name: &str) -> Option<ChannelConfig> {
39        self.known_channels.iter().find(|channel_config| channel_config.name == name).cloned()
40    }
41}
42
43#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
44pub struct ChannelConfig {
45    pub name: String,
46    pub repo: String,
47    pub appid: Option<String>,
48    pub check_interval_secs: Option<u64>,
49}
50
51impl ChannelConfig {
52    pub fn new_for_test(name: &str) -> Self {
53        testing::ChannelConfigBuilder::new(name, name.to_owned() + "-repo").build()
54    }
55
56    pub fn with_appid_for_test(name: &str, appid: &str) -> Self {
57        testing::ChannelConfigBuilder::new(name, name.to_owned() + "-repo").appid(appid).build()
58    }
59}
60
61pub mod testing {
62    use super::*;
63    #[derive(Debug, Default)]
64    pub struct ChannelConfigBuilder {
65        name: String,
66        repo: String,
67        appid: Option<String>,
68        check_interval_secs: Option<u64>,
69    }
70
71    impl ChannelConfigBuilder {
72        pub fn new(name: impl Into<String>, repo: impl Into<String>) -> Self {
73            ChannelConfigBuilder {
74                name: name.into(),
75                repo: repo.into(),
76                ..ChannelConfigBuilder::default()
77            }
78        }
79
80        pub fn appid(mut self, appid: impl Into<String>) -> Self {
81            self.appid = Some(appid.into());
82            self
83        }
84
85        pub fn check_interval_secs(mut self, check_interval_secs: u64) -> Self {
86            self.check_interval_secs = Some(check_interval_secs);
87            self
88        }
89
90        pub fn build(self) -> ChannelConfig {
91            ChannelConfig {
92                name: self.name,
93                repo: self.repo,
94                appid: self.appid,
95                check_interval_secs: self.check_interval_secs,
96            }
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Fixture shared by the lookup tests: three known channels, in the order
    /// `some_channel`, `default_channel`, `other`, with `default_channel` as the default.
    fn three_channel_configs() -> ChannelConfigs {
        let known_channels = ["some_channel", "default_channel", "other"]
            .iter()
            .map(|name| ChannelConfig::new_for_test(name))
            .collect();
        ChannelConfigs { default_channel: Some("default_channel".to_string()), known_channels }
    }

    #[test]
    fn test_channel_configs_get_default() {
        let configs = three_channel_configs();
        let found = configs.get_default_channel().unwrap();
        assert_eq!(found, configs.known_channels[1]);
    }

    #[test]
    fn test_channel_configs_get_default_none() {
        // With no default configured, the lookup must come back empty even though a channel
        // exists.
        let known_channels = vec![ChannelConfig::new_for_test("some_channel")];
        let configs = ChannelConfigs { default_channel: None, known_channels };
        assert_eq!(configs.get_default_channel(), None);
    }

    #[test]
    fn test_channel_configs_get_channel() {
        let configs = three_channel_configs();
        let found = configs.get_channel("other").unwrap();
        assert_eq!(found, configs.known_channels[2]);
    }

    #[test]
    fn test_channel_configs_get_channel_missing() {
        let configs = three_channel_configs();
        assert_eq!(configs.get_channel("missing"), None);
    }

    #[test]
    fn test_channel_cfg_builder_app_id() {
        let builder = testing::ChannelConfigBuilder::new("name", "repo").appid("appid");
        let config = builder.build();
        assert_eq!("name", config.name);
        assert_eq!("repo", config.repo);
        assert_eq!(Some("appid".to_owned()), config.appid);
        // The check interval was never set, so it must remain unset.
        assert_eq!(None, config.check_interval_secs);
    }

    #[test]
    fn test_channel_cfg_builder_check_interval() {
        let builder = testing::ChannelConfigBuilder::new("name", "repo")
            .appid("appid")
            .check_interval_secs(3600);
        let config = builder.build();
        assert_eq!("name", config.name);
        assert_eq!("repo", config.repo);
        assert_eq!(Some("appid".to_owned()), config.appid);
        assert_eq!(Some(3600), config.check_interval_secs);
    }
}