omaha_client/
version.rs

1// Copyright 2020 The Fuchsia Authors
2//
3// Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
4// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
5// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
6// This file may not be copied, modified, or distributed except according to
7// those terms.
8
9use {
10    itertools::Itertools,
11    serde::{
12        de::{self, Visitor},
13        Deserialize, Deserializer, Serialize, Serializer,
14    },
15    std::{fmt, str::FromStr},
16};
17
/// An Omaha-style version number made of up to four dot-separated numeric
/// components: `A`, `A.B`, `A.B.C` or `A.B.C.D`.
///
/// All four slots are always stored; the conversions in this file leave any
/// component that was not supplied at zero. Equality and ordering are derived
/// from the underlying array, so versions compare component by component
/// starting from the left.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Version([u32; 4]);
21
22impl fmt::Display for Version {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        write!(f, "{}", self.0.iter().format("."))
25    }
26}
27
28impl fmt::Debug for Version {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        // The Debug trait just forwards to the Display trait implementation for this type
31        fmt::Display::fmt(self, f)
32    }
33}
34
35#[derive(Debug, thiserror::Error)]
36struct TooManyNumbersError;
37
38impl fmt::Display for TooManyNumbersError {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        f.write_str("Too many numbers in version, the maximum is 4.")
41    }
42}
43
44impl FromStr for Version {
45    type Err = anyhow::Error;
46
47    fn from_str(s: &str) -> Result<Self, Self::Err> {
48        let nums = s.split('.').map(|s| s.parse::<u32>());
49
50        let mut array: [u32; 4] = [0; 4];
51        for (i, v) in nums.enumerate() {
52            if i >= 4 {
53                return Err(TooManyNumbersError.into());
54            }
55            array[i] = v?;
56        }
57        Ok(Version(array))
58    }
59}
60
61macro_rules! impl_from {
62    ($($t:ty),+) => {
63        $(
64            impl From<$t> for Version {
65                fn from(v: $t) -> Self {
66                    let mut array: [u32; 4] = [0; 4];
67                    array.split_at_mut(v.len()).0.copy_from_slice(&v);
68                    Version(array)
69                }
70            }
71        )+
72    }
73}
74impl_from!([u32; 1], [u32; 2], [u32; 3], [u32; 4]);
75
76struct VersionVisitor;
77
78impl<'de> Visitor<'de> for VersionVisitor {
79    type Value = Version;
80
81    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
82        formatter.write_str("a string of the format A.B.C.D")
83    }
84
85    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
86    where
87        E: de::Error,
88    {
89        Version::from_str(v).map_err(de::Error::custom)
90    }
91}
92
93impl<'de> Deserialize<'de> for Version {
94    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95    where
96        D: Deserializer<'de>,
97    {
98        deserializer.deserialize_str(VersionVisitor)
99    }
100}
101
102impl Serialize for Version {
103    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104    where
105        S: Serializer,
106    {
107        serializer.serialize_str(&self.to_string())
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` as a `Version`, panicking if it is rejected.
    fn parsed(text: &str) -> Version {
        text.parse::<Version>().unwrap()
    }

    #[test]
    fn test_version_display() {
        assert_eq!(Version::from([1, 2, 3, 4]).to_string(), "1.2.3.4");
        assert_eq!(Version::from([0, 6, 4, 7]).to_string(), "0.6.4.7");
    }

    #[test]
    fn test_version_debug() {
        assert_eq!(format!("{:?}", Version::from([1, 2, 3, 4])), "1.2.3.4");
        assert_eq!(format!("{:?}", Version::from([0, 6, 4, 7])), "0.6.4.7");
    }

    #[test]
    fn test_version_parse() {
        let cases = [
            ("1.2.3.4", Version::from([1, 2, 3, 4])),
            ("6.4.7", Version::from([6, 4, 7])),
            ("999", Version::from([999])),
        ];
        for (text, expected) in cases {
            assert_eq!(parsed(text), expected, "parsing {text:?}");
        }
    }

    #[test]
    fn test_version_parse_leading_zeros() {
        let cases = [
            ("1.02.003.0004", Version::from([1, 2, 3, 4])),
            ("06.4.07", Version::from([6, 4, 7])),
            ("0000999", Version::from([999])),
        ];
        for (text, expected) in cases {
            assert_eq!(parsed(text), expected, "parsing {text:?}");
        }
    }

    #[test]
    fn test_version_parse_error() {
        let invalid = [
            "1.2.3.4.5",
            "1.2.",
            ".1.2",
            "-1",
            "abc",
            ".",
            "",
            "999999999999999999999999",
        ];
        for text in invalid {
            assert!(
                text.parse::<Version>().is_err(),
                "{text:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_version_to_string() {
        // Missing trailing components are printed as zeros.
        let cases = [
            ("1.2", "1.2.0.0"),
            ("1.2.3.4", "1.2.3.4"),
            ("1", "1.0.0.0"),
            ("3.2.1", "3.2.1.0"),
        ];
        for (text, rendered) in cases {
            assert_eq!(parsed(text).to_string(), rendered, "rendering {text:?}");
        }
    }

    #[test]
    fn test_version_compare() {
        // Pairs where the left side is strictly lower.
        let ascending = [
            (Version::from([1, 2, 3, 4]), Version::from([2, 0, 3])),
            (Version::from([1, 2, 3]), Version::from([1, 2, 3, 4])),
            (Version::from([0]), Version::from([0, 0, 1, 0])),
            (Version::from([1]), Version::from([1, 0, 1, 0])),
            (Version::from([1, 0]), Version::from([1, 0, 0, 1])),
        ];
        for (lower, higher) in ascending {
            assert!(lower < higher, "{lower} should be below {higher}");
        }

        // Pairs where the left side is strictly higher.
        let descending = [
            (Version::from([0, 1, 0]), Version::from([0, 0, 1, 0])),
            (Version::from([1, 0, 0]), Version::from([0, 1, 2, 0])),
        ];
        for (higher, lower) in descending {
            assert!(higher > lower, "{higher} should be above {lower}");
        }

        // Trailing zeros do not affect equality.
        assert_eq!(Version::from([1, 0]), Version::from([1, 0, 0]));
        assert_eq!(Version::from([1]), Version::from([1, 0, 0, 0]));
        assert_eq!(Version::from([0]), Version::from([0, 0, 0, 0]));
        assert!(Version::from([1, 0]) <= Version::from([1, 0, 0]));
        assert!(Version::from([1, 0]) >= Version::from([1, 0, 0]));
    }

    #[test]
    fn test_version_deserialize() {
        let from_json = |json: &str| serde_json::from_str::<Version>(json);

        assert_eq!(
            from_json(r#""1.2.3.4""#).unwrap(),
            Version::from([1, 2, 3, 4])
        );
        assert_eq!(from_json(r#""1.2.3""#).unwrap(), Version::from([1, 2, 3]));
        from_json(r#""1.2.3.4.5""#).expect_err("Parsing invalid version should fail");
    }

    #[test]
    fn test_version_serialize() {
        let to_json = |version: Version| serde_json::to_string(&version).unwrap();

        assert_eq!(to_json(Version::from([1, 2, 3, 4])), r#""1.2.3.4""#);
        assert_eq!(to_json(Version::from([1, 2, 3])), r#""1.2.3.0""#);
    }
}