// mina_p2p_messages/number.rs

1use std::{fmt::Display, marker::PhantomData, str::FromStr};
2
3use malloc_size_of::MallocSizeOf;
4use serde::{de::Visitor, Deserialize, Serialize};
5
6#[derive(
7    Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, derive_more::From, derive_more::Deref,
8)]
9pub struct Number<T>(pub T);
10
11impl<T: std::fmt::Debug> std::fmt::Debug for Number<T> {
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        // Avoid vertical alignment
14        f.write_fmt(format_args!("Number({inner:?})", inner = self.0))
15    }
16}
17
18impl<T> MallocSizeOf for Number<T> {
19    fn size_of(&self, _ops: &mut malloc_size_of::MallocSizeOfOps) -> usize {
20        0
21    }
22}
23
24pub type Int32 = Number<i32>;
25pub type UInt32 = Number<u32>;
26pub type Int64 = Number<i64>;
27pub type UInt64 = Number<u64>;
28pub type Float64 = Number<f64>;
29
30impl Int32 {
31    pub const fn as_u32(&self) -> u32 {
32        self.0 as u32
33    }
34}
35
36impl Int64 {
37    pub const fn as_u64(&self) -> u64 {
38        self.0 as u64
39    }
40}
41
42impl UInt32 {
43    pub const fn as_u32(&self) -> u32 {
44        self.0
45    }
46}
47
48impl UInt64 {
49    pub const fn as_u64(&self) -> u64 {
50        self.0
51    }
52}
53
54impl From<u32> for Number<i32> {
55    fn from(value: u32) -> Self {
56        Self(value as i32)
57    }
58}
59
60impl From<u64> for Number<i64> {
61    fn from(value: u64) -> Self {
62        Self(value as i64)
63    }
64}
65
66impl From<&u32> for Number<i32> {
67    fn from(value: &u32) -> Self {
68        Self(*value as i32)
69    }
70}
71
72impl From<&u64> for Number<i64> {
73    fn from(value: &u64) -> Self {
74        Self(*value as i64)
75    }
76}
77
78impl From<&u32> for Number<u32> {
79    fn from(value: &u32) -> Self {
80        Self(*value)
81    }
82}
83
84impl From<&u64> for Number<u64> {
85    fn from(value: &u64) -> Self {
86        Self(*value)
87    }
88}
89
90impl<T> Serialize for Number<T>
91where
92    T: Serialize + Display,
93{
94    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
95    where
96        S: serde::Serializer,
97    {
98        if !serializer.is_human_readable() {
99            return self.0.serialize(serializer);
100        }
101        serializer.serialize_str(&self.0.to_string())
102    }
103}
104
105impl<'de, T> Deserialize<'de> for Number<T>
106where
107    T: Deserialize<'de> + FromStr,
108{
109    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
110    where
111        D: serde::Deserializer<'de>,
112    {
113        if !deserializer.is_human_readable() {
114            return T::deserialize(deserializer).map(Self);
115        }
116        struct V<T>(PhantomData<T>);
117        impl<'de, T> Visitor<'de> for V<T>
118        where
119            T: FromStr,
120        {
121            type Value = T;
122
123            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
124                formatter.write_str("a stringified number or a literal integer")
125            }
126
127            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
128            where
129                E: serde::de::Error,
130            {
131                v.parse().map_err(|_| {
132                    serde::de::Error::custom("failed to parse string as number".to_string())
133                })
134            }
135
136            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
137            where
138                E: serde::de::Error,
139            {
140                v.parse().map_err(|_| {
141                    serde::de::Error::custom("failed to parse string as number".to_string())
142                })
143            }
144
145            fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
146            where
147                E: serde::de::Error,
148            {
149                v.parse().map_err(|_| {
150                    serde::de::Error::custom("failed to parse string as number".to_string())
151                })
152            }
153
154            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
155            where
156                E: serde::de::Error,
157            {
158                let s = v.to_string();
159                s.parse()
160                    .map_err(|_| serde::de::Error::custom("failed to parse integer as number"))
161            }
162
163            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
164            where
165                E: serde::de::Error,
166            {
167                let s = v.to_string();
168                s.parse().map_err(|_| {
169                    serde::de::Error::custom("failed to parse unsigned integer as number")
170                })
171            }
172        }
173        deserializer
174            .deserialize_any(V::<T>(Default::default()))
175            .map(Self)
176    }
177}
178
179macro_rules! binprot_number {
180    ($base_type:ident, $binprot_type:ident) => {
181        impl binprot::BinProtRead for Number<$base_type> {
182            fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
183            where
184                Self: Sized,
185            {
186                $binprot_type::binprot_read(r).map(|v| Self(v as $base_type))
187            }
188        }
189
190        impl binprot::BinProtWrite for Number<$base_type> {
191            fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
192                (self.0 as $binprot_type).binprot_write(w)
193            }
194        }
195    };
196}
197
198binprot_number!(i32, i32);
199binprot_number!(i64, i64);
200binprot_number!(u32, i32);
201binprot_number!(u64, i64);
202binprot_number!(f64, f64);
203
#[cfg(test)]
mod tests {
    use binprot::{BinProtRead, BinProtWrite};

