1use std::io::Read;
3
4use binprot::{BinProtRead, BinProtWrite};
5use binprot_derive::{BinProtRead, BinProtWrite};
6use malloc_size_of_derive::MallocSizeOf;
7use serde::{Deserialize, Serialize};
8
9use crate::versioned::Ver;
10
/// RPC tag (method name) as carried in [`Query::tag`], e.g. `"__Versioned_rpc.Menu"`.
pub type BinprotTag = super::string::CharString;
/// RPC tag as a static byte string (see [`RpcMethod::NAME`]).
pub type RpcTag = &'static [u8];
/// RPC version number.
pub type RpcVersion = Ver;
/// Identifier carried by queries and responses.
pub type QueryID = u64;
/// Placeholder for an S-expression type; currently just the unit type.
pub type Sexp = ();

/// Result of an RPC call.
///
/// Binprot-encoded (by the manual impls below) as an [`RpcResultKind`] tag
/// followed by either the `Ok` or the `Err` payload.
#[derive(
    Clone, Debug, Serialize, Deserialize, PartialEq, Eq, derive_more::From, derive_more::Into,
)]
pub struct RpcResult<T, E>(pub Result<T, E>);
24
/// Binprot tag written in front of an [`RpcResult`] payload.
///
/// Variant order defines the encoded tag (`Ok` = 0, `Err` = 1), so it must not change.
#[derive(Debug, BinProtRead, BinProtWrite)]
pub enum RpcResultKind {
    Ok,
    Err,
}
31
32impl<T, E> BinProtRead for RpcResult<T, E>
33where
34 T: BinProtRead,
35 E: BinProtRead,
36{
37 fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
38 where
39 Self: Sized,
40 {
41 Ok(match RpcResultKind::binprot_read(r)? {
42 RpcResultKind::Ok => Ok(T::binprot_read(r)?),
43 RpcResultKind::Err => Err(E::binprot_read(r)?),
44 }
45 .into())
46 }
47}
48
49impl<T, E> BinProtWrite for RpcResult<T, E>
50where
51 T: BinProtWrite,
52 E: BinProtWrite,
53{
54 fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
55 match &self.0 {
56 Ok(v) => {
57 RpcResultKind::Ok.binprot_write(w)?;
58 v.binprot_write(w)?;
59 }
60 Err(e) => {
61 RpcResultKind::Err.binprot_write(w)?;
62 e.binprot_write(w)?;
63 }
64 }
65 Ok(())
66 }
67}
68
/// Wrapper for a value that is binprot-encoded with a `Nat0` byte-length prefix
/// (see its `BinProtRead`/`BinProtWrite` impls).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, derive_more::From)]
pub struct NeedsLength<T>(pub T);
71
72impl<T> NeedsLength<T> {
73 pub fn into_inner(self) -> T {
74 self.0
75 }
76}
77
78impl<T> BinProtRead for NeedsLength<T>
79where
80 T: BinProtRead,
81{
82 fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
83 where
84 Self: Sized,
85 {
86 let _size = binprot::Nat0::binprot_read(r)?.0;
87 let v = T::binprot_read(r)?;
93 Ok(v.into())
94 }
95}
96
97impl<T> BinProtWrite for NeedsLength<T>
98where
99 T: BinProtWrite,
100{
101 fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
102 let mut buf = Vec::new();
103 self.0.binprot_write(&mut buf)?;
104 binprot::Nat0(buf.len() as u64).binprot_write(w)?;
105 w.write_all(&buf)?;
106 Ok(())
107 }
108}
109
110#[allow(non_camel_case_types)]
129#[derive(
130 Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq, thiserror::Error,
131)]
132pub enum Error {
133 #[error("binprot expection")]
134 Bin_io_exn, #[error("connection closed")]
136 Connection_closed,
137 #[error("write error")]
138 Write_error, #[error("uncaught exception")]
140 Uncaught_exn, #[error("unimplemented method {}:{}", .0.to_string(), .1)]
142 Unimplemented_rpc(BinprotTag, Ver),
143 #[error("unknown query id: {0}")]
144 Unknown_query_id(QueryID),
145}
146
/// Query payload: the query value, length-prefixed on the wire.
pub type QueryPayload<T> = NeedsLength<T>;
152
/// Body of an RPC query message.
///
/// Fields are binprot-encoded in declaration order, so the order must not change.
#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
pub struct Query<T> {
    /// Name of the RPC being called.
    pub tag: BinprotTag,
    /// Version of the RPC being called.
    pub version: Ver,
    /// Query identifier (presumably matched against [`Response::id`] — verify).
    pub id: QueryID,
    /// Length-prefixed query value.
    pub data: QueryPayload<T>,
}
174
/// Response payload: either a length-prefixed value or an [`Error`] sent by the peer.
pub type ResponsePayload<T> = RpcResult<NeedsLength<T>, Error>;
180
/// Body of an RPC response message.
///
/// Fields are binprot-encoded in declaration order, so the order must not change.
#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
pub struct Response<T> {
    /// Identifier of the query being answered.
    pub id: QueryID,
    /// Response value or peer error.
    pub data: ResponsePayload<T>,
}
198
/// Response that additionally carries the RPC tag and version.
///
/// Converts into a plain [`Response`] by dropping `tag` and `version`.
#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
pub struct DebuggerResponse<T> {
    pub tag: BinprotTag,
    pub version: Ver,
    pub id: QueryID,
    pub data: ResponsePayload<T>,
}
208
209impl<T> From<DebuggerResponse<T>> for Response<T> {
210 fn from(source: DebuggerResponse<T>) -> Self {
211 Response {
212 id: source.id,
213 data: source.data,
214 }
215 }
216}
217
/// A single RPC protocol message, generic over the payload type.
///
/// Variant order defines the binprot tag (`Heartbeat` = 0, `Query` = 1,
/// `Response` = 2), so it must not change.
#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
pub enum Message<T> {
    Heartbeat,
    Query(Query<T>),
    Response(Response<T>),
}
237
/// Same layout as [`Message`], except that responses are [`DebuggerResponse`]s.
///
/// Converts into a [`Message`] via `From`.
#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
pub enum DebuggerMessage<T> {
    Heartbeat,
    Query(Query<T>),
    Response(DebuggerResponse<T>),
}
244
245impl<T> From<DebuggerMessage<T>> for Message<T> {
246 fn from(source: DebuggerMessage<T>) -> Self {
247 match source {
248 DebuggerMessage::Heartbeat => Message::Heartbeat,
249 DebuggerMessage::Query(query) => Message::Query(query),
250 DebuggerMessage::Response(response) => Message::Response(response.into()),
251 }
252 }
253}
254
/// The fields of a [`Query`] other than its payload, in the same order.
#[derive(
    Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq, MallocSizeOf,
)]
pub struct QueryHeader {
    pub tag: BinprotTag,
    pub version: Ver,
    pub id: QueryID,
}
263
/// The fields of a [`Response`] other than its payload.
#[derive(
    Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq, MallocSizeOf,
)]
pub struct ResponseHeader {
    pub id: QueryID,
}
270
/// Header-only counterpart of [`Message`].
///
/// Variant order matches `Message`, so both decode the same binprot tag.
#[derive(Clone, Debug, Serialize, Deserialize, BinProtRead, BinProtWrite, PartialEq, Eq)]
pub enum MessageHeader {
    Heartbeat,
    Query(QueryHeader),
    Response(ResponseHeader),
}
277
278pub trait RpcMethod {
279 const NAME: RpcTag;
280 const NAME_STR: &'static str;
281 const VERSION: Ver;
282 type Query: BinProtRead + BinProtWrite;
283 type Response: BinProtRead + BinProtWrite;
284
285 fn rpc_id() -> String {
286 format!("{}:{}", Self::NAME_STR, Self::VERSION)
287 }
288}
289
/// Something that decodes a value of type `Output` from a binprot reader.
pub trait BinableDecoder {
    /// Decoded value type.
    type Output;
    /// Decodes from `r`.
    // NOTE(review): `Box<&mut dyn Read>` boxes a reference, which is unusual;
    // `&mut dyn Read` would suffice, but changing it would break implementors.
    fn handle(&self, r: Box<&mut dyn Read>) -> Self::Output;
}
296
/// Decodes the length-prefixed query/response payloads of an [`RpcMethod`].
///
/// A blanket implementation exists for every `RpcMethod`.
pub trait PayloadBinprotReader: RpcMethod {
    /// Decodes a [`QueryPayload`] and returns the inner query.
    fn query_payload<R>(r: &mut R) -> Result<Self::Query, RpcQueryReadError>
    where
        R: Read;
    /// Decodes a [`ResponsePayload`] and returns the inner response; a peer
    /// [`Error`] is surfaced as an error of this function.
    fn response_payload<R>(r: &mut R) -> Result<Self::Response, RpcResponseReadError>
    where
        R: Read;
}
310
/// Failure to decode an RPC query payload.
#[derive(Debug, thiserror::Error)]
pub enum RpcQueryReadError {
    /// Binprot decoding failed; `rpc_id` is [`RpcMethod::rpc_id`] of the method.
    #[error("rpc query {rpc_id}: failed to decode binprot: {error}")]
    Binprot {
        rpc_id: String,
        error: binprot::Error,
    },
}
319
/// Failure to obtain an RPC response value.
#[derive(Debug, thiserror::Error)]
pub enum RpcResponseReadError {
    /// Binprot decoding failed; `rpc_id` is [`RpcMethod::rpc_id`] of the method.
    #[error("rpc response {rpc_id}: failed to decode binprot: {error}")]
    Binprot {
        rpc_id: String,
        error: binprot::Error,
    },
    /// Decoding succeeded, but the peer answered with an [`Error`].
    #[error("rpc response {rpc_id}: peer failed to respond: {error}")]
    Failure { rpc_id: String, error: self::Error },
}
330
331impl<T> PayloadBinprotReader for T
332where
333 T: RpcMethod,
334 T::Query: BinProtRead,
335 T::Response: BinProtRead,
336{
337 fn query_payload<R>(r: &mut R) -> Result<Self::Query, RpcQueryReadError>
338 where
339 R: Read,
340 {
341 QueryPayload::<Self::Query>::binprot_read(r)
342 .map(|NeedsLength(v)| v)
343 .map_err(|error| RpcQueryReadError::Binprot {
344 rpc_id: T::rpc_id(),
345 error,
346 })
347 }
348
349 fn response_payload<R>(r: &mut R) -> Result<Self::Response, RpcResponseReadError>
350 where
351 R: Read,
352 {
353 ResponsePayload::<Self::Response>::binprot_read(r)
354 .map(|v| Result::from(v).map(NeedsLength::into_inner))
355 .map_err(|error| RpcResponseReadError::Binprot {
356 rpc_id: T::rpc_id(),
357 error,
358 })?
359 .map_err(|error| RpcResponseReadError::Failure {
360 rpc_id: T::rpc_id(),
361 error,
362 })
363 }
364}
365
/// Failure to extract a query/response from a whole [`Message`].
#[derive(Debug, thiserror::Error)]
pub enum RpcDebuggerReaderError {
    /// The message itself could not be decoded.
    #[error(transparent)]
    BinProtError(#[from] binprot::Error),
    /// A query was requested, but the message was another variant.
    #[error("Query expected")]
    ExpectQuery,
    /// A response was requested, but the message was another variant.
    #[error("Response expected")]
    ExpectResponse,
}
375
/// Decodes a complete [`Message`] (not just a payload) and extracts the
/// query or the response of an [`RpcMethod`].
pub trait RpcDebuggerReader: RpcMethod {
    /// Decodes a message that must be a `Query`, returning its payload.
    fn debugger_query<R>(r: &mut R) -> Result<Self::Query, RpcDebuggerReaderError>
    where
        R: Read;
    /// Decodes a message that must be a `Response`. The outer `Result` reports
    /// decoding problems; the inner one is the value or the peer's [`Error`].
    fn debugger_response<R>(
        r: &mut R,
    ) -> Result<Result<Self::Response, Error>, RpcDebuggerReaderError>
    where
        R: Read;
}
393
394impl<T> RpcDebuggerReader for T
395where
396 T: RpcMethod,
397 T::Query: BinProtRead,
398 T::Response: BinProtRead,
399{
400 fn debugger_query<R>(r: &mut R) -> Result<Self::Query, RpcDebuggerReaderError>
401 where
402 R: Read,
403 {
404 if let Message::Query(query) = Message::<T::Query>::binprot_read(r)? {
405 Ok(query.data.0)
406 } else {
407 Err(RpcDebuggerReaderError::ExpectQuery)
408 }
409 }
410
411 fn debugger_response<R>(
412 r: &mut R,
413 ) -> Result<Result<Self::Response, Error>, RpcDebuggerReaderError>
414 where
415 R: Read,
416 {
417 if let Message::Response(response) = Message::<T::Response>::binprot_read(r)? {
418 Ok(Result::from(response.data).map(NeedsLength::into_inner))
419 } else {
420 Err(RpcDebuggerReaderError::ExpectResponse)
421 }
422 }
423}
424
/// Failure while converting a binprot payload to JSON.
#[derive(Debug, thiserror::Error)]
pub enum JSONinifyError {
    /// The binprot payload could not be decoded.
    #[error(transparent)]
    Binprot(#[from] binprot::Error),
    /// The decoded value could not be converted to a JSON value.
    #[error(transparent)]
    JSON(#[from] serde_json::Error),
}
432
/// Decodes query/response payloads and renders them as `serde_json::Value`.
///
/// Methods take `&self` and `&mut dyn Read`, so the trait can be used as a trait object.
pub trait JSONinifyPayloadReader {
    /// Decodes a length-prefixed query payload into JSON.
    fn read_query(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError>;
    /// Decodes a response payload (value or peer error) into JSON.
    fn read_response(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError>;
}
437
438impl<T> JSONinifyPayloadReader for T
439where
440 T: RpcMethod,
441 T::Query: BinProtRead + Serialize,
442 T::Response: BinProtRead + Serialize,
443{
444 fn read_query(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError> {
445 let v = QueryPayload::<T::Query>::binprot_read(r).map(|NeedsLength(v)| v)?;
446 let json = serde_json::to_value(&v)?;
447 Ok(json)
448 }
449
450 fn read_response(&self, r: &mut dyn Read) -> Result<serde_json::Value, JSONinifyError> {
451 let v = ResponsePayload::<T::Response>::binprot_read(r)
452 .map(|v| Result::from(v).map(|NeedsLength(v)| v))?;
453 let json = serde_json::to_value(&v)?;
454 Ok(json)
455 }
456}
457
/// By-value conversion into an associated `Output` type.
pub trait Converter {
    type Output;
    fn convert(self) -> Self::Output;
}
462
/// Decodes a query or a response of an [`RpcMethod`] from a boxed reader and
/// converts it into a common `Output` type.
pub trait RpcConverter: RpcMethod {
    type Output;
    fn read_query(&self, r: Box<dyn Read>) -> Result<Self::Output, binprot::Error>;
    fn read_response(&self, r: Box<dyn Read>) -> Result<Self::Output, binprot::Error>;
}
468
469impl<T, FQ, FR> RpcMethod for (T, FQ, FR)
470where
471 T: RpcMethod,
472{
473 const NAME: RpcTag = T::NAME;
474 const NAME_STR: &'static str = T::NAME_STR;
475 const VERSION: Ver = T::VERSION;
476 type Query = T::Query;
477 type Response = T::Response;
478}
479
480impl<T, FQ, FR, O> RpcConverter for (T, FQ, FR)
481where
482 T: RpcMethod,
483 T::Query: BinProtRead,
484 T::Response: BinProtRead,
485 FQ: Fn(T::Query) -> O,
486 FR: Fn(T::Query) -> O,
487{
488 type Output = O;
489
490 fn read_query(&self, mut r: Box<dyn Read>) -> Result<Self::Output, binprot::Error> {
491 let v = Self::Query::binprot_read(r.as_mut())?;
492 Ok(self.1(v))
493 }
494
495 fn read_response(&self, mut r: Box<dyn Read>) -> Result<Self::Output, binprot::Error> {
496 let v = Self::Query::binprot_read(r.as_mut())?;
497 Ok(self.2(v))
498 }
499}
500
#[cfg(test)]
mod tests {
    // Decoding tests against captured wire frames. Each hex frame starts with
    // an 8-byte little-endian length of the rest of the frame (`1e00000000000000`
    // = 30 bytes, `f800000000000000` = 248 bytes), followed by the binprot
    // message: variant tag (01 = Query, 02 = Response), then the fields.
    use binprot::BinProtRead;
    use binprot_derive::BinProtRead;

    use crate::{
        list::List,
        rpc_kernel::{BinprotTag, NeedsLength, RpcResult},
        utils::FromBinProtStream,
        versioned::Ver,
    };

    use super::{Message, MessageHeader, Query, QueryHeader, Response, ResponseHeader};

    // Each frame decodes independently as a `MessageHeader`, i.e. the
    // payloads do not have to be understood to read the header.
    #[test]
    fn message_header() {
        for (s, m) in [
            (
                "1e0000000000000001145f5f56657273696f6e65645f7270632e4d656e7501fd484f01000100",
                MessageHeader::Query(QueryHeader {
                    tag: "__Versioned_rpc.Menu".into(),
                    version: 1,
                    id: 0x00014f48,
                }),
            ),
            (
                concat!(
                    "f80000000000000002fdec57010000feee000a166765745f736f6d655f69",
                    "6e697469616c5f706565727301336765745f7374616765645f6c65646765",
                    "725f6175785f616e645f70656e64696e675f636f696e62617365735f6174",
                    "5f686173680118616e737765725f73796e635f6c65646765725f71756572",
                    "79010c6765745f626573745f746970010c6765745f616e63657374727901",
                    "184765745f7472616e736974696f6e5f6b6e6f776c656467650114676574",
                    "5f7472616e736974696f6e5f636861696e011a6765745f7472616e736974",
                    "696f6e5f636861696e5f70726f6f66010a62616e5f6e6f74696679011067",
                    "65745f65706f63685f6c656467657201"
                ),
                MessageHeader::Response(ResponseHeader { id: 0x000157ec }),
            ),
        ] {
            let s = hex::decode(s).unwrap();
            let mut p = s.as_slice();
            let msg = MessageHeader::read_from_stream(&mut p).unwrap();
            assert_eq!(msg, m);
        }
    }

    // The same two frames, concatenated in one buffer, must be read back to
    // back from a single stream without losing alignment.
    #[test]
    fn multiple_messages() {
        let s = hex::decode(concat!(
            "1e0000000000000001145f5f56657273696f6e65645f7270632e4d656e7501fd484f01000100",
            "f80000000000000002fdec57010000feee000a166765745f736f6d655f69",
            "6e697469616c5f706565727301336765745f7374616765645f6c65646765",
            "725f6175785f616e645f70656e64696e675f636f696e62617365735f6174",
            "5f686173680118616e737765725f73796e635f6c65646765725f71756572",
            "79010c6765745f626573745f746970010c6765745f616e63657374727901",
            "184765745f7472616e736974696f6e5f6b6e6f776c656467650114676574",
            "5f7472616e736974696f6e5f636861696e011a6765745f7472616e736974",
            "696f6e5f636861696e5f70726f6f66010a62616e5f6e6f74696679011067",
            "65745f65706f63685f6c656467657201"
        ))
        .unwrap();

        let mut p = s.as_slice();
        for msg in [
            MessageHeader::Query(QueryHeader {
                tag: "__Versioned_rpc.Menu".into(),
                version: 1,
                id: 0x00014f48,
            }),
            MessageHeader::Response(ResponseHeader { id: 0x000157ec }),
        ] {
            assert_eq!(MessageHeader::read_from_stream(&mut p).unwrap(), msg);
        }
    }

    // Decodes one length-framed hex message as `T` and compares it with `decoded`.
    fn test_message<T>(encoded: &str, decoded: T)
    where
        T: BinProtRead + std::fmt::Debug + PartialEq,
    {
        let s = hex::decode(encoded).unwrap();
        let mut p = s.as_slice();
        let msg = T::read_from_stream(&mut p).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn message() {
        // Query with a unit payload: trailing `0100` is the Nat0 length 1
        // followed by the single unit byte 00.
        test_message(
            "1e0000000000000001145f5f56657273696f6e65645f7270632e4d656e7501fd484f01000100",
            Message::Query(Query {
                tag: "__Versioned_rpc.Menu".into(),
                version: 1,
                id: 0x00014f48,
                data: ().into(),
            }),
        );

        #[derive(Debug, BinProtRead, PartialEq)]
        struct RpcTagVersion {
            tag: BinprotTag,
            version: Ver,
        }

        type QueryType = List<RpcTagVersion>;

        // Response: `00` = RpcResultKind::Ok, `feee00` = Nat0 payload length
        // 238, then a list of 10 (tag, version) pairs.
        test_message(
            concat!(
                "f80000000000000002fdec57010000feee000a166765745f736f6d655f69",
                "6e697469616c5f706565727301336765745f7374616765645f6c65646765",
                "725f6175785f616e645f70656e64696e675f636f696e62617365735f6174",
                "5f686173680118616e737765725f73796e635f6c65646765725f71756572",
                "79010c6765745f626573745f746970010c6765745f616e63657374727901",
                "184765745f7472616e736974696f6e5f6b6e6f776c656467650114676574",
                "5f7472616e736974696f6e5f636861696e011a6765745f7472616e736974",
                "696f6e5f636861696e5f70726f6f66010a62616e5f6e6f74696679011067",
                "65745f65706f63685f6c656467657201"
            ),
            Message::<QueryType>::Response(Response {
                id: 0x000157ec,
                data: RpcResult::from(Ok(NeedsLength::from(
                    [
                        "get_some_initial_peers",
                        "get_staged_ledger_aux_and_pending_coinbases_at_hash",
                        "answer_sync_ledger_query",
                        "get_best_tip",
                        "get_ancestry",
                        "Get_transition_knowledge",
                        "get_transition_chain",
                        "get_transition_chain_proof",
                        "ban_notify",
                        "get_epoch_ledger",
                    ]
                    .into_iter()
                    .map(|tag| RpcTagVersion {
                        tag: tag.into(),
                        version: 1,
                    })
                    .collect::<List<_>>(),
                ))),
            }),
        );
    }
}