o1_utils/lazy_cache.rs

//! This is a polyfill of the `LazyLock` type stabilized in the standard
//! library in Rust 1.80. This file should be deleted soon: now that we support
//! Rust 1.81, `LazyCache` can be reduced to a wrapper around the official
//! `LazyLock` that only adds custom serialization definitions.
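//!
//! A minimal usage sketch (illustrative only; the `o1_utils::lazy_cache` path
//! is assumed to match this crate's public re-export):
//!
//! ```
//! use o1_utils::lazy_cache::LazyCache;
//!
//! // The closure runs at most once, on first access.
//! let cache = LazyCache::new(|| 21 * 2);
//! assert_eq!(*cache.get(), 42);
//! ```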

use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
use std::{cell::UnsafeCell, fmt, ops::Deref, sync::Once};

type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync + 'static>;

/// A thread-safe, lazily-initialized value.
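///
/// A sketch of the two construction modes (illustrative; the import path is
/// assumed):
///
/// ```
/// use o1_utils::lazy_cache::LazyCache;
///
/// // `preinit` stores an already-computed value; `Deref` gives access to it.
/// let eager = LazyCache::preinit(String::from("ready"));
/// assert_eq!(eager.len(), 5);
///
/// // `new` defers the computation until the first `get`.
/// let lazy = LazyCache::new(|| vec![1, 2, 3]);
/// assert_eq!(lazy.get().len(), 3);
/// ```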
pub struct LazyCache<T> {
    pub(crate) once: Once,
    pub(crate) value: UnsafeCell<Option<T>>,
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}

#[derive(Debug, PartialEq, Clone)]
pub enum LazyCacheError {
    LockPoisoned,
    UninitializedCache,
    MissingFunctionOrInitializedTwice,
}

#[derive(Debug, PartialEq)]
pub enum LazyCacheErrorOr<E> {
    Inner(E),
    Outer(LazyCacheError),
}

// SAFETY: the boxed init closure is `dyn FnOnce() -> T + Send + Sync` and we
// never hand out a reference to it, so sharing a `LazyCache<T>` across threads
// only ever exposes `&T`. `Send`/`Sync` therefore only need bounds on `T`
// (see `test_lazy_cache_concurrent` in the test module below for a sketch of
// concurrent use).
unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
unsafe impl<T: Send> Send for LazyCache<T> {}

impl<T> LazyCache<T> {
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + Sync + 'static,
    {
        LazyCache {
            once: Once::new(),
            value: UnsafeCell::new(None),
            init: UnsafeCell::new(Some(Box::new(f))),
        }
    }

    /// Creates a new lazy value that is already initialized.
    pub fn preinit(value: T) -> LazyCache<T> {
        let once = Once::new();
        once.call_once(|| {});
        LazyCache {
            once,
            value: UnsafeCell::new(Some(value)),
            init: UnsafeCell::new(None),
        }
    }

    fn try_initialize(&self) -> Result<(), LazyCacheError> {
        let mut error = None;

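        // `call_once_force` still runs the closure if an earlier
        // initialization attempt panicked (poisoning the `Once`); that case is
        // detected via `state.is_poisoned()` and reported as an error instead
        // of retrying the initializer.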
        self.once.call_once_force(|state| {
            if state.is_poisoned() {
                error = Some(LazyCacheError::LockPoisoned);
                return;
            }

            let init_fn = unsafe { (*self.init.get()).take() };
            match init_fn {
                Some(f) => {
                    let value = f();
                    unsafe {
                        *self.value.get() = Some(value);
                    }
                }
                None => {
                    error = Some(LazyCacheError::MissingFunctionOrInitializedTwice);
                }
            }
        });

        if let Some(e) = error {
            return Err(e);
        }

        if self.once.is_completed() {
            Ok(())
        } else {
            Err(LazyCacheError::LockPoisoned)
        }
    }

    pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
        self.try_initialize()?;
        unsafe {
            (*self.value.get())
                .as_ref()
                .ok_or(LazyCacheError::UninitializedCache)
        }
    }

    pub fn get(&self) -> &T {
        self.try_get().unwrap()
    }
}

// Wrapper to support cases where the init function itself returns a `Result`
// whose error needs to be handled separately (for example,
// `LookupConstraintSystem::create()`); see `test_try_get_or_err` in the test
// module below for a usage sketch.
impl<T, E: Clone> LazyCache<Result<T, E>> {
    pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
        match self.try_get() {
            Ok(Ok(v)) => Ok(v),
            Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
            Err(e) => Err(LazyCacheErrorOr::Outer(e)),
        }
    }
}

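// `Deref` delegates to `get()`, so dereferencing lazily initializes the value
// and panics if initialization failed or was poisoned.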
impl<T> Deref for LazyCache<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.once.is_completed() {
            // SAFETY: initialization has completed, so no other thread can be
            // writing to `value` and a shared read is safe.
            match unsafe { &*self.value.get() } {
                Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
                None => f.write_str("LazyCache(<uninitialized>)"),
            }
        } else {
            f.write_str("LazyCache(<uninitialized>)")
        }
    }
}

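// Serializing forces initialization: `get()` runs the init closure if it has
// not run yet and panics if initialization previously failed.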
impl<T> Serialize for LazyCache<T>
where
    T: Serialize + Send + Sync + 'static,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.get().serialize(serializer)
    }
}

impl<'de, T> Deserialize<'de> for LazyCache<T>
where
    T: DeserializeOwned + Send + Sync + 'static,
{
    // Deserializing creates an already-initialized (`preinit`) `LazyCache`,
    // or fails with the deserializer's error
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        Ok(LazyCache::preinit(value))
    }
}

// Unit tests for LazyCache
#[cfg(test)]
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex},
        thread,
    };

    #[cfg(not(target_arch = "wasm32"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        epoch::advance().unwrap(); // refresh internal stats!
        let allocated = stats::allocated::read().unwrap();
        println!("[{label}] Heap allocated: {} kilobytes", allocated / 1024);
    }

    /// Test creating and getting `LazyCache` values
    #[test]
    fn test_lazy_cache() {
        // get
        {
            // Cached variant
            let cache = LazyCache::preinit(100);
            assert_eq!(*cache.get(), 100);

            // Lazy variant
            let lazy = LazyCache::new(|| {
                let a = 10;
                let b = 20;
                a + b
            });
            assert_eq!(*lazy.get(), 30);
            // Ensure the value is cached and can be accessed multiple times
            assert_eq!(*lazy.get(), 30);
        }

        // function called only once
        {
            let counter = Arc::new(Mutex::new(0));
            let counter_clone = Arc::clone(&counter);

            let cache = LazyCache::new(move || {
                let mut count = counter_clone.lock().unwrap();
                *count += 1;
                // counter_clone will be dropped here
                99
            });

            assert_eq!(*cache.get(), 99);
            assert_eq!(*cache.get(), 99); // Ensure cached
            assert_eq!(*counter.lock().unwrap(), 1); // Function was called exactly once
        }
        // serde
        {
            let cache = LazyCache::preinit(10);
            let serialized = serde_json::to_string(&cache).unwrap();
            let deserialized: LazyCache<i32> = serde_json::from_str(&serialized).unwrap();
            assert_eq!(*deserialized.get(), 10);
        }
        // debug
        {
            let cache = LazyCache::preinit(10);
            assert_eq!(format!("{:?}", cache), "LazyCache(10)");

            let lazy = LazyCache::new(|| 20);
            assert_eq!(format!("{:?}", lazy), "LazyCache(<uninitialized>)");
        }
        // LazyCacheError::MissingFunctionOrInitializedTwice
        {
            let cache: LazyCache<i32> = LazyCache {
                once: Once::new(),
                value: UnsafeCell::new(None),
                init: UnsafeCell::new(None), // No function set
            };
            let err = cache.try_get();
            assert_eq!(
                err.unwrap_err(),
                LazyCacheError::MissingFunctionOrInitializedTwice
            );
        }
        // LazyCacheError::LockPoisoned
        {
            let lazy = Arc::new(LazyCache::<()>::new(|| {
                panic!("poison the lock");
            }));

            let lazy_clone = Arc::clone(&lazy);
            let _ = thread::spawn(move || {
                let _ = lazy_clone.try_initialize();
            })
            .join(); // triggers panic inside init

            // Now the Once is poisoned
            let result = lazy.try_initialize();
            assert_eq!(result, Err(LazyCacheError::LockPoisoned));
        }
    }
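
    /// Illustrative sketch (added example, not from the original file):
    /// exercises the `Send`/`Sync` impls by sharing one `LazyCache` across
    /// several threads and checking the initializer still runs exactly once.
    #[test]
    fn test_lazy_cache_concurrent() {
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = Arc::clone(&counter);

        let cache = Arc::new(LazyCache::new(move || {
            *counter_clone.lock().unwrap() += 1;
            7
        }));

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || *cache.get())
            })
            .collect();

        for handle in handles {
            assert_eq!(handle.join().unwrap(), 7);
        }
        // The init closure ran exactly once, no matter how many threads raced.
        assert_eq!(*counter.lock().unwrap(), 1);
    }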

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        let cache = Arc::new(LazyCache::new(|| vec![42u8; 1024 * 1024])); // 1MB

        print_heap_usage("Before initializing LazyCache");

        let _ = cache.get();

        print_heap_usage("After initializing LazyCache");
    }
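
    /// Illustrative sketch (added example, not from the original file):
    /// `try_get_or_err` with an init function that itself returns a `Result`.
    /// The inner error type here is just `String`, chosen for the example.
    #[test]
    fn test_try_get_or_err() {
        // Initializer succeeds: the cached `Ok` value is exposed directly.
        let ok_cache: LazyCache<Result<u32, String>> = LazyCache::new(|| Ok(42));
        assert_eq!(ok_cache.try_get_or_err(), Ok(&42));

        // Initializer fails: the inner error is cloned out as `Inner`.
        let err_cache: LazyCache<Result<u32, String>> =
            LazyCache::new(|| Err("boom".to_string()));
        assert_eq!(
            err_cache.try_get_or_err(),
            Err(LazyCacheErrorOr::Inner("boom".to_string()))
        );
    }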
}