1use serde::{de::DeserializeOwned, Deserialize, Serialize, Serializer};
7use std::{cell::UnsafeCell, fmt, ops::Deref, sync::Once};
8
/// Type-erased initializer held by a [`LazyCache`] until it is first forced.
type LazyFn<T> = Box<dyn FnOnce() -> T + Send + Sync + 'static>;

/// A value computed at most once, on first access, and then shared by reference.
///
/// All mutation is funnelled through `once`: the initializer is taken out of
/// `init` and the result stored into `value` only inside the `Once` closure.
/// Any read of `value` must therefore happen after `once` has completed.
pub struct LazyCache<T> {
    /// Guards the one-time initialization.
    pub(crate) once: Once,
    /// The cached result; `None` until an initializer has stored a value.
    pub(crate) value: UnsafeCell<Option<T>>,
    /// The pending initializer; replaced with `None` when it is taken to run.
    pub(crate) init: UnsafeCell<Option<LazyFn<T>>>,
}
17
/// Failure modes reported by [`LazyCache`] when initializing or reading its value.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum LazyCacheError {
    /// The `Once` guarding initialization was found poisoned, i.e. an earlier
    /// initialization attempt panicked.
    LockPoisoned,
    /// Initialization has finished but no value is stored in the cache.
    UninitializedCache,
    /// Initialization ran but there was no initializer function left to call.
    MissingFunctionOrInitializedTwice,
}

impl fmt::Display for LazyCacheError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            LazyCacheError::LockPoisoned => {
                "lazy cache initialization was poisoned by a panic"
            }
            LazyCacheError::UninitializedCache => "lazy cache holds no value",
            LazyCacheError::MissingFunctionOrInitializedTwice => {
                "lazy cache has no initializer to run"
            }
        };
        f.write_str(msg)
    }
}

// Lets the error participate in `?`-based propagation into `Box<dyn Error>`,
// `anyhow`, and similar error-handling stacks.
impl std::error::Error for LazyCacheError {}
24
25#[derive(Debug, PartialEq)]
26pub enum LazyCacheErrorOr<E> {
27 Inner(E),
28 Outer(LazyCacheError),
29}
30
31unsafe impl<T: Send + Sync> Sync for LazyCache<T> {}
34unsafe impl<T: Send> Send for LazyCache<T> {}
35
36impl<T> LazyCache<T> {
40 pub fn new<F>(f: F) -> Self
41 where
42 F: FnOnce() -> T + Send + Sync + 'static,
43 {
44 LazyCache {
45 once: Once::new(),
46 value: UnsafeCell::new(None),
47 init: UnsafeCell::new(Some(Box::new(f))),
48 }
49 }
50
51 pub fn preinit(value: T) -> LazyCache<T> {
53 let once = Once::new();
54 once.call_once(|| {});
55 LazyCache {
56 once,
57 value: UnsafeCell::new(Some(value)),
58 init: UnsafeCell::new(None),
59 }
60 }
61
62 fn try_initialize(&self) -> Result<(), LazyCacheError> {
63 let mut error = None;
64
65 self.once.call_once_force(|state| {
66 if state.is_poisoned() {
67 error = Some(LazyCacheError::LockPoisoned);
68 return;
69 }
70
71 let init_fn = unsafe { (*self.init.get()).take() };
72 match init_fn {
73 Some(f) => {
74 let value = f();
75 unsafe {
76 *self.value.get() = Some(value);
77 }
78 }
79 None => {
80 error = Some(LazyCacheError::MissingFunctionOrInitializedTwice);
81 }
82 }
83 });
84
85 if let Some(e) = error {
86 return Err(e);
87 }
88
89 if self.once.is_completed() {
90 Ok(())
91 } else {
92 Err(LazyCacheError::LockPoisoned)
93 }
94 }
95
96 pub(crate) fn try_get(&self) -> Result<&T, LazyCacheError> {
97 self.try_initialize()?;
98 unsafe {
99 (*self.value.get())
100 .as_ref()
101 .ok_or(LazyCacheError::UninitializedCache)
102 }
103 }
104
105 pub fn get(&self) -> &T {
106 self.try_get().unwrap()
107 }
108}
109
110impl<T, E: Clone> LazyCache<Result<T, E>> {
113 pub fn try_get_or_err(&self) -> Result<&T, LazyCacheErrorOr<E>> {
114 match self.try_get() {
115 Ok(Ok(v)) => Ok(v),
116 Ok(Err(e)) => Err(LazyCacheErrorOr::Inner(e.clone())),
117 Err(_) => Err(LazyCacheErrorOr::Outer(LazyCacheError::LockPoisoned)),
118 }
119 }
120}
121
122impl<T> Deref for LazyCache<T> {
123 type Target = T;
124
125 fn deref(&self) -> &Self::Target {
126 self.get()
127 }
128}
129
130impl<T: fmt::Debug> fmt::Debug for LazyCache<T> {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 let value = unsafe { &*self.value.get() };
134 match value {
135 Some(v) => f.debug_tuple("LazyCache").field(v).finish(),
136 None => f.write_str("LazyCache(<uninitialized>)"),
137 }
138 }
139}
140
141impl<T> Serialize for LazyCache<T>
142where
143 T: Serialize + Send + Sync + 'static,
144{
145 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
146 where
147 S: Serializer,
148 {
149 self.get().serialize(serializer)
150 }
151}
152
153impl<'de, T> Deserialize<'de> for LazyCache<T>
154where
155 T: DeserializeOwned + Send + Sync + 'static,
156{
157 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
159 where
160 D: serde::Deserializer<'de>,
161 {
162 let value = T::deserialize(deserializer)?;
163 Ok(LazyCache::preinit(value))
164 }
165}
166
#[cfg(test)]
mod test {
    use super::*;
    use std::{
        sync::{Arc, Mutex},
        thread,
    };

