1#![allow(unsafe_code)]
2use core::{cell::UnsafeCell, fmt, ops::Deref};
9use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
10
11#[cfg(feature = "std")]
12use alloc::boxed::Box;
/// Boxed one-shot initializer that `LazyCache::new` stores until the first
/// access takes it out and runs it.
#[cfg(feature = "std")]
type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync + 'static>;
15
/// A value that is computed on first access and then cached.
///
/// With the `std` feature the value is produced by a boxed initializer whose
/// single run is coordinated by a `std::sync::Once`. Without `std`, this
/// module only offers `preinit`, which builds the cache already filled.
pub struct LazyCache<T> {
    /// Coordinates the one-time run of `init` (std builds only).
    #[cfg(feature = "std")]
    pub(crate) once: std::sync::Once,
    /// The cached value; `None` until initialization stores one
    /// (`preinit` stores it immediately).
    pub(crate) value: UnsafeCell<Option<T>>,
    /// The pending initializer; `take`n (left as `None`) when it is run.
    #[cfg(feature = "std")]
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}
24
/// Reasons a [`LazyCache`] could not hand out its value.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum LazyCacheError {
    /// The initialization `Once` was poisoned by an initializer that panicked.
    LockPoisoned,
    /// Initialization finished but no value is stored in the cache.
    UninitializedCache,
    /// No initializer closure was present when initialization ran.
    MissingFunctionOrInitializedTwice,
}

impl fmt::Display for LazyCacheError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            Self::LockPoisoned => {
                "lazy cache initialization was poisoned by a panicking initializer"
            }
            Self::UninitializedCache => "lazy cache holds no initialized value",
            Self::MissingFunctionOrInitializedTwice => {
                "lazy cache initializer is missing or was already consumed"
            }
        };
        f.write_str(msg)
    }
}

// Lets callers box the error or propagate it into `dyn Error` contexts.
#[cfg(feature = "std")]
impl std::error::Error for LazyCacheError {}
31
32#[derive(Debug, PartialEq, Eq)]
33pub enum LazyCacheErrorOr<E> {
34 Inner(E),
35 Outer(LazyCacheError),
36}
37
// SAFETY: In std builds, `value` and `init` are only written during
// construction or inside `self.once.call_once_force`. `Once` never runs two
// closures concurrently, and its completion makes the closure's writes
// visible to callers that return from `call_once_force`. `try_get` reads
// `value` only after that has happened. Without `std`, nothing in this module
// writes `value` after construction.
// `T: Sync` is required because every thread sharing the cache receives a
// `&T`. `T: Send` is required because the value may be created on one thread
// and dropped on another.
unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
// SAFETY: Moving the cache moves the owned `T` (hence `T: Send`). The boxed
// initializer is declared `Send`, and `Once` is `Send`.
unsafe impl<T: Send> Send for LazyCache<T> {}
42
#[cfg(feature = "std")]
impl<T> LazyCache<T> {
    /// Creates a cache whose value is produced by `f` on first access.
    ///
    /// `f` is boxed and stored until an accessor such as `get` triggers
    /// initialization. At that point it is taken out and run.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + Sync + 'static,
    {
        Self {
            once: std::sync::Once::new(),
            value: UnsafeCell::new(None),
            init: UnsafeCell::new(Some(Box::new(f))),
        }
    }

    /// Runs the stored initializer if initialization has not completed yet.
    ///
    /// # Errors
    ///
    /// - [`LazyCacheError::LockPoisoned`] if a previous initializer panicked
    ///   and poisoned the `Once`.
    /// - [`LazyCacheError::MissingFunctionOrInitializedTwice`] if no
    ///   initializer was stored when initialization ran.
    ///
    /// Each of these errors is reported only by the call that observes it.
    /// The closure returns normally in those cases, so the `Once` becomes
    /// complete. Later calls then return `Ok(())` with no value stored, and
    /// `try_get` reports [`LazyCacheError::UninitializedCache`].
    ///
    /// # Panics
    ///
    /// Propagates a panic raised by the initializer itself.
    fn try_initialize(&self) -> Result<(), LazyCacheError> {
        let mut error = None;

        self.once.call_once_force(|state| {
            if state.is_poisoned() {
                error = Some(LazyCacheError::LockPoisoned);
                return;
            }

            // SAFETY: after construction, `init` is only accessed inside this
            // closure, and `Once` never runs two closures concurrently.
            let init_fn = unsafe { (*self.init.get()).take() };
            match init_fn {
                Some(f) => {
                    let value = f();
                    // SAFETY: same exclusivity as above. `try_get` reads
                    // `value` only after `call_once_force` has returned.
                    unsafe {
                        *self.value.get() = Some(value);
                    }
                }
                None => error = Some(LazyCacheError::MissingFunctionOrInitializedTwice),
            }
        });

        if let Some(e) = error {
            return Err(e);
        }

        // `call_once_force` only returns normally once the `Once` is complete,
        // so this is a defensive check rather than an expected failure path.
        if self.once.is_completed() {
            Ok(())
        } else {
            Err(LazyCacheError::LockPoisoned)
        }
    }

