openmina_core/
distributed_pool.rs

1use std::{cmp::Ord, collections::BTreeMap, ops::RangeBounds};
2
3use crate::bug_condition;
4
/// A pool of states addressable both by key and by insertion index.
///
/// Every state is stored under an index taken from an internal counter that is
/// bumped on each insert, and is also reachable through the key returned by
/// `State::as_ref()`. Index order lets callers walk the pool incrementally via
/// ranges of indices; the key map gives direct lookup.
#[derive(Clone)]
pub struct DistributedPool<State, Key: Ord> {
    /// States ordered by the index they were inserted at.
    list: BTreeMap<u64, State>,
    /// For each stored state's key, the index under which it lives in `list`.
    by_key: BTreeMap<Key, u64>,
    /// Index that the next inserted state will receive.
    counter: u64,
}
11
12impl<State, Key: Ord> Default for DistributedPool<State, Key> {
13    fn default() -> Self {
14        Self {
15            counter: 0,
16            list: Default::default(),
17            by_key: Default::default(),
18        }
19    }
20}
21
22impl<State, Key> DistributedPool<State, Key>
23where
24    State: AsRef<Key>,
25    Key: Ord + Clone,
26{
27    pub fn is_empty(&self) -> bool {
28        self.list.is_empty()
29    }
30
31    pub fn len(&self) -> usize {
32        self.list.len()
33    }
34
35    pub fn contains(&self, key: &Key) -> bool {
36        self.by_key
37            .get(key)
38            .is_some_and(|i| self.list.contains_key(i))
39    }
40
41    pub fn get(&self, key: &Key) -> Option<&State> {
42        self.by_key.get(key).and_then(|i| self.list.get(i))
43    }
44
45    fn get_mut(&mut self, key: &Key) -> Option<&mut State> {
46        self.by_key.get(key).and_then(|i| self.list.get_mut(i))
47    }
48
49    pub fn range<R>(&self, range: R) -> impl '_ + DoubleEndedIterator<Item = (u64, &'_ State)>
50    where
51        R: RangeBounds<u64>,
52    {
53        self.list.range(range).map(|(k, v)| (*k, v))
54    }
55
56    pub fn last_index(&self) -> u64 {
57        self.list.last_key_value().map_or(0, |(k, _)| *k)
58    }
59
60    pub fn insert(&mut self, state: State) {
61        let key = state.as_ref().clone();
62        self.list.insert(self.counter, state);
63        self.by_key.insert(key, self.counter);
64        self.counter = self.counter.saturating_add(1);
65    }
66
67    pub fn update<F, R>(&mut self, key: &Key, f: F) -> Option<R>
68    where
69        F: FnOnce(&mut State) -> R,
70    {
71        let mut state = self.remove(key)?;
72        let res = f(&mut state);
73        self.insert(state);
74        Some(res)
75    }
76
77    /// Don't use if the change needs to be synced with other peers.
78    pub fn silent_update<F, R>(&mut self, key: &Key, f: F) -> Option<R>
79    where
80        F: FnOnce(&mut State) -> R,
81    {
82        self.get_mut(key).map(f)
83    }
84
85    pub fn remove(&mut self, key: &Key) -> Option<State> {
86        let index = self.by_key.remove(key)?;
87        self.list.remove(&index)
88    }
89
90    pub fn retain<F>(&mut self, mut f: F)
91    where
92        F: FnMut(&Key, &State) -> bool,
93    {
94        self.retain_and_update(|key, state| f(key, state))
95    }
96
97    pub fn retain_and_update<F>(&mut self, mut f: F)
98    where
99        F: FnMut(&Key, &mut State) -> bool,
100    {
101        let list = &mut self.list;
102        self.by_key.retain(|key, index| {
103            let Some(v) = list.get_mut(index) else {
104                bug_condition!("Pool: key found in the index, but the item not found");
105                return false;
106            };
107            if f(key, v) {
108                return true;
109            }
110            list.remove(index);
111            false
112        });
113    }
114
115    pub fn states(&self) -> impl Iterator<Item = &State> {
116        self.list.values()
117    }
118}
119
120impl<State, Key> DistributedPool<State, Key>
121where
122    State: AsRef<Key>,
123    Key: Ord + Clone,
124{
125    pub fn next_messages_to_send<F, T>(
126        &self,
127        (index, limit): (u64, u8),
128        extract_message: F,
129    ) -> (Vec<T>, u64, u64)
130    where
131        F: Fn(&State) -> Option<T>,
132    {
133        if limit == 0 {
134            let index = index.saturating_sub(1);
135            return (vec![], index, index);
136        }
137
138        self.range(index..)
139            .try_fold(
140                (vec![], None),
141                |(mut list, mut first_index), (index, job)| {
142                    if let Some(data) = extract_message(job) {
143                        let first_index = *first_index.get_or_insert(index);
144                        list.push(data);
145                        if list.len() >= limit as usize {
146                            return Err((list, first_index, index));
147                        }
148                    }
149
150                    Ok((list, first_index))
151                },
152            )
153            // Loop iterated on whole list.
154            .map(|(list, first_index)| (list, first_index.unwrap_or(index), self.last_index()))
155            // Loop preemptively ended.
156            .unwrap_or_else(|v| v)
157    }
158}
159
160mod ser {
161    use super::*;
162    use serde::{ser::SerializeStruct, Deserialize, Serialize};
163
164    #[derive(Deserialize)]
165    struct Pool<State> {
166        counter: u64,
167        list: BTreeMap<u64, State>,
168    }
169
170    impl<State, Key> Serialize for super::DistributedPool<State, Key>
171    where
172        State: Serialize,
173        Key: Ord,
174    {
175        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
176        where
177            S: serde::Serializer,
178        {
179            let mut s = serializer.serialize_struct("Pool", 2)?;
180            s.serialize_field("counter", &self.counter)?;
181            s.serialize_field("list", &self.list)?;
182            s.end()
183        }
184    }
185    impl<'de, State, Key> Deserialize<'de> for super::DistributedPool<State, Key>
186    where
187        State: Deserialize<'de> + AsRef<Key>,
188        Key: Ord + Clone,
189    {
190        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191        where
192            D: serde::Deserializer<'de>,
193        {
194            let v = Pool::<State>::deserialize(deserializer)?;
195            let by_key = v
196                .list
197                .iter()
198                .map(|(k, v)| (v.as_ref().clone(), *k))
199                .collect();
200            Ok(Self {
201                counter: v.counter,
202                list: v.list,
203                by_key,
204            })
205        }
206    }
207}
208
209impl<State, Key: Ord> std::fmt::Debug for DistributedPool<State, Key> {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        f.debug_struct("Pool")
212            .field("counter", &self.counter)
213            .field("len", &self.list.len())
214            .finish()
215    }
216}