Skip to main content

o1_utils/
lazy_cache.rs

1#![allow(unsafe_code)]
2//! Polyfill of the `LazyLock` type in the std library as of Rust 1.80.
3//!
//! This file should be deleted soon: we now support Rust 1.81, so the official
//! `LazyLock` can be used directly, with [`LazyCache`] kept only as a wrapper
//! around `LazyLock` to allow for custom serialization definitions.
7
8use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
9use std::{cell::UnsafeCell, fmt, ops::Deref, sync::Once};
10
/// Boxed one-shot initializer stored by [`LazyCache`] until first use.
type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync + 'static>;

/// A thread-safe, lazily-initialized value.
///
/// The value is either supplied up front (`preinit`) or computed by a stored
/// closure the first time it is requested.
pub struct LazyCache<T> {
    /// Gate that ensures the initialization closure body runs at most once.
    pub(crate) once: Once,
    /// Slot for the cached value; `None` until initialization has stored one.
    pub(crate) value: UnsafeCell<Option<T>>,
    /// Pending initializer; taken (left as `None`) when initialization runs.
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}
19
/// Failures reported by [`LazyCache`] itself (as opposed to errors produced by
/// the cached computation).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LazyCacheError {
    /// The underlying `Once` was found poisoned when initialization was attempted.
    LockPoisoned,
    /// Initialization finished but no value is stored in the cache.
    UninitializedCache,
    /// Initialization ran but there was no initializer function left to call.
    MissingFunctionOrInitializedTwice,
}
26
27#[derive(Debug, PartialEq, Eq)]
28pub enum LazyCacheErrorOr<E> {
29    Inner(E),
30    Outer(LazyCacheError),
31}
32
33// We never create a `&F` from a `&LazyCache<T, F>` so it is fine
34// to not impl `Sync` for `F`.
35unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
36unsafe impl<T: Send> Send for LazyCache<T> {}
37
38// auto-derived `Send` impl is OK.
39//unsafe impl<T: Send, F: Send> Send for LazyCache<T, F> {}
40
41impl<T> LazyCache<T> {
42    pub fn new<F>(f: F) -> Self
43    where
44        F: FnOnce() -> T + Send + Sync + 'static,
45    {
46        Self {
47            once: Once::new(),
48            value: UnsafeCell::new(None),
49            init: UnsafeCell::new(Some(Box::new(f))),
50        }
51    }
52
53    /// Creates a new lazy value that is already initialized.
54    pub fn preinit(value: T) -> Self {
55        let once = Once::new();
56        once.call_once(|| {});
57        Self {
58            once,
59            value: UnsafeCell::new(Some(value)),
60            init: UnsafeCell::new(None),
61        }
62    }
63
64    fn try_initialize(&self) -> Result<(), LazyCacheError> {
65        let mut error = None;
66
67        self.once.call_once_force(|state| {
68            if state.is_poisoned() {
69                error = Some(LazyCacheError::LockPoisoned);
70                return;
71            }
72
73            let init_fn = unsafe { (*self.init.get()).take() };
74            match init_fn {
75                Some(f) => {
76                    let value = f();
77                    unsafe {
78                        *self.value.get() = Some(value);
79                    }
80                }
81                None => {
82                    error = Some(LazyCacheError::MissingFunctionOrInitializedTwice);
83                }
84            }
85        });
86
87        if let Some(e) = error {
88            return Err(e);
89        }
90
91        if self.once.is_completed() {
92            Ok(())
93        } else {
94            Err(LazyCacheError::LockPoisoned)
95        }
96    }
97
98    /// # Errors
99    ///
100    /// Returns `LazyCacheError` if initialization fails or the cache is poisoned.
101    pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
102        self.try_initialize()?;
103        unsafe {
104            (*self.value.get())
105                .as_ref()
106                .ok_or(LazyCacheError::UninitializedCache)
107        }
108    }
109
110    /// # Panics
111    ///
112    /// Panics if initialization fails or the cache is poisoned.
113    pub fn get(&self) -> &T {
114        self.try_get().unwrap()
115    }
116}
117
118// Wrapper to support cases where the init function might return an error that
119// needs to be handled separately (for example, LookupConstraintSystem::crate())
120impl<T, E: Clone> LazyCache<Result<T, E>> {
121    /// # Errors
122    ///
123    /// Returns `LazyCacheErrorOr` if initialization fails or the inner result is an error.
124    pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
125        match self.try_get() {
126            Ok(Ok(v)) => Ok(v),
127            Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
128            Err(_) => Err(LazyCacheErrorOr::Outer(LazyCacheError::LockPoisoned)),
129        }
130    }
131}
132
133impl<T> Deref for LazyCache<T> {
134    type Target = T;
135
136    fn deref(&self) -> &Self::Target {
137        self.get()
138    }
139}
140
141impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        // SAFETY: It's safe to access self.value here, read-only
144        let value = unsafe { &*self.value.get() };
145        match value {
146            Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
147            None => f.write_str("LazyCache(<uninitialized>)"),
148        }
149    }
150}
151
152impl<T> Serialize for LazyCache<T>
153where
154    T: Serialize + Send + Sync + 'static,
155{
156    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
157    where
158        S: Serializer,
159    {
160        self.get().serialize(serializer)
161    }
162}
163
164impl<'de, T> Deserialize<'de> for LazyCache<T>
165where
166    T: DeserializeOwned + Send + Sync + 'static,
167{
168    // Deserializing will create a `LazyCache` with a cached value or an error
169    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
170    where
171        D: serde::Deserializer<'de>,
172    {
173        let value = T::deserialize(deserializer)?;
174        Ok(Self::preinit(value))
175    }
176}
177
#[cfg(test)]
// Unit tests for LazyCache
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex},
        thread,
    };

    /// Prints the heap size currently reported by jemalloc, tagged with `label`.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        // Refresh jemalloc's internal stats before reading them.
        epoch::advance().unwrap();
        let kilobytes = stats::allocated::read().unwrap() / 1024;
        println!("[{label}] Heap allocated: {} kilobytes", kilobytes);
    }

    /// Test creating and getting `LazyCache` values
    #[test]
    fn test_lazy_cache() {
        // Reading values
        {
            // Value supplied up front
            let ready = LazyCache::preinit(100);
            assert_eq!(*ready.get(), 100);

            // Value computed on demand
            let deferred = LazyCache::new(|| {
                let (a, b) = (10, 20);
                a + b
            });
            // The first read computes the value, the second is served from the cache
            for _ in 0..2 {
                assert_eq!(*deferred.get(), 30);
            }
        }

        // The initializer runs exactly once
        {
            let calls = Arc::new(Mutex::new(0));
            let calls_in_init = Arc::clone(&calls);

            let cache = LazyCache::new(move || {
                *calls_in_init.lock().unwrap() += 1;
                99
            });

            assert_eq!(*cache.get(), 99);
            assert_eq!(*cache.get(), 99); // Second read hits the cache
            assert_eq!(*calls.lock().unwrap(), 1);
        }

        // Serde round trip
        {
            let original = LazyCache::preinit(10);
            let json = serde_json::to_string(&original).unwrap();
            let restored: LazyCache<i32> = serde_json::from_str(&json).unwrap();
            assert_eq!(*restored.get(), 10);
        }

        // Debug formatting
        {
            let ready = LazyCache::preinit(10);
            assert_eq!(format!("{ready:?}"), "LazyCache(10)");

            let pending = LazyCache::new(|| 20);
            assert_eq!(format!("{pending:?}"), "LazyCache(<uninitialized>)");
        }

        // LazyCacheError::MissingFunctionOrInitializedTwice
        {
            // Built by hand so that there is no initializer to run
            let cache: LazyCache<i32> = LazyCache {
                once: Once::new(),
                value: UnsafeCell::new(None),
                init: UnsafeCell::new(None),
            };
            let failure = cache.try_get().unwrap_err();
            assert_eq!(failure, LazyCacheError::MissingFunctionOrInitializedTwice);
        }

        // LazyCacheError::LockPoisoned
        {
            let lazy = Arc::new(LazyCache::<()>::new(|| {
                panic!("poison the lock");
            }));

            // Panic inside the initializer on another thread to poison the `Once`
            let worker = {
                let lazy = Arc::clone(&lazy);
                thread::spawn(move || {
                    let _ = lazy.try_initialize();
                })
            };
            let _ = worker.join();

            // The next attempt observes the poisoned state
            assert_eq!(lazy.try_initialize(), Err(LazyCacheError::LockPoisoned));
        }
    }

    /// Prints heap usage before and after a 1MB lazy value is materialized.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        // 1MB buffer, allocated only when first requested
        let big = Arc::new(LazyCache::new(|| vec![42u8; 1024 * 1024]));

        print_heap_usage("Before initializing LazyCache");

        let _ = big.get();

        print_heap_usage("After initializing LazyCache");
    }
}