    /// Prints how many bytes jemalloc currently reports as allocated, in KiB.
    #[cfg(not(target_arch = "wasm32"))]
    fn print_heap_usage(label: &str) {
        use tikv_jemalloc_ctl::{epoch, stats};

        // jemalloc caches its statistics; advancing the epoch refreshes them.
        epoch::advance().unwrap();
        let allocated = stats::allocated::read().unwrap();
        println!("[{label}] Heap allocated: {} kilobytes", allocated / 1024);
    }

    #[test]
    fn test_lazy_cache_preinit_and_new() {
        let cache = LazyCache::preinit(100);
        assert_eq!(*cache.get(), 100);

        let lazy = LazyCache::new(|| {
            let a = 10;
            let b = 20;
            a + b
        });
        assert_eq!(*lazy.get(), 30);
        // A second access returns the cached value.
        assert_eq!(*lazy.get(), 30);
    }

    #[test]
    fn test_lazy_cache_initializer_runs_once() {
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = Arc::clone(&counter);

        let cache = LazyCache::new(move || {
            *counter_clone.lock().unwrap() += 1;
            99
        });

        assert_eq!(*cache.get(), 99);
        assert_eq!(*cache.get(), 99);
        assert_eq!(*counter.lock().unwrap(), 1);
    }

    #[test]
    fn test_lazy_cache_serde_round_trip() {
        let cache = LazyCache::preinit(10);
        let serialized = serde_json::to_string(&cache).unwrap();
        let deserialized: LazyCache<i32> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(*deserialized.get(), 10);
    }

    #[test]
    fn test_lazy_cache_debug_format() {
        let cache = LazyCache::preinit(10);
        assert_eq!(format!("{:?}", cache), "LazyCache(10)");

        // Formatting must not force initialization.
        let lazy = LazyCache::new(|| 20);
        assert_eq!(format!("{:?}", lazy), "LazyCache(<uninitialized>)");
    }

    #[test]
    fn test_lazy_cache_missing_initializer() {
        let cache: LazyCache<i32> = LazyCache {
            once: Once::new(),
            value: UnsafeCell::new(None),
            init: UnsafeCell::new(None),
        };
        assert_eq!(
            cache.try_get().unwrap_err(),
            LazyCacheError::MissingFunctionOrInitializedTwice
        );
    }

    #[test]
    fn test_lazy_cache_poisoned_initializer() {
        let lazy = Arc::new(LazyCache::<()>::new(|| {
            panic!("poison the lock");
        }));

        let lazy_clone = Arc::clone(&lazy);
        let joined = thread::spawn(move || {
            let _ = lazy_clone.try_initialize();
        })
        .join();
        // The initializer panics, so the worker thread must have panicked too;
        // otherwise the `Once` was never poisoned and the check below is moot.
        assert!(joined.is_err());

        let result = lazy.try_initialize();
        assert_eq!(result, Err(LazyCacheError::LockPoisoned));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_lazy_cache_allocation() {
        use tikv_jemallocator::Jemalloc;

        // Installs jemalloc as the allocator for the whole test binary so that
        // `print_heap_usage` can read its statistics.
        #[global_allocator]
        static GLOBAL: Jemalloc = Jemalloc;

        print_heap_usage("Start");

        let cache = LazyCache::new(|| vec![42u8; 1024 * 1024]);
        print_heap_usage("Before initializing LazyCache");

        let _ = cache.get();

        print_heap_usage("After initializing LazyCache");
    }
}