openmina_node_common/service/archive/
rpc.rs

1use binprot::BinProtWrite;
2use mina_p2p_messages::{
3    rpc_kernel::{Message, NeedsLength, Query, RpcMethod},
4    v2::{self, ArchiveRpc},
5};
6use mio::{event::Event, net::TcpStream, Events, Interest, Poll, Registry, Token};
7use std::{
8    io::{self, Read, Write},
9    net::SocketAddr,
10};
11
/// Retry budget handed to `RecursionGuard` for a single `send_data` call: every short
/// write and every `Interrupted` error consumes one unit.
const MAX_RECURSION_DEPTH: u8 = 25;

// Fixed frame payloads exchanged with the archive node. On the wire each payload is
// preceded by its 8-byte little-endian length (`prepend_length` / `parse_raw`).

/// Handshake header payload: written first on a fresh connection, and the peer's copy
/// must be seen before the query is sent. Indices 2..=4 spell `"RPC"`.
/// NOTE(review): presumably part of the Async RPC magic number — confirm.
const HEADER_MSG: [u8; 7] = [0x02, 0xFD, b'R', b'P', b'C', 0x00, 0x01];
/// Payload that, when received, marks the diff as delivered.
const OK_MSG: [u8; 5] = [0x02, 0x01, 0x00, 0x01, 0x00];
/// Note: this is the close message that the OCaml node receives. If the peer sends it
/// to us, it is treated the same way as `OK_MSG`.
const CLOSE_MSG: [u8; 7] = [0x02, 0xFE, 0xA7, 0x07, 0x00, 0x01, 0x00];
/// Heartbeat payload; `encode_to_rpc` frames one before and one after the query.
const HEARTBEAT_MSG: [u8; 1] = [0x00];
20
/// Frames `message` for the wire: its length as an 8-byte little-endian integer,
/// followed by the message bytes themselves.
fn prepend_length(message: &[u8]) -> Vec<u8> {
    // Allocate the final size once instead of copying `message` into a temporary
    // vector and then appending that temporary.
    let mut framed = Vec::with_capacity(8 + message.len());
    framed.extend_from_slice(&(message.len() as u64).to_le_bytes());
    framed.extend_from_slice(message);
    framed
}
/// Outcome of driving the archive RPC connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandleResult {
    /// The peer acknowledged the diff (`OK_MSG` or `CLOSE_MSG` was received).
    MessageSent,
    /// The connection ended without an acknowledgement, or the peer sent an
    /// unrecognised frame.
    ConnectionClosed,
    /// The connection is still open and the exchange is in progress.
    ConnectionAlive,
    /// A write would have blocked; wait for the next readiness event.
    MessageWouldBlock,
}

