1#![allow(unsafe_code)]
2use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
9use std::{cell::UnsafeCell, fmt, ops::Deref, sync::Once};
10
/// Boxed, thread-safe, run-once initializer stored by a `LazyCache` until first use.
///
/// `Box<dyn …>` already defaults the trait-object lifetime to `'static`, so no
/// explicit lifetime bound is spelled out here.
type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync>;
12
/// A value that is computed on first access and then cached.
///
/// Build it with `LazyCache::new` (deferred computation) or
/// `LazyCache::preinit` (value already known).
pub struct LazyCache<T> {
    /// Guards one-time initialization of `value`; `preinit` completes it up front.
    pub(crate) once: Once,
    /// The cached value; `None` until an initializer has stored a result.
    pub(crate) value: UnsafeCell<Option<T>>,
    /// The pending initializer; it is taken (left as `None`) when initialization runs.
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}
19
/// Reasons a `LazyCache` could not hand out its value.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum LazyCacheError {
    /// A previous initializer panicked, poisoning the cache's internal `Once`.
    LockPoisoned,
    /// Initialization finished but no value was stored.
    UninitializedCache,
    /// Initialization was attempted but no initializer was available
    /// (never provided, or already consumed).
    MissingFunctionOrInitializedTwice,
}

impl fmt::Display for LazyCacheError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            Self::LockPoisoned => "lazy cache initialization panicked (lock poisoned)",
            Self::UninitializedCache => "lazy cache has no value",
            Self::MissingFunctionOrInitializedTwice => {
                "lazy cache has no initializer to run (missing or already consumed)"
            }
        };
        f.write_str(msg)
    }
}

impl std::error::Error for LazyCacheError {}

/// Error from a cache holding a `Result`: either the cached computation's own
/// error (`Inner`) or a failure of the cache itself (`Outer`).
#[derive(Debug, PartialEq, Eq)]
pub enum LazyCacheErrorOr<E> {
    /// The cached value was `Err(E)`.
    Inner(E),
    /// The cache could not provide a value at all.
    Outer(LazyCacheError),
}

/// Displays transparently as whichever error is wrapped.
impl<E: fmt::Display> fmt::Display for LazyCacheErrorOr<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Inner(err) => fmt::Display::fmt(err, f),
            Self::Outer(err) => fmt::Display::fmt(err, f),
        }
    }
}

