Skip to main content

mina_p2p_messages/
pseq.rs

1use std::{array, fmt::Formatter, marker::PhantomData};
2
3use binprot::{BinProtRead, BinProtWrite};
4use malloc_size_of_derive::MallocSizeOf;
5use rsexp::{OfSexp, SexpOf};
6use serde::ser::SerializeTuple;
7#[derive(Clone, Debug, PartialEq, MallocSizeOf)]
8pub struct PaddedSeq<T, const N: usize>(pub [T; N]);
9
#[cfg(feature = "openapi")]
const _: () = {
    use std::borrow::Cow;

    use utoipa::openapi::{
        schema::{Array, Schema},
        RefOr,
    };

    // OpenAPI view of `PaddedSeq<T, N>`: an array whose items use `T`'s schema and whose
    // length is pinned to exactly `N` (min and max are both `N`).
    impl<T: utoipa::ToSchema, const N: usize> utoipa::PartialSchema for PaddedSeq<T, N> {
        fn schema() -> RefOr<Schema> {
            let exact_len = Some(N);
            Array::builder()
                .items(T::schema())
                .min_items(exact_len)
                .max_items(exact_len)
                .build()
                .into()
        }
    }

    // Schema name embeds both the element schema name and the fixed length.
    impl<T: utoipa::ToSchema, const N: usize> utoipa::ToSchema for PaddedSeq<T, N> {
        fn name() -> Cow<'static, str> {
            let element_name = T::name();
            Cow::Owned(format!("PaddedSeq<{}, {}>", element_name, N))
        }
    }
};
28
29impl<T, const N: usize> Default for PaddedSeq<T, N>
30where
31    T: Default,
32{
33    fn default() -> Self {
34        Self(array::from_fn(|_| T::default()))
35    }
36}
37
38impl<T, const N: usize> std::ops::Deref for PaddedSeq<T, N> {
39    type Target = [T; N];
40
41    fn deref(&self) -> &Self::Target {
42        &self.0
43    }
44}
45
46impl<T: OfSexp, const N: usize> OfSexp for PaddedSeq<T, N> {
47    fn of_sexp(s: &rsexp::Sexp) -> Result<Self, rsexp::IntoSexpError>
48    where
49        Self: Sized,
50    {
51        let elts = s.extract_list("PaddedSeq")?;
52        if elts.len() != N {
53            return Err(rsexp::IntoSexpError::ListLengthMismatch {
54                type_: "PaddedSeq",
55                expected_len: N,
56                list_len: elts.len(),
57            });
58        }
59
60        let mut converted: [Option<T>; N] = [(); N].map(|_| None);
61
62        for (i, item) in elts.iter().enumerate() {
63            converted[i] = Some(T::of_sexp(item)?);
64        }
65
66        // Unwrap cannot fail, otherwise we wouldn't have rechead this point
67        Ok(Self(converted.map(|item| item.unwrap())))
68    }
69}
70
71impl<T: SexpOf, const N: usize> rsexp::SexpOf for PaddedSeq<T, N> {
72    fn sexp_of(&self) -> rsexp::Sexp {
73        let elements: Vec<rsexp::Sexp> = self.0.iter().map(|item| item.sexp_of()).collect();
74
75        rsexp::Sexp::List(elements)
76    }
77}
78
79impl<T: BinProtRead, const N: usize> binprot::BinProtRead for PaddedSeq<T, N> {
80    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
81    where
82        Self: Sized,
83    {
84        let mut vec = Vec::with_capacity(N);
85        for _i in 0..N {
86            vec.push(BinProtRead::binprot_read(r)?);
87        }
88        let _: () = BinProtRead::binprot_read(r)?;
89        match vec.try_into() {
90            Ok(arr) => Ok(PaddedSeq(arr)),
91            Err(_) => unreachable!(),
92        }
93    }
94}
95
96impl<T: BinProtWrite, const N: usize> binprot::BinProtWrite for PaddedSeq<T, N> {
97    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
98        for elt in &self.0 {
99            elt.binprot_write(w)?;
100        }
101        ().binprot_write(w)?;
102        Ok(())
103    }
104}
105
106impl<T, const N: usize> serde::Serialize for PaddedSeq<T, N>
107where
108    T: serde::Serialize,
109{
110    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
111    where
112        S: serde::Serializer,
113    {
114        let mut serializer = serializer.serialize_tuple(N + 1)?;
115        for elt in &self.0 {
116            serializer.serialize_element(elt)?;
117        }
118        serializer.end()
119    }
120}
121
122impl<'de, T, const N: usize> serde::Deserialize<'de> for PaddedSeq<T, N>
123where
124    T: serde::Deserialize<'de>,
125{
126    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
127    where
128        D: serde::Deserializer<'de>,
129    {
130        struct Visitor<'de, T, const S: usize>
131        where
132            T: serde::Deserialize<'de>,
133        {
134            marker: PhantomData<PaddedSeq<T, S>>,
135            lifetime: PhantomData<&'de ()>,
136        }
137        impl<'de, T, const S: usize> serde::de::Visitor<'de> for Visitor<'de, T, S>
138        where
139            T: serde::Deserialize<'de>,
140        {
141            type Value = PaddedSeq<T, S>;
142            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
143                Formatter::write_str(formatter, "tuple struct PaddedSeq")
144            }
145            #[inline]
146            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
147            where
148                A: serde::de::SeqAccess<'de>,
149            {
150                let mut vec = Vec::with_capacity(S);
151                for i in 0..S {
152                    match serde::de::SeqAccess::next_element(&mut seq)? {
153                        Some(value) => vec.push(value),
154                        None => {
155                            return Err(serde::de::Error::invalid_length(
156                                i,
157                                &concat!(
158                                    "tuple struct PaddedSeq with ",
159                                    stringify!(S),
160                                    " element(s)"
161                                ),
162                            ));
163                        }
164                    }
165                }
166                let res = match <[T; S]>::try_from(vec) {
167                    Ok(a) => a,
168                    Err(_) => unreachable!(),
169                };
170                Ok(PaddedSeq(res))
171            }
172        }
173        deserializer.deserialize_tuple(
174            N,
175            Visitor {
176                marker: PhantomData::<PaddedSeq<T, N>>,
177                lifetime: PhantomData,
178            },
179        )
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_json() {
        let v = PaddedSeq([1, 2, 3]);
        let json = serde_json::to_string(&v).unwrap();
        assert_eq!(&json, "[1,2,3]");
    }

    #[test]
    fn from_json() {
        let json = "[1, 2, 3]";
        let v = serde_json::from_str::<PaddedSeq<_, 3>>(json).unwrap();
        assert_eq!(v, PaddedSeq([1, 2, 3]));
    }

    /// Too few elements hit the visitor's `invalid_length` path. Too many are rejected by
    /// serde_json when it closes the array.
    #[test]
    fn from_json_rejects_wrong_length() {
        assert!(serde_json::from_str::<PaddedSeq<i32, 3>>("[1, 2]").is_err());
        assert!(serde_json::from_str::<PaddedSeq<i32, 3>>("[1, 2, 3, 4]").is_err());
    }

    #[test]
    fn to_binprot() {
        let v = PaddedSeq([1, 2, 3]);
        let mut binprot = Vec::new();
        v.binprot_write(&mut binprot).unwrap();
        assert_eq!(&binprot, b"\x01\x02\x03\x00");
    }

    #[test]
    fn from_binprot() {
        let binprot = b"\x01\x02\x03\x00";
        let v = PaddedSeq::<_, 3>::binprot_read(&mut &binprot[..]).unwrap();
        assert_eq!(v, PaddedSeq([1, 2, 3]));
    }

    /// The trailing unit value is mandatory: input that stops right after the elements
    /// must fail to decode.
    #[test]
    fn from_binprot_requires_trailing_unit() {
        let truncated = b"\x01\x02\x03";
        assert!(PaddedSeq::<i32, 3>::binprot_read(&mut &truncated[..]).is_err());
    }

    #[test]
    fn binprot_roundtrip() {
        let original = PaddedSeq([7, 8, 9, 10]);
        let mut buf = Vec::new();
        original.binprot_write(&mut buf).unwrap();
        let decoded = PaddedSeq::<i32, 4>::binprot_read(&mut &buf[..]).unwrap();
        assert_eq!(decoded, original);
    }
}