impl HandleResult {
    /// Returns `true` only for [`HandleResult::ConnectionClosed`], i.e. when the
    /// connection ended before the diff was acknowledged.
    pub fn should_retry(&self) -> bool {
        matches!(self, Self::ConnectionClosed)
    }
}
39
40pub fn send_diff(address: SocketAddr, data: v2::ArchiveRpc) -> io::Result<HandleResult> {
41    let rpc = encode_to_rpc(data)?;
42    process_rpc(address, &rpc)
43}
44
45fn encode_to_rpc(data: ArchiveRpc) -> io::Result<Vec<u8>> {
46    type Method = mina_p2p_messages::rpc::SendArchiveDiffUnversioned;
47    let mut v = vec![0; 8];
48
49    if let Err(e) = Message::Query(Query {
50        tag: Method::NAME.into(),
51        version: Method::VERSION,
52        id: 1,
53        data: NeedsLength(data),
54    })
55    .binprot_write(&mut v)
56    {
57        node::core::warn!(
58            summary = "Failed binprot serializastion",
59            error = e.to_string()
60        );
61        return Err(e);
62    }
63
64    let payload_length = (v.len() - 8) as u64;
65    v[..8].copy_from_slice(&payload_length.to_le_bytes());
66    // Bake in the heartbeat message
67    v.splice(0..0, prepend_length(&HEARTBEAT_MSG).iter().cloned());
68    // also add the heartbeat message to the end of the message
69    v.extend_from_slice(&prepend_length(&HEARTBEAT_MSG));
70
71    Ok(v)
72}
73
74fn process_rpc(address: SocketAddr, data: &[u8]) -> io::Result<HandleResult> {
75    let mut poll = Poll::new()?;
76    let mut events = Events::with_capacity(128);
77    let mut event_count = 0;
78
79    // We still need a token even for one connection
80    const TOKEN: Token = Token(0);
81
82    let mut stream = TcpStream::connect(address)?;
83
84    let mut handshake_received = false;
85    let mut handshake_sent = false;
86    let mut message_sent = false;
87    let mut first_heartbeat_received = false;
88    poll.registry()
89        .register(&mut stream, TOKEN, Interest::WRITABLE)?;
90
91    loop {
92        if let Err(e) = poll.poll(&mut events, None) {
93            if interrupted(&e) {
94                continue;
95            }
96            return Err(e);
97        }
98
99        for event in events.iter() {
100            event_count += 1;
101            // Failsafe to prevent infinite loops
102            if event_count > super::MAX_EVENT_COUNT {
103                return Err(io::Error::new(
104                    io::ErrorKind::Other,
105                    format!("FAILSAFE triggered, event count: {}", event_count),
106                ));
107            }
108            match event.token() {
109                TOKEN => {
110                    match handle_connection_event(
111                        poll.registry(),
112                        &mut stream,
113                        event,
114                        data,
115                        &mut handshake_received,
116                        &mut handshake_sent,
117                        &mut message_sent,
118                        &mut first_heartbeat_received,
119                    )? {
120                        HandleResult::MessageSent => return Ok(HandleResult::MessageSent),
121                        HandleResult::ConnectionClosed => {
122                            return Ok(HandleResult::ConnectionClosed)
123                        }
124                        HandleResult::MessageWouldBlock => {
125                            // do nothing, wait for the next event
126                            continue;
127                        }
128                        HandleResult::ConnectionAlive => {
129                            // keep swapping between readable and writable until we successfully send the message, then keep in read mode.
130                            if message_sent {
131                                poll.registry().reregister(
132                                    &mut stream,
133                                    TOKEN,
134                                    Interest::READABLE,
135                                )?;
136                                continue;
137                            }
138
139                            if event.is_writable() {
140                                poll.registry().reregister(
141                                    &mut stream,
142                                    TOKEN,
143                                    Interest::READABLE,
144                                )?;
145                            } else {
146                                poll.registry().reregister(
147                                    &mut stream,
148                                    TOKEN,
149                                    Interest::WRITABLE,
150                                )?;
151                            }
152                            continue;
153                        }
154                    }
155                }
156                _ => unreachable!(),
157            }
158        }
159    }
160}
161
162fn _send_heartbeat(connection: &mut TcpStream) -> io::Result<HandleResult> {
163    match connection.write_all(&HEARTBEAT_MSG) {
164        Ok(_) => {
165            connection.flush()?;
166            Ok(HandleResult::ConnectionAlive)
167        }
168        Err(ref err) if would_block(err) => Ok(HandleResult::MessageWouldBlock),
169        Err(ref err) if interrupted(err) => Ok(HandleResult::MessageWouldBlock),
170        Err(err) => Err(err),
171    }
172}
173
/// Bounded retry counter. `send_data` charges one unit per short write or
/// `Interrupted` error.
struct RecursionGuard {
    // Number of successful `increment` calls so far; never exceeds `max_depth`.
    count: u8,
    // Number of increments allowed before `increment` starts failing.
    max_depth: u8,
}

impl RecursionGuard {
    /// Creates a guard that allows `max_depth` increments.
    fn new(max_depth: u8) -> Self {
        Self {
            count: 0,
            max_depth,
        }
    }

