1#![allow(unsafe_code)]
2use core::{cell::UnsafeCell, fmt, ops::Deref};
9use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
10
11#[cfg(feature = "std")]
12use alloc::boxed::Box;
/// Boxed one-shot initializer held by [`LazyCache`] until it is taken and run.
///
/// The closure must be `Send + Sync + 'static` so the cache that owns it can
/// itself be shared across threads.
#[cfg(feature = "std")]
type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync + 'static>;
15
/// A cached value that is either supplied up front or computed on first access.
///
/// With the `std` feature the struct carries a [`std::sync::Once`] guarding the
/// initialization plus the pending initializer; without `std` only the value
/// slot exists.
pub struct LazyCache<T> {
    /// Guards the one-time execution of the initializer.
    #[cfg(feature = "std")]
    pub(crate) once: std::sync::Once,
    /// The cached value; `None` until it has been produced or preinitialized.
    pub(crate) value: UnsafeCell<Option<T>>,
    /// The initializer still waiting to run; `None` once taken or when the
    /// cache was built from an existing value.
    #[cfg(feature = "std")]
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}
24
/// Failures that can occur while initializing or reading a `LazyCache`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LazyCacheError {
    /// The initialization guard was found poisoned.
    LockPoisoned,
    /// Initialization reported success but no value is stored.
    UninitializedCache,
    /// The initializer was absent when initialization tried to run it.
    MissingFunctionOrInitializedTwice,
}
31
32#[derive(Debug, PartialEq, Eq)]
33pub enum LazyCacheErrorOr<E> {
34 Inner(E),
35 Outer(LazyCacheError),
36}
37
38unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
41unsafe impl<T: Send> Send for LazyCache<T> {}
42
#[cfg(feature = "std")]
impl<T> LazyCache<T> {
    /// Builds a cache whose value is produced by `f` the first time it is
    /// requested.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + Sync + 'static,
    {
        let pending: LazyFn<T> = Box::new(f);
        Self {
            once: std::sync::Once::new(),
            value: UnsafeCell::new(None),
            init: UnsafeCell::new(Some(pending)),
        }
    }

    /// Runs the stored initializer if it has not run yet.
    ///
    /// # Errors
    ///
    /// * [`LazyCacheError::LockPoisoned`] when the guard is observed poisoned.
    /// * [`LazyCacheError::MissingFunctionOrInitializedTwice`] when there is no
    ///   initializer left to run.
    fn try_initialize(&self) -> Result<(), LazyCacheError> {
        let mut outcome: Result<(), LazyCacheError> = Ok(());

        self.once.call_once_force(|once_state| {
            outcome = if once_state.is_poisoned() {
                Err(LazyCacheError::LockPoisoned)
            } else {
                // SAFETY: this closure runs under `self.once`, so no other
                // caller executes it concurrently; within this file `init` is
                // not accessed anywhere else after construction.
                let pending = unsafe { (*self.init.get()).take() };
                match pending {
                    None => Err(LazyCacheError::MissingFunctionOrInitializedTwice),
                    Some(produce) => {
                        let produced = produce();
                        // SAFETY: still inside the `Once`-guarded closure;
                        // readers in this impl only look at `value` after
                        // `try_initialize` has returned.
                        unsafe {
                            *self.value.get() = Some(produced);
                        }
                        Ok(())
                    }
                }
            };
        });

        outcome?;

        // Defensive fallback: report a guard that did not complete as poisoned.
        if !self.once.is_completed() {
            return Err(LazyCacheError::LockPoisoned);
        }
        Ok(())
    }

