mina_tree/ondisk/
database.rs

1use std::{
2    collections::HashMap,
3    fs::{File, OpenOptions},
4    io::{BufReader, BufWriter, Seek, SeekFrom, Write},
5    path::{Path, PathBuf},
6};
7
8use std::io::ErrorKind::{InvalidData, Other, UnexpectedEof};
9
10use super::{
11    batch::Batch,
12    compression::{compress, decompress, MaybeCompressed},
13    lock::LockedFile,
14};
15
/// Key bytes, as stored in the in-memory index.
pub(super) type Key = Box<[u8]>;
/// Value bytes, as returned to callers.
pub(super) type Value = Box<[u8]>;
/// Absolute byte position inside the database file.
pub(super) type Offset = u64;

/// Identifier of a `Database` instance, held as a string.
pub type Uuid = String;
21
/// Header flag bit: the key bytes of the entry are stored compressed.
const KEY_IS_COMPRESSED_BIT: u8 = 0b0000_0001;
/// Header flag bit: the value bytes of the entry are stored compressed.
const VALUE_IS_COMPRESSED_BIT: u8 = 0b0000_0010;
/// Header flag bit: the entry records the removal of its key.
const IS_REMOVED_BIT: u8 = 0b0000_0100;

/// Initial size of the scratch buffers used for reads.
const BUFFER_DEFAULT_CAPACITY: usize = 4 * 1024;