    /// Returns the cached value, running the initializer first if needed.
    ///
    /// # Errors
    ///
    /// Returns any error from `try_initialize`. Returns
    /// [`LazyCacheError::UninitializedCache`] if initialization completed
    /// without storing a value.
    pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
        self.try_initialize()?;
        // SAFETY: the `Once` has completed, which synchronizes with the write
        // made inside `call_once_force`. `value` is not written again afterwards.
        let slot = unsafe { &*self.value.get() };
        slot.as_ref().ok_or(LazyCacheError::UninitializedCache)
    }

    /// Returns the cached value, initializing it on first access.
    ///
    /// # Panics
    ///
    /// Panics with the underlying [`LazyCacheError`] if initialization fails.
    /// Also propagates any panic raised by the initializer.
    pub fn get(&self) -> &T {
        self.try_get().expect("LazyCache initialization failed")
    }
}
109
110impl<T> LazyCache<T> {
111 pub fn preinit(value: T) -> Self {
113 #[cfg(feature = "std")]
114 {
115 let once = std::sync::Once::new();
116 once.call_once(|| {});
117 Self {
118 once,
119 value: UnsafeCell::new(Some(value)),
120 init: UnsafeCell::new(None),
121 }
122 }
123 #[cfg(not(feature = "std"))]
124 {
125 Self {
126 value: UnsafeCell::new(Some(value)),
127 }
128 }
129 }
130}
131
132#[cfg(not(feature = "std"))]
133impl<T> LazyCache<T> {
134 pub fn get(&self) -> &T {
138 unsafe {
139 (*self.value.get())
140 .as_ref()
141 .expect("LazyCache not initialized (no_std mode requires preinit)")
142 }
143 }
144}
145
146impl<T, E: Clone> LazyCache<Result<T, E>> {
149 pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
153 #[cfg(feature = "std")]
154 let result = self.try_get();
155 #[cfg(not(feature = "std"))]
156 let result = unsafe {
157 (*self.value.get())
158 .as_ref()
159 .ok_or(LazyCacheError::UninitializedCache)
160 };
161
162 match result {
163 Ok(Ok(v)) => Ok(v),
164 Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
165 Err(_) => Err(LazyCacheErrorOr::Outer(LazyCacheError::LockPoisoned)),
166 }
167 }
168}
169
170impl<T> Deref for LazyCache<T> {
171 type Target = T;
172
173 fn deref(&self) -> &Self::Target {
174 self.get()
175 }
176}
177
178impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 let value = unsafe { &*self.value.get() };
182 match value {
183 Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
184 None => f.write_str("LazyCache(<uninitialized>)"),
185 }
186 }
187}
188
189impl<T> Serialize for LazyCache<T>
190where
191 T: Serialize + Send + Sync + 'static,
192{
193 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194 where
195 S: Serializer,
196 {
197 self.get().serialize(serializer)
198 }
199}
200
201impl<'de, T> Deserialize<'de> for LazyCache<T>
202where
203 T: DeserializeOwned + Send + Sync + 'static,
204{
205 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
207 where
208 D: serde::Deserializer<'de>,
209 {
210 let value = T::deserialize(deserializer)?;
211 Ok(Self::preinit(value))
212 }
213}
214
#[cfg(all(test, feature = "std"))]
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex, Once},
        thread,
    };

    /// Prints the heap usage reported by jemalloc, tagged with `label`.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        // jemalloc statistics are cached; advancing the epoch refreshes them.
        epoch::advance().unwrap();
        let allocated = stats::allocated::read().unwrap();
        println!("[{label}] Heap allocated: {} kilobytes", allocated / 1024);
    }

    #[test]
    fn preinit_and_lazy_values_are_returned() {
        let cache = LazyCache::preinit(100);
        assert_eq!(*cache.get(), 100);

        let lazy = LazyCache::new(|| {
            let a = 10;
            let b = 20;
            a + b
        });
        assert_eq!(*lazy.get(), 30);
        // A second access must return the same cached value.
        assert_eq!(*lazy.get(), 30);
    }

    #[test]
    fn initializer_runs_exactly_once() {
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = Arc::clone(&counter);

        let cache = LazyCache::new(move || {
            let mut count = counter_clone.lock().unwrap();
            *count += 1;
            99
        });

        assert_eq!(*cache.get(), 99);
        assert_eq!(*cache.get(), 99);
        assert_eq!(*counter.lock().unwrap(), 1);
    }

    #[test]
    fn serde_round_trip_preserves_value() {
        let cache = LazyCache::preinit(10);
        let serialized = serde_json::to_string(&cache).unwrap();
        let deserialized: LazyCache<i32> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(*deserialized.get(), 10);
    }

    #[test]
    fn debug_shows_value_or_placeholder() {
        let cache = LazyCache::preinit(10);
        assert_eq!(format!("{cache:?}"), "LazyCache(10)");

        let lazy = LazyCache::new(|| 20);
        assert_eq!(format!("{lazy:?}"), "LazyCache(<uninitialized>)");
    }

    #[test]
    fn missing_initializer_is_reported() {
        let cache: LazyCache<i32> = LazyCache {
            once: Once::new(),
            value: UnsafeCell::new(None),
            init: UnsafeCell::new(None),
        };
        assert_eq!(
            cache.try_get().unwrap_err(),
            LazyCacheError::MissingFunctionOrInitializedTwice
        );
    }

    #[test]
    fn panicking_initializer_poisons_cache() {
        let lazy = Arc::new(LazyCache::<()>::new(|| {
            panic!("poison the lock");
        }));

        let lazy_clone = Arc::clone(&lazy);
        // The spawned thread's panic poisons the `Once`; `join` surfaces it as `Err`.
        let _ = thread::spawn(move || {
            let _ = lazy_clone.try_initialize();
        })
        .join();

        assert_eq!(lazy.try_initialize(), Err(LazyCacheError::LockPoisoned));
    }

    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        let cache = Arc::new(LazyCache::new(|| vec![42u8; 1024 * 1024]));
        print_heap_usage("Before initializing LazyCache");

        let _ = cache.get();

        print_heap_usage("After initializing LazyCache");
    }
}