    /// Records one more retry.
    ///
    /// # Errors
    /// Returns an `io::ErrorKind::WriteZero` error once `max_depth` increments have
    /// already been recorded (and on every call after that).
    fn increment(&mut self) -> io::Result<()> {
        // Check before adding so `count` can never pass `max_depth` and therefore never
        // overflows, even when `max_depth == u8::MAX` or the caller keeps calling after
        // an error.
        if self.count >= self.max_depth {
            return Err(io::ErrorKind::WriteZero.into());
        }
        self.count += 1;
        Ok(())
    }
}
196
197fn send_data<F>(
198    connection: &mut TcpStream,
199    data: &[u8],
200    recursion_guard: &mut RecursionGuard,
201    // closure that can be called when the data is sent
202    on_success: F,
203) -> io::Result<HandleResult>
204where
205    F: FnOnce() -> io::Result<HandleResult>,
206{
207    match connection.write(data) {
208        Ok(n) if n < data.len() => {
209            recursion_guard.increment()?;
210            let remaining_data = data[n..].to_vec();
211            send_data(connection, &remaining_data, recursion_guard, on_success)
212        }
213        Ok(_) => {
214            connection.flush()?;
215            on_success()
216        }
217        Err(ref err) if would_block(err) => Ok(HandleResult::MessageWouldBlock),
218        Err(ref err) if interrupted(err) => {
219            recursion_guard
220                .increment()
221                .map_err(|_| io::ErrorKind::Interrupted)?;
222            send_data(connection, data, recursion_guard, on_success)
223        }
224        Err(err) => Err(err),
225    }
226}
227
228#[allow(clippy::too_many_arguments)]
229fn handle_connection_event(
230    registry: &Registry,
231    connection: &mut TcpStream,
232    event: &Event,
233    data: &[u8],
234    handshake_received: &mut bool,
235    handshake_sent: &mut bool,
236    message_sent: &mut bool,
237    first_heartbeat_received: &mut bool,
238) -> io::Result<HandleResult> {
239    if event.is_writable() {
240        if !*handshake_sent {
241            let msg = prepend_length(&HEADER_MSG);
242            send_data(
243                connection,
244                &msg,
245                &mut RecursionGuard::new(MAX_RECURSION_DEPTH),
246                || {
247                    *handshake_sent = true;
248                    Ok(HandleResult::ConnectionAlive)
249                },
250            )?;
251            return Ok(HandleResult::ConnectionAlive);
252        }
253
254        if *handshake_received && *handshake_sent && !*message_sent && *first_heartbeat_received {
255            send_data(
256                connection,
257                data,
258                &mut RecursionGuard::new(MAX_RECURSION_DEPTH),
259                || {
260                    *message_sent = true;
261                    Ok(HandleResult::ConnectionAlive)
262                },
263            )?;
264        }
265    }
266
267    if event.is_readable() {
268        let mut connection_closed = false;
269        let mut received_data = vec![0; 4096];
270        let mut bytes_read = 0;
271
272        loop {
273            match connection.read(&mut received_data[bytes_read..]) {
274                Ok(0) => {
275                    connection_closed = true;
276                    break;
277                }
278                Ok(n) => {
279                    bytes_read += n;
280                    if bytes_read == received_data.len() {
281                        received_data.resize(received_data.len() + 1024, 0);
282                    }
283                }
284                // Would block "errors" are the OS's way of saying that the
285                // connection is not actually ready to perform this I/O operation.
286                Err(ref err) if would_block(err) => break,
287                Err(ref err) if interrupted(err) => continue,
288                // Other errors we'll consider fatal.
289                Err(err) => return Err(err),
290            }
291        }
292
293        if connection_closed {
294            registry.deregister(connection)?;
295            connection.shutdown(std::net::Shutdown::Both)?;
296            return Ok(HandleResult::ConnectionClosed);
297        }
298
299        if bytes_read < 8 {
300            // malformed message, at least the length should be present
301            return Ok(HandleResult::ConnectionAlive);
302        }
303
304        let raw_message = RawMessage::from_bytes(&received_data[..bytes_read]);
305        let messages = raw_message.parse_raw()?;
306
307        for message in messages {
308            match message {
309                ParsedMessage::Header => {
310                    *handshake_received = true;
311                }
312                ParsedMessage::Ok | ParsedMessage::Close => {
313                    connection.flush()?;
314                    registry.deregister(connection)?;
315                    connection.shutdown(std::net::Shutdown::Both)?;
316                    return Ok(HandleResult::MessageSent);
317                }
318                ParsedMessage::Heartbeat => {
319                    *first_heartbeat_received = true;
320                }
321                ParsedMessage::Unknown(msg) => {
322                    registry.deregister(connection)?;
323                    connection.shutdown(std::net::Shutdown::Both)?;
324                    node::core::warn!(
325                        summary = "Received unknown message",
326                        msg = format!("{:?}", msg)
327                    );
328                    return Ok(HandleResult::ConnectionClosed);
329                }
330            }
331        }
332    }
333
334    Ok(HandleResult::ConnectionAlive)
335}
336
/// Returns `true` if `err` says the non-blocking operation is not ready yet and must
/// wait for the next readiness event.
fn would_block(err: &io::Error) -> bool {
    matches!(err.kind(), io::ErrorKind::WouldBlock)
}