impl<E: std::error::Error> std::error::Error for LazyCacheErrorOr<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            // Display is transparent, so expose the wrapped error's own source.
            Self::Inner(err) => err.source(),
            Self::Outer(_) => None,
        }
    }
}
32
// SAFETY: `&LazyCache<T>` may be shared across threads. The only writes to
// `value` and `init` happen inside the `Once::call_once_force` closure in
// `try_initialize`, which runs on one thread at a time, and `try_get` reads
// `value` only after `call_once_force` has returned (the `Once` establishes a
// happens-before edge between the closure and code after the call).
// `T: Sync` is required because every sharing thread receives `&T`; `T: Send`
// because the value may be produced on one thread and dropped on another.
// The stored initializer is `Send + Sync` by construction (`LazyFn`).
// NOTE(review): soundness depends on *every* read of `value` being ordered
// after initialization completes; an unsynchronized read while another thread
// is still initializing would be a data race.
unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
// SAFETY: every field (`Once`, `UnsafeCell<Option<T>>`, and the boxed
// `Send + Sync` initializer) is `Send` when `T: Send`, so moving the whole
// cache to another thread is sound; this mirrors the auto-derived bound.
unsafe impl<T: Send> Send for LazyCache<T> {}
37
38impl<T> LazyCache<T> {
42 pub fn new<F>(f: F) -> Self
43 where
44 F: FnOnce() -> T + Send + Sync + 'static,
45 {
46 Self {
47 once: Once::new(),
48 value: UnsafeCell::new(None),
49 init: UnsafeCell::new(Some(Box::new(f))),
50 }
51 }
52
53 pub fn preinit(value: T) -> Self {
55 let once = Once::new();
56 once.call_once(|| {});
57 Self {
58 once,
59 value: UnsafeCell::new(Some(value)),
60 init: UnsafeCell::new(None),
61 }
62 }
63
64 fn try_initialize(&self) -> Result<(), LazyCacheError> {
65 let mut error = None;
66
67 self.once.call_once_force(|state| {
68 if state.is_poisoned() {
69 error = Some(LazyCacheError::LockPoisoned);
70 return;
71 }
72
73 let init_fn = unsafe { (*self.init.get()).take() };
74 match init_fn {
75 Some(f) => {
76 let value = f();
77 unsafe {
78 *self.value.get() = Some(value);
79 }
80 }
81 None => {
82 error = Some(LazyCacheError::MissingFunctionOrInitializedTwice);
83 }
84 }
85 });
86
87 if let Some(e) = error {
88 return Err(e);
89 }
90
91 if self.once.is_completed() {
92 Ok(())
93 } else {
94 Err(LazyCacheError::LockPoisoned)
95 }
96 }
97
98 pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
102 self.try_initialize()?;
103 unsafe {
104 (*self.value.get())
105 .as_ref()
106 .ok_or(LazyCacheError::UninitializedCache)
107 }
108 }
109
110 pub fn get(&self) -> &T {
114 self.try_get().unwrap()
115 }
116}
117
118impl<T, E: Clone> LazyCache<Result<T, E>> {
121 pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
125 match self.try_get() {
126 Ok(Ok(v)) => Ok(v),
127 Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
128 Err(_) => Err(LazyCacheErrorOr::Outer(LazyCacheError::LockPoisoned)),
129 }
130 }
131}
132
133impl<T> Deref for LazyCache<T> {
134 type Target = T;
135
136 fn deref(&self) -> &Self::Target {
137 self.get()
138 }
139}
140
141impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 let value = unsafe { &*self.value.get() };
145 match value {
146 Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
147 None => f.write_str("LazyCache(<uninitialized>)"),
148 }
149 }
150}
151
152impl<T> Serialize for LazyCache<T>
153where
154 T: Serialize + Send + Sync + 'static,
155{
156 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
157 where
158 S: Serializer,
159 {
160 self.get().serialize(serializer)
161 }
162}
163
164impl<'de, T> Deserialize<'de> for LazyCache<T>
165where
166 T: DeserializeOwned + Send + Sync + 'static,
167{
168 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
170 where
171 D: serde::Deserializer<'de>,
172 {
173 let value = T::deserialize(deserializer)?;
174 Ok(Self::preinit(value))
175 }
176}
177
#[cfg(test)]
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex},
        thread,
    };

    /// Prints jemalloc's current "allocated" statistic, in kilobytes, tagged with `label`.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        // Advance the epoch so jemalloc refreshes its cached statistics before reading.
        epoch::advance().unwrap();
        let allocated = stats::allocated::read().unwrap();
        println!("[{label}] Heap allocated: {} kilobytes", allocated / 1024);
    }

    #[test]
    fn test_lazy_cache() {
        // Pre-initialized and lazily-initialized caches both return their value,
        // and repeated access returns the same value.
        {
            let cache = LazyCache::preinit(100);
            assert_eq!(*cache.get(), 100);

            let lazy = LazyCache::new(|| {
                let a = 10;
                let b = 20;
                a + b
            });
            assert_eq!(*lazy.get(), 30);
            assert_eq!(*lazy.get(), 30);
        }

        // The initializer runs exactly once across multiple `get` calls.
        {
            let counter = Arc::new(Mutex::new(0));
            let counter_clone = Arc::clone(&counter);

            let cache = LazyCache::new(move || {
                let mut count = counter_clone.lock().unwrap();
                *count += 1;
                99
            });

            assert_eq!(*cache.get(), 99);
            assert_eq!(*cache.get(), 99);
            assert_eq!(*counter.lock().unwrap(), 1);
        }
        // JSON round-trip: serializes the value, deserializes into a pre-initialized cache.
        {
            let cache = LazyCache::preinit(10);
            let serialized = serde_json::to_string(&cache).unwrap();
            let deserialized: LazyCache<i32> = serde_json::from_str(&serialized).unwrap();
            assert_eq!(*deserialized.get(), 10);
        }
        // `Debug` shows the value when present and a placeholder before initialization.
        {
            let cache = LazyCache::preinit(10);
            assert_eq!(format!("{cache:?}"), "LazyCache(10)");

            let lazy = LazyCache::new(|| 20);
            assert_eq!(format!("{lazy:?}"), "LazyCache(<uninitialized>)");
        }
        // A hand-built cache with an uncompleted `Once` and no initializer reports
        // `MissingFunctionOrInitializedTwice`.
        {
            let cache: LazyCache<i32> = LazyCache {
                once: Once::new(),
                value: UnsafeCell::new(None),
                init: UnsafeCell::new(None),
            };
            let err = cache.try_get();
            assert_eq!(
                err.unwrap_err(),
                LazyCacheError::MissingFunctionOrInitializedTwice
            );
        }
        // A panicking initializer (run on another thread) poisons the `Once`; the
        // next `try_initialize` observes the poisoned state and reports `LockPoisoned`.
        {
            let lazy = Arc::new(LazyCache::<()>::new(|| {
                panic!("poison the lock");
            }));

            let lazy_clone = Arc::clone(&lazy);
            let _ = thread::spawn(move || {
                let _ = lazy_clone.try_initialize();
            })
            .join();
            let result = lazy.try_initialize();
            assert_eq!(result, Err(LazyCacheError::LockPoisoned));
        }
    }

    /// Diagnostic only (no assertions): installs jemalloc as the global allocator
    /// and prints heap usage before and after forcing a 1 MiB lazy allocation.
    #[cfg(all(not(target_arch = "wasm32"), feature = "diagnostics"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        let cache = Arc::new(LazyCache::new(|| vec![42u8; 1024 * 1024]));
        print_heap_usage("Before initializing LazyCache");

        let _ = cache.get();

        print_heap_usage("After initializing LazyCache");
    }
}