1use serde::{Deserialize, Serialize};
2use slab::Slab;
3use std::fmt;
4
5mod request_id;
6pub use request_id::{RequestId, RequestIdType};
7
8mod rpc_id;
9pub use rpc_id::{RpcId, RpcIdType};
10
11#[derive(Serialize, Deserialize, Debug, Clone)]
12struct PendingRequest<Request> {
13 counter: usize,
14 request: Request,
15}
16
17pub struct PendingRequests<IdType: RequestIdType, Request> {
18 list: Slab<PendingRequest<Request>>,
19 counter: usize,
20 last_added_req_id: RequestId<IdType>,
21}
22
23impl<IdType, Request> PendingRequests<IdType, Request>
24where
25 IdType: RequestIdType,
26{
27 pub fn new() -> Self {
28 Self {
29 list: Slab::new(),
30 counter: 0,
31 last_added_req_id: RequestId::new(0, 0),
32 }
33 }
34
35 #[inline]
36 pub fn len(&self) -> usize {
37 self.list.len()
38 }
39
40 pub fn is_empty(&self) -> bool {
41 self.len() == 0
42 }
43
44 pub fn counter(&self) -> usize {
45 self.counter
46 }
47
48 #[inline]
49 pub fn last_added_req_id(&self) -> RequestId<IdType> {
50 self.last_added_req_id
51 }
52
53 #[inline]
54 pub fn next_req_id(&self) -> RequestId<IdType> {
55 RequestId::new(self.list.vacant_key(), self.counter.wrapping_add(1))
56 }
57
58 #[inline]
59 pub fn contains(&self, id: RequestId<IdType>) -> bool {
60 self.get(id).is_some()
61 }
62
63 #[inline]
64 pub fn get(&self, id: RequestId<IdType>) -> Option<&Request> {
65 self.list
66 .get(id.locator())
67 .filter(|req| req.counter == id.counter())
68 .map(|x| &x.request)
69 }
70
71 #[inline]
72 pub fn get_mut(&mut self, id: RequestId<IdType>) -> Option<&mut Request> {
73 self.list
74 .get_mut(id.locator())
75 .filter(|req| req.counter == id.counter())
76 .map(|x| &mut x.request)
77 }
78
79 #[inline]
80 pub fn add(&mut self, request: Request) -> RequestId<IdType> {
81 self.counter = self.counter.wrapping_add(1);
82
83 let locator = self.list.insert(PendingRequest {
84 counter: self.counter,
85 request,
86 });
87
88 let req_id = RequestId::new(locator, self.counter);
89 self.last_added_req_id = req_id;
90
91 req_id
92 }
93
94 #[inline]
95 fn remove_pending(&mut self, id: RequestId<IdType>) -> Option<PendingRequest<Request>> {
96 self.get(id)?;
97 self.list.try_remove(id.locator())
98 }
99
100 #[inline]
101 pub fn remove(&mut self, id: RequestId<IdType>) -> Option<Request> {
102 self.remove_pending(id).map(|req| req.request)
103 }
104
105 #[inline]
106 pub fn update<F>(&mut self, id: RequestId<IdType>, update: F) -> bool
107 where
108 F: FnOnce(Request) -> Request,
109 {
110 if let Some(mut req) = self.remove_pending(id) {
111 req.request = update(req.request);
112 let new_locator = self.list.insert(req);
113 assert_eq!(id.locator(), new_locator, "when adding element to the slab right after removal, index should be the same as the index of the removed element");
114 true
115 } else {
116 false
117 }
118 }
119
120 pub fn iter(&self) -> impl Iterator<Item = (RequestId<IdType>, &Request)> {
121 self.list
122 .iter()
123 .map(|(locator, req)| (RequestId::new(locator, req.counter), &req.request))
124 }
125}
126
127impl<IdType, Request> Default for PendingRequests<IdType, Request>
128where
129 IdType: RequestIdType,
130{
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl<IdType, Request> Serialize for PendingRequests<IdType, Request>
137where
138 IdType: RequestIdType,
139 Request: Serialize,
140{
141 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
142 where
143 S: serde::Serializer,
144 {
145 use serde::ser::SerializeStruct;
146 let mut s = serializer.serialize_struct("PendingRequests", 3)?;
147 s.serialize_field("list", &self.list)?;
148 s.serialize_field("counter", &self.counter)?;
149 s.serialize_field("last_added_req_id", &self.last_added_req_id)?;
150 s.end()
151 }
152}
153
154impl<'de, IdType, Request> Deserialize<'de> for PendingRequests<IdType, Request>
155where
156 IdType: RequestIdType,
157 Request: Deserialize<'de>,
158{
159 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
160 where
161 D: serde::Deserializer<'de>,
162 {
163 use serde::de::{self, MapAccess, SeqAccess, Visitor};
164 struct RequestsVisitor<IdType, Request>(std::marker::PhantomData<(IdType, Request)>);
165
166 const FIELDS: &[&str] = &["list", "counter", "last_added_req_id"];
167
168 impl<'de, IdType, Request> Visitor<'de> for RequestsVisitor<IdType, Request>
169 where
170 IdType: RequestIdType,
171 Request: Deserialize<'de>,
172 {
173 type Value = PendingRequests<IdType, Request>;
174
175 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
176 formatter.write_str("struct PendingRequests")
177 }
178
179 fn visit_seq<V>(self, mut seq: V) -> Result<PendingRequests<IdType, Request>, V::Error>
180 where
181 V: SeqAccess<'de>,
182 {
183 Ok(PendingRequests {
184 list: seq
185 .next_element()?
186 .ok_or_else(|| de::Error::missing_field("list"))?,
187 counter: seq
188 .next_element()?
189 .ok_or_else(|| de::Error::missing_field("counter"))?,
190 last_added_req_id: seq
191 .next_element()?
192 .ok_or_else(|| de::Error::missing_field("last_added_req_id"))?,
193 })
194 }
195
196 fn visit_map<V>(self, mut map: V) -> Result<PendingRequests<IdType, Request>, V::Error>
197 where
198 V: MapAccess<'de>,
199 {
200 let mut list = None;
201 let mut counter = None;
202 let mut last_added_req_id = None;
203 while let Some(key) = map.next_key()? {
204 match key {
205 "list" => {
206 if list.is_some() {
207 return Err(de::Error::duplicate_field("list"));
208 }
209 list = Some(map.next_value()?);
210 }
211 "counter" => {
212 if counter.is_some() {
213 return Err(de::Error::duplicate_field("counter"));
214 }
215 counter = Some(map.next_value()?);
216 }
217 "last_added_req_id" => {
218 if last_added_req_id.is_some() {
219 return Err(de::Error::duplicate_field("last_added_req_id"));
220 }
221 last_added_req_id = Some(map.next_value()?);
222 }
223 field => return Err(de::Error::unknown_field(field, FIELDS)),
224 }
225 }
226 let list = list.ok_or_else(|| de::Error::missing_field("list"))?;
227 let counter = counter.ok_or_else(|| de::Error::missing_field("counter"))?;
228 let last_added_req_id = last_added_req_id
229 .ok_or_else(|| de::Error::missing_field("last_added_req_id"))?;
230 Ok(PendingRequests {
231 list,
232 counter,
233 last_added_req_id,
234 })
235 }
236 }
237
238 let visitor = RequestsVisitor(Default::default());
239 deserializer.deserialize_struct("PendingRequests", FIELDS, visitor)
240 }
241}
242
243impl<IdType, Request> fmt::Debug for PendingRequests<IdType, Request>
244where
245 IdType: RequestIdType,
246 Request: fmt::Debug,
247{
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 f.debug_struct("PendingRequests")
250 .field("list", &self.list)
251 .field("counter", &self.counter)
252 .field("last_added_req_id", &self.last_added_req_id)
253 .finish()
254 }
255}
256
257impl<IdType, Request> Clone for PendingRequests<IdType, Request>
258where
259 IdType: RequestIdType,
260 Request: Clone,
261{
262 fn clone(&self) -> Self {
263 Self {
264 list: self.list.clone(),
265 counter: self.counter,
266 last_added_req_id: self.last_added_req_id,
267 }
268 }
269}