openmina_node_common/service/archive/
rpc.rs1use binprot::BinProtWrite;
2use mina_p2p_messages::{
3 rpc_kernel::{Message, NeedsLength, Query, RpcMethod},
4 v2::{self, ArchiveRpc},
5};
6use mio::{event::Event, net::TcpStream, Events, Interest, Poll, Registry, Token};
7use std::{
8 io::{self, Read, Write},
9 net::SocketAddr,
10};
11
/// Retry budget handed to each `RecursionGuard` created for a single send.
const MAX_RECURSION_DEPTH: u8 = 25;

/// Body of the handshake frame; bytes 2..5 are the ASCII letters `RPC`.
const HEADER_MSG: [u8; 7] = [2, 253, b'R', b'P', b'C', 0, 1];
/// Frame body that is parsed as `ParsedMessage::Ok`.
const OK_MSG: [u8; 5] = [2, 1, 0, 1, 0];
/// Frame body that is parsed as `ParsedMessage::Close`.
const CLOSE_MSG: [u8; 7] = [2, 254, 167, 7, 0, 1, 0];
/// Frame body of a heartbeat: a single zero byte.
const HEARTBEAT_MSG: [u8; 1] = [0];
20
/// Frames `message` by prefixing it with its byte length, encoded as an
/// 8-byte little-endian `u64`.
fn prepend_length(message: &[u8]) -> Vec<u8> {
    let prefix = (message.len() as u64).to_le_bytes();

    let mut framed = Vec::with_capacity(prefix.len() + message.len());
    framed.extend_from_slice(&prefix);
    framed.extend_from_slice(message);
    framed
}
/// Outcome of handling one step of the archive RPC exchange.
pub enum HandleResult {
    /// The peer answered with an `Ok` or `Close` message after the exchange.
    MessageSent,
    /// The peer closed the connection or sent a message that was not recognised.
    ConnectionClosed,
    /// Nothing conclusive happened yet; keep polling the connection.
    ConnectionAlive,
    /// A write could not complete right now (would block / interrupted).
    MessageWouldBlock,
}

impl HandleResult {
    /// Tells the caller whether the whole send should be attempted again.
    ///
    /// Only [`HandleResult::ConnectionClosed`] asks for a retry.
    pub fn should_retry(&self) -> bool {
        match self {
            Self::ConnectionClosed => true,
            Self::MessageSent | Self::ConnectionAlive | Self::MessageWouldBlock => false,
        }
    }
}
39
40pub fn send_diff(address: SocketAddr, data: v2::ArchiveRpc) -> io::Result<HandleResult> {
41 let rpc = encode_to_rpc(data)?;
42 process_rpc(address, &rpc)
43}
44
45fn encode_to_rpc(data: ArchiveRpc) -> io::Result<Vec<u8>> {
46 type Method = mina_p2p_messages::rpc::SendArchiveDiffUnversioned;
47 let mut v = vec![0; 8];
48
49 if let Err(e) = Message::Query(Query {
50 tag: Method::NAME.into(),
51 version: Method::VERSION,
52 id: 1,
53 data: NeedsLength(data),
54 })
55 .binprot_write(&mut v)
56 {
57 node::core::warn!(
58 summary = "Failed binprot serializastion",
59 error = e.to_string()
60 );
61 return Err(e);
62 }
63
64 let payload_length = (v.len() - 8) as u64;
65 v[..8].copy_from_slice(&payload_length.to_le_bytes());
66 v.splice(0..0, prepend_length(&HEARTBEAT_MSG).iter().cloned());
68 v.extend_from_slice(&prepend_length(&HEARTBEAT_MSG));
70
71 Ok(v)
72}
73
74fn process_rpc(address: SocketAddr, data: &[u8]) -> io::Result<HandleResult> {
75 let mut poll = Poll::new()?;
76 let mut events = Events::with_capacity(128);
77 let mut event_count = 0;
78
79 const TOKEN: Token = Token(0);
81
82 let mut stream = TcpStream::connect(address)?;
83
84 let mut handshake_received = false;
85 let mut handshake_sent = false;
86 let mut message_sent = false;
87 let mut first_heartbeat_received = false;
88 poll.registry()
89 .register(&mut stream, TOKEN, Interest::WRITABLE)?;
90
91 loop {
92 if let Err(e) = poll.poll(&mut events, None) {
93 if interrupted(&e) {
94 continue;
95 }
96 return Err(e);
97 }
98
99 for event in events.iter() {
100 event_count += 1;
101 if event_count > super::MAX_EVENT_COUNT {
103 return Err(io::Error::new(
104 io::ErrorKind::Other,
105 format!("FAILSAFE triggered, event count: {}", event_count),
106 ));
107 }
108 match event.token() {
109 TOKEN => {
110 match handle_connection_event(
111 poll.registry(),
112 &mut stream,
113 event,
114 data,
115 &mut handshake_received,
116 &mut handshake_sent,
117 &mut message_sent,
118 &mut first_heartbeat_received,
119 )? {
120 HandleResult::MessageSent => return Ok(HandleResult::MessageSent),
121 HandleResult::ConnectionClosed => {
122 return Ok(HandleResult::ConnectionClosed)
123 }
124 HandleResult::MessageWouldBlock => {
125 continue;
127 }
128 HandleResult::ConnectionAlive => {
129 if message_sent {
131 poll.registry().reregister(
132 &mut stream,
133 TOKEN,
134 Interest::READABLE,
135 )?;
136 continue;
137 }
138
139 if event.is_writable() {
140 poll.registry().reregister(
141 &mut stream,
142 TOKEN,
143 Interest::READABLE,
144 )?;
145 } else {
146 poll.registry().reregister(
147 &mut stream,
148 TOKEN,
149 Interest::WRITABLE,
150 )?;
151 }
152 continue;
153 }
154 }
155 }
156 _ => unreachable!(),
157 }
158 }
159 }
160}
161
162fn _send_heartbeat(connection: &mut TcpStream) -> io::Result<HandleResult> {
163 match connection.write_all(&HEARTBEAT_MSG) {
164 Ok(_) => {
165 connection.flush()?;
166 Ok(HandleResult::ConnectionAlive)
167 }
168 Err(ref err) if would_block(err) => Ok(HandleResult::MessageWouldBlock),
169 Err(ref err) if interrupted(err) => Ok(HandleResult::MessageWouldBlock),
170 Err(err) => Err(err),
171 }
172}
173
/// Counts retries and refuses to go past a configured maximum.
struct RecursionGuard {
    /// Number of retries granted so far; never exceeds `max_depth`.
    count: u8,
    /// Maximum number of retries that may be granted.
    max_depth: u8,
}