    /// Returns the cached value, initializing it first when necessary.
    ///
    /// # Errors
    ///
    /// Propagates any error from initialization, and yields
    /// [`LazyCacheError::UninitializedCache`] when initialization reported
    /// success but the slot is empty.
    pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
        self.try_initialize()?;
        // SAFETY: `try_initialize` returned `Ok`, so the guard has completed
        // and the initializer closure will not run (and write) again.
        let slot = unsafe { &*self.value.get() };
        match slot {
            Some(cached) => Ok(cached),
            None => Err(LazyCacheError::UninitializedCache),
        }
    }

    /// Returns the cached value, initializing it first when necessary.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::try_get`] returns an error.
    pub fn get(&self) -> &T {
        self.try_get().unwrap()
    }
}
109
110impl<T> LazyCache<T> {
111 #[allow(clippy::nursery)]
112 pub fn preinit(value: T) -> Self {
114 #[cfg(feature = "std")]
115 {
116 let once = std::sync::Once::new();
117 once.call_once(|| {});
118 Self {
119 once,
120 value: UnsafeCell::new(Some(value)),
121 init: UnsafeCell::new(None),
122 }
123 }
124 #[cfg(not(feature = "std"))]
125 {
126 Self {
127 value: UnsafeCell::new(Some(value)),
128 }
129 }
130 }
131}
132
133#[cfg(not(feature = "std"))]
134impl<T> LazyCache<T> {
135 pub fn get(&self) -> &T {
139 unsafe {
140 (*self.value.get())
141 .as_ref()
142 .expect("LazyCache not initialized (no_std mode requires preinit)")
143 }
144 }
145}
146
147impl<T, E: Clone> LazyCache<Result<T, E>> {
150 pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
154 #[cfg(feature = "std")]
155 let result = self.try_get();
156 #[cfg(not(feature = "std"))]
157 let result = unsafe {
158 (*self.value.get())
159 .as_ref()
160 .ok_or(LazyCacheError::UninitializedCache)
161 };
162
163 match result {
164 Ok(Ok(v)) => Ok(v),
165 Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
166 Err(_) => Err(LazyCacheErrorOr::Outer(LazyCacheError::LockPoisoned)),
167 }
168 }
169}
170
171impl<T> Deref for LazyCache<T> {
172 type Target = T;
173
174 fn deref(&self) -> &Self::Target {
175 self.get()
176 }
177}
178
179impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 let value = unsafe { &*self.value.get() };
183 match value {
184 Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
185 None => f.write_str("LazyCache(<uninitialized>)"),
186 }
187 }
188}
189
190impl<T> Serialize for LazyCache<T>
191where
192 T: Serialize + Send + Sync + 'static,
193{
194 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
195 where
196 S: Serializer,
197 {
198 self.get().serialize(serializer)
199 }
200}
201
202impl<'de, T> Deserialize<'de> for LazyCache<T>
203where
204 T: DeserializeOwned + Send + Sync + 'static,
205{
206 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
208 where
209 D: serde::Deserializer<'de>,
210 {
211 let value = T::deserialize(deserializer)?;
212 Ok(Self::preinit(value))
213 }
214}
215
#[cfg(all(test, feature = "std"))]
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex, Once},
        thread,
    };

    /// Prints the allocated-heap figure reported by jemalloc, tagged with `label`.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        // Advance the epoch before reading the statistic.
        epoch::advance().unwrap();
        let allocated_kib = stats::allocated::read().unwrap() / 1024;
        println!("[{label}] Heap allocated: {} kilobytes", allocated_kib);
    }

    #[test]
    fn test_lazy_cache() {
        // Preinitialized and lazily computed values are both readable, and
        // repeated reads agree.
        {
            let preset = LazyCache::preinit(100);
            assert_eq!(*preset.get(), 100);

            let deferred = LazyCache::new(|| 10 + 20);
            for _ in 0..2 {
                assert_eq!(*deferred.get(), 30);
            }
        }

        // The initializer runs exactly once no matter how often we read.
        {
            let calls = Arc::new(Mutex::new(0));
            let calls_in_init = Arc::clone(&calls);

            let cache = LazyCache::new(move || {
                *calls_in_init.lock().unwrap() += 1;
                99
            });

            for _ in 0..2 {
                assert_eq!(*cache.get(), 99);
            }
            assert_eq!(*calls.lock().unwrap(), 1);
        }

        // Serde round trip preserves the value.
        {
            let original = LazyCache::preinit(10);
            let json = serde_json::to_string(&original).unwrap();
            let restored: LazyCache<i32> = serde_json::from_str(&json).unwrap();
            assert_eq!(*restored.get(), 10);
        }

        // Debug output distinguishes initialized from uninitialized caches.
        {
            let ready = LazyCache::preinit(10);
            assert_eq!(format!("{ready:?}"), "LazyCache(10)");

            let pending = LazyCache::new(|| 20);
            assert_eq!(format!("{pending:?}"), "LazyCache(<uninitialized>)");
        }

        // A cache with neither value nor initializer reports the missing function.
        {
            let hollow: LazyCache<i32> = LazyCache {
                once: Once::new(),
                value: UnsafeCell::new(None),
                init: UnsafeCell::new(None),
            };
            assert_eq!(
                hollow.try_get().unwrap_err(),
                LazyCacheError::MissingFunctionOrInitializedTwice
            );
        }

        // A panicking initializer poisons the guard; the next attempt reports it.
        {
            let shared = Arc::new(LazyCache::<()>::new(|| {
                panic!("poison the lock");
            }));

            let for_worker = Arc::clone(&shared);
            let worker = thread::spawn(move || {
                let _ = for_worker.try_initialize();
            });
            // The worker is expected to panic; its join result is irrelevant.
            let _ = worker.join();

            let outcome = shared.try_initialize();
            assert_eq!(outcome, Err(LazyCacheError::LockPoisoned));
        }
    }

    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        let big = Arc::new(LazyCache::new(|| vec![42u8; 1024 * 1024]));
        print_heap_usage("Before initializing LazyCache");

        let _ = big.get();

        print_heap_usage("After initializing LazyCache");
    }
}