    // Generates a bin_prot write/read round-trip test for `Number<$ty>` over the boundary
    // values of every integer width. Each value is `as`-cast into `$ty`, so out-of-range
    // values wrap (e.g. negative minimums become large values for unsigned `$ty`).
    macro_rules! number_test {
        ($name:ident, $ty:ident) => {
            #[test]
            fn $name() {
                for n in [
                    0,
                    1,
                    u8::MAX as $ty,
                    u16::MAX as $ty,
                    u32::MAX as $ty,
                    u64::MAX as $ty,
                    i8::MAX as $ty,
                    i16::MAX as $ty,
                    i32::MAX as $ty,
                    i64::MAX as $ty,
                    i8::MIN as $ty,
                    i16::MIN as $ty,
                    i32::MIN as $ty,
                    i64::MIN as $ty,
                ] {
                    let n: super::Number<$ty> = n.into();
                    let mut buf = Vec::new();
                    n.binprot_write(&mut buf).unwrap();
                    let mut r = buf.as_slice();
                    let n_ = super::Number::<$ty>::binprot_read(&mut r).unwrap();
                    assert!(r.is_empty(), "trailing bytes after reading {n:?}");
                    assert_eq!(n, n_);
                }
            }
        };
    }

    // Checks that `$ty::MAX` is read from and written as the two-byte bin_prot encoding of
    // `-1` (`0xff 0xff`), since unsigned numbers are carried as their signed counterparts.
    macro_rules! max_number_test {
        ($name:ident, $ty:ident) => {
            #[test]
            fn $name() {
                let binprot = b"\xff\xff";
                let mut r = &binprot[..];
                let n = super::Number::<$ty>::binprot_read(&mut r).unwrap();
                assert_eq!(n.0, $ty::MAX);

                let n: super::Number<$ty> = $ty::MAX.into();
                let mut buf = Vec::new();
                n.binprot_write(&mut buf).unwrap();
                assert_eq!(buf.as_slice(), b"\xff\xff");
            }
        };
    }

    number_test!(i32_roundtrip, i32);
    number_test!(u32_roundtrip, u32);
    number_test!(i64_roundtrip, i64);
    number_test!(u64_roundtrip, u64);

    max_number_test!(u32_max, u32);
    max_number_test!(u64_max, u64);

    #[test]
    fn f64_roundtrip() {
        for x in [
            0.0,
            -0.0,
            1.5,
            -1.5,
            f64::MIN,
            f64::MAX,
            f64::MIN_POSITIVE,
            f64::EPSILON,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ] {
            let n: super::Number<f64> = x.into();
            let mut buf = Vec::new();
            n.binprot_write(&mut buf).unwrap();
            let mut r = buf.as_slice();
            let n_ = super::Number::<f64>::binprot_read(&mut r).unwrap();
            assert!(r.is_empty(), "trailing bytes after reading {x:?}");
            // Compare bit patterns so that `-0.0` is distinguished from `0.0`.
            assert_eq!(n.0.to_bits(), n_.0.to_bits());
        }
    }
}