libp2p_rpc_behaviour/state.rs

1use std::{
2    collections::{BTreeSet, VecDeque},
3    io,
4    pin::Pin,
5    sync::Arc,
6    task::{self, Context, Poll},
7};
8
9use libp2p::futures::{AsyncRead, AsyncWrite};
10
11use mina_p2p_messages::{
12    binprot::{self, BinProtRead, BinProtWrite},
13    rpc_kernel::RpcTag,
14};
15
16use mina_p2p_messages::{
17    rpc::VersionedRpcMenuV1,
18    rpc_kernel::{
19        Message, MessageHeader, NeedsLength, Query, QueryHeader, Response, ResponseHeader,
20        ResponsePayload, RpcMethod, RpcResult,
21    },
22};
23
24#[derive(Debug)]
25pub enum Received {
26    Query {
27        header: QueryHeader,
28        bytes: Vec<u8>,
29    },
30    Response {
31        header: ResponseHeader,
32        bytes: Vec<u8>,
33    },
34    Menu(Vec<(String, u32)>),
35    HandshakeDone,
36    // SentConfirmation(i64),
37}
38
39pub struct Inner {
40    menu: Arc<BTreeSet<(RpcTag, u32)>>,
41    command_queue: VecDeque<(usize, Vec<u8>)>,
42    buffer: Buffer,
43    ask_menu: bool,
44}
45
46impl Inner {
47    pub fn new(menu: Arc<BTreeSet<(RpcTag, u32)>>, ask_menu: bool) -> Self {
48        Inner {
49            menu,
50            command_queue: {
51                if ask_menu {
52                    let msg = Message::<<VersionedRpcMenuV1 as RpcMethod>::Query>::Query(Query {
53                        tag: <VersionedRpcMenuV1 as RpcMethod>::NAME.into(),
54                        version: <VersionedRpcMenuV1 as RpcMethod>::VERSION,
55                        id: 0,
56                        data: NeedsLength(()),
57                    });
58                    let mut bytes = vec![0; 8];
59                    msg.binprot_write(&mut bytes).expect("valid constant");
60                    let len = (bytes.len() - 8) as u64;
61                    bytes[..8].clone_from_slice(&len.to_le_bytes());
62                    [(0, Self::HANDSHAKE_MSG.to_vec()), (0, bytes)]
63                        .into_iter()
64                        .collect()
65                } else {
66                    [(0, Self::HANDSHAKE_MSG.to_vec())].into_iter().collect()
67                }
68            },
69            buffer: Buffer::default(),
70            ask_menu,
71        }
72    }
73}
74
75struct Buffer {
76    offset: usize,
77    buf: Vec<u8>,
78}
79
80impl Default for Buffer {
81    fn default() -> Self {
82        Buffer {
83            offset: 0,
84            buf: vec![0; Self::INITIAL_SIZE],
85        }
86    }
87}
88
89impl Buffer {
90    const INITIAL_SIZE: usize = 0x1000;
91
92    pub fn poll_fill<T>(&mut self, cx: &mut Context<'_>, io: &mut T) -> Poll<io::Result<usize>>
93    where
94        T: AsyncRead + Unpin,
95    {
96        loop {
97            let read =
98                task::ready!(Pin::new(&mut *io).poll_read(cx, &mut self.buf[self.offset..]))?;
99            self.offset += read;
100            if self.offset < self.buf.len() {
101                return Poll::Ready(Ok(read));
102            } else {
103                self.buf.resize(2 * self.buf.len(), 0);
104            }
105        }
106    }
107
108    pub fn try_cut(&mut self) -> Option<Result<(MessageHeader, Vec<u8>), binprot::Error>> {
109        if self.offset >= 8 {
110            let msg_len = u64::from_le_bytes(
111                self.buf[..8]
112                    .try_into()
113                    .expect("cannot fail, offset is >= 8"),
114            ) as usize;
115            if self.offset >= 8 + msg_len {
116                self.offset -= 8 + msg_len;
117                let mut all_bytes = &self.buf[8..(8 + msg_len)];
118                let header = match MessageHeader::binprot_read(&mut all_bytes) {
119                    Ok(v) => v,
120                    Err(err) => return Some(Err(err)),
121                };
122                let bytes = all_bytes.to_vec();
123                self.buf = self.buf[(8 + msg_len)..].to_vec();
124                let new_len = self.buf.len().next_power_of_two().max(Self::INITIAL_SIZE);
125                self.buf.resize(new_len, 0);
126                return Some(Ok((header, bytes)));
127            }
128        }
129
130        None
131    }
132}
133
134impl Inner {
135    const HANDSHAKE_MSG: [u8; 15] = *b"\x07\x00\x00\x00\x00\x00\x00\x00\x02\xfdRPC\x00\x01";
136
137    pub fn add(&mut self, bytes: Vec<u8>) {
138        self.command_queue.push_back((0, bytes));
139    }
140
141    pub fn poll<T>(&mut self, cx: &mut Context<'_>, io: &mut T) -> Poll<io::Result<Received>>
142    where
143        T: AsyncRead + AsyncWrite + Unpin,
144    {
145        let mut send_pending = false;
146        let mut recv_pending = false;
147
148        loop {
149            if !send_pending && !self.command_queue.is_empty() {
150                match self.poll_send(cx, io) {
151                    Poll::Pending => send_pending = true,
152                    Poll::Ready(r) => r?,
153                }
154            }
155
156            if !recv_pending {
157                match self.poll_recv(cx, io) {
158                    Poll::Pending => {
159                        recv_pending = true;
160                        if self.command_queue.is_empty() {
161                            return Poll::Pending;
162                        }
163                    }
164                    Poll::Ready(r) => return Poll::Ready(r),
165                }
166            }
167
168            if (send_pending || self.command_queue.is_empty()) && recv_pending {
169                return Poll::Pending;
170            }
171        }
172    }
173
174    pub fn poll_recv<T>(
175        &mut self,
176        cx: &mut Context<'_>,
177        mut io: &mut T,
178    ) -> Poll<io::Result<Received>>
179    where
180        T: AsyncRead + Unpin,
181    {
182        let h_id = u64::from_le_bytes(*b"RPC\x00\x00\x00\x00\x00");
183        while let Some(v) = self.buffer.try_cut() {
184            // TODO: proper error type
185            let (header, bytes) = v.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
186            match header {
187                MessageHeader::Heartbeat => {
188                    // TODO: handle heartbeat properly
189                    self.add(b"\x01\x00\x00\x00\x00\x00\x00\x00\x00".to_vec());
190                }
191                MessageHeader::Response(ResponseHeader { id }) if id == h_id => {
192                    return Poll::Ready(Ok(Received::HandshakeDone));
193                }
194                MessageHeader::Response(ResponseHeader { id }) if id == 0 && self.ask_menu => {
195                    let mut bytes_slice = bytes.as_slice();
196                    type P = ResponsePayload<<VersionedRpcMenuV1 as RpcMethod>::Response>;
197                    let menu = P::binprot_read(&mut bytes_slice)
198                        .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
199                        .0
200                        .ok()
201                        .map(|NeedsLength(x)| x)
202                        .unwrap_or_default()
203                        .into_iter()
204                        .map(|(tag, version)| (tag.to_string_lossy(), version))
205                        .collect();
206                    return Poll::Ready(Ok(Received::Menu(menu)));
207                }
208                MessageHeader::Response(header) => {
209                    return Poll::Ready(Ok(Received::Response { header, bytes }))
210                }
211                MessageHeader::Query(QueryHeader { tag, version, id })
212                    if &tag == VersionedRpcMenuV1::NAME
213                        && version == VersionedRpcMenuV1::VERSION =>
214                {
215                    let msg = Message::<<VersionedRpcMenuV1 as RpcMethod>::Response>::Response(
216                        Response {
217                            id,
218                            data: RpcResult(Ok(NeedsLength(
219                                self.menu
220                                    .iter()
221                                    .cloned()
222                                    .map(|(tag, version)| (tag.into(), version))
223                                    .collect(),
224                            ))),
225                        },
226                    );
227                    let mut bytes = vec![0; 8];
228                    msg.binprot_write(&mut bytes)?;
229                    let len = (bytes.len() - 8) as u64;
230                    bytes[..8].clone_from_slice(&len.to_le_bytes());
231
232                    self.add(bytes);
233                }
234                MessageHeader::Query(header) => {
235                    return Poll::Ready(Ok(Received::Query { header, bytes }))
236                }
237            };
238        }
239
240        if task::ready!(self.buffer.poll_fill(cx, &mut io))? != 0 {
241            self.poll_recv(cx, io)
242        } else {
243            Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
244        }
245    }
246
247    pub fn poll_send<T>(&mut self, cx: &mut Context<'_>, mut io: &mut T) -> Poll<io::Result<()>>
248    where
249        T: AsyncWrite + Unpin,
250    {
251        while let Some((offset, buf)) = self.command_queue.front_mut() {
252            if *offset < buf.len() {
253                let written = task::ready!(Pin::new(&mut io).poll_write(cx, &buf[*offset..]))?;
254                *offset += written;
255                if *offset >= buf.len() {
256                    self.command_queue.pop_front();
257                }
258            }
259        }
260
261        Poll::Ready(Ok(()))
262    }
263}