/// Format version written at the very start of the database file.
const DATABASE_VERSION: u64 = 1;
/// Number of bytes occupied on disk by the format version.
const DATABASE_VERSION_NBYTES: usize = std::mem::size_of::<u64>();
30
/// Key-value store backed by a single file that entries are appended to,
/// plus an in-memory index locating the latest entry of every live key.
pub struct Database {
    /// Identifier of this instance (see `get_uuid`)
    uuid: Uuid,
    /// Index of keys to the offset of their entry header in the file
    index: HashMap<Key, Offset>,
    /// Points to end of file, i.e. where the next entry will be written
    current_file_offset: Offset,
    /// Buffered writer over the database file
    file: BufWriter<LockedFile>,
    /// Read buffer, reused between reads
    buffer: Vec<u8>,
    /// Filename of the inner file
    filename: PathBuf,
}
43
44/// Compute crc32 of an entry
45///
46/// This is used to verify data corruption
47fn compute_crc32(header: &EntryHeader, key_bytes: &[u8], value_bytes: &[u8]) -> u32 {
48    let bool_to_byte = |b| if b { 1 } else { 0 };
49
50    let is_removed = bool_to_byte(header.is_removed);
51    let key_is_compressed = bool_to_byte(header.key_is_compressed);
52    let value_is_compressed = bool_to_byte(header.value_is_compressed);
53
54    let mut crc32: crc32fast::Hasher = Default::default();
55
56    crc32.update(&header.key_length.to_le_bytes());
57    crc32.update(&header.value_length.to_le_bytes());
58    crc32.update(&[key_is_compressed, value_is_compressed, is_removed]);
59    crc32.update(key_bytes);
60    if !header.is_removed {
61        crc32.update(value_bytes);
62    };
63
64    crc32.finalize()
65}
66
/// Header for each entry in the database
///
/// On disk the header is followed by the key bytes, then the value bytes.
#[derive(Debug)]
struct EntryHeader {
    /// Length in bytes of the key as stored (after optional compression)
    key_length: u32,
    /// Length in bytes of the value as stored; 0 for a removal
    value_length: u64,
    /// Whether the stored key bytes are compressed
    key_is_compressed: bool,
    /// Whether the stored value bytes are compressed
    value_is_compressed: bool,
    /// Whether this entry records the removal of the key
    is_removed: bool,
    /// Checksum of the entry, see `compute_crc32`
    crc32: u32,
}
77
78impl EntryHeader {
79    /// Number of bytes the `EntryHeader` occupies on disk
80    pub const NBYTES: usize = 17;
81
82    /// Returns key + value length
83    fn entry_length(&self) -> std::io::Result<u64> {
84        (self.key_length as u64)
85            .checked_add(self.value_length)
86            .ok_or_else(|| std::io::Error::from(InvalidData))
87    }
88
89    /// Returns the value offset of this entry
90    fn compute_value_offset(&self, header_offset: Offset) -> Option<Offset> {
91        header_offset
92            .checked_add(self.key_length as u64)?
93            .checked_add(EntryHeader::NBYTES as u64)
94    }
95
96    /// Convert this header to bytes
97    fn to_bytes(&self) -> std::io::Result<[u8; Self::NBYTES]> {
98        let set_bit = |cond, bit| if cond { bit } else { 0 };
99
100        let mut bitflags = 0;
101        bitflags |= set_bit(self.key_is_compressed, KEY_IS_COMPRESSED_BIT);
102        bitflags |= set_bit(self.value_is_compressed, VALUE_IS_COMPRESSED_BIT);
103        bitflags |= set_bit(self.is_removed, IS_REMOVED_BIT);
104
105        let bytes = [0; Self::NBYTES];
106        let mut bytes = std::io::Cursor::new(bytes);
107
108        bytes.write_all(&self.key_length.to_le_bytes())?;
109        bytes.write_all(&self.value_length.to_le_bytes())?;
110        bytes.write_all(&[bitflags])?;
111        bytes.write_all(&self.crc32.to_le_bytes())?;
112
113        Ok(bytes.into_inner())
114    }
115
116    /// Build a `Header` from its entry (key and value)
117    fn make(key: &MaybeCompressed, value: &Option<MaybeCompressed>) -> std::io::Result<Self> {
118        let to_u64 = |n: usize| n.try_into().map_err(|_| std::io::Error::from(InvalidData));
119        let to_u32 = |n: usize| n.try_into().map_err(|_| std::io::Error::from(InvalidData));
120
121        let key_is_compressed = key.is_compressed();
122        let key = key.as_ref();
123
124        let value_is_compressed = value
125            .as_ref()
126            .map(|value| value.is_compressed())
127            .unwrap_or(false);
128
129        let key_length: u32 = to_u32(key.len())?;
130        let value_length = match value.as_ref() {
131            None => 0,
132            Some(value) => to_u64(value.as_ref().len())?,
133        };
134        let is_removed = value.is_none();
135
136        let mut header = EntryHeader {
137            key_length,
138            key_is_compressed,
139            value_length,
140            value_is_compressed,
141            is_removed,
142            crc32: 0, // Set with correct value below
143        };
144
145        let crc32 = compute_crc32(
146            &header,
147            key,
148            value.as_ref().map(AsRef::as_ref).unwrap_or(&[]),
149        );
150        header.crc32 = crc32;
151
152        Ok(header)
153    }
154
155    /// Reads a header from a slice of bytes
156    ///
157    /// Returns an error when the slice is too small
158    fn read(bytes: &[u8]) -> std::io::Result<Self> {
159        if bytes.len() < Self::NBYTES {
160            return Err(UnexpectedEof.into());
161        }
162
163        let key_length = read_u32(bytes)?;
164        let value_length = read_u64(&bytes[4..])?;
165        let bitflags = read_u8(&bytes[12..])?;
166        let crc32 = read_u32(&bytes[13..])?;
167
168        let key_is_compressed = (bitflags & KEY_IS_COMPRESSED_BIT) != 0;
169        let value_is_compressed = (bitflags & VALUE_IS_COMPRESSED_BIT) != 0;
170        let is_removed = (bitflags & IS_REMOVED_BIT) != 0;
171
172        Ok(Self {
173            key_length,
174            key_is_compressed,
175            value_length,
176            value_is_compressed,
177            is_removed,
178            crc32,
179        })
180    }
181
182    /// Returns an error when the checksum doesn't match
183    fn verify_checksum(&self, key_bytes: &[u8], value_bytes: &[u8]) -> std::io::Result<()> {
184        let crc32 = compute_crc32(self, key_bytes, value_bytes);
185
186        if crc32 != self.crc32 {
187            return Err(InvalidData.into());
188        }
189
190        Ok(())
191    }
192}
193
/// Returns a new version-4 UUID rendered as a string.
fn next_uuid() -> Uuid {
    uuid::Uuid::new_v4().to_string()
}
197
/// Decodes a little-endian `u64` from the first 8 bytes of `slice`.
///
/// Fails with `UnexpectedEof` when `slice` holds fewer than 8 bytes.
fn read_u64(slice: &[u8]) -> std::io::Result<u64> {
    match slice.get(..8) {
        Some(prefix) => {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(prefix);
            Ok(u64::from_le_bytes(raw))
        }
        None => Err(UnexpectedEof.into()),
    }
}
205
/// Decodes a little-endian `u32` from the first 4 bytes of `slice`.
///
/// Fails with `UnexpectedEof` when `slice` holds fewer than 4 bytes.
fn read_u32(slice: &[u8]) -> std::io::Result<u32> {
    match slice.get(..4) {
        Some(prefix) => {
            let mut raw = [0u8; 4];
            raw.copy_from_slice(prefix);
            Ok(u32::from_le_bytes(raw))
        }
        None => Err(UnexpectedEof.into()),
    }
}
213
/// Returns the first byte of `slice`.
///
/// Fails with `UnexpectedEof` when `slice` is empty.
fn read_u8(slice: &[u8]) -> std::io::Result<u8> {
    match slice.first() {
        Some(&byte) => Ok(byte),
        None => Err(UnexpectedEof.into()),
    }
}
221
/// Grows `buffer` with zero bytes until it holds at least `length` bytes.
///
/// A buffer that is already long enough is left untouched (never shrunk).
fn ensure_buffer_length(buffer: &mut Vec<u8>, length: usize) {
    let current = buffer.len();
    if length > current {
        buffer.resize(length, 0);
    }
}
227
228#[cfg(unix)]
229fn read_exact_at(file: &mut File, buffer: &mut [u8], offset: Offset) -> std::io::Result<()> {
230    use std::os::unix::prelude::FileExt;
231
232    file.read_exact_at(buffer, offset)
233}
234
/// Fills `buffer` with the bytes found at `offset` in `file`.
///
/// Portable flavour: moves the file cursor to `offset`, then reads.
#[cfg(not(unix))]
fn read_exact_at(file: &mut File, buffer: &mut [u8], offset: Offset) -> std::io::Result<()> {
    use std::io::Read;

    let position = SeekFrom::Start(offset);
    file.seek(position)
        .and_then(|_| file.read_exact(buffer))
}
242
/// Selects which file `Database::create_impl` works on.
enum CreateMode {
    /// The main database file (`db`); reloaded when it already exists.
    Regular,
    /// A scratch database file (`db_tmp`); any existing one is deleted first.
    Temporary,
}
247
248impl Database {
    /// Creates a new instance of the database at the specified directory.
    /// If the directory contains an existing database, its content will be loaded.
    ///
    /// # Arguments
    ///
    /// * `directory` - The path where the database will be created or opened.
    ///
    /// # Returns
    ///
    /// * `Result<Self>` - Returns an instance of the database if successful, otherwise
    ///    returns an error.
    ///
    /// # Errors
    ///
    /// This method will return an error in the following cases:
    ///
    ///   * Unable to open or create the directory or the database file.
    ///   * Another process is already using the database.
    ///   * The database is corrupted (when the path contains an existing database).
    ///   * The database version is incompatible.
    pub fn create(directory: impl AsRef<Path>) -> std::io::Result<Self> {
        // Regular mode: work on the `db` file, reloading it when it already exists.
        Self::create_impl(directory, CreateMode::Regular)
    }
273
274    fn create_impl(directory: impl AsRef<Path>, mode: CreateMode) -> std::io::Result<Self> {
275        let directory = directory.as_ref();
276
277        let filename = directory.join(match mode {
278            CreateMode::Regular => "db",
279            CreateMode::Temporary => "db_tmp",
280        });
281
282        if filename.try_exists()? {
283            if let CreateMode::Temporary = mode {
284                std::fs::remove_file(&filename)?;
285            } else {
286                return Self::reload(filename);
287            }
288        }
289
290        if !directory.try_exists()? {
291            std::fs::create_dir_all(directory)?;
292        }
293
294        let mut file = LockedFile::try_open_exclusively(
295            &filename,
296            OpenOptions::new()
297                .read(true)
298                .write(true)
299                .append(true)
300                .create_new(true),
301        )?;
302
303        file.write_all(&DATABASE_VERSION.to_le_bytes())?;
304
305        Ok(Self {
306            uuid: next_uuid(),
307            index: HashMap::with_capacity(128),
308            current_file_offset: DATABASE_VERSION_NBYTES as u64,
309            file: BufWriter::with_capacity(4 * 1024 * 1024, file), // 4 MB
310            buffer: Vec::with_capacity(BUFFER_DEFAULT_CAPACITY),
311            filename,
312        })
313    }
314
315    /// Reload the database at the specified path
316    fn reload(filename: PathBuf) -> std::io::Result<Self> {
317        use std::io::Read;
318
319        let mut file = LockedFile::try_open_exclusively(
320            &filename,
321            OpenOptions::new()
322                .read(true)
323                .write(true)
324                .append(true)
325                .create_new(false),
326        )?;
327
328        let mut current_offset = 0;
329        let eof = file.seek(SeekFrom::End(0))?;
330
331        file.seek(SeekFrom::Start(0))?;
332
333        let mut reader = BufReader::with_capacity(4 * 1024 * 1024, file); // 4 MB
334        let mut bytes = vec![0; BUFFER_DEFAULT_CAPACITY];
335
336        // Check if the database is the same version
337        {
338            reader.read_exact(&mut bytes[..DATABASE_VERSION_NBYTES])?;
339            let database_version = read_u64(&bytes)?;
340            if database_version != DATABASE_VERSION {
341                return Err(std::io::Error::new(Other, "Incompatible database"));
342            }
343            current_offset += DATABASE_VERSION_NBYTES as u64;
344        }
345
346        let mut index = HashMap::with_capacity(256);
347
348        while current_offset < eof {
349            let header_offset = current_offset;
350
351            ensure_buffer_length(&mut bytes, EntryHeader::NBYTES);
352            reader.read_exact(&mut bytes[..EntryHeader::NBYTES])?;
353
354            let header = EntryHeader::read(&bytes)?;
355            let entry_length = header.entry_length()? as usize;
356            let key_length = header.key_length as usize;
357
358            ensure_buffer_length(&mut bytes, entry_length);
359            reader.read_exact(&mut bytes[..entry_length])?;
360
361            ensure_buffer_length(&mut bytes, entry_length);
362            let (key_bytes, value_bytes) = bytes[..entry_length].split_at(key_length);
363
364            header.verify_checksum(key_bytes, value_bytes)?;
365
366            let key = decompress(key_bytes, header.key_is_compressed)?;
367
368            if header.is_removed {
369                index.remove(&key);
370            } else {
371                index.insert(key, header_offset);
372            }
373
374            current_offset += (EntryHeader::NBYTES + entry_length) as u64;
375        }
376
377        if eof != current_offset {
378            return Err(UnexpectedEof.into());
379        }
380
381        Ok(Self {
382            uuid: next_uuid(),
383            index,
384            current_file_offset: eof,
385            file: BufWriter::with_capacity(4 * 1024 * 1024, reader.into_inner()), // 4 MB
386            buffer: Vec::with_capacity(BUFFER_DEFAULT_CAPACITY),
387            filename,
388        })
389    }
390
    /// Retrieves the UUID of the current database instance.
    ///
    /// A new UUID is generated each time a database is created or reloaded;
    /// `gc` keeps the UUID of the instance it compacts.
    ///
    /// # Returns
    ///
    /// * `&Uuid` - Returns a reference to the UUID of the instance.
    pub fn get_uuid(&self) -> &Uuid {
        &self.uuid
    }
399
    /// Closes the current database instance.
    ///
    /// This method itself does nothing: the instance is released when it is
    /// dropped, which per the note below is done by the ffi layer. The "usage
    /// after close returns an error" behaviour is therefore provided by that
    /// layer, not by this type.
    pub fn close(&self) {
        // NOTE: `close` is actually implemented at the ffi level, where `Self` is dropped
    }
406
407    fn read_header(&mut self, header_offset: Offset) -> std::io::Result<EntryHeader> {
408        ensure_buffer_length(&mut self.buffer, EntryHeader::NBYTES);
409        read_exact_at(
410            self.file.get_mut(),
411            &mut self.buffer[..EntryHeader::NBYTES],
412            header_offset,
413        )?;
414
415        EntryHeader::read(&self.buffer)
416    }
417
418    fn read_value(&mut self, offset: Offset, length: usize) -> std::io::Result<&[u8]> {
419        ensure_buffer_length(&mut self.buffer, length);
420        read_exact_at(self.file.get_mut(), &mut self.buffer[..length], offset)?;
421
422        Ok(&self.buffer[..length])
423    }
424
425    /// Retrieves the value associated with a given key.
426    ///
427    /// # Arguments
428    ///
429    /// * `key` - Bytes representing the key to fetch the value of.
430    ///
431    /// # Returns
432    ///
433    /// * `Result<Option<Box<[u8]>>>` - Returns an optional values if the key exists;
434    ///    otherwise, None. Returns an error if something goes wrong.
435    pub fn get(&mut self, key: &[u8]) -> std::io::Result<Option<Value>> {
436        // Note: `&mut self` is required for `File::seek`
437
438        let header_offset = match self.index.get(key).copied() {
439            Some(header_offset) => header_offset,
440            None => return Ok(None),
441        };
442
443        let header = self.read_header(header_offset)?;
444
445        let value_offset = header
446            .compute_value_offset(header_offset)
447            .ok_or_else(|| std::io::Error::from(InvalidData))?;
448        let value_length = header.value_length as usize;
449
450        let value = self.read_value(value_offset, value_length)?;
451
452        decompress(value, header.value_is_compressed).map(Some)
453    }
454
455    fn set_impl(&mut self, key: Key, value: Option<Value>) -> std::io::Result<()> {
456        let is_removed = value.is_none();
457
458        let compressed_key = compress(&key)?;
459        let compressed_value = match value.as_ref() {
460            Some(value) => Some(compress(value)?),
461            None => None,
462        };
463
464        let header = EntryHeader::make(&compressed_key, &compressed_value)?;
465        let header_offset = self.current_file_offset;
466
467        self.file.write_all(&header.to_bytes()?)?;
468        self.file.write_all(compressed_key.as_ref())?;
469        if let Some(value) = compressed_value.as_ref() {
470            self.file.write_all(value.as_ref())?;
471        };
472
473        let buffer_len = EntryHeader::NBYTES as u64 + header.entry_length()?;
474        self.current_file_offset += buffer_len;
475
476        // Update index
477        if is_removed {
478            self.index.remove(&key);
479        } else {
480            self.index.insert(key, header_offset);
481        }
482
483        Ok(())
484    }
485
486    /// Adds or updates an entry (key-value pair) in the database.
487    ///
488    /// # Arguments
489    ///
490    /// * `key` - Bytes representing the key to store.
491    /// * `value` - Bytes representing the value to store.
492    ///
493    /// # Returns
494    ///
495    /// * `Result<()>` - Returns () if successful, otherwise returns an error.
496    pub fn set(&mut self, key: Key, value: Value) -> std::io::Result<()> {
497        self.set_impl(key, Some(value))?;
498        self.flush()?;
499        Ok(())
500    }
501
502    /// Processes multiple entries (key-value pairs) to set and keys to remove in
503    /// a single batch operation.
504    ///
505    /// # Arguments
506    ///
507    /// * `key_data_pairs` - An iterable of key-value pairs to add or update.
508    /// * `remove_keys` - An iterable of keys to remove from the database.
509    ///
510    /// # Returns
511    ///
512    /// * `Result<()>` - Returns () if successful, otherwise returns an error.
513    pub fn set_batch<KV, R>(&mut self, key_data_pairs: KV, remove_keys: R) -> std::io::Result<()>
514    where
515        KV: IntoIterator<Item = (Key, Value)>,
516        R: IntoIterator<Item = Key>,
517    {
518        for (key, value) in key_data_pairs {
519            self.set_impl(key, Some(value))?;
520        }
521
522        for key in remove_keys {
523            self.set_impl(key, None)? // empty value
524        }
525
526        self.flush()?;
527
528        Ok(())
529    }
530
531    /// Fetches a batch of values for the given keys.
532    ///
533    /// # Arguments
534    ///
535    /// * `keys` - An iterable of keys to fetch the values of
536    ///
537    /// # Returns
538    ///
539    /// * `Result<Vec<Option<Box<[u8]>>>>` - Returns a vector of optional values
540    ///    corresponding to each key; if a key is not found, returns None.
541    pub fn get_batch<K>(&mut self, keys: K) -> std::io::Result<Vec<Option<Value>>>
542    where
543        K: IntoIterator<Item = Key>,
544    {
545        keys.into_iter().map(|key| self.get(&key)).collect()
546    }
547
548    /// Creates a new checkpoint, saving a consistent snapshot of the
549    /// current state of the database.
550    ///
551    /// # Arguments
552    ///
553    /// * `directory` - The path where the checkpoint files will be created.
554    ///
555    /// # Returns
556    ///
557    /// * `Result<()>` - Returns () if checkpoint creation is successful,
558    ///   otherwise returns an error.
559    pub fn make_checkpoint(&mut self, directory: impl AsRef<Path>) -> std::io::Result<()> {
560        self.create_checkpoint(directory.as_ref())?;
561        Ok(())
562    }
563
564    /// Creates a new checkpoint, and instantiates a new database from it.
565    ///
566    /// # Arguments
567    ///
568    /// * `directory` - The path where the checkpoint files will be created.
569    ///
570    /// # Returns
571    ///
572    /// * `Result<Self>` - Returns a new instance of the database if successful,
573    ///   otherwise returns an error.
574    pub fn create_checkpoint(&mut self, directory: impl AsRef<Path>) -> std::io::Result<Self> {
575        let mut checkpoint = Self::create(directory.as_ref())?;
576
577        let keys: Vec<Key> = self.index.keys().cloned().collect();
578
579        for key in keys {
580            let value = self.get(&key)?;
581            checkpoint.set_impl(key, value)?;
582        }
583
584        checkpoint.flush()?;
585
586        Ok(checkpoint)
587    }
588
589    /// Flush writes buffer to fs and call `fsync`
590    fn flush(&mut self) -> std::io::Result<()> {
591        self.file.flush()?;
592        self.file.get_ref().sync_all()
593    }
594
    /// Appends a removal entry for `key` and drops it from the index.
    ///
    /// Does not flush; see `remove` for the flushing variant.
    fn remove_impl(&mut self, key: Key) -> std::io::Result<()> {
        self.set_impl(key, None) // empty value
    }
598
599    /// Removes a key-value pair from the database.
600    ///
601    /// # Arguments
602    ///
603    /// * `key` - Bytes representing the key to remove.
604    ///
605    /// # Returns
606    ///
607    /// * `Result<()>` - Returns () if the key is removed successfully,
608    ///   otherwise returns an error.
609    pub fn remove(&mut self, key: Key) -> std::io::Result<()> {
610        self.remove_impl(key)?;
611        self.flush()
612    }
613
614    /// Retrieves all entries (key-value pairs) from the database.
615    ///
616    /// # Returns
617    ///
618    /// * `Result<Vec<(Box<[u8]>, Box<[u8]>)>>` - Returns a vector containing
619    ///   all key-value pairs as boxed byte arrays. Returns an error if retrieval fails.
620    pub fn to_alist(&mut self) -> std::io::Result<Vec<(Key, Value)>> {
621        let keys: Vec<Key> = self.index.keys().cloned().collect();
622
623        keys.into_iter()
624            .map(|key| {
625                Ok((
626                    key.clone(),
627                    self.get(&key)?
628                        .ok_or_else(|| std::io::Error::from(InvalidData))?,
629                ))
630            })
631            .collect()
632    }
633
634    /// Processes a pre-built batch of operations, effectively running the batch on the database.
635    ///
636    /// # Arguments
637    ///
638    /// * `batch` - A mutable reference to a `Batch` struct containing the operations to execute.
639    ///
640    /// # Returns
641    ///
642    /// * `Result<()>` - Returns () if the batch is executed successfully,
643    ///   otherwise returns an error.
644    pub fn run_batch(&mut self, batch: &mut Batch) -> std::io::Result<()> {
645        use super::batch::Action::{Remove, Set};
646
647        for action in batch.take() {
648            match action {
649                Set(key, value) => self.set_impl(key, Some(value))?,
650                Remove(key) => self.remove_impl(key)?,
651            }
652        }
653
654        self.flush()
655    }
656
657    /// Triggers garbage collection for the database, cleaning up obsolete
658    /// data and potentially freeing up storage space.
659    ///
660    /// # Returns
661    ///
662    /// * `Result<()>` - Returns () if garbage collection is successful,
663    ///   otherwise returns an error.
664    pub fn gc(&mut self) -> std::io::Result<()> {
665        let directory = self.filename.parent().unwrap();
666        let mut new_db = Self::create_impl(directory, CreateMode::Temporary)?;
667
668        let keys: Vec<Key> = self.index.keys().cloned().collect();
669
670        for key in keys {
671            let value = self.get(&key)?;
672            new_db.set_impl(key, value)?;
673        }
674
675        new_db.flush()?;
676
677        exchange_file_atomically(&self.filename, &new_db.filename)?;
678
679        new_db.filename.clone_from(&self.filename);
680        new_db.uuid.clone_from(&self.uuid);
681
682        *self = new_db;
683
684        Ok(())
685    }
686}
687
/// Moves the file at `tmp_path` to `db_path`.
///
/// Fallback for targets other than Linux: a plain `std::fs::rename`.
#[cfg(not(target_os = "linux"))]
fn exchange_file_atomically(db_path: &Path, tmp_path: &Path) -> std::io::Result<()> {
    std::fs::rename(tmp_path, db_path)
}
692
693// `renameat2` is a Linux syscall
694#[cfg(target_os = "linux")]
695fn exchange_file_atomically(db_path: &Path, tmp_path: &Path) -> std::io::Result<()> {
696    use std::os::unix::prelude::OsStrExt;
697
698    let cstr_db_path = std::ffi::CString::new(db_path.as_os_str().as_bytes())?;
699    let cstr_db_path = cstr_db_path.as_ptr();
700
701    let cstr_tmp_path = std::ffi::CString::new(tmp_path.as_os_str().as_bytes())?;
702    let cstr_tmp_path = cstr_tmp_path.as_ptr();
703
704    // Exchange `db_path` with `tmp_path` atomically
705    let result = unsafe {
706        libc::syscall(
707            libc::SYS_renameat2,
708            libc::AT_FDCWD,
709            cstr_tmp_path,
710            libc::AT_FDCWD,
711            cstr_db_path,
712            libc::RENAME_EXCHANGE,
713        )
714    };
715
716    if result != 0 {
717        let error = std::io::Error::last_os_error();
718        return Err(error);
719    }
720
721    // Remove previous file
722    std::fs::remove_file(tmp_path)?;
723
724    Ok(())
725}
726
#[cfg(test)]
mod tests {
    use rand::{Fill, Rng};
    use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

