1use std::io::{Read, Write};
2
3use binprot::{
4 byteorder::{LittleEndian, ReadBytesExt},
5 BinProtRead, BinProtWrite,
6};
7use serde::{Deserialize, Serialize};
8
9pub fn decode_int<T, R>(r: &mut R) -> Result<T, binprot::Error>
11where
12 T: BinProtRead,
13 R: Read,
14{
15 T::binprot_read(r)
16}
17
18pub fn decode_string<R>(r: &mut R) -> Result<String, binprot::Error>
20where
21 R: Read,
22{
23 binprot::SmallString1k::binprot_read(r).map(|s| s.0)
24}
25
26pub fn decode_int_from_slice<T>(slice: &[u8]) -> Result<(T, usize), binprot::Error>
30where
31 T: BinProtRead,
32{
33 let mut ptr = slice;
34 Ok((decode_int(&mut ptr)?, slice.len() - ptr.len()))
35}
36
37pub fn decode_string_from_slice(slice: &[u8]) -> Result<(String, usize), binprot::Error> {
40 let mut ptr = slice;
41 Ok((decode_string(&mut ptr)?, slice.len() - ptr.len()))
42}
43
44pub fn decode_bstr_from_slice(slice: &[u8]) -> Result<&[u8], binprot::Error> {
47 let mut ptr = slice;
48 let len = binprot::Nat0::binprot_read(&mut ptr)?.0 as usize;
49 Ok(&ptr[..len])
50}
51
52pub fn stream_decode_size<R>(r: &mut R) -> Result<usize, binprot::Error>
55where
56 R: Read,
57{
58 let len = r.read_u64::<LittleEndian>()?;
59 len.try_into()
60 .map_err(|_| binprot::Error::CustomError("integer conversion".into()))
61}
62
63pub fn get_sized_slice(mut slice: &[u8]) -> Result<&[u8], binprot::Error> {
66 let len = (&mut slice).read_u64::<LittleEndian>()? as usize;
67 Ok(&slice[..len])
68}
69
70pub trait FromBinProtStream: BinProtRead + Sized {
71 fn read_from_stream<R>(r: &mut R) -> Result<Self, binprot::Error>
79 where
80 R: Read,
81 {
82 use std::io;
83 let len = r.read_u64::<LittleEndian>()?;
84 let mut r = r.take(len);
85 let v = Self::binprot_read(&mut r)?;
86 let _discarded = io::copy(&mut r, &mut io::sink())?;
87 Ok(v)
88 }
89}
90
91impl<T> FromBinProtStream for T where T: BinProtRead {}
92
/// Opaque byte buffer.
///
/// Wraps the raw bytes in a private `Vec<u8>`.
#[derive(Debug, Clone)]
pub struct Greedy(Vec<u8>);
95
96impl Serialize for Greedy {
97 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
98 where
99 S: serde::Serializer,
100 {
101 let hex = hex::encode(&self.0);
102 hex.serialize(serializer)
103 }
104}
105
106impl<'de> Deserialize<'de> for Greedy {
107 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108 where
109 D: serde::Deserializer<'de>,
110 {
111 let hex = String::deserialize(deserializer)?;
112 Ok(Self(hex::decode(hex).unwrap()))
113 }
114}
115
116impl BinProtRead for Greedy {
117 fn binprot_read<R: Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
118 where
119 Self: Sized,
120 {
121 let mut buf = Vec::new();
122 r.read_to_end(&mut buf)?;
123 Ok(Self(buf))
124 }
125}
126
127impl BinProtWrite for Greedy {
128 fn binprot_write<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
129 w.write_all(&self.0)
130 }
131}
132
133#[cfg(test)]
134mod tests {
135 use crate::utils::{decode_bstr_from_slice, get_sized_slice};
136
137 use super::{decode_int_from_slice, decode_string_from_slice};
138
139 #[test]
140 fn u8() {
141 for (b, i, l) in [(b"\x00", 0_u8, 1), (b"\x7f", 0x7f, 1)] {
142 assert_eq!(decode_int_from_slice(b).unwrap(), (i, l));
143 }
144 }
145
146 #[test]
147 fn i8() {
148 for (b, i, l) in [(b"\x00", 0_i8, 1), (b"\x7f", 0x7f, 1)] {
149 assert_eq!(decode_int_from_slice(b).unwrap(), (i, l));
150 }
151 }
152
153 #[test]
154 fn u16() {
155 for (b, i, l) in [
156 (&b"\x00"[..], 0_u16, 1),
157 (b"\x7f", 0x7f, 1),
158 (b"\xfe\x80\x00", 0x80, 3),
159 ] {
160 assert_eq!(decode_int_from_slice(b).unwrap(), (i, l));
161 }
162 }
163
164 #[test]
165 fn i16() {
166 for (b, i, l) in [
167 (&b"\x00"[..], 0_i16, 1),
168 (b"\x7f", 0x7f, 1),
169 (b"\xfe\x80\x00", 0x80, 3),
170 ] {
171 assert_eq!(decode_int_from_slice(b).unwrap(), (i, l));
172 }
173 }
174
175 #[test]
176 fn string() {
177 let tests: &[(&[u8], &str, usize)] = &[
178 (b"\x00", "", 1),
179 (b"\x00\xff", "", 1),
180 (b"\x01a", "a", 2),
181 (b"\x0bsome string", "some string", 12),
182 ];
183 for (b, s, l) in tests {
184 let (string, len) = decode_string_from_slice(b).unwrap();
185 assert_eq!((string.as_str(), len), (*s, *l));
186 }
187 }
188
189 #[test]
190 fn bstr() {
191 let tests: &[(&[u8], &[u8])] = &[
192 (b"\x00", b""),
193 (b"\x00\xff", b""),
194 (b"\x01a", b"a"),
195 (b"\x0bsome string", b"some string"),
196 (b"\x0bsome string with more bytes", b"some string"),
197 ];
198 for (b, s) in tests {
199 let bstr = decode_bstr_from_slice(b).unwrap();
200 assert_eq!(bstr, *s);
201 }
202 }
203
204 #[test]
205 fn slice() {
206 let tests: &[(&[u8], &[u8])] = &[
207 (b"\x00\x00\x00\x00\x00\x00\x00\x00", b""),
208 (b"\x00\x00\x00\x00\x00\x00\x00\x00\xff", b""),
209 (b"\x01\x00\x00\x00\x00\x00\x00\x00\xff", b"\xff"),
210 ];
211 for (b, s) in tests {
212 let slice = get_sized_slice(b).unwrap();
213 assert_eq!(slice, *s);
214 }
215 }
216
217 #[test]
218 fn stream() {
219 use super::FromBinProtStream;
220 let tests: &[(&[u8], &[u8], usize)] = &[
221 (b"\x01\x00\x00\x00\x00\x00\x00\x00\x00", b"", 9),
222 (b"\x02\x00\x00\x00\x00\x00\x00\x00\x01b", b"b", 10),
223 (b"\x02\x00\x00\x00\x00\x00\x00\x00\x01bcdf", b"b", 10),
224 (b"\x05\x00\x00\x00\x00\x00\x00\x00\x01bcdf", b"b", 13),
225 ];
226 for (b, s, l) in tests {
227 let mut p = *b;
228 let string = crate::string::ByteString::read_from_stream(&mut p).unwrap();
229 assert_eq!((string.as_ref(), b.len() - p.len()), (*s, *l));
230 }
231 }
232}