// o1_utils/lazy_cache.rs
1#![allow(unsafe_code)]
2//! Polyfill of the `LazyLock` type in the std library as of Rust 1.80.
3//!
4//! The current file should be deleted soon, as we now support Rust 1.81 and
5//! use the official `LazyLock`, and [`LazyCache`] as a wrapper around `LazyLock`
6//! to allow for custom serialization definitions.
7
8use core::{cell::UnsafeCell, fmt, ops::Deref};
9use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
10
// NOTE(review): importing `Box` from `alloc` under the `std` feature is
// unusual (with `std`, `Box` is in the prelude); presumably the crate declares
// `extern crate alloc` unconditionally — confirm.
#[cfg(feature = "std")]
use alloc::boxed::Box;
// Type-erased, one-shot initializer. The `Send + Sync` bounds let
// `LazyCache` itself be shared across threads (see the unsafe impls below).
#[cfg(feature = "std")]
type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync + 'static>;
15
/// A thread-safe, lazily-initialized value.
pub struct LazyCache<T> {
    /// Serializes one-time execution of `init` (std only).
    #[cfg(feature = "std")]
    pub(crate) once: std::sync::Once,
    /// The cached value; `None` until initialization has run.
    pub(crate) value: UnsafeCell<Option<T>>,
    /// The init function; consumed (`take`n) the first time the cache is
    /// initialized (std only).
    #[cfg(feature = "std")]
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}
24
/// Errors produced by the lazy-initialization machinery itself.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum LazyCacheError {
    /// A previous initialization attempt panicked, poisoning the `Once`.
    LockPoisoned,
    /// The value slot was empty when a value was expected.
    UninitializedCache,
    /// No init function was available to run (the cache was constructed
    /// without one, or it was already consumed).
    MissingFunctionOrInitializedTwice,
}
31
/// Either an error produced by the cached computation itself (`Inner`),
/// or a [`LazyCacheError`] from the caching machinery (`Outer`).
#[derive(Debug, PartialEq, Eq)]
pub enum LazyCacheErrorOr<E> {
    /// The init function ran and returned this error.
    Inner(E),
    /// Initialization/bookkeeping failed before a result was available.
    Outer(LazyCacheError),
}
37
// SAFETY: interior mutability is confined to initialization, which is
// serialized by `once` (std) or happens only at construction via `preinit`
// (no_std). The boxed init function is required to be `Send + Sync` and is
// never handed out by reference, so sharing `LazyCache<T>` only requires
// `T: Send + Sync`.
// NOTE(review): an earlier version of this comment referenced a type
// parameter `F` (`LazyCache<T, F>`) that no longer exists.
unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
unsafe impl<T: Send> Send for LazyCache<T> {}
42
#[cfg(feature = "std")]
impl<T> LazyCache<T> {
    /// Creates an uninitialized cache; `f` runs on first access to compute
    /// the value.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + Sync + 'static,
    {
        Self {
            once: std::sync::Once::new(),
            value: UnsafeCell::new(None),
            init: UnsafeCell::new(Some(Box::new(f))),
        }
    }

    /// Runs the stored init function at most once, storing its result in
    /// `value`.
    ///
    /// Returns `LockPoisoned` if a previous initialization attempt panicked,
    /// and `MissingFunctionOrInitializedTwice` if there is no init function
    /// left to run.
    fn try_initialize(&self) -> Result<(), LazyCacheError> {
        let mut error = None;

        // `call_once_force` still runs the closure when the `Once` was
        // poisoned by a panic, which lets us report an error instead of
        // panicking ourselves.
        self.once.call_once_force(|state| {
            if state.is_poisoned() {
                // NOTE(review): returning normally here marks the `Once` as
                // completed ("curing" the poison), so a *third* call would
                // take the `is_completed` path below and succeed, after which
                // `try_get` would see an empty slot (`UninitializedCache`).
                // Confirm this progression of error kinds is intended.
                error = Some(LazyCacheError::LockPoisoned);
                return;
            }

            // SAFETY: we are inside `call_once_force`, so exactly one thread
            // executes this closure and nothing else touches `init`/`value`
            // concurrently.
            let init_fn = unsafe { (*self.init.get()).take() };
            match init_fn {
                Some(f) => {
                    let value = f();
                    unsafe {
                        *self.value.get() = Some(value);
                    }
                }
                None => {
                    error = Some(LazyCacheError::MissingFunctionOrInitializedTwice);
                }
            }
        });

        if let Some(e) = error {
            return Err(e);
        }

        // If the closure did not run on this thread and the `Once` still is
        // not completed, the only remaining possibility is poisoning.
        if self.once.is_completed() {
            Ok(())
        } else {
            Err(LazyCacheError::LockPoisoned)
        }
    }

