1use std::{fmt::Debug, marker::PhantomData};
2
3use malloc_size_of_derive::MallocSizeOf;
4use serde::{ser::SerializeStruct, Deserialize, Serialize};
5
6pub type Ver = u32;
8
9#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, MallocSizeOf)]
11pub struct Versioned<T, const V: Ver>(T);
12
13impl<T, const V: Ver> Versioned<T, V> {
14 pub fn inner(&self) -> &T {
15 &self.0
16 }
17
18 pub fn into_inner(self) -> T {
19 self.0
20 }
21}
22
23impl<T, const V: Ver> From<T> for Versioned<T, V> {
24 fn from(t: T) -> Self {
25 Self(t)
26 }
27}
28
29impl<T, const V: Ver> Serialize for Versioned<T, V>
30where
31 T: Serialize,
32{
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: serde::Serializer,
36 {
37 if serializer.is_human_readable() {
38 return self.0.serialize(serializer);
39 }
40 let mut s = serializer.serialize_struct("MakeVersioned", 2)?;
41 s.serialize_field("version", &V)?;
42 s.serialize_field("t", &self.0)?;
43 s.end()
44 }
45}
46
47impl<'de, T, const V: Ver> Deserialize<'de> for Versioned<T, V>
48where
49 T: Deserialize<'de>,
50{
51 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
52 where
53 D: serde::Deserializer<'de>,
54 {
55 if deserializer.is_human_readable() {
56 return T::deserialize(deserializer).map(Self);
57 }
58 struct FieldsVisitor<T, const V: Ver>(PhantomData<T>);
59 impl<'de, T, const V: Ver> serde::de::Visitor<'de> for FieldsVisitor<T, V>
60 where
61 T: Deserialize<'de>,
62 {
63 type Value = T;
64
65 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
66 formatter.write_str("expecting a struct")
67 }
68
69 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
70 where
71 A: serde::de::SeqAccess<'de>,
72 {
73 let version: Ver = seq
74 .next_element()?
75 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
76 if version != V {
77 return Err(serde::de::Error::custom(format!(
78 "invalid version, expecting {}, actual {version}",
79 V
80 )));
81 }
82 let t = seq
83 .next_element()?
84 .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
85 Ok(t)
86 }
87 }
88
89 const FIELDS: &[&str] = &["version", "t"];
90 deserializer
91 .deserialize_struct(
92 "MakeVersioned",
93 FIELDS,
94 FieldsVisitor::<T, V>(Default::default()),
95 )
96 .map(Self)
97 }
98}
99
100#[derive(Debug)]
101pub struct VersionMismatchError {
102 expected: Ver,
103 actual: Ver,
104}
105
106impl std::fmt::Display for VersionMismatchError {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 write!(
109 f,
110 "version mismatch, expected {}, actual {}",
111 self.expected, self.actual
112 )
113 }
114}
115
116impl std::error::Error for VersionMismatchError {
117 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
118 None
119 }
120}
121
122impl<T, const V: Ver> binprot::BinProtRead for Versioned<T, V>
123where
124 T: binprot::BinProtRead,
125{
126 fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
127 where
128 Self: Sized,
129 {
130 let version: Ver = binprot::BinProtRead::binprot_read(r)?;
131 if version != V {
132 return Err(binprot::Error::CustomError(Box::new(
133 VersionMismatchError {
134 expected: V,
135 actual: version,
136 },
137 )));
138 }
139 let t: T = binprot::BinProtRead::binprot_read(r)?;
140 Ok(Self(t))
141 }
142}
143
144impl<T, const V: Ver> binprot::BinProtWrite for Versioned<T, V>
145where
146 T: binprot::BinProtWrite,
147{
148 fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
149 binprot::BinProtWrite::binprot_write(&V, w)?;
150 binprot::BinProtWrite::binprot_write(&self.0, w)
151 }
152}
153
#[cfg(test)]
mod tests {
    use binprot::{BinProtRead, BinProtWrite};
    use binprot_derive::{BinProtRead, BinProtWrite};
    use serde::{Deserialize, Serialize};

    use crate::versioned::Versioned;

    /// Decodes one `T` from the front of `buf`, returning it with the unread tail.
    fn binprot_read<T>(buf: &[u8]) -> Result<(T, &[u8]), binprot::Error>
    where
        T: BinProtRead,
    {
        let mut rest = buf;
        let res = T::binprot_read(&mut rest)?;
        Ok((res, rest))
    }