    use super::*;

    /// Scratch directory for one test, deleted when the value is dropped.
    struct TempDir {
        path: PathBuf,
    }

    /// Counter making directory names unique within this process.
    static DIRECTORY_NUMBER: AtomicUsize = AtomicUsize::new(0);

    impl TempDir {
        /// Creates a fresh, empty directory under the system temporary directory.
        ///
        /// The name combines the process id and a counter, and the directory
        /// is claimed with `create_dir` (which fails if it already exists),
        /// so concurrent tests and concurrent test processes never share one.
        fn new() -> Self {
            let base = std::env::temp_dir();
            let pid = std::process::id();

            let path = loop {
                let number = DIRECTORY_NUMBER.fetch_add(1, SeqCst);
                let candidate = base.join(format!("mina-keyvaluedb-test-{}-{}", pid, number));

                match std::fs::create_dir(&candidate) {
                    Ok(()) => break candidate,
                    // Leftover from a previous run: try the next number.
                    Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => continue,
                    Err(e) => panic!(
                        "[test] Failed to create temporary directory {:?}: {:?}",
                        candidate, e
                    ),
                }
            };

            Self { path }
        }

        fn as_path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            if let Err(e) = std::fs::remove_dir_all(&self.path) {
                eprintln!(
                    "[test] Failed to remove temporary directory {:?}: {:?}",
                    self.path, e
                );
            }
        }
    }

    /// Builds a key from a string.
    fn key(s: &str) -> Key {
        Box::<[u8]>::from(s.as_bytes())
    }

    /// Builds a value from a string.
    fn value(s: &str) -> Value {
        Box::<[u8]>::from(s.as_bytes())
    }

    /// Sorts entries by key so that two listings can be compared.
    fn sorted_vec(mut vec: Vec<(Key, Value)>) -> Vec<(Key, Value)> {
        vec.sort_by_cached_key(|(k, _)| k.clone());
        vec
    }

    /// An empty value is a regular value, distinct from a removed key.
    #[test]
    fn test_empty_value() {
        let db_dir = TempDir::new();

        let mut db = Database::create(db_dir.as_path()).unwrap();

        db.set(key("a"), value("abc")).unwrap();
        let v = db.get(&key("a")).unwrap().unwrap();
        assert_eq!(v, value("abc"));

        db.set(key("a"), value("")).unwrap();
        let v = db.get(&key("a")).unwrap().unwrap();
        assert_eq!(v, value(""));
    }

    /// A removed key stays removed after the database is reopened.
    #[test]
    fn test_persistent_removed_value() {
        let db_dir = TempDir::new();

        let first = {
            let mut db = Database::create(db_dir.as_path()).unwrap();

            db.set(key("abcd"), value("abcd")).unwrap();

            db.set(key("a"), value("abc")).unwrap();
            let v = db.get(&key("a")).unwrap().unwrap();
            assert_eq!(v, value("abc"));

            db.set(key("a"), value("")).unwrap();
            let v = db.get(&key("a")).unwrap().unwrap();
            assert_eq!(v, value(""));

            db.remove(key("a")).unwrap();
            let v = db.get(&key("a")).unwrap();
            assert!(v.is_none());

            sorted_vec(db.to_alist().unwrap())
        };

        assert_eq!(first.len(), 1);

        let second = {
            let mut db = Database::create(db_dir.as_path()).unwrap();
            sorted_vec(db.to_alist().unwrap())
        };

        assert_eq!(first, second);
    }

    /// `get_batch` returns one result per key, in order, `None` for unknown keys.
    #[test]
    fn test_get_batch() {
        let db_dir = TempDir::new();

        let mut db = Database::create(db_dir.as_path()).unwrap();

        let (key1, key2, key3): (Key, Key, Key) = (
            "a".as_bytes().into(),
            "b".as_bytes().into(),
            "c".as_bytes().into(),
        );
        let data: Value = value("test");

        db.set(key1.clone(), data.clone()).unwrap();
        db.set(key3.clone(), data.clone()).unwrap();

        let res = db.get_batch([key1, key2, key3]).unwrap();

        assert_eq!(res[0].as_ref().unwrap(), &data);
        assert!(res[1].is_none());
        assert_eq!(res[2].as_ref().unwrap(), &data);
    }

    /// Generates `nkeys` entries with distinct random keys, sorted by key.
    fn make_random_key_values(nkeys: usize) -> Vec<(Key, Value)> {
        let mut rng = rand::thread_rng();

        let mut key = [0; 32];

        // The map guarantees that the generated keys are distinct.
        let mut key_values = HashMap::with_capacity(nkeys);

        while key_values.len() < nkeys {
            let key_length: usize = rng.gen_range(2..=32);
            key[..key_length].try_fill(&mut rng).unwrap();

            let i = Box::<[u8]>::from(key_values.len().to_ne_bytes());
            key_values.insert(Box::<[u8]>::from(&key[..key_length]), i);
        }

        sorted_vec(key_values.into_iter().collect())
    }

    /// Entries written in a batch are still there after the database is reopened.
    #[test]
    fn test_persistent() {
        let db_dir = TempDir::new();

        let mut rng = rand::thread_rng();
        let nkeys: usize = rng.gen_range(1000..2000);
        let sorted = make_random_key_values(nkeys);

        let first = {
            let mut db = Database::create(db_dir.as_path()).unwrap();
            db.set_batch(sorted.clone(), []).unwrap();
            sorted_vec(db.to_alist().unwrap())
        };

        assert_eq!(sorted, first);

        let second = {
            let mut db = Database::create(db_dir.as_path()).unwrap();
            sorted_vec(db.to_alist().unwrap())
        };

        assert_eq!(first, second);
    }

    /// `gc` shrinks the file, keeps the content, and leaves the database usable.
    #[test]
    fn test_gc() {
        let db_dir = TempDir::new();

        let mut rng = rand::thread_rng();
        let nkeys: usize = rng.gen_range(1000..2000);
        let sorted = make_random_key_values(nkeys);

        let mut db = Database::create(db_dir.as_path()).unwrap();
        db.set_batch(sorted.clone(), []).unwrap();

        (10..50).for_each(|index| {
            db.remove(sorted[index].0.clone()).unwrap();
        });

        let offset = db.current_file_offset;

        let alist1 = sorted_vec(db.to_alist().unwrap());

        db.gc().unwrap();
        assert!(offset > db.current_file_offset);

        let alist2 = sorted_vec(db.to_alist().unwrap());
        assert_eq!(alist1, alist2);

        db.set(key("a"), value("b")).unwrap();
        assert_eq!(db.get(&key("a")).unwrap().unwrap(), value("b"));
    }

    /// `to_alist` returns exactly the entries that were written.
    #[test]
    fn test_to_alist() {
        let db_dir = TempDir::new();

        let mut rng = rand::thread_rng();

        let nkeys: usize = rng.gen_range(1000..2000);

        let sorted = make_random_key_values(nkeys);

        let mut db = Database::create(db_dir.as_path()).unwrap();

        db.set_batch(sorted.clone(), []).unwrap();

        let alist = sorted_vec(db.to_alist().unwrap());

        assert_eq!(sorted, alist);
    }

    /// A checkpoint starts with the same content, then evolves independently.
    #[test]
    fn test_checkpoint_read() {
        let db_dir = TempDir::new();

        let mut rng = rand::thread_rng();

        let nkeys: usize = rng.gen_range(1000..2000);

        let sorted = make_random_key_values(nkeys);

        let mut db_hashtbl: HashMap<_, _> = sorted.into_iter().collect();
        let mut cp_hashtbl: HashMap<_, _> = db_hashtbl.clone();

        let mut db = Database::create(db_dir.as_path()).unwrap();

        for (key, data) in &db_hashtbl {
            db.set(key.clone(), data.clone()).unwrap();
        }

        let cp_dir = TempDir::new();
        let mut cp = db.create_checkpoint(cp_dir.as_path()).unwrap();

        db_hashtbl.insert(key("db_key"), value("db_data"));
        cp_hashtbl.insert(key("cp_key"), value("cp_data"));

        db.set(key("db_key"), value("db_data")).unwrap();
        cp.set(key("cp_key"), value("cp_data")).unwrap();

        let db_sorted: Vec<_> = sorted_vec(db_hashtbl.into_iter().collect());
        let cp_sorted: Vec<_> = sorted_vec(cp_hashtbl.into_iter().collect());

        let db_alist = sorted_vec(db.to_alist().unwrap());
        let cp_alist = sorted_vec(cp.to_alist().unwrap());

        assert_eq!(db_sorted, db_alist);
        assert_eq!(cp_sorted, cp_alist);
    }
}