1use std::{
2 collections::{BTreeSet, VecDeque},
3 io,
4 pin::Pin,
5 sync::Arc,
6 task::{self, Context, Poll},
7};
8
9use libp2p::futures::{AsyncRead, AsyncWrite};
10
11use mina_p2p_messages::{
12 binprot::{self, BinProtRead, BinProtWrite},
13 rpc_kernel::RpcTag,
14};
15
16use mina_p2p_messages::{
17 rpc::VersionedRpcMenuV1,
18 rpc_kernel::{
19 Message, MessageHeader, NeedsLength, Query, QueryHeader, Response, ResponseHeader,
20 ResponsePayload, RpcMethod, RpcResult,
21 },
22};
23
24#[derive(Debug)]
25pub enum Received {
26 Query {
27 header: QueryHeader,
28 bytes: Vec<u8>,
29 },
30 Response {
31 header: ResponseHeader,
32 bytes: Vec<u8>,
33 },
34 Menu(Vec<(String, u32)>),
35 HandshakeDone,
36 }
38
39pub struct Inner {
40 menu: Arc<BTreeSet<(RpcTag, u32)>>,
41 command_queue: VecDeque<(usize, Vec<u8>)>,
42 buffer: Buffer,
43 ask_menu: bool,
44}
45
46impl Inner {
47 pub fn new(menu: Arc<BTreeSet<(RpcTag, u32)>>, ask_menu: bool) -> Self {
48 Inner {
49 menu,
50 command_queue: {
51 if ask_menu {
52 let msg = Message::<<VersionedRpcMenuV1 as RpcMethod>::Query>::Query(Query {
53 tag: <VersionedRpcMenuV1 as RpcMethod>::NAME.into(),
54 version: <VersionedRpcMenuV1 as RpcMethod>::VERSION,
55 id: 0,
56 data: NeedsLength(()),
57 });
58 let mut bytes = vec![0; 8];
59 msg.binprot_write(&mut bytes).expect("valid constant");
60 let len = (bytes.len() - 8) as u64;
61 bytes[..8].clone_from_slice(&len.to_le_bytes());
62 [(0, Self::HANDSHAKE_MSG.to_vec()), (0, bytes)]
63 .into_iter()
64 .collect()
65 } else {
66 [(0, Self::HANDSHAKE_MSG.to_vec())].into_iter().collect()
67 }
68 },
69 buffer: Buffer::default(),
70 ask_menu,
71 }
72 }
73}
74
/// Growable receive buffer.
///
/// `buf[..offset]` holds bytes read from the transport but not yet cut into
/// frames; `buf[offset..]` is spare room for the next read.
struct Buffer {
    /// Number of valid bytes at the front of `buf`.
    offset: usize,
    /// Backing storage; its length (not capacity) is the readable window.
    buf: Vec<u8>,
}
79
80impl Default for Buffer {
81 fn default() -> Self {
82 Buffer {
83 offset: 0,
84 buf: vec![0; Self::INITIAL_SIZE],
85 }
86 }
87}
88
89impl Buffer {
90 const INITIAL_SIZE: usize = 0x1000;
91
92 pub fn poll_fill<T>(&mut self, cx: &mut Context<'_>, io: &mut T) -> Poll<io::Result<usize>>
93 where
94 T: AsyncRead + Unpin,
95 {
96 loop {
97 let read =
98 task::ready!(Pin::new(&mut *io).poll_read(cx, &mut self.buf[self.offset..]))?;
99 self.offset += read;
100 if self.offset < self.buf.len() {
101 return Poll::Ready(Ok(read));
102 } else {
103 self.buf.resize(2 * self.buf.len(), 0);
104 }
105 }
106 }
107
108 pub fn try_cut(&mut self) -> Option<Result<(MessageHeader, Vec<u8>), binprot::Error>> {
109 if self.offset >= 8 {
110 let msg_len = u64::from_le_bytes(
111 self.buf[..8]
112 .try_into()
113 .expect("cannot fail, offset is >= 8"),
114 ) as usize;
115 if self.offset >= 8 + msg_len {
116 self.offset -= 8 + msg_len;
117 let mut all_bytes = &self.buf[8..(8 + msg_len)];
118 let header = match MessageHeader::binprot_read(&mut all_bytes) {
119 Ok(v) => v,
120 Err(err) => return Some(Err(err)),
121 };
122 let bytes = all_bytes.to_vec();
123 self.buf = self.buf[(8 + msg_len)..].to_vec();
124 let new_len = self.buf.len().next_power_of_two().max(Self::INITIAL_SIZE);
125 self.buf.resize(new_len, 0);
126 return Some(Ok((header, bytes)));
127 }
128 }
129
130 None
131 }
132}
133
134impl Inner {
135 const HANDSHAKE_MSG: [u8; 15] = *b"\x07\x00\x00\x00\x00\x00\x00\x00\x02\xfdRPC\x00\x01";
136
137 pub fn add(&mut self, bytes: Vec<u8>) {
138 self.command_queue.push_back((0, bytes));
139 }
140
141 pub fn poll<T>(&mut self, cx: &mut Context<'_>, io: &mut T) -> Poll<io::Result<Received>>
142 where
143 T: AsyncRead + AsyncWrite + Unpin,
144 {
145 let mut send_pending = false;
146 let mut recv_pending = false;
147
148 loop {
149 if !send_pending && !self.command_queue.is_empty() {
150 match self.poll_send(cx, io) {
151 Poll::Pending => send_pending = true,
152 Poll::Ready(r) => r?,
153 }
154 }
155
156 if !recv_pending {
157 match self.poll_recv(cx, io) {
158 Poll::Pending => {
159 recv_pending = true;
160 if self.command_queue.is_empty() {
161 return Poll::Pending;
162 }
163 }
164 Poll::Ready(r) => return Poll::Ready(r),
165 }
166 }
167
168 if (send_pending || self.command_queue.is_empty()) && recv_pending {
169 return Poll::Pending;
170 }
171 }
172 }
173
174 pub fn poll_recv<T>(
175 &mut self,
176 cx: &mut Context<'_>,
177 mut io: &mut T,
178 ) -> Poll<io::Result<Received>>
179 where
180 T: AsyncRead + Unpin,
181 {
182 let h_id = u64::from_le_bytes(*b"RPC\x00\x00\x00\x00\x00");
183 while let Some(v) = self.buffer.try_cut() {
184 let (header, bytes) = v.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
186 match header {
187 MessageHeader::Heartbeat => {
188 self.add(b"\x01\x00\x00\x00\x00\x00\x00\x00\x00".to_vec());
190 }
191 MessageHeader::Response(ResponseHeader { id }) if id == h_id => {
192 return Poll::Ready(Ok(Received::HandshakeDone));
193 }
194 MessageHeader::Response(ResponseHeader { id }) if id == 0 && self.ask_menu => {
195 let mut bytes_slice = bytes.as_slice();
196 type P = ResponsePayload<<VersionedRpcMenuV1 as RpcMethod>::Response>;
197 let menu = P::binprot_read(&mut bytes_slice)
198 .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
199 .0
200 .ok()
201 .map(|NeedsLength(x)| x)
202 .unwrap_or_default()
203 .into_iter()
204 .map(|(tag, version)| (tag.to_string_lossy(), version))
205 .collect();
206 return Poll::Ready(Ok(Received::Menu(menu)));
207 }
208 MessageHeader::Response(header) => {
209 return Poll::Ready(Ok(Received::Response { header, bytes }))
210 }
211 MessageHeader::Query(QueryHeader { tag, version, id })
212 if &tag == VersionedRpcMenuV1::NAME
213 && version == VersionedRpcMenuV1::VERSION =>
214 {
215 let msg = Message::<<VersionedRpcMenuV1 as RpcMethod>::Response>::Response(
216 Response {
217 id,
218 data: RpcResult(Ok(NeedsLength(
219 self.menu
220 .iter()
221 .cloned()
222 .map(|(tag, version)| (tag.into(), version))
223 .collect(),
224 ))),
225 },
226 );
227 let mut bytes = vec![0; 8];
228 msg.binprot_write(&mut bytes)?;
229 let len = (bytes.len() - 8) as u64;
230 bytes[..8].clone_from_slice(&len.to_le_bytes());
231
232 self.add(bytes);
233 }
234 MessageHeader::Query(header) => {
235 return Poll::Ready(Ok(Received::Query { header, bytes }))
236 }
237 };
238 }
239
240 if task::ready!(self.buffer.poll_fill(cx, &mut io))? != 0 {
241 self.poll_recv(cx, io)
242 } else {
243 Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
244 }
245 }
246
247 pub fn poll_send<T>(&mut self, cx: &mut Context<'_>, mut io: &mut T) -> Poll<io::Result<()>>
248 where
249 T: AsyncWrite + Unpin,
250 {
251 while let Some((offset, buf)) = self.command_queue.front_mut() {
252 if *offset < buf.len() {
253 let written = task::ready!(Pin::new(&mut io).poll_write(cx, &buf[*offset..]))?;
254 *offset += written;
255 if *offset >= buf.len() {
256 self.command_queue.pop_front();
257 }
258 }
259 }
260
261 Poll::Ready(Ok(()))
262 }
263}