    /// Returns a reference to the value, initializing it first if needed.
    ///
    /// # Errors
    ///
    /// Returns `LazyCacheError` if initialization fails or the cache is poisoned.
    pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
        self.try_initialize()?;
        // SAFETY: `try_initialize` succeeded, so `value` was written inside
        // the `Once` (which synchronizes-with this read) and is never
        // mutated again afterwards.
        unsafe {
            (*self.value.get())
                .as_ref()
                .ok_or(LazyCacheError::UninitializedCache)
        }
    }

    /// Returns a reference to the value, initializing it first if needed.
    ///
    /// # Panics
    ///
    /// Panics if initialization fails or the cache is poisoned.
    pub fn get(&self) -> &T {
        self.try_get().unwrap()
    }
}
109
110impl<T> LazyCache<T> {
111    /// Creates a new lazy value that is already initialized.
112    pub fn preinit(value: T) -> Self {
113        #[cfg(feature = "std")]
114        {
115            let once = std::sync::Once::new();
116            once.call_once(|| {});
117            Self {
118                once,
119                value: UnsafeCell::new(Some(value)),
120                init: UnsafeCell::new(None),
121            }
122        }
123        #[cfg(not(feature = "std"))]
124        {
125            Self {
126                value: UnsafeCell::new(Some(value)),
127            }
128        }
129    }
130}
131
132#[cfg(not(feature = "std"))]
133impl<T> LazyCache<T> {
134    /// # Panics
135    ///
136    /// Panics if the cache has not been pre-initialized.
137    pub fn get(&self) -> &T {
138        unsafe {
139            (*self.value.get())
140                .as_ref()
141                .expect("LazyCache not initialized (no_std mode requires preinit)")
142        }
143    }
144}
145
146// Wrapper to support cases where the init function might return an error that
147// needs to be handled separately (for example, LookupConstraintSystem::crate())
148impl<T, E: Clone> LazyCache<Result<T, E>> {
149    /// # Errors
150    ///
151    /// Returns `LazyCacheErrorOr` if initialization fails or the inner result is an error.
152    pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
153        #[cfg(feature = "std")]
154        let result = self.try_get();
155        #[cfg(not(feature = "std"))]
156        let result = unsafe {
157            (*self.value.get())
158                .as_ref()
159                .ok_or(LazyCacheError::UninitializedCache)
160        };
161
162        match result {
163            Ok(Ok(v)) => Ok(v),
164            Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
165            Err(_) => Err(LazyCacheErrorOr::Outer(LazyCacheError::LockPoisoned)),
166        }
167    }
168}
169
impl<T> Deref for LazyCache<T> {
    type Target = T;

