Skip to main content

o1_utils/
lazy_cache.rs

1#![allow(unsafe_code)]
2//! Polyfill of the `LazyLock` type in the std library as of Rust 1.80.
3//!
//! The current file should be deleted soon: we now support Rust 1.81, so the
//! official `LazyLock` can be used directly, keeping [`LazyCache`] only as a
//! wrapper around `LazyLock` to allow for custom serialization definitions.
7
8use core::{cell::UnsafeCell, fmt, ops::Deref};
9use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
10
11#[cfg(feature = "std")]
12use alloc::boxed::Box;
/// Boxed one-shot initializer stored by a lazily-built cache.
///
/// The closure is type-erased so the cache type does not need a second
/// generic parameter for the initializer.
#[cfg(feature = "std")]
type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync + 'static>;
15
/// A value that is computed lazily, at most once, and then cached.
///
/// With the `std` feature the value is produced by a stored closure on first
/// access; without it the struct only carries the `value` slot.
pub struct LazyCache<T> {
    /// Gate that serializes initialization and records its completion.
    #[cfg(feature = "std")]
    pub(crate) once: std::sync::Once,
    /// Slot for the cached value; `None` while nothing has been stored.
    pub(crate) value: UnsafeCell<Option<T>>,
    /// Pending initializer; it is taken out of the slot when it is run.
    #[cfg(feature = "std")]
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}
24
/// Failures that can be reported by the cache itself, independently of the
/// type of the cached value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LazyCacheError {
    /// A previous initialization attempt panicked.
    LockPoisoned,
    /// No value is stored in the cache.
    UninitializedCache,
    /// Initialization was requested but no initializer was available.
    MissingFunctionOrInitializedTwice,
}
31
32#[derive(Debug, PartialEq, Eq)]
33pub enum LazyCacheErrorOr<E> {
34    Inner(E),
35    Outer(LazyCacheError),
36}
37
38// We never create a `&F` from a `&LazyCache<T, F>` so it is fine
39// to not impl `Sync` for `F`.
40unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
41unsafe impl<T: Send> Send for LazyCache<T> {}
42
#[cfg(feature = "std")]
impl<T> LazyCache<T> {
    /// Builds a cache whose value will be produced by `f` the first time it
    /// is requested. `f` is not called here.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + Sync + 'static,
    {
        Self {
            once: std::sync::Once::new(),
            value: UnsafeCell::new(None),
            init: UnsafeCell::new(Some(Box::new(f))),
        }
    }

    /// Runs the stored initializer if the `Once` has not completed yet.
    ///
    /// Returns `LockPoisoned` when the `Once` is observed poisoned and
    /// `MissingFunctionOrInitializedTwice` when there is no initializer left
    /// to run.
    fn try_initialize(&self) -> Result<(), LazyCacheError> {
        let mut outcome = Ok(());

        self.once.call_once_force(|state| {
            if state.is_poisoned() {
                // NOTE(review): returning normally from this closure lets the
                // `Once` complete, so later calls will no longer observe the
                // poisoned state here — confirm this is intended.
                outcome = Err(LazyCacheError::LockPoisoned);
                return;
            }

            // SAFETY: `Once` runs at most one closure at a time, so nothing
            // else in this impl touches `init` while we are in here.
            let pending = unsafe { (*self.init.get()).take() };
            if let Some(init) = pending {
                let produced = init();
                // SAFETY: same exclusivity argument as above; `try_get` only
                // reads `value` after `call_once_force` has returned.
                unsafe {
                    *self.value.get() = Some(produced);
                }
            } else {
                outcome = Err(LazyCacheError::MissingFunctionOrInitializedTwice);
            }
        });

        match outcome {
            Err(e) => Err(e),
            Ok(()) if self.once.is_completed() => Ok(()),
            Ok(()) => Err(LazyCacheError::LockPoisoned),
        }
    }

