mina_p2p_messages/
versioned.rs

1use std::{fmt::Debug, marker::PhantomData};
2
3use malloc_size_of_derive::MallocSizeOf;
4use serde::{ser::SerializeStruct, Deserialize, Serialize};
5
/// Integer type used for version numbers: `Bin_prot` represents the version
/// of a type as an integer.
pub type Ver = u32;
8
9/// Wrapper for a type that adds explicit version information.
10#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, MallocSizeOf)]
11pub struct Versioned<T, const V: Ver>(T);
12
13impl<T, const V: Ver> Versioned<T, V> {
14    pub fn inner(&self) -> &T {
15        &self.0
16    }
17
18    pub fn into_inner(self) -> T {
19        self.0
20    }
21}
22
23impl<T, const V: Ver> From<T> for Versioned<T, V> {
24    fn from(t: T) -> Self {
25        Self(t)
26    }
27}
28
29impl<T, const V: Ver> Serialize for Versioned<T, V>
30where
31    T: Serialize,
32{
33    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34    where
35        S: serde::Serializer,
36    {
37        if serializer.is_human_readable() {
38            return self.0.serialize(serializer);
39        }
40        let mut s = serializer.serialize_struct("MakeVersioned", 2)?;
41        s.serialize_field("version", &V)?;
42        s.serialize_field("t", &self.0)?;
43        s.end()
44    }
45}
46
47impl<'de, T, const V: Ver> Deserialize<'de> for Versioned<T, V>
48where
49    T: Deserialize<'de>,
50{
51    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
52    where
53        D: serde::Deserializer<'de>,
54    {
55        if deserializer.is_human_readable() {
56            return T::deserialize(deserializer).map(Self);
57        }
58        struct FieldsVisitor<T, const V: Ver>(PhantomData<T>);
59        impl<'de, T, const V: Ver> serde::de::Visitor<'de> for FieldsVisitor<T, V>
60        where
61            T: Deserialize<'de>,
62        {
63            type Value = T;
64
65            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
66                formatter.write_str("expecting a struct")
67            }
68
69            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
70            where
71                A: serde::de::SeqAccess<'de>,
72            {
73                let version: Ver = seq
74                    .next_element()?
75                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
76                if version != V {
77                    return Err(serde::de::Error::custom(format!(
78                        "invalid version, expecting {}, actual {version}",
79                        V
80                    )));
81                }
82                let t = seq
83                    .next_element()?
84                    .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
85                Ok(t)
86            }
87        }
88
89        const FIELDS: &[&str] = &["version", "t"];
90        deserializer
91            .deserialize_struct(
92                "MakeVersioned",
93                FIELDS,
94                FieldsVisitor::<T, V>(Default::default()),
95            )
96            .map(Self)
97    }
98}
99
100#[derive(Debug)]
101pub struct VersionMismatchError {
102    expected: Ver,
103    actual: Ver,
104}
105
106impl std::fmt::Display for VersionMismatchError {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        write!(
109            f,
110            "version mismatch, expected {}, actual {}",
111            self.expected, self.actual
112        )
113    }
114}
115
116impl std::error::Error for VersionMismatchError {
117    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
118        None
119    }
120}
121
122impl<T, const V: Ver> binprot::BinProtRead for Versioned<T, V>
123where
124    T: binprot::BinProtRead,
125{
126    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
127    where
128        Self: Sized,
129    {
130        let version: Ver = binprot::BinProtRead::binprot_read(r)?;
131        if version != V {
132            return Err(binprot::Error::CustomError(Box::new(
133                VersionMismatchError {
134                    expected: V,
135                    actual: version,
136                },
137            )));
138        }
139        let t: T = binprot::BinProtRead::binprot_read(r)?;
140        Ok(Self(t))
141    }
142}
143
144impl<T, const V: Ver> binprot::BinProtWrite for Versioned<T, V>
145where
146    T: binprot::BinProtWrite,
147{
148    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
149        binprot::BinProtWrite::binprot_write(&V, w)?;
150        binprot::BinProtWrite::binprot_write(&self.0, w)
151    }
152}
153
#[cfg(test)]
mod tests {
    use binprot::{BinProtRead, BinProtWrite};
    use binprot_derive::{BinProtRead, BinProtWrite};
    use serde::{Deserialize, Serialize};