    /// Encodes `t` into a freshly allocated byte vector.
    fn binprot_write<T>(t: &T) -> std::io::Result<Vec<u8>>
    where
        T: BinProtWrite,
    {
        let mut buf = Vec::new();
        t.binprot_write(&mut buf)?;
        Ok(buf)
    }

    /// Baseline: bin_prot and JSON encodings of a plain, unversioned struct.
    #[test]
    fn binprot() {
        #[derive(Debug, Serialize, Deserialize, PartialEq, BinProtRead, BinProtWrite)]
        struct Foo {
            a: u8,
            b: u32,
        }

        for (foo, foo_bin_prot) in [
            (Foo { a: 0x00, b: 0x00 }, b"\x00\x00" as &[u8]),
            (Foo { a: 0x01, b: 0x01 }, b"\x01\x01"),
            (Foo { a: 0x7f, b: 0x7fff }, b"\x7f\xfe\xff\x7f"),
        ] {
            let foo_json = serde_json::json!({"a": foo.a, "b": foo.b});

            let bytes = binprot_write(&foo).unwrap();
            assert_eq!(&bytes, foo_bin_prot);

            let json = serde_json::to_value(&foo).unwrap();
            assert_eq!(json, foo_json);

            let (foo_de, rest) = binprot_read::<Foo>(foo_bin_prot).unwrap();
            assert_eq!(rest.len(), 0);
            assert_eq!(&foo_de, &foo);
        }
    }

    /// Versioned values gain a version prefix in bin_prot. JSON, a
    /// human-readable format, carries only the payload in both directions.
    #[test]
    fn binprot_versioned() {
        #[derive(Debug, Serialize, Deserialize, PartialEq, BinProtRead, BinProtWrite)]
        struct Foo {
            a: u8,
            b: u32,
        }

        type VersionedFoo = Versioned<Foo, 1>;

        for (foo, foo_bin_prot) in [
            (Foo { a: 0x00, b: 0x00 }, b"\x01\x00\x00" as &[u8]),
            (Foo { a: 0x01, b: 0x01 }, b"\x01\x01\x01"),
            (Foo { a: 0x7f, b: 0x7fff }, b"\x01\x7f\xfe\xff\x7f"),
        ] {
            let foo_json = serde_json::json!({"a": foo.a, "b": foo.b});
            let foo: VersionedFoo = Versioned::from(foo);

            let bytes = binprot_write(&foo).unwrap();
            assert_eq!(&bytes, foo_bin_prot);

            let json = serde_json::to_value(&foo).unwrap();
            assert_eq!(json, foo_json);

            // Human-readable round trip: the bare payload decodes back into the wrapper.
            let foo_from_json: VersionedFoo = serde_json::from_value(json).unwrap();
            assert_eq!(&foo_from_json, &foo);

            let (foo_de, rest) = binprot_read::<VersionedFoo>(foo_bin_prot).unwrap();
            assert_eq!(rest.len(), 0);
            assert_eq!(&foo_de, &foo);
        }
    }

    /// Data written with a different version prefix must be rejected, not decoded.
    #[test]
    fn binprot_version_mismatch_is_rejected() {
        // Version 2 on the wire, but the target type expects version 1.
        match binprot_read::<Versioned<u8, 1>>(b"\x02\x05") {
            Err(binprot::Error::CustomError(err)) => {
                assert_eq!(err.to_string(), "version mismatch, expected 1, actual 2");
            }
            Err(_) => panic!("expected a version-mismatch error"),
            Ok(_) => panic!("version 2 data must not decode as version 1"),
        }
    }

    /// The version prefix uses bin_prot's variable-length integer encoding.
    #[test]
    fn binprot_version_num_write() {
        fn versioned<const V: u32>() -> Versioned<(), V> {
            Versioned(())
        }
        assert_eq!(&binprot_write(&versioned::<0>()).unwrap(), b"\x00\x00");
        assert_eq!(&binprot_write(&versioned::<1>()).unwrap(), b"\x01\x00");
        assert_eq!(&binprot_write(&versioned::<0x7f>()).unwrap(), b"\x7f\x00");
        assert_eq!(
            &binprot_write(&versioned::<0x80>()).unwrap(),
            b"\xfe\x80\x00\x00"
        );
    }
}