mina_p2p_messages/
pseq.rs

1use std::{array, fmt::Formatter, marker::PhantomData};
2
3use binprot::{BinProtRead, BinProtWrite};
4use malloc_size_of_derive::MallocSizeOf;
5use rsexp::{OfSexp, SexpOf};
6use serde::ser::SerializeTuple;
7#[derive(Clone, Debug, PartialEq, MallocSizeOf)]
8pub struct PaddedSeq<T, const N: usize>(pub [T; N]);
9
10impl<T, const N: usize> Default for PaddedSeq<T, N>
11where
12    T: Default,
13{
14    fn default() -> Self {
15        Self(array::from_fn(|_| T::default()))
16    }
17}
18
19impl<T, const N: usize> std::ops::Deref for PaddedSeq<T, N> {
20    type Target = [T; N];
21
22    fn deref(&self) -> &Self::Target {
23        &self.0
24    }
25}
26
27impl<T: OfSexp, const N: usize> OfSexp for PaddedSeq<T, N> {
28    fn of_sexp(s: &rsexp::Sexp) -> Result<Self, rsexp::IntoSexpError>
29    where
30        Self: Sized,
31    {
32        let elts = s.extract_list("PaddedSeq")?;
33        if elts.len() != N {
34            return Err(rsexp::IntoSexpError::ListLengthMismatch {
35                type_: "PaddedSeq",
36                expected_len: N,
37                list_len: elts.len(),
38            });
39        }
40
41        let mut converted: [Option<T>; N] = [(); N].map(|_| None);
42
43        for (i, item) in elts.iter().enumerate() {
44            converted[i] = Some(T::of_sexp(item)?);
45        }
46
47        // Unwrap cannot fail, otherwise we wouldn't have rechead this point
48        Ok(Self(converted.map(|item| item.unwrap())))
49    }
50}
51
52impl<T: SexpOf, const N: usize> rsexp::SexpOf for PaddedSeq<T, N> {
53    fn sexp_of(&self) -> rsexp::Sexp {
54        let elements: Vec<rsexp::Sexp> = self.0.iter().map(|item| item.sexp_of()).collect();
55
56        rsexp::Sexp::List(elements)
57    }
58}
59
60impl<T: BinProtRead, const N: usize> binprot::BinProtRead for PaddedSeq<T, N> {
61    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
62    where
63        Self: Sized,
64    {
65        let mut vec = Vec::with_capacity(N);
66        for _i in 0..N {
67            vec.push(BinProtRead::binprot_read(r)?);
68        }
69        let _: () = BinProtRead::binprot_read(r)?;
70        match vec.try_into() {
71            Ok(arr) => Ok(PaddedSeq(arr)),
72            Err(_) => unreachable!(),
73        }
74    }
75}
76
77impl<T: BinProtWrite, const N: usize> binprot::BinProtWrite for PaddedSeq<T, N> {
78    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
79        for elt in &self.0 {
80            elt.binprot_write(w)?;
81        }
82        ().binprot_write(w)?;
83        Ok(())
84    }
85}
86
87impl<T, const N: usize> serde::Serialize for PaddedSeq<T, N>
88where
89    T: serde::Serialize,
90{
91    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92    where
93        S: serde::Serializer,
94    {
95        let mut serializer = serializer.serialize_tuple(N + 1)?;
96        for elt in &self.0 {
97            serializer.serialize_element(elt)?;
98        }
99        serializer.end()
100    }
101}
102
103impl<'de, T, const N: usize> serde::Deserialize<'de> for PaddedSeq<T, N>
104where
105    T: serde::Deserialize<'de>,
106{
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        struct Visitor<'de, T, const S: usize>
112        where
113            T: serde::Deserialize<'de>,
114        {
115            marker: PhantomData<PaddedSeq<T, S>>,
116            lifetime: PhantomData<&'de ()>,
117        }
118        impl<'de, T, const S: usize> serde::de::Visitor<'de> for Visitor<'de, T, S>
119        where
120            T: serde::Deserialize<'de>,
121        {
122            type Value = PaddedSeq<T, S>;
123            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
124                Formatter::write_str(formatter, "tuple struct PaddedSeq")
125            }
126            #[inline]
127            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
128            where
129                A: serde::de::SeqAccess<'de>,
130            {
131                let mut vec = Vec::with_capacity(S);
132                for i in 0..S {
133                    match serde::de::SeqAccess::next_element(&mut seq)? {
134                        Some(value) => vec.push(value),
135                        None => {
136                            return Err(serde::de::Error::invalid_length(
137                                i,
138                                &concat!(
139                                    "tuple struct PaddedSeq with ",
140                                    stringify!(S),
141                                    " element(s)"
142                                ),
143                            ));
144                        }
145                    }
146                }
147                let res = match <[T; S]>::try_from(vec) {
148                    Ok(a) => a,
149                    Err(_) => unreachable!(),
150                };
151                Ok(PaddedSeq(res))
152            }
153        }
154        deserializer.deserialize_tuple(
155            N,
156            Visitor {
157                marker: PhantomData::<PaddedSeq<T, N>>,
158                lifetime: PhantomData,
159            },
160        )
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_json() {
        let v = PaddedSeq([1, 2, 3]);
        let json = serde_json::to_string(&v).unwrap();
        assert_eq!(&json, "[1,2,3]");
    }

    #[test]
    fn from_json() {
        let json = "[1, 2, 3]";
        let v = serde_json::from_str::<PaddedSeq<_, 3>>(json).unwrap();
        assert_eq!(v, PaddedSeq([1, 2, 3]));
    }

    /// A JSON array with fewer than `N` items must be rejected.
    #[test]
    fn from_json_rejects_too_few_elements() {
        let json = "[1, 2]";
        assert!(serde_json::from_str::<PaddedSeq<i32, 3>>(json).is_err());
    }

    #[test]
    fn json_roundtrip() {
        let v = PaddedSeq([7u8, 8, 9]);
        let json = serde_json::to_string(&v).unwrap();
        let back: PaddedSeq<u8, 3> = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn to_binprot() {
        let v = PaddedSeq([1, 2, 3]);
        let mut binprot = Vec::new();
        v.binprot_write(&mut binprot).unwrap();
        assert_eq!(&binprot, b"\x01\x02\x03\x00");
    }

    #[test]
    fn from_binprot() {
        let binprot = b"\x01\x02\x03\x00";
        let v = PaddedSeq::<_, 3>::binprot_read(&mut &binprot[..]).unwrap();
        assert_eq!(v, PaddedSeq([1, 2, 3]));
    }

    #[test]
    fn binprot_roundtrip() {
        let v = PaddedSeq([10, 20, 30, 40]);
        let mut buf = Vec::new();
        v.binprot_write(&mut buf).unwrap();
        let back = PaddedSeq::<i32, 4>::binprot_read(&mut &buf[..]).unwrap();
        assert_eq!(back, v);
    }

    /// Input that ends before all `N` elements are present must be rejected.
    #[test]
    fn from_binprot_rejects_truncated_input() {
        let binprot = b"\x01\x02";
        assert!(PaddedSeq::<i32, 3>::binprot_read(&mut &binprot[..]).is_err());
    }

    #[test]
    fn default_fills_every_slot() {
        assert_eq!(PaddedSeq::<i32, 3>::default(), PaddedSeq([0, 0, 0]));
    }
}
196}