1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright 2019 The Fuchsia Authors
//
// Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
// This file may not be copied, modified, or distributed except according to
// those terms.

use serde::{Deserialize, Serialize};

// Submodules of this module; their contents live outside this file.
// NOTE(review): presumably these hold the Omaha request and response message types -- confirm.
pub mod request;
pub mod response;

/// Version string (`"3.0"`) for version 3 of the protocol.
// NOTE(review): presumably the value placed in the protocol-version field of outgoing requests --
// confirm against the `request` module.
pub const PROTOCOL_V3: &str = "3.0";

/// Per-application update 'track' (or 'channel') information, which is also used to implement the
/// tracking of membership in a fractional roll-out.
///
/// The client sends its current cohort to Omaha to identify which cohort the application is in.
/// Omaha may answer with (possibly new) cohort values, indicating that the application now belongs
/// to a different cohort; the updater must then send those returned values with that application
/// on its next update check.
///
/// The fields correspond to the 'cohort', 'cohorthint', and 'cohortname' attributes of the
/// Request.App object, which are described at:
///
/// <https://github.com/google/omaha/blob/HEAD/doc/ServerProtocolV3.md#app-request>
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct Cohort {
    /// The cohort id itself.  (De)serialized as `cohort`; omitted from the output when `None`.
    #[serde(rename = "cohort", skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// The cohort hint.  (De)serialized as `cohorthint`; omitted from the output when `None`.
    #[serde(rename = "cohorthint", skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,

    /// The cohort name.  (De)serialized as `cohortname`; omitted from the output when `None`.
    #[serde(rename = "cohortname", skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

impl Cohort {
    /// Create a new Cohort instance from just a cohort id (channel name).
    pub fn new(id: &str) -> Cohort {
        Cohort {
            id: Some(id.to_string()),
            hint: None,
            name: None,
        }
    }

    pub fn from_hint(hint: &str) -> Cohort {
        Cohort {
            id: None,
            hint: Some(hint.to_string()),
            name: None,
        }
    }

    pub fn update_from_omaha(&mut self, omaha_cohort: Self) {
        // From Omaha spec:
        // If this attribute is transmitted in the response (even if the value is empty-string),
        // the client should overwrite the current cohort of this app with the sent value.
        if omaha_cohort.id.is_some() {
            self.id = omaha_cohort.id;
        }
        if omaha_cohort.hint.is_some() {
            self.hint = omaha_cohort.hint;
        }
        if omaha_cohort.name.is_some() {
            self.name = omaha_cohort.name;
        }
    }

    /// A validation function to test that a given Cohort hint or name is valid per the Omaha spec:
    ///  1-1024 ascii characters, with values in the range [\u20-\u7e].
    pub fn validate_name(name: &str) -> bool {
        !name.is_empty()
            && name.len() <= 1024
            && name.chars().all(|c| ('\u{20}'..='\u{7e}').contains(&c))
    }
}

/// Unit tests for `Cohort`: its constructors, the merge of Omaha-returned values, and name
/// validation (including the boundaries of the allowed length and character range).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cohort_new() {
        let cohort = Cohort::new("my_cohort");
        assert_eq!(Some("my_cohort".to_string()), cohort.id);
        assert_eq!(None, cohort.hint);
        assert_eq!(None, cohort.name);
    }

    #[test]
    fn test_cohort_from_hint() {
        let cohort = Cohort::from_hint("my_hint");
        assert_eq!(None, cohort.id);
        assert_eq!(Some("my_hint".to_string()), cohort.hint);
        assert_eq!(None, cohort.name);
    }

    #[test]
    fn test_cohort_update_from_omaha() {
        let mut cohort = Cohort::from_hint("hint");
        let omaha_cohort = Cohort::new("my_cohort");
        cohort.update_from_omaha(omaha_cohort);
        assert_eq!(Some("my_cohort".to_string()), cohort.id);
        assert_eq!(Some("hint".to_string()), cohort.hint);
        assert_eq!(None, cohort.name);
    }

    #[test]
    fn test_cohort_update_from_omaha_none() {
        let mut cohort = Cohort {
            id: Some("id".to_string()),
            hint: Some("hint".to_string()),
            name: Some("name".to_string()),
        };
        let expected_cohort = cohort.clone();
        cohort.update_from_omaha(Cohort::default());
        assert_eq!(cohort, expected_cohort);
    }

    #[test]
    fn test_cohort_update_from_omaha_empty_string() {
        // A transmitted attribute overwrites the current value even when it is the empty string;
        // only an absent (`None`) attribute leaves the current value alone.
        let mut cohort = Cohort {
            id: Some("id".to_string()),
            hint: Some("hint".to_string()),
            name: Some("name".to_string()),
        };
        cohort.update_from_omaha(Cohort {
            id: Some(String::new()),
            hint: None,
            name: Some(String::new()),
        });
        assert_eq!(Some(String::new()), cohort.id);
        assert_eq!(Some("hint".to_string()), cohort.hint);
        assert_eq!(Some(String::new()), cohort.name);
    }

    #[test]
    fn test_valid_cohort_names() {
        assert!(Cohort::validate_name("some-channel"));
        assert!(Cohort::validate_name("a"));

        let max_len_name = "a".repeat(1024);
        assert!(Cohort::validate_name(&max_len_name));
    }

    #[test]
    fn test_valid_cohort_name_char_boundaries() {
        // Both ends of the allowed range are themselves valid.
        assert!(Cohort::validate_name("\u{20}"));
        assert!(Cohort::validate_name("\u{7e}"));

        // Every character of the allowed range, all at once.
        let every_valid_char: String = (0x20u8..=0x7e).map(char::from).collect();
        assert!(Cohort::validate_name(&every_valid_char));
    }

    #[test]
    fn test_invalid_cohort_name_length() {
        assert!(!Cohort::validate_name(""));

        let too_long_name = "a".repeat(1025);
        assert!(!Cohort::validate_name(&too_long_name));
    }

    #[test]
    fn test_invalid_cohort_name_chars() {
        assert!(!Cohort::validate_name("some\u{09}channel"));
        assert!(!Cohort::validate_name("some\u{07f}channel"));
        assert!(!Cohort::validate_name("some\u{080}channel"));
    }

    #[test]
    fn test_invalid_cohort_name_chars_just_outside_range() {
        // The character immediately below the allowed range.
        assert!(!Cohort::validate_name("some\u{1f}channel"));
        // A line break is a control character and therefore not allowed.
        assert!(!Cohort::validate_name("some\nchannel"));
        // A character that needs more than one byte in UTF-8.
        assert!(!Cohort::validate_name("caf\u{e9}"));
    }
}