1use std::{array, fmt::Formatter, marker::PhantomData};
2
3use binprot::{BinProtRead, BinProtWrite};
4use malloc_size_of_derive::MallocSizeOf;
5use rsexp::{OfSexp, SexpOf};
6use serde::ser::SerializeTuple;
/// Fixed-length sequence of exactly `N` values of type `T`, stored inline as
/// the array `[T; N]`.
///
/// The codecs implemented in this file encode it as follows: the bin_prot
/// form is the `N` elements followed by a unit value, the serde form is a
/// tuple of the elements, and the s-expression form is a list of `N` items.
#[derive(Clone, Debug, PartialEq, MallocSizeOf)]
pub struct PaddedSeq<T, const N: usize>(pub [T; N]);
9
#[cfg(feature = "openapi")]
const _: () = {
    use std::borrow::Cow;

    use utoipa::openapi::{
        schema::{Array, Schema},
        RefOr,
    };

    /// OpenAPI schema: an array of `T` items whose minimum and maximum
    /// length are both `N`.
    impl<T: utoipa::ToSchema, const N: usize> utoipa::PartialSchema for PaddedSeq<T, N> {
        fn schema() -> RefOr<Schema> {
            let fixed_len_array = Array::builder()
                .items(T::schema())
                .min_items(Some(N))
                .max_items(Some(N))
                .build();
            fixed_len_array.into()
        }
    }

    /// Schema name derived from the element schema name and the length `N`.
    impl<T: utoipa::ToSchema, const N: usize> utoipa::ToSchema for PaddedSeq<T, N> {
        fn name() -> Cow<'static, str> {
            let element_name = T::name();
            Cow::Owned(format!("PaddedSeq<{}, {}>", element_name, N))
        }
    }
};
28
29impl<T, const N: usize> Default for PaddedSeq<T, N>
30where
31 T: Default,
32{
33 fn default() -> Self {
34 Self(array::from_fn(|_| T::default()))
35 }
36}
37
38impl<T, const N: usize> std::ops::Deref for PaddedSeq<T, N> {
39 type Target = [T; N];
40
41 fn deref(&self) -> &Self::Target {
42 &self.0
43 }
44}
45
46impl<T: OfSexp, const N: usize> OfSexp for PaddedSeq<T, N> {
47 fn of_sexp(s: &rsexp::Sexp) -> Result<Self, rsexp::IntoSexpError>
48 where
49 Self: Sized,
50 {
51 let elts = s.extract_list("PaddedSeq")?;
52 if elts.len() != N {
53 return Err(rsexp::IntoSexpError::ListLengthMismatch {
54 type_: "PaddedSeq",
55 expected_len: N,
56 list_len: elts.len(),
57 });
58 }
59
60 let mut converted: [Option<T>; N] = [(); N].map(|_| None);
61
62 for (i, item) in elts.iter().enumerate() {
63 converted[i] = Some(T::of_sexp(item)?);
64 }
65
66 Ok(Self(converted.map(|item| item.unwrap())))
68 }
69}
70
71impl<T: SexpOf, const N: usize> rsexp::SexpOf for PaddedSeq<T, N> {
72 fn sexp_of(&self) -> rsexp::Sexp {
73 let elements: Vec<rsexp::Sexp> = self.0.iter().map(|item| item.sexp_of()).collect();
74
75 rsexp::Sexp::List(elements)
76 }
77}
78
79impl<T: BinProtRead, const N: usize> binprot::BinProtRead for PaddedSeq<T, N> {
80 fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
81 where
82 Self: Sized,
83 {
84 let mut vec = Vec::with_capacity(N);
85 for _i in 0..N {
86 vec.push(BinProtRead::binprot_read(r)?);
87 }
88 let _: () = BinProtRead::binprot_read(r)?;
89 match vec.try_into() {
90 Ok(arr) => Ok(PaddedSeq(arr)),
91 Err(_) => unreachable!(),
92 }
93 }
94}
95
96impl<T: BinProtWrite, const N: usize> binprot::BinProtWrite for PaddedSeq<T, N> {
97 fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
98 for elt in &self.0 {
99 elt.binprot_write(w)?;
100 }
101 ().binprot_write(w)?;
102 Ok(())
103 }
104}
105
106impl<T, const N: usize> serde::Serialize for PaddedSeq<T, N>
107where
108 T: serde::Serialize,
109{
110 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
111 where
112 S: serde::Serializer,
113 {
114 let mut serializer = serializer.serialize_tuple(N + 1)?;
115 for elt in &self.0 {
116 serializer.serialize_element(elt)?;
117 }
118 serializer.end()
119 }
120}
121
122impl<'de, T, const N: usize> serde::Deserialize<'de> for PaddedSeq<T, N>
123where
124 T: serde::Deserialize<'de>,
125{
126 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
127 where
128 D: serde::Deserializer<'de>,
129 {
130 struct Visitor<'de, T, const S: usize>
131 where
132 T: serde::Deserialize<'de>,
133 {
134 marker: PhantomData<PaddedSeq<T, S>>,
135 lifetime: PhantomData<&'de ()>,
136 }
137 impl<'de, T, const S: usize> serde::de::Visitor<'de> for Visitor<'de, T, S>
138 where
139 T: serde::Deserialize<'de>,
140 {
141 type Value = PaddedSeq<T, S>;
142 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
143 Formatter::write_str(formatter, "tuple struct PaddedSeq")
144 }
145 #[inline]
146 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
147 where
148 A: serde::de::SeqAccess<'de>,
149 {
150 let mut vec = Vec::with_capacity(S);
151 for i in 0..S {
152 match serde::de::SeqAccess::next_element(&mut seq)? {
153 Some(value) => vec.push(value),
154 None => {
155 return Err(serde::de::Error::invalid_length(
156 i,
157 &concat!(
158 "tuple struct PaddedSeq with ",
159 stringify!(S),
160 " element(s)"
161 ),
162 ));
163 }
164 }
165 }
166 let res = match <[T; S]>::try_from(vec) {
167 Ok(a) => a,
168 Err(_) => unreachable!(),
169 };
170 Ok(PaddedSeq(res))
171 }
172 }
173 deserializer.deserialize_tuple(
174 N,
175 Visitor {
176 marker: PhantomData::<PaddedSeq<T, N>>,
177 lifetime: PhantomData,
178 },
179 )
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON output is a plain array of the elements.
    #[test]
    fn to_json() {
        let encoded = serde_json::to_string(&PaddedSeq([1, 2, 3])).unwrap();
        assert_eq!(encoded, "[1,2,3]");
    }

    /// A JSON array of the right length parses back into the sequence.
    #[test]
    fn from_json() {
        let decoded: PaddedSeq<_, 3> = serde_json::from_str("[1, 2, 3]").unwrap();
        assert_eq!(decoded, PaddedSeq([1, 2, 3]));
    }

    /// bin_prot output is the elements followed by a single zero byte.
    #[test]
    fn to_binprot() {
        let mut buffer = Vec::new();
        PaddedSeq([1, 2, 3]).binprot_write(&mut buffer).unwrap();
        assert_eq!(buffer, b"\x01\x02\x03\x00".to_vec());
    }

    /// bin_prot input of three elements plus the trailing zero byte decodes.
    #[test]
    fn from_binprot() {
        let mut reader: &[u8] = b"\x01\x02\x03\x00";
        let decoded = PaddedSeq::<_, 3>::binprot_read(&mut reader).unwrap();
        assert_eq!(decoded, PaddedSeq([1, 2, 3]));
    }
}