    use crate::versioned::Versioned;

    /// Sample payload shared by the round-trip tests.
    #[derive(Debug, Serialize, Deserialize, PartialEq, BinProtRead, BinProtWrite)]
    struct Foo {
        a: u8,
        b: u32,
    }

    /// Decodes one `T` from the front of `input`, returning it together with
    /// the bytes that were not consumed.
    fn decode<T: BinProtRead>(mut input: &[u8]) -> Result<(T, &[u8]), binprot::Error> {
        T::binprot_read(&mut input).map(|value| (value, input))
    }

    /// Encodes `value` into a freshly allocated buffer.
    fn encode<T: BinProtWrite>(value: &T) -> std::io::Result<Vec<u8>> {
        let mut out = Vec::new();
        value.binprot_write(&mut out).map(|()| out)
    }

    /// Plain (unversioned) struct: bin_prot bytes, JSON and round trip.
    #[test]
    fn binprot() {
        let cases = [
            (Foo { a: 0x00, b: 0x00 }, &b"\x00\x00"[..]),
            (Foo { a: 0x01, b: 0x01 }, &b"\x01\x01"[..]),
            (Foo { a: 0x7f, b: 0x7fff }, &b"\x7f\xfe\xff\x7f"[..]),
        ];

        for (foo, expected_bytes) in cases {
            let expected_json = serde_json::json!({"a": foo.a, "b": foo.b});

            let encoded = encode(&foo).unwrap();
            assert_eq!(encoded.as_slice(), expected_bytes);

            let as_json = serde_json::to_value(&foo).unwrap();
            assert_eq!(as_json, expected_json);

            let (decoded, leftover) = decode::<Foo>(expected_bytes).unwrap();
            assert_eq!(leftover.len(), 0);
            assert_eq!(&decoded, &foo);
        }
    }

    /// Versioned struct: bin_prot gains a version prefix, JSON stays the same.
    #[test]
    fn binprot_versioned() {
        type VersionedFoo = Versioned<Foo, 1>;

        let cases = [
            (Foo { a: 0x00, b: 0x00 }, &b"\x01\x00\x00"[..]),
            (Foo { a: 0x01, b: 0x01 }, &b"\x01\x01\x01"[..]),
            (Foo { a: 0x7f, b: 0x7fff }, &b"\x01\x7f\xfe\xff\x7f"[..]),
        ];

        for (foo, expected_bytes) in cases {
            let expected_json = serde_json::json!({"a": foo.a, "b": foo.b});
            let wrapped = VersionedFoo::from(foo);

            let encoded = encode(&wrapped).unwrap();
            assert_eq!(encoded.as_slice(), expected_bytes);

            let as_json = serde_json::to_value(&wrapped).unwrap();
            assert_eq!(as_json, expected_json);

            let (decoded, leftover) = decode::<VersionedFoo>(expected_bytes).unwrap();
            assert_eq!(leftover.len(), 0);
            assert_eq!(&decoded, &wrapped);
        }
    }

    /// Checks how the version number itself is encoded for several values,
    /// using a unit payload.
    #[test]
    fn binprot_version_num_write() {
        fn encoded_unit<const V: u32>() -> Vec<u8> {
            let unit: Versioned<(), V> = Versioned(());
            encode(&unit).unwrap()
        }

        assert_eq!(encoded_unit::<0>(), b"\x00\x00");
        assert_eq!(encoded_unit::<1>(), b"\x01\x00");
        assert_eq!(encoded_unit::<0x7f>(), b"\x7f\x00");
        assert_eq!(encoded_unit::<0x80>(), b"\xfe\x80\x00\x00");
    }
}