impl RecursionGuard {
    /// Creates a guard that allows at most `max_depth` successful increments.
    fn new(max_depth: u8) -> Self {
        Self {
            count: 0,
            max_depth,
        }
    }

    /// Records one more retry.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::WriteZero` error once `max_depth` retries have
    /// already been granted, and keeps returning it on every later call.
    fn increment(&mut self) -> io::Result<()> {
        // Check before incrementing so `count` can never overflow `u8`,
        // even when `max_depth` is `u8::MAX`.
        if self.count >= self.max_depth {
            return Err(io::ErrorKind::WriteZero.into());
        }
        self.count += 1;
        Ok(())
    }
}
196
197fn send_data<F>(
198 connection: &mut TcpStream,
199 data: &[u8],
200 recursion_guard: &mut RecursionGuard,
201 on_success: F,
203) -> io::Result<HandleResult>
204where
205 F: FnOnce() -> io::Result<HandleResult>,
206{
207 match connection.write(data) {
208 Ok(n) if n < data.len() => {
209 recursion_guard.increment()?;
210 let remaining_data = data[n..].to_vec();
211 send_data(connection, &remaining_data, recursion_guard, on_success)
212 }
213 Ok(_) => {
214 connection.flush()?;
215 on_success()
216 }
217 Err(ref err) if would_block(err) => Ok(HandleResult::MessageWouldBlock),
218 Err(ref err) if interrupted(err) => {
219 recursion_guard
220 .increment()
221 .map_err(|_| io::ErrorKind::Interrupted)?;
222 send_data(connection, data, recursion_guard, on_success)
223 }
224 Err(err) => Err(err),
225 }
226}
227
228#[allow(clippy::too_many_arguments)]
229fn handle_connection_event(
230 registry: &Registry,
231 connection: &mut TcpStream,
232 event: &Event,
233 data: &[u8],
234 handshake_received: &mut bool,
235 handshake_sent: &mut bool,
236 message_sent: &mut bool,
237 first_heartbeat_received: &mut bool,
238) -> io::Result<HandleResult> {
239 if event.is_writable() {
240 if !*handshake_sent {
241 let msg = prepend_length(&HEADER_MSG);
242 send_data(
243 connection,
244 &msg,
245 &mut RecursionGuard::new(MAX_RECURSION_DEPTH),
246 || {
247 *handshake_sent = true;
248 Ok(HandleResult::ConnectionAlive)
249 },
250 )?;
251 return Ok(HandleResult::ConnectionAlive);
252 }
253
254 if *handshake_received && *handshake_sent && !*message_sent && *first_heartbeat_received {
255 send_data(
256 connection,
257 data,
258 &mut RecursionGuard::new(MAX_RECURSION_DEPTH),
259 || {
260 *message_sent = true;
261 Ok(HandleResult::ConnectionAlive)
262 },
263 )?;
264 }
265 }
266
267 if event.is_readable() {
268 let mut connection_closed = false;
269 let mut received_data = vec![0; 4096];
270 let mut bytes_read = 0;
271
272 loop {
273 match connection.read(&mut received_data[bytes_read..]) {
274 Ok(0) => {
275 connection_closed = true;
276 break;
277 }
278 Ok(n) => {
279 bytes_read += n;
280 if bytes_read == received_data.len() {
281 received_data.resize(received_data.len() + 1024, 0);
282 }
283 }
284 Err(ref err) if would_block(err) => break,
287 Err(ref err) if interrupted(err) => continue,
288 Err(err) => return Err(err),
290 }
291 }
292
293 if connection_closed {
294 registry.deregister(connection)?;
295 connection.shutdown(std::net::Shutdown::Both)?;
296 return Ok(HandleResult::ConnectionClosed);
297 }
298
299 if bytes_read < 8 {
300 return Ok(HandleResult::ConnectionAlive);
302 }
303
304 let raw_message = RawMessage::from_bytes(&received_data[..bytes_read]);
305 let messages = raw_message.parse_raw()?;
306
307 for message in messages {
308 match message {
309 ParsedMessage::Header => {
310 *handshake_received = true;
311 }
312 ParsedMessage::Ok | ParsedMessage::Close => {
313 connection.flush()?;
314 registry.deregister(connection)?;
315 connection.shutdown(std::net::Shutdown::Both)?;
316 return Ok(HandleResult::MessageSent);
317 }
318 ParsedMessage::Heartbeat => {
319 *first_heartbeat_received = true;
320 }
321 ParsedMessage::Unknown(msg) => {
322 registry.deregister(connection)?;
323 connection.shutdown(std::net::Shutdown::Both)?;
324 node::core::warn!(
325 summary = "Received unknown message",
326 msg = format!("{:?}", msg)
327 );
328 return Ok(HandleResult::ConnectionClosed);
329 }
330 }
331 }
332 }
333
334 Ok(HandleResult::ConnectionAlive)
335}
336
/// Returns `true` when `err` has kind `io::ErrorKind::WouldBlock`.
fn would_block(err: &io::Error) -> bool {
    matches!(err.kind(), io::ErrorKind::WouldBlock)
}
340
/// Returns `true` when `err` has kind `io::ErrorKind::Interrupted`.
fn interrupted(err: &io::Error) -> bool {
    matches!(err.kind(), io::ErrorKind::Interrupted)
}
344
345enum ParsedMessage {
346 Heartbeat,
347 Ok,
348 Close,
349 Header,
350 Unknown(Vec<u8>),
351}
352
353struct RawMessage {
354 length: usize,
355 data: Vec<u8>,
356}
357
358impl RawMessage {
359 fn from_bytes(bytes: &[u8]) -> Self {
360 Self {
361 length: bytes.len(),
362 data: bytes.to_vec(),
363 }
364 }
365
366 fn parse_raw(&self) -> io::Result<Vec<ParsedMessage>> {
367 let mut parsed_bytes: usize = 0;
368
369 let mut messages = Vec::new();
371
372 while parsed_bytes < self.length {
373 let length = u64::from_le_bytes(
375 self.data[parsed_bytes..parsed_bytes + 8]
376 .try_into()
377 .unwrap(),
378 ) as usize;
379 parsed_bytes += 8;
380
381 if parsed_bytes + length > self.length {
382 return Err(io::Error::new(
383 io::ErrorKind::InvalidData,
384 "Message length exceeds raw message length",
385 ));
386 }
387
388 if length == HEADER_MSG.len()
389 && self.data[parsed_bytes..parsed_bytes + length] == HEADER_MSG
390 {
391 messages.push(ParsedMessage::Header);
392 } else if length == OK_MSG.len()
393 && self.data[parsed_bytes..parsed_bytes + length] == OK_MSG
394 {
395 messages.push(ParsedMessage::Ok);
396 } else if length == HEARTBEAT_MSG.len()
397 && self.data[parsed_bytes..parsed_bytes + length] == HEARTBEAT_MSG
398 {
399 messages.push(ParsedMessage::Heartbeat);
400 } else if length == CLOSE_MSG.len()
401 && self.data[parsed_bytes..parsed_bytes + length] == CLOSE_MSG
402 {
403 messages.push(ParsedMessage::Close);
404 } else {
405 messages.push(ParsedMessage::Unknown(
406 self.data[parsed_bytes..parsed_bytes + length].to_vec(),
407 ));
408 }
409
410 parsed_bytes += length;
411 }
412 Ok(messages)
413 }
414}