bitflags_serde_legacy/
lib.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
/// Generates `serde` implementations for a `bitflags` 2.x type that are compatible with the
/// `bitflags` 1.x serialization format.
///
/// # Wire format
///
/// * **Serialization** always emits the legacy shape: a struct with a single `bits` field
///   holding the raw bits of the flags value (e.g. `{ "bits": 3 }` in JSON).
/// * **Deserialization** accepts the legacy `bits` map, and additionally:
///   * for human-readable formats, a string that is parsed with `bitflags::parser::from_str`
///     (e.g. `"A | B"`);
///   * for other formats, a bare primitive value that is forwarded to the `Deserialize`
///     implementation of the underlying `Bits` type.
///
/// Values built from raw bits go through `Flags::from_bits_retain`.
///
/// # Requirements
///
/// * `$ty` must be nameable from the invoking module, because the generated code refers to it
///   as `super::$ty` from a nested private module.
/// * The `bitflags` and `serde` crates (the latter with its derive macros) must be resolvable
///   as `::bitflags` and `::serde` at the call site.
/// * The macro expands to a module with a fixed name, so it can be invoked at most once per
///   module.
#[macro_export]
macro_rules! impl_traits {
    ($ty:ident) => {
        mod __private_bitflags_serde_legacy {
            use ::bitflags::Flags;
            use ::serde::{de, Deserialize, Serialize};
            use ::std::marker::PhantomData;

            // Mirror of the `bitflags` 1.x serialized layout: a struct with one `bits` field.
            #[derive(Serialize, Deserialize)]
            struct Helper {
                bits: <super::$ty as Flags>::Bits,
            }

            // Decodes the legacy `{ bits: .. }` map. Shared by both visitors so the two code
            // paths cannot drift apart.
            fn from_legacy_map<'de, M>(map: M) -> Result<super::$ty, M::Error>
            where
                M: de::MapAccess<'de>,
            {
                let helper = Helper::deserialize(de::value::MapAccessDeserializer::new(map))?;
                Ok(Flags::from_bits_retain(helper.bits))
            }

            impl ::serde::Serialize for super::$ty {
                fn serialize<S: ::serde::Serializer>(
                    &self,
                    serializer: S,
                ) -> Result<S::Ok, S::Error> {
                    // Always write the legacy shape, regardless of the target format.
                    let helper = Helper { bits: self.bits() };
                    ::serde::Serialize::serialize(&helper, serializer)
                }
            }

            impl<'de> ::serde::Deserialize<'de> for super::$ty {
                fn deserialize<D: ::serde::Deserializer<'de>>(
                    deserializer: D,
                ) -> Result<Self, D::Error> {
                    // The non-legacy representation differs between human-readable and other
                    // formats, so each gets its own visitor.
                    if deserializer.is_human_readable() {
                        deserializer.deserialize_any(HumanReadableVisitor(PhantomData))
                    } else {
                        deserializer.deserialize_any(BinaryVisitor(PhantomData))
                    }
                }
            }

            // Visitor for human-readable formats: accepts a flags string or the legacy map.
            struct HumanReadableVisitor(PhantomData<<super::$ty as Flags>::Bits>);

            impl<'de> de::Visitor<'de> for HumanReadableVisitor {
                type Value = super::$ty;

                fn expecting(
                    &self,
                    formatter: &mut ::std::fmt::Formatter<'_>,
                ) -> ::std::fmt::Result {
                    formatter.write_str("string or map")
                }

                fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    ::bitflags::parser::from_str(value).map_err(E::custom)
                }

                fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
                where
                    M: de::MapAccess<'de>,
                {
                    from_legacy_map(map)
                }
            }

            // Visitor for non-human-readable formats: accepts a bare primitive that is handed
            // to the `Bits` type's own `Deserialize` impl, or the legacy map.
            struct BinaryVisitor(PhantomData<<super::$ty as Flags>::Bits>);

            // Generates one `visit_*` method that wraps the visited primitive in the matching
            // `serde::de::value` deserializer and decodes the `Bits` type from it.
            macro_rules! delegate_binary_visitor {
                ($method:ident, $value_ty:ty, $deserializer:ident) => {
                    fn $method<E>(self, value: $value_ty) -> Result<Self::Value, E>
                    where
                        E: de::Error,
                    {
                        let bits: <super::$ty as Flags>::Bits =
                            Deserialize::deserialize(de::value::$deserializer::new(value))?;
                        Ok(Flags::from_bits_retain(bits))
                    }
                };
            }

            impl<'de> de::Visitor<'de> for BinaryVisitor {
                type Value = super::$ty;

                fn expecting(
                    &self,
                    formatter: &mut ::std::fmt::Formatter<'_>,
                ) -> ::std::fmt::Result {
                    // Lists every kind of input this visitor has a `visit_*` method for, so
                    // that "invalid type" errors do not claim only strings/maps are accepted.
                    formatter.write_str("integer, char, string, bytes, or map")
                }

                delegate_binary_visitor!(visit_i8, i8, I8Deserializer);
                delegate_binary_visitor!(visit_i16, i16, I16Deserializer);
                delegate_binary_visitor!(visit_i32, i32, I32Deserializer);
                delegate_binary_visitor!(visit_i64, i64, I64Deserializer);
                delegate_binary_visitor!(visit_i128, i128, I128Deserializer);
                delegate_binary_visitor!(visit_u8, u8, U8Deserializer);
                delegate_binary_visitor!(visit_u16, u16, U16Deserializer);
                delegate_binary_visitor!(visit_u32, u32, U32Deserializer);
                delegate_binary_visitor!(visit_u64, u64, U64Deserializer);
                delegate_binary_visitor!(visit_u128, u128, U128Deserializer);
                delegate_binary_visitor!(visit_char, char, CharDeserializer);
                delegate_binary_visitor!(visit_str, &str, StrDeserializer);
                delegate_binary_visitor!(visit_borrowed_str, &'de str, BorrowedStrDeserializer);
                delegate_binary_visitor!(visit_string, String, StringDeserializer);
                delegate_binary_visitor!(visit_bytes, &[u8], BytesDeserializer);
                delegate_binary_visitor!(
                    visit_borrowed_bytes,
                    &'de [u8],
                    BorrowedBytesDeserializer
                );

                fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
                where
                    M: de::MapAccess<'de>,
                {
                    from_legacy_map(map)
                }
            }
        }
    };
}
127
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    // Flags type whose serde impls come from `impl_traits!` (legacy wire format).
    bitflags::bitflags! {
        #[derive(Debug, PartialEq)]
        pub struct LegacyFlags: u32 {
            const A = 0b00000001;
            const B = 0b00000010;
        }
    }

    impl_traits!(LegacyFlags);

    // Flags type using the serde impls derived through `bitflags` 2.x itself.
    bitflags::bitflags! {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        pub struct NewFlags: u32 {
            const A = 0b00000001;
            const B = 0b00000010;
        }
    }

    /// Legacy flags serialize to `{ "bits": .. }` in JSON and read back unchanged.
    #[test]
    fn test_roundtrip() {
        let original = LegacyFlags::A | LegacyFlags::B;

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded, serde_json::json!({ "bits": 3 }));

        let decoded: LegacyFlags = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded, LegacyFlags::A | LegacyFlags::B);
    }

    /// The string form written by the derived impl in JSON is accepted by the legacy type.
    #[test]
    fn test_parses_new_human_readable_format() {
        let source = NewFlags::A | NewFlags::B;
        let encoded = serde_json::to_value(&source).unwrap();
        assert_eq!(encoded, serde_json::json!("A | B"));

        let decoded: LegacyFlags = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded, LegacyFlags::A | LegacyFlags::B);
    }

    /// The CBOR encoding written by the derived impl is accepted by the legacy type.
    #[test]
    fn test_parses_new_binary_format() {
        let source = NewFlags::A | NewFlags::B;
        let mut encoded: Vec<u8> = Vec::new();
        ciborium::into_writer(&source, &mut encoded).unwrap();

        let decoded: LegacyFlags = ciborium::from_reader(encoded.as_slice()).unwrap();
        assert_eq!(decoded, LegacyFlags::A | LegacyFlags::B);
    }
}