    /// Returns the cached value, initializing it first if needed.
    ///
    /// # Errors
    ///
    /// Returns `LazyCacheError` if initialization fails or the cache is poisoned.
    pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
        self.try_initialize()?;
        // SAFETY: `try_initialize` returned `Ok`, so the `Once` has completed
        // and nothing in this impl writes to `value` afterwards.
        let slot: &Option<T> = unsafe { &*self.value.get() };
        match slot {
            Some(cached) => Ok(cached),
            None => Err(LazyCacheError::UninitializedCache),
        }
    }

    /// Returns the cached value, initializing it first if needed.
    ///
    /// # Panics
    ///
    /// Panics if initialization fails or the cache is poisoned.
    pub fn get(&self) -> &T {
        self.try_get().unwrap()
    }
}
109
110impl<T> LazyCache<T> {
111    #[allow(clippy::nursery)]
112    /// Creates a new lazy value that is already initialized.
113    pub fn preinit(value: T) -> Self {
114        #[cfg(feature = "std")]
115        {
116            let once = std::sync::Once::new();
117            once.call_once(|| {});
118            Self {
119                once,
120                value: UnsafeCell::new(Some(value)),
121                init: UnsafeCell::new(None),
122            }
123        }
124        #[cfg(not(feature = "std"))]
125        {
126            Self {
127                value: UnsafeCell::new(Some(value)),
128            }
129        }
130    }
131}
132
133#[cfg(not(feature = "std"))]
134impl<T> LazyCache<T> {
135    /// # Panics
136    ///
137    /// Panics if the cache has not been pre-initialized.
138    pub fn get(&self) -> &T {
139        unsafe {
140            (*self.value.get())
141                .as_ref()
142                .expect("LazyCache not initialized (no_std mode requires preinit)")
143        }
144    }
145}
146
147// Wrapper to support cases where the init function might return an error that
148// needs to be handled separately (for example, LookupConstraintSystem::crate())
149impl<T, E: Clone> LazyCache<Result<T, E>> {
150    /// # Errors
151    ///
152    /// Returns `LazyCacheErrorOr` if initialization fails or the inner result is an error.
153    pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
154        #[cfg(feature = "std")]
155        let result = self.try_get();
156        #[cfg(not(feature = "std"))]
157        let result = unsafe {
158            (*self.value.get())
159                .as_ref()
160                .ok_or(LazyCacheError::UninitializedCache)
161        };
162
163        match result {
164            Ok(Ok(v)) => Ok(v),
165            Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
166            Err(_) => Err(LazyCacheErrorOr::Outer(LazyCacheError::LockPoisoned)),
167        }
168    }
169}
170
171impl<T> Deref for LazyCache<T> {
172    type Target = T;
173
174    fn deref(&self) -> &Self::Target {
175        self.get()
176    }
177}
178
179impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        // SAFETY: It's safe to access self.value here, read-only
182        let value = unsafe { &*self.value.get() };
183        match value {
184            Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
185            None => f.write_str("LazyCache(<uninitialized>)"),
186        }
187    }
188}
189
190impl<T> Serialize for LazyCache<T>
191where
192    T: Serialize + Send + Sync + 'static,
193{
194    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
195    where
196        S: Serializer,
197    {
198        self.get().serialize(serializer)
199    }
200}
201
202impl<'de, T> Deserialize<'de> for LazyCache<T>
203where
204    T: DeserializeOwned + Send + Sync + 'static,
205{
206    // Deserializing will create a `LazyCache` with a cached value or an error
207    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
208    where
209        D: serde::Deserializer<'de>,
210    {
211        let value = T::deserialize(deserializer)?;
212        Ok(Self::preinit(value))
213    }
214}
215
#[cfg(all(test, feature = "std"))]
// Unit tests for LazyCache
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex, Once},
        thread,
    };

    /// Prints how many kilobytes jemalloc reports as allocated, tagged with
    /// `label`.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        // Refresh jemalloc's internal stats before reading them.
        epoch::advance().unwrap();
        let kilobytes = stats::allocated::read().unwrap() / 1024;
        println!("[{label}] Heap allocated: {kilobytes} kilobytes");
    }

    /// Exercises construction, access, serde, `Debug` and the error paths of
    /// `LazyCache`.
    #[test]
    fn test_lazy_cache() {
        // Access through `get`, for both construction paths.
        {
            // Value supplied up front.
            let preloaded = LazyCache::preinit(100);
            assert_eq!(*preloaded.get(), 100);

            // Value computed on first access.
            let deferred = LazyCache::new(|| {
                let a = 10;
                let b = 20;
                a + b
            });
            assert_eq!(*deferred.get(), 30);
            // A second access must return the same cached value.
            assert_eq!(*deferred.get(), 30);
        }

        // The initializer runs exactly once.
        {
            let calls = Arc::new(Mutex::new(0));
            let calls_in_init = Arc::clone(&calls);

            let cache = LazyCache::new(move || {
                *calls_in_init.lock().unwrap() += 1;
                // `calls_in_init` is dropped when the closure finishes.
                99
            });

            assert_eq!(*cache.get(), 99);
            assert_eq!(*cache.get(), 99); // served from the cache
            assert_eq!(*calls.lock().unwrap(), 1); // initializer ran once
        }

        // Round trip through serde.
        {
            let original = LazyCache::preinit(10);
            let json = serde_json::to_string(&original).unwrap();
            let restored: LazyCache<i32> = serde_json::from_str(&json).unwrap();
            assert_eq!(*restored.get(), 10);
        }

        // `Debug` output for initialized and untouched caches.
        {
            let ready = LazyCache::preinit(10);
            assert_eq!(format!("{ready:?}"), "LazyCache(10)");

            let untouched = LazyCache::new(|| 20);
            assert_eq!(format!("{untouched:?}"), "LazyCache(<uninitialized>)");
        }

        // LazyCacheError::MissingFunctionOrInitializedTwice
        {
            // Built by hand so that there is neither a value nor an initializer.
            let empty: LazyCache<i32> = LazyCache {
                once: Once::new(),
                value: UnsafeCell::new(None),
                init: UnsafeCell::new(None),
            };
            assert_eq!(
                empty.try_get(),
                Err(LazyCacheError::MissingFunctionOrInitializedTwice)
            );
        }

        // LazyCacheError::LockPoisoned
        {
            let poisoned = Arc::new(LazyCache::<()>::new(|| {
                panic!("poison the lock");
            }));

            // Run the panicking initializer on another thread so the panic
            // does not abort this test.
            let for_worker = Arc::clone(&poisoned);
            let worker = thread::spawn(move || {
                let _ = for_worker.try_initialize();
            });
            let _ = worker.join();

            // The `Once` is poisoned at this point.
            assert_eq!(
                poisoned.try_initialize(),
                Err(LazyCacheError::LockPoisoned)
            );
        }
    }

    /// Prints heap usage before and after forcing a 1MB cached value.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        // 1MB payload, only allocated once the cache is forced.
        let cache = Arc::new(LazyCache::new(|| vec![42u8; 1024 * 1024]));

        print_heap_usage("Before initializing LazyCache");

        let _ = cache.get();

        print_heap_usage("After initializing LazyCache");
    }
}