mina_p2p_messages/
rpc_kernel.rs

//! Partial implementation of Jane Street's `core_rpc_kernel`.
2use std::io::Read;
3
4use binprot::{BinProtRead, BinProtWrite};
5use binprot_derive::{BinProtRead, BinProtWrite};
6use malloc_size_of_derive::MallocSizeOf;
7use serde::{Deserialize, Serialize};
8
9use crate::versioned::Ver;
10
/// Binprot representation of the RPC method tag.
pub type BinprotTag = super::string::CharString;
/// Internal representation of the RPC method tag.
pub type RpcTag = &'static [u8];
/// RPC method version.
pub type RpcVersion = Ver;
/// Identifier carried in the `id` field of queries and responses
/// (`Query_id.t` in the OCaml definitions quoted in this file).
pub type QueryID = u64;
/// Stand-in for OCaml's `Sexp.t`; currently just the unit type.
pub type Sexp = (); // TODO
19
20#[derive(
21    Clone, Debug, Serialize, Deserialize, PartialEq, Eq, derive_more::From, derive_more::Into,
22)]
23pub struct RpcResult<T, E>(pub Result<T, E>);
24
25/// Auxiliary type to encode [RpcResult]'s tag.
26#[derive(Debug, BinProtRead, BinProtWrite)]
27pub enum RpcResultKind {
28    Ok,
29    Err,
30}
31
32impl<T, E> BinProtRead for RpcResult<T, E>
33where
34    T: BinProtRead,
35    E: BinProtRead,
36{
37    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
38    where
39        Self: Sized,
40    {
41        Ok(match RpcResultKind::binprot_read(r)? {
42            RpcResultKind::Ok => Ok(T::binprot_read(r)?),
43            RpcResultKind::Err => Err(E::binprot_read(r)?),
44        }
45        .into())
46    }
47}
48
49impl<T, E> BinProtWrite for RpcResult<T, E>
50where
51    T: BinProtWrite,
52    E: BinProtWrite,
53{
54    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
55        match &self.0 {
56            Ok(v) => {
57                RpcResultKind::Ok.binprot_write(w)?;
58                v.binprot_write(w)?;
59            }
60            Err(e) => {
61                RpcResultKind::Err.binprot_write(w)?;
62                e.binprot_write(w)?;
63            }
64        }
65        Ok(())
66    }
67}
68
69#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, derive_more::From)]
70pub struct NeedsLength<T>(pub T);
71
72impl<T> NeedsLength<T> {
73    pub fn into_inner(self) -> T {
74        self.0
75    }
76}
77
78impl<T> BinProtRead for NeedsLength<T>
79where
80    T: BinProtRead,
81{
82    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
83    where
84        Self: Sized,
85    {
86        let _size = binprot::Nat0::binprot_read(r)?.0;
87        // Trait function requires r to be ?Sized, so we cannot use `take`
88        // use std::io;
89        // let mut r = r.take(size);
90        // let v = T::binprot_read(&mut r)?;
91        // io::copy(&mut r, &mut io::sink())?;
92        let v = T::binprot_read(r)?;
93        Ok(v.into())
94    }
95}
96
97impl<T> BinProtWrite for NeedsLength<T>
98where
99    T: BinProtWrite,
100{
101    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
102        let mut buf = Vec::new();
103        self.0.binprot_write(&mut buf)?;
104        binprot::Nat0(buf.len() as u64).binprot_write(w)?;
105        w.write_all(&buf)?;
106        Ok(())
107    }
108}
109
110/// RPC error.
111///
112/// ```ocaml
113/// module Rpc_error : sig
114///   open Core_kernel
115///
116///   type t =
117///     | Bin_io_exn        of Sexp.t
118///     | Connection_closed
119///     | Write_error       of Sexp.t
120///     | Uncaught_exn      of Sexp.t
121///     | Unimplemented_rpc of Rpc_tag.t * [`Version of int]
122///     | Unknown_query_id  of Query_id.t
123///   [@@deriving bin_io, sexp, compare]
124///
125///   include Comparable.S with type t := t
126/// end
127/// ```
128#[allow(non_camel_case_types)]
129#[derive(
130    Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq, thiserror::Error,
131)]
132pub enum Error {
133    #[error("binprot expection")]
134    Bin_io_exn, //(Sexp),
135    #[error("connection closed")]
136    Connection_closed,
137    #[error("write error")]
138    Write_error, //(Sexp),
139    #[error("uncaught exception")]
140    Uncaught_exn, //(Sexp),
141    #[error("unimplemented method {}:{}", .0.to_string(), .1)]
142    Unimplemented_rpc(BinprotTag, Ver),
143    #[error("unknown query id: {0}")]
144    Unknown_query_id(QueryID),
145}
146
/// Type used for encoding RPC query payload.
///
/// Effectively this is the bin_prot encoding of the data, preceded by the
/// `Nat0` encoding of that encoding's size.
pub type QueryPayload<T> = NeedsLength<T>;
152
153/// RPC query.
154///
155/// ```ocaml
156/// module Query = struct
157///   type 'a needs_length =
158///     { tag     : Rpc_tag.t
159///     ; version : int
160///     ; id      : Query_id.t
161///     ; data    : 'a
162///     }
163///   [@@deriving bin_io, sexp_of]
164///   type 'a t = 'a needs_length [@@deriving bin_read]
165/// end
166/// ```
167#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
168pub struct Query<T> {
169    pub tag: BinprotTag,
170    pub version: Ver,
171    pub id: QueryID,
172    pub data: QueryPayload<T>,
173}
174
/// Type used to encode response payload.
///
/// A response is either successful, consisting of the result value preceded
/// by its length, or an error of type [Error].
pub type ResponsePayload<T> = RpcResult<NeedsLength<T>, Error>;
180
181/// RPC response.
182///
183/// ```ocaml
184/// module Response = struct
185///   type 'a needs_length =
186///     { id   : Query_id.t
187///     ; data : 'a Rpc_result.t
188///     }
189///   [@@deriving bin_io, sexp_of]
190///   type 'a t = 'a needs_length [@@deriving bin_read]
191/// end
192/// ```
193#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
194pub struct Response<T> {
195    pub id: QueryID,
196    pub data: ResponsePayload<T>,
197}
198
199/// RPC response in the form used by the Mina Network Debugger, with prepended
200/// RPC tag and version.
201#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
202pub struct DebuggerResponse<T> {
203    pub tag: BinprotTag,
204    pub version: Ver,
205    pub id: QueryID,
206    pub data: ResponsePayload<T>,
207}
208
209impl<T> From<DebuggerResponse<T>> for Response<T> {
210    fn from(source: DebuggerResponse<T>) -> Self {
211        Response {
212            id: source.id,
213            data: source.data,
214        }
215    }
216}
217
218/// RPC message.
219///
220/// ```ocaml
221/// module Message = struct
222///   type 'a needs_length =
223///     | Heartbeat
224///     | Query     of 'a Query.   needs_length
225///     | Response  of 'a Response.needs_length
226///   [@@deriving bin_io, sexp_of]
227///   type 'a t = 'a needs_length [@@deriving bin_read, sexp_of]
228///   type nat0_t = Nat0.t needs_length [@@deriving bin_read, bin_write]
229/// end
230/// ```
231#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
232pub enum Message<T> {
233    Heartbeat,
234    Query(Query<T>),
235    Response(Response<T>),
236}
237
238#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
239pub enum DebuggerMessage<T> {
240    Heartbeat,
241    Query(Query<T>),
242    Response(DebuggerResponse<T>),
243}
244
245impl<T> From<DebuggerMessage<T>> for Message<T> {
246    fn from(source: DebuggerMessage<T>) -> Self {
247        match source {
248            DebuggerMessage::Heartbeat => Message::Heartbeat,
249            DebuggerMessage::Query(query) => Message::Query(query),
250            DebuggerMessage::Response(response) => Message::Response(response.into()),
251        }
252    }
253}
254
255#[derive(
256    Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq, MallocSizeOf,
257)]
258pub struct QueryHeader {
259    pub tag: BinprotTag,
260    pub version: Ver,
261    pub id: QueryID,
262}
263
264#[derive(
265    Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq, MallocSizeOf,
266)]
267pub struct ResponseHeader {
268    pub id: QueryID,
269}
270
271#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
272pub enum MessageHeader {
273    Heartbeat,
274    Query(QueryHeader),
275    Response(ResponseHeader),
276}
277
278pub trait RpcMethod {
279    const NAME: RpcTag;
280    const NAME_STR: &'static str;
281    const VERSION: Ver;
282    type Query: BinProtRead + BinProtWrite;
283    type Response: BinProtRead + BinProtWrite;
284
285    fn rpc_id() -> String {
286        format!("{}:{}", Self::NAME_STR, Self::VERSION)
287    }
288}
289
/// Reads a binable (bin_prot-encoded) value from a stream, handles it and
/// returns a result.
pub trait BinableDecoder {
    /// Result produced by [`BinableDecoder::handle`].
    type Output;
    /// Handles the value available from the reader `r`.
    fn handle(&self, r: Box<&mut dyn Read>) -> Self::Output;
}
296
297/// Trait for reading RPC query and response payloads.
298///
299/// This is a helper trait that makes it easier to decode raw payload data from
300/// bin_prot encoded data, following the message header. It simply decodes data
301/// wrapped in auxiliary types and returns unwrapped data.
302pub trait PayloadBinprotReader: RpcMethod {
303    fn query_payload<R>(r: &mut R) -> Result<Self::Query, RpcQueryReadError>
304    where
305        R: Read;
306    fn response_payload<R>(r: &mut R) -> Result<Self::Response, RpcResponseReadError>
307    where
308        R: Read;
309}
310
311#[derive(Debug, thiserror::Error)]
312pub enum RpcQueryReadError {
313    #[error("rpc query {rpc_id}: failed to decode binprot: {error}")]
314    Binprot {
315        rpc_id: String,
316        error: binprot::Error,
317    },
318}
319
320#[derive(Debug, thiserror::Error)]
321pub enum RpcResponseReadError {
322    #[error("rpc response {rpc_id}: failed to decode binprot: {error}")]
323    Binprot {
324        rpc_id: String,
325        error: binprot::Error,
326    },
327    #[error("rpc response {rpc_id}: peer failed to respond: {error}")]
328    Failure { rpc_id: String, error: self::Error },
329}
330
331impl<T> PayloadBinprotReader for T
332where
333    T: RpcMethod,
334    T::Query: BinProtRead,
335    T::Response: BinProtRead,
336{
337    fn query_payload<R>(r: &mut R) -> Result<Self::Query, RpcQueryReadError>
338    where
339        R: Read,
340    {
341        QueryPayload::<Self::Query>::binprot_read(r)
342            .map(|NeedsLength(v)| v)
343            .map_err(|error| RpcQueryReadError::Binprot {
344                rpc_id: T::rpc_id(),
345                error,
346            })
347    }
348
349    fn response_payload<R>(r: &mut R) -> Result<Self::Response, RpcResponseReadError>
350    where
351        R: Read,
352    {
353        ResponsePayload::<Self::Response>::binprot_read(r)
354            .map(|v| Result::from(v).map(NeedsLength::into_inner))
355            .map_err(|error| RpcResponseReadError::Binprot {
356                rpc_id: T::rpc_id(),
357                error,
358            })?
359            .map_err(|error| RpcResponseReadError::Failure {
360                rpc_id: T::rpc_id(),
361                error,
362            })
363    }
364}
365
366#[derive(Debug, thiserror::Error)]
367pub enum RpcDebuggerReaderError {
368    #[error(transparent)]
369    BinProtError(#[from] binprot::Error),
370    #[error("Query expected")]
371    ExpectQuery,
372    #[error("Response expected")]
373    ExpectResponse,
374}
375
376/// Trait for reading RPC query and response in the format provided by the
377/// debugger.
378///
379/// This is a helper trait that makes it easier to decode data obtain from the
380/// Mina Network Debugger, that stores [DebuggerResponse] that has tag and
381/// version encoded, instead of [Response]. It simply decodes data wrapped in
382/// auxiliary types and returns unwrapped data.
383pub trait RpcDebuggerReader: RpcMethod {
384    fn debugger_query<R>(r: &mut R) -> Result<Self::Query, RpcDebuggerReaderError>
385    where
386        R: Read;
387    fn debugger_response<R>(
388        r: &mut R,
389    ) -> Result<Result<Self::Response, Error>, RpcDebuggerReaderError>
390    where
391        R: Read;
392}
393
394impl<T> RpcDebuggerReader for T
395where
396    T: RpcMethod,
397    T::Query: BinProtRead,
398    T::Response: BinProtRead,
399{
400    fn debugger_query<R>(r: &mut R) -> Result<Self::Query, RpcDebuggerReaderError>
401    where
402        R: Read,
403    {
404        if let Message::Query(query) = Message::<T::Query>::binprot_read(r)? {
405            Ok(query.data.0)
406        } else {
407            Err(RpcDebuggerReaderError::ExpectQuery)
408        }
409    }
410
411    fn debugger_response<R>(
412        r: &mut R,
413    ) -> Result<Result<Self::Response, Error>, RpcDebuggerReaderError>
414    where
415        R: Read,
416    {
417        if let Message::Response(response) = Message::<T::Response>::binprot_read(r)? {
418            Ok(Result::from(response.data).map(NeedsLength::into_inner))
419        } else {
420            Err(RpcDebuggerReaderError::ExpectResponse)
421        }
422    }
423}
424
425#[derive(Debug, thiserror::Error)]
426pub enum JSONinifyError {
427    #[error(transparent)]
428    Binprot(#[from] binprot::Error),
429    #[error(transparent)]
430    JSON(#[from] serde_json::Error),
431}
432
/// Object-safe reader that turns a bin_prot encoded RPC payload into a JSON
/// value.
pub trait JSONinifyPayloadReader {
    /// Reads a query payload from `r` and returns it as JSON.
    fn read_query(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError>;
    /// Reads a response payload from `r` and returns it as JSON.
    fn read_response(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError>;
}
437
438impl<T> JSONinifyPayloadReader for T
439where
440    T: RpcMethod,
441    T::Query: BinProtRead + Serialize,
442    T::Response: BinProtRead + Serialize,
443{
444    fn read_query(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError> {
445        let v = QueryPayload::<T::Query>::binprot_read(r).map(|NeedsLength(v)| v)?;
446        let json = serde_json::to_value(&v)?;
447        Ok(json)
448    }
449
450    fn read_response(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError> {
451        let v = ResponsePayload::<T::Response>::binprot_read(r)
452            .map(|v| Result::from(v).map(|NeedsLength(v)| v))?;
453        let json = serde_json::to_value(&v)?;
454        Ok(json)
455    }
456}
457
/// Conversion that consumes `self` and produces a value of type
/// [`Converter::Output`].
pub trait Converter {
    /// Type produced by the conversion.
    type Output;
    /// Consumes `self` and returns the converted value.
    fn convert(self) -> Self::Output;
}
462
/// An [RpcMethod] that can read a query or a response from a stream and turn
/// either into a value of one common type.
pub trait RpcConverter: RpcMethod {
    /// Type both readers produce.
    type Output;
    /// Reads a query from `r` and converts it to [`RpcConverter::Output`].
    fn read_query(&self, r: Box<dyn Read>) -> Result<Self::Output, binprot::Error>;
    /// Reads a response from `r` and converts it to [`RpcConverter::Output`].
    fn read_response(&self, r: Box<dyn Read>) -> Result<Self::Output, binprot::Error>;
}
468
469impl<T, FQ, FR> RpcMethod for (T, FQ, FR)
470where
471    T: RpcMethod,
472{
473    const NAME: RpcTag = T::NAME;
474    const NAME_STR: &'static str = T::NAME_STR;
475    const VERSION: Ver = T::VERSION;
476    type Query = T::Query;
477    type Response = T::Response;
478}
479
480impl<T, FQ, FR, O> RpcConverter for (T, FQ, FR)
481where
482    T: RpcMethod,
483    T::Query: BinProtRead,
484    T::Response: BinProtRead,
485    FQ: Fn(T::Query) -> O,
486    FR: Fn(T::Query) -> O,
487{
488    type Output = O;
489
490    fn read_query(&self, mut r: Box<dyn Read>) -> Result<Self::Output, binprot::Error> {
491        let v = Self::Query::binprot_read(r.as_mut())?;
492        Ok(self.1(v))
493    }
494
495    fn read_response(&self, mut r: Box<dyn Read>) -> Result<Self::Output, binprot::Error> {
496        let v = Self::Query::binprot_read(r.as_mut())?;
497        Ok(self.2(v))
498    }
499}
500
#[cfg(test)]
mod tests {
    use binprot::BinProtRead;
    use binprot_derive::BinProtRead;

    use crate::{
        list::List,
        rpc_kernel::{BinprotTag, NeedsLength, RpcResult},
        utils::FromBinProtStream,
        versioned::Ver,
    };

    use super::{Message, MessageHeader, Query, QueryHeader, Response, ResponseHeader};

    /// Hex dump of a `__Versioned_rpc.Menu` query, starting with an 8-byte
    /// length.
    const MENU_QUERY_HEX: &str =
        "1e0000000000000001145f5f56657273696f6e65645f7270632e4d656e7501fd484f01000100";

    /// Hex dump of a response carrying a list of ten RPC tags, starting with
    /// an 8-byte length.
    const MENU_RESPONSE_HEX: &str = concat!(
        "f80000000000000002fdec57010000feee000a166765745f736f6d655f69",
        "6e697469616c5f706565727301336765745f7374616765645f6c65646765",
        "725f6175785f616e645f70656e64696e675f636f696e62617365735f6174",
        "5f686173680118616e737765725f73796e635f6c65646765725f71756572",
        "79010c6765745f626573745f746970010c6765745f616e63657374727901",
        "184765745f7472616e736974696f6e5f6b6e6f776c656467650114676574",
        "5f7472616e736974696f6e5f636861696e011a6765745f7472616e736974",
        "696f6e5f636861696e5f70726f6f66010a62616e5f6e6f74696679011067",
        "65745f65706f63685f6c656467657201"
    );

    /// Header expected when decoding [`MENU_QUERY_HEX`].
    fn menu_query_header() -> MessageHeader {
        MessageHeader::Query(QueryHeader {
            tag: "__Versioned_rpc.Menu".into(),
            version: 1,
            id: 0x00014f48,
        })
    }

    /// Header expected when decoding [`MENU_RESPONSE_HEX`].
    fn menu_response_header() -> MessageHeader {
        MessageHeader::Response(ResponseHeader { id: 0x000157ec })
    }

    /// Each dump decodes to the expected header on its own.
    #[test]
    fn message_header() {
        let cases = [
            (MENU_QUERY_HEX, menu_query_header()),
            (MENU_RESPONSE_HEX, menu_response_header()),
        ];
        for (encoded, expected) in cases {
            let bytes = hex::decode(encoded).unwrap();
            let mut cursor = bytes.as_slice();
            let header = MessageHeader::read_from_stream(&mut cursor).unwrap();
            assert_eq!(header, expected);
        }
    }

    /// Two dumps laid back to back decode to their headers one after another.
    #[test]
    fn multiple_messages() {
        let bytes = hex::decode([MENU_QUERY_HEX, MENU_RESPONSE_HEX].concat()).unwrap();

        let mut cursor = bytes.as_slice();
        for expected in [menu_query_header(), menu_response_header()] {
            assert_eq!(
                MessageHeader::read_from_stream(&mut cursor).unwrap(),
                expected
            );
        }
    }

    /// Decodes `encoded` (hex) as `T` and compares it with `decoded`.
    fn test_message<T>(encoded: &str, decoded: T)
    where
        T: BinProtRead + std::fmt::Debug + PartialEq,
    {
        let bytes = hex::decode(encoded).unwrap();
        let mut cursor = bytes.as_slice();
        let actual = T::read_from_stream(&mut cursor).unwrap();
        assert_eq!(actual, decoded);
    }

    /// Full messages, payload included, decode to the expected values.
    #[test]
    fn message() {
        test_message(
            MENU_QUERY_HEX,
            Message::Query(Query {
                tag: "__Versioned_rpc.Menu".into(),
                version: 1,
                id: 0x00014f48,
                data: NeedsLength(()),
            }),
        );

        #[derive(Debug, BinProtRead, PartialEq)]
        struct RpcTagVersion {
            tag: BinprotTag,
            version: Ver,
        }

        type MenuEntries = List<RpcTagVersion>;

        let entries = [
            "get_some_initial_peers",
            "get_staged_ledger_aux_and_pending_coinbases_at_hash",
            "answer_sync_ledger_query",
            "get_best_tip",
            "get_ancestry",
            "Get_transition_knowledge",
            "get_transition_chain",
            "get_transition_chain_proof",
            "ban_notify",
            "get_epoch_ledger",
        ]
        .into_iter()
        .map(|tag| RpcTagVersion {
            tag: tag.into(),
            version: 1,
        })
        .collect::<List<_>>();

        test_message(
            MENU_RESPONSE_HEX,
            Message::<MenuEntries>::Response(Response {
                id: 0x000157ec,
                data: RpcResult(Ok(NeedsLength(entries))),
            }),
        );
    }
}