/// Returns `true` if `err` reports an interrupted operation; callers in this file
/// simply retry it.
fn interrupted(err: &io::Error) -> bool {
    matches!(err.kind(), io::ErrorKind::Interrupted)
}
344
345enum ParsedMessage {
346    Heartbeat,
347    Ok,
348    Close,
349    Header,
350    Unknown(Vec<u8>),
351}
352
353struct RawMessage {
354    length: usize,
355    data: Vec<u8>,
356}
357
358impl RawMessage {
359    fn from_bytes(bytes: &[u8]) -> Self {
360        Self {
361            length: bytes.len(),
362            data: bytes.to_vec(),
363        }
364    }
365
366    fn parse_raw(&self) -> io::Result<Vec<ParsedMessage>> {
367        let mut parsed_bytes: usize = 0;
368
369        // more than one message can be sent in a single packet
370        let mut messages = Vec::new();
371
372        while parsed_bytes < self.length {
373            // first 8 bytes are the length in little endian
374            let length = u64::from_le_bytes(
375                self.data[parsed_bytes..parsed_bytes + 8]
376                    .try_into()
377                    .unwrap(),
378            ) as usize;
379            parsed_bytes += 8;
380
381            if parsed_bytes + length > self.length {
382                return Err(io::Error::new(
383                    io::ErrorKind::InvalidData,
384                    "Message length exceeds raw message length",
385                ));
386            }
387
388            if length == HEADER_MSG.len()
389                && self.data[parsed_bytes..parsed_bytes + length] == HEADER_MSG
390            {
391                messages.push(ParsedMessage::Header);
392            } else if length == OK_MSG.len()
393                && self.data[parsed_bytes..parsed_bytes + length] == OK_MSG
394            {
395                messages.push(ParsedMessage::Ok);
396            } else if length == HEARTBEAT_MSG.len()
397                && self.data[parsed_bytes..parsed_bytes + length] == HEARTBEAT_MSG
398            {
399                messages.push(ParsedMessage::Heartbeat);
400            } else if length == CLOSE_MSG.len()
401                && self.data[parsed_bytes..parsed_bytes + length] == CLOSE_MSG
402            {
403                messages.push(ParsedMessage::Close);
404            } else {
405                messages.push(ParsedMessage::Unknown(
406                    self.data[parsed_bytes..parsed_bytes + length].to_vec(),
407                ));
408            }
409
410            parsed_bytes += length;
411        }
412        Ok(messages)
413    }
414}