    /// Dereferences to the cached value, initializing it on first access.
    ///
    /// # Panics
    ///
    /// Panics (via `get`) if initialization fails, the `Once` is poisoned,
    /// or (in `no_std` mode) the cache was not pre-initialized.
    fn deref(&self) -> &Self::Target {
        self.get()
    }
}
177
178impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        // SAFETY: It's safe to access self.value here, read-only
181        let value = unsafe { &*self.value.get() };
182        match value {
183            Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
184            None => f.write_str("LazyCache(<uninitialized>)"),
185        }
186    }
187}
188
impl<T> Serialize for LazyCache<T>
where
    T: Serialize + Send + Sync + 'static,
{
    /// Serializes the cached value, forcing initialization first.
    ///
    /// # Panics
    ///
    /// Panics (via `get`) if initialization fails or the cache is poisoned.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.get().serialize(serializer)
    }
}
200
impl<'de, T> Deserialize<'de> for LazyCache<T>
where
    T: DeserializeOwned + Send + Sync + 'static,
{
    // Deserializing always yields an already-initialized (`preinit`) cache:
    // the init function is not serialized, so laziness cannot be restored.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        Ok(Self::preinit(value))
    }
}
214
#[cfg(all(test, feature = "std"))]
// Unit tests for LazyCache
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex, Once},
        thread,
    };

    /// Prints the current jemalloc heap usage with a label, for the
    /// allocation test below (diagnostics builds only).
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        epoch::advance().unwrap(); // refresh internal stats!
        let allocated = stats::allocated::read().unwrap();
        println!("[{label}] Heap allocated: {} kilobytes", allocated / 1024);
    }

    /// Test creating and getting `LazyCache` values
    #[test]
    fn test_lazy_cache() {
        // get
        {
            // Cached variant
            let cache = LazyCache::preinit(100);
            assert_eq!(*cache.get(), 100);

            // Lazy variant
            let lazy = LazyCache::new(|| {
                let a = 10;
                let b = 20;
                a + b
            });
            assert_eq!(*lazy.get(), 30);
            // Ensure the value is cached and can be accessed multiple times
            assert_eq!(*lazy.get(), 30);
        }

        // function called only once
        {
            let counter = Arc::new(Mutex::new(0));
            let counter_clone = Arc::clone(&counter);

            let cache = LazyCache::new(move || {
                let mut count = counter_clone.lock().unwrap();
                *count += 1;
                // counter_clone will be dropped here
                99
            });

            assert_eq!(*cache.get(), 99);
            assert_eq!(*cache.get(), 99); // Ensure cached
            assert_eq!(*counter.lock().unwrap(), 1); // Function was called exactly once
        }
        // serde: round-trips through JSON as the plain inner value
        {
            let cache = LazyCache::preinit(10);
            let serialized = serde_json::to_string(&cache).unwrap();
            let deserialized: LazyCache<i32> = serde_json::from_str(&serialized).unwrap();
            assert_eq!(*deserialized.get(), 10);
        }
        // debug: initialized vs. uninitialized formatting
        {
            let cache = LazyCache::preinit(10);
            assert_eq!(format!("{cache:?}"), "LazyCache(10)");

            let lazy = LazyCache::new(|| 20);
            assert_eq!(format!("{lazy:?}"), "LazyCache(<uninitialized>)");
        }
        // LazyCacheError::MissingFunctionOrInitializedTwice
        {
            let cache: LazyCache<i32> = LazyCache {
                once: Once::new(),
                value: UnsafeCell::new(None),
                init: UnsafeCell::new(None), // No function set
            };
            let err = cache.try_get();
            assert_eq!(
                err.unwrap_err(),
                LazyCacheError::MissingFunctionOrInitializedTwice
            );
        }
        // LazyCacheError::LockPoisoned
        {
            let lazy = Arc::new(LazyCache::<()>::new(|| {
                panic!("poison the lock");
            }));

            let lazy_clone = Arc::clone(&lazy);
            let _ = thread::spawn(move || {
                let _ = lazy_clone.try_initialize();
            })
            .join(); // triggers panic inside init

            // Now the Once is poisoned
            let result = lazy.try_initialize();
            assert_eq!(result, Err(LazyCacheError::LockPoisoned));
        }
    }

    /// Prints heap usage before/after forcing a 1 MB lazy allocation, to
    /// eyeball that nothing is allocated until `get` is called.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        let cache = Arc::new(LazyCache::new(|| vec![42u8; 1024 * 1024])); // 1MB

        print_heap_usage("Before initializing LazyCache");

        let _ = cache.get();

        print_heap_usage("After initializing LazyCache");
    }
}
333}