openmina_core/requests/
mod.rs

1use serde::{Deserialize, Serialize};
2use slab::Slab;
3use std::fmt;
4
5mod request_id;
6pub use request_id::{RequestId, RequestIdType};
7
8mod rpc_id;
9pub use rpc_id::{RpcId, RpcIdType};
10
11#[derive(Serialize, Deserialize, Debug, Clone)]
12struct PendingRequest<Request> {
13    counter: usize,
14    request: Request,
15}
16
17pub struct PendingRequests<IdType: RequestIdType, Request> {
18    list: Slab<PendingRequest<Request>>,
19    counter: usize,
20    last_added_req_id: RequestId<IdType>,
21}
22
23impl<IdType, Request> PendingRequests<IdType, Request>
24where
25    IdType: RequestIdType,
26{
27    pub fn new() -> Self {
28        Self {
29            list: Slab::new(),
30            counter: 0,
31            last_added_req_id: RequestId::new(0, 0),
32        }
33    }
34
35    #[inline]
36    pub fn len(&self) -> usize {
37        self.list.len()
38    }
39
40    pub fn is_empty(&self) -> bool {
41        self.len() == 0
42    }
43
44    pub fn counter(&self) -> usize {
45        self.counter
46    }
47
48    #[inline]
49    pub fn last_added_req_id(&self) -> RequestId<IdType> {
50        self.last_added_req_id
51    }
52
53    #[inline]
54    pub fn next_req_id(&self) -> RequestId<IdType> {
55        RequestId::new(self.list.vacant_key(), self.counter.wrapping_add(1))
56    }
57
58    #[inline]
59    pub fn contains(&self, id: RequestId<IdType>) -> bool {
60        self.get(id).is_some()
61    }
62
63    #[inline]
64    pub fn get(&self, id: RequestId<IdType>) -> Option<&Request> {
65        self.list
66            .get(id.locator())
67            .filter(|req| req.counter == id.counter())
68            .map(|x| &x.request)
69    }
70
71    #[inline]
72    pub fn get_mut(&mut self, id: RequestId<IdType>) -> Option<&mut Request> {
73        self.list
74            .get_mut(id.locator())
75            .filter(|req| req.counter == id.counter())
76            .map(|x| &mut x.request)
77    }
78
79    #[inline]
80    pub fn add(&mut self, request: Request) -> RequestId<IdType> {
81        self.counter = self.counter.wrapping_add(1);
82
83        let locator = self.list.insert(PendingRequest {
84            counter: self.counter,
85            request,
86        });
87
88        let req_id = RequestId::new(locator, self.counter);
89        self.last_added_req_id = req_id;
90
91        req_id
92    }
93
94    #[inline]
95    fn remove_pending(&mut self, id: RequestId<IdType>) -> Option<PendingRequest<Request>> {
96        self.get(id)?;
97        self.list.try_remove(id.locator())
98    }
99
100    #[inline]
101    pub fn remove(&mut self, id: RequestId<IdType>) -> Option<Request> {
102        self.remove_pending(id).map(|req| req.request)
103    }
104
105    #[inline]
106    pub fn update<F>(&mut self, id: RequestId<IdType>, update: F) -> bool
107    where
108        F: FnOnce(Request) -> Request,
109    {
110        if let Some(mut req) = self.remove_pending(id) {
111            req.request = update(req.request);
112            let new_locator = self.list.insert(req);
113            assert_eq!(id.locator(), new_locator, "when adding element to the slab right after removal, index should be the same as the index of the removed element");
114            true
115        } else {
116            false
117        }
118    }
119
120    pub fn iter(&self) -> impl Iterator<Item = (RequestId<IdType>, &Request)> {
121        self.list
122            .iter()
123            .map(|(locator, req)| (RequestId::new(locator, req.counter), &req.request))
124    }
125}
126
127impl<IdType, Request> Default for PendingRequests<IdType, Request>
128where
129    IdType: RequestIdType,
130{
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl<IdType, Request> Serialize for PendingRequests<IdType, Request>
137where
138    IdType: RequestIdType,
139    Request: Serialize,
140{
141    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
142    where
143        S: serde::Serializer,
144    {
145        use serde::ser::SerializeStruct;
146        let mut s = serializer.serialize_struct("PendingRequests", 3)?;
147        s.serialize_field("list", &self.list)?;
148        s.serialize_field("counter", &self.counter)?;
149        s.serialize_field("last_added_req_id", &self.last_added_req_id)?;
150        s.end()
151    }
152}
153
154impl<'de, IdType, Request> Deserialize<'de> for PendingRequests<IdType, Request>
155where
156    IdType: RequestIdType,
157    Request: Deserialize<'de>,
158{
159    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
160    where
161        D: serde::Deserializer<'de>,
162    {
163        use serde::de::{self, MapAccess, SeqAccess, Visitor};
164        struct RequestsVisitor<IdType, Request>(std::marker::PhantomData<(IdType, Request)>);
165
166        const FIELDS: &[&str] = &["list", "counter", "last_added_req_id"];
167
168        impl<'de, IdType, Request> Visitor<'de> for RequestsVisitor<IdType, Request>
169        where
170            IdType: RequestIdType,
171            Request: Deserialize<'de>,
172        {
173            type Value = PendingRequests<IdType, Request>;
174
175            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
176                formatter.write_str("struct PendingRequests")
177            }
178
179            fn visit_seq<V>(self, mut seq: V) -> Result<PendingRequests<IdType, Request>, V::Error>
180            where
181                V: SeqAccess<'de>,
182            {
183                Ok(PendingRequests {
184                    list: seq
185                        .next_element()?
186                        .ok_or_else(|| de::Error::missing_field("list"))?,
187                    counter: seq
188                        .next_element()?
189                        .ok_or_else(|| de::Error::missing_field("counter"))?,
190                    last_added_req_id: seq
191                        .next_element()?
192                        .ok_or_else(|| de::Error::missing_field("last_added_req_id"))?,
193                })
194            }
195
196            fn visit_map<V>(self, mut map: V) -> Result<PendingRequests<IdType, Request>, V::Error>
197            where
198                V: MapAccess<'de>,
199            {
200                let mut list = None;
201                let mut counter = None;
202                let mut last_added_req_id = None;
203                while let Some(key) = map.next_key()? {
204                    match key {
205                        "list" => {
206                            if list.is_some() {
207                                return Err(de::Error::duplicate_field("list"));
208                            }
209                            list = Some(map.next_value()?);
210                        }
211                        "counter" => {
212                            if counter.is_some() {
213                                return Err(de::Error::duplicate_field("counter"));
214                            }
215                            counter = Some(map.next_value()?);
216                        }
217                        "last_added_req_id" => {
218                            if last_added_req_id.is_some() {
219                                return Err(de::Error::duplicate_field("last_added_req_id"));
220                            }
221                            last_added_req_id = Some(map.next_value()?);
222                        }
223                        field => return Err(de::Error::unknown_field(field, FIELDS)),
224                    }
225                }
226                let list = list.ok_or_else(|| de::Error::missing_field("list"))?;
227                let counter = counter.ok_or_else(|| de::Error::missing_field("counter"))?;
228                let last_added_req_id = last_added_req_id
229                    .ok_or_else(|| de::Error::missing_field("last_added_req_id"))?;
230                Ok(PendingRequests {
231                    list,
232                    counter,
233                    last_added_req_id,
234                })
235            }
236        }
237
238        let visitor = RequestsVisitor(Default::default());
239        deserializer.deserialize_struct("PendingRequests", FIELDS, visitor)
240    }
241}
242
243impl<IdType, Request> fmt::Debug for PendingRequests<IdType, Request>
244where
245    IdType: RequestIdType,
246    Request: fmt::Debug,
247{
248    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249        f.debug_struct("PendingRequests")
250            .field("list", &self.list)
251            .field("counter", &self.counter)
252            .field("last_added_req_id", &self.last_added_req_id)
253            .finish()
254    }
255}
256
257impl<IdType, Request> Clone for PendingRequests<IdType, Request>
258where
259    IdType: RequestIdType,
260    Request: Clone,
261{
262    fn clone(&self) -> Self {
263        Self {
264            list: self.list.clone(),
265            counter: self.counter,
266            last_added_req_id: self.last_added_req_id,
267        }
268    }
269}