1use std::{cmp::Ord, collections::BTreeMap, ops::RangeBounds};
2
3use crate::bug_condition;
4
/// Pool of states kept in insertion order and additionally indexed by key.
///
/// Every stored state lives in `list` under a `u64` index handed out from
/// `counter`, while `by_key` maps the key of a state to that index.
#[derive(Clone)]
pub struct DistributedPool<State, Key>
where
    Key: Ord,
{
    /// Index that the next inserted state will be stored under.
    counter: u64,
    /// Stored states, ordered by the index they were inserted with.
    list: BTreeMap<u64, State>,
    /// Lookup table from a state's key to its index in `list`.
    by_key: BTreeMap<Key, u64>,
}
11
12impl<State, Key: Ord> Default for DistributedPool<State, Key> {
13 fn default() -> Self {
14 Self {
15 counter: 0,
16 list: Default::default(),
17 by_key: Default::default(),
18 }
19 }
20}
21
22impl<State, Key> DistributedPool<State, Key>
23where
24 State: AsRef<Key>,
25 Key: Ord + Clone,
26{
27 pub fn is_empty(&self) -> bool {
28 self.list.is_empty()
29 }
30
31 pub fn len(&self) -> usize {
32 self.list.len()
33 }
34
35 pub fn contains(&self, key: &Key) -> bool {
36 self.by_key
37 .get(key)
38 .is_some_and(|i| self.list.contains_key(i))
39 }
40
41 pub fn get(&self, key: &Key) -> Option<&State> {
42 self.by_key.get(key).and_then(|i| self.list.get(i))
43 }
44
45 fn get_mut(&mut self, key: &Key) -> Option<&mut State> {
46 self.by_key.get(key).and_then(|i| self.list.get_mut(i))
47 }
48
49 pub fn range<R>(&self, range: R) -> impl '_ + DoubleEndedIterator<Item = (u64, &'_ State)>
50 where
51 R: RangeBounds<u64>,
52 {
53 self.list.range(range).map(|(k, v)| (*k, v))
54 }
55
56 pub fn last_index(&self) -> u64 {
57 self.list.last_key_value().map_or(0, |(k, _)| *k)
58 }
59
60 pub fn insert(&mut self, state: State) {
61 let key = state.as_ref().clone();
62 self.list.insert(self.counter, state);
63 self.by_key.insert(key, self.counter);
64 self.counter = self.counter.saturating_add(1);
65 }
66
67 pub fn update<F, R>(&mut self, key: &Key, f: F) -> Option<R>
68 where
69 F: FnOnce(&mut State) -> R,
70 {
71 let mut state = self.remove(key)?;
72 let res = f(&mut state);
73 self.insert(state);
74 Some(res)
75 }
76
77 pub fn silent_update<F, R>(&mut self, key: &Key, f: F) -> Option<R>
79 where
80 F: FnOnce(&mut State) -> R,
81 {
82 self.get_mut(key).map(f)
83 }
84
85 pub fn remove(&mut self, key: &Key) -> Option<State> {
86 let index = self.by_key.remove(key)?;
87 self.list.remove(&index)
88 }
89
90 pub fn retain<F>(&mut self, mut f: F)
91 where
92 F: FnMut(&Key, &State) -> bool,
93 {
94 self.retain_and_update(|key, state| f(key, state))
95 }
96
97 pub fn retain_and_update<F>(&mut self, mut f: F)
98 where
99 F: FnMut(&Key, &mut State) -> bool,
100 {
101 let list = &mut self.list;
102 self.by_key.retain(|key, index| {
103 let Some(v) = list.get_mut(index) else {
104 bug_condition!("Pool: key found in the index, but the item not found");
105 return false;
106 };
107 if f(key, v) {
108 return true;
109 }
110 list.remove(index);
111 false
112 });
113 }
114
115 pub fn states(&self) -> impl Iterator<Item = &State> {
116 self.list.values()
117 }
118}
119
120impl<State, Key> DistributedPool<State, Key>
121where
122 State: AsRef<Key>,
123 Key: Ord + Clone,
124{
125 pub fn next_messages_to_send<F, T>(
126 &self,
127 (index, limit): (u64, u8),
128 extract_message: F,
129 ) -> (Vec<T>, u64, u64)
130 where
131 F: Fn(&State) -> Option<T>,
132 {
133 if limit == 0 {
134 let index = index.saturating_sub(1);
135 return (vec![], index, index);
136 }
137
138 self.range(index..)
139 .try_fold(
140 (vec![], None),
141 |(mut list, mut first_index), (index, job)| {
142 if let Some(data) = extract_message(job) {
143 let first_index = *first_index.get_or_insert(index);
144 list.push(data);
145 if list.len() >= limit as usize {
146 return Err((list, first_index, index));
147 }
148 }
149
150 Ok((list, first_index))
151 },
152 )
153 .map(|(list, first_index)| (list, first_index.unwrap_or(index), self.last_index()))
155 .unwrap_or_else(|v| v)
157 }
158}
159
160mod ser {
161 use super::*;
162 use serde::{ser::SerializeStruct, Deserialize, Serialize};
163
164 #[derive(Deserialize)]
165 struct Pool<State> {
166 counter: u64,
167 list: BTreeMap<u64, State>,
168 }
169
170 impl<State, Key> Serialize for super::DistributedPool<State, Key>
171 where
172 State: Serialize,
173 Key: Ord,
174 {
175 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
176 where
177 S: serde::Serializer,
178 {
179 let mut s = serializer.serialize_struct("Pool", 2)?;
180 s.serialize_field("counter", &self.counter)?;
181 s.serialize_field("list", &self.list)?;
182 s.end()
183 }
184 }
185 impl<'de, State, Key> Deserialize<'de> for super::DistributedPool<State, Key>
186 where
187 State: Deserialize<'de> + AsRef<Key>,
188 Key: Ord + Clone,
189 {
190 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191 where
192 D: serde::Deserializer<'de>,
193 {
194 let v = Pool::<State>::deserialize(deserializer)?;
195 let by_key = v
196 .list
197 .iter()
198 .map(|(k, v)| (v.as_ref().clone(), *k))
199 .collect();
200 Ok(Self {
201 counter: v.counter,
202 list: v.list,
203 by_key,
204 })
205 }
206 }
207}
208
209impl<State, Key: Ord> std::fmt::Debug for DistributedPool<State, Key> {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("Pool")
212 .field("counter", &self.counter)
213 .field("len", &self.list.len())
214 .finish()
215 }
216}