266 lines
7.8 KiB
Rust
266 lines
7.8 KiB
Rust
pub mod errors;
|
|
use redb::{Database, ReadableDatabase, ReadableTable, TableDefinition};
|
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
|
use std::path::Path;
|
|
use std::sync::OnceLock;
|
|
use std::time::{Duration, SystemTime};
|
|
|
|
use crate::cache::errors::CacheError;
|
|
|
|
const CACHE_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("cache");
|
|
static CACHE_SINGLETON: OnceLock<Cache> = OnceLock::new();
|
|
|
|
pub struct CacheTTL {}
|
|
|
|
impl CacheTTL {
|
|
pub fn hours(hours: u64) -> Option<Duration> {
|
|
return Some(Duration::from_secs(hours * 60 * 60));
|
|
}
|
|
|
|
pub fn minutes(minutes: u64) -> Option<Duration> {
|
|
return Some(Duration::from_secs(minutes * 60));
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
struct CacheEntry<T> {
|
|
data: T,
|
|
expires_at: Option<SystemTime>,
|
|
}
|
|
|
|
pub struct Cache {
|
|
db: Database,
|
|
}
|
|
|
|
pub fn get() -> &'static Cache {
|
|
CACHE_SINGLETON.get_or_init(|| {
|
|
let cache_path = waycast_config::cache_path("waycast_cache.db")
|
|
.unwrap_or_else(|| std::env::current_dir().unwrap().join("waycast_cache.db"));
|
|
|
|
// Ensure cache directory exists
|
|
if let Some(parent) = cache_path.parent() {
|
|
if let Err(e) = std::fs::create_dir_all(parent) {
|
|
eprintln!("Warning: Failed to create cache directory {}: {}", parent.display(), e);
|
|
}
|
|
}
|
|
|
|
new(cache_path).expect("Failed to initialize cache :(")
|
|
})
|
|
}
|
|
|
|
// Get an existing cache at the given path or
|
|
// create it if it doesn't exist
|
|
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Cache, CacheError> {
|
|
let db = Database::create(db_path)?;
|
|
|
|
// Initialize the table if it doesn't exist
|
|
let write_txn = db.begin_write()?;
|
|
{
|
|
let _ = write_txn.open_table(CACHE_TABLE)?;
|
|
}
|
|
write_txn.commit()?;
|
|
|
|
Ok(Cache { db })
|
|
}
|
|
|
|
impl Cache {
|
|
/// Cache a value with an optional TTL. If TTL is None, the value never expires.
|
|
pub fn remember_with_ttl<T>(
|
|
&self,
|
|
key: &str,
|
|
ttl: Option<Duration>,
|
|
compute: impl FnOnce() -> T,
|
|
) -> Result<T, CacheError>
|
|
where
|
|
T: Serialize + DeserializeOwned + Clone,
|
|
{
|
|
// Try to get from cache first
|
|
if let Some(entry) = self.get_cached_entry::<T>(key)? {
|
|
// Check if entry has expired
|
|
if let Some(expires_at) = entry.expires_at {
|
|
if SystemTime::now() < expires_at {
|
|
return Ok(entry.data);
|
|
}
|
|
// Entry has expired, continue to recompute
|
|
} else {
|
|
// No expiration, return cached data
|
|
return Ok(entry.data);
|
|
}
|
|
}
|
|
|
|
// Not in cache or expired, compute the value
|
|
let data = compute();
|
|
let expires_at = ttl.map(|duration| SystemTime::now() + duration);
|
|
let entry = CacheEntry {
|
|
data: data.clone(),
|
|
expires_at,
|
|
};
|
|
|
|
// Store in cache
|
|
self.store_entry(key, &entry)?;
|
|
|
|
Ok(data)
|
|
}
|
|
|
|
/// Cache a value with no expiration
|
|
pub fn remember<T>(&self, key: &str, compute: impl FnOnce() -> T) -> Result<T, CacheError>
|
|
where
|
|
T: Serialize + DeserializeOwned + Clone,
|
|
{
|
|
self.remember_with_ttl(key, None, compute)
|
|
}
|
|
|
|
/// Get a cached value if it exists and hasn't expired
|
|
pub fn get<T>(&self, key: &str) -> Result<Option<T>, CacheError>
|
|
where
|
|
T: Serialize + DeserializeOwned,
|
|
{
|
|
if let Some(entry) = self.get_cached_entry::<T>(key)? {
|
|
// Check if entry has expired
|
|
if let Some(expires_at) = entry.expires_at {
|
|
if SystemTime::now() < expires_at {
|
|
return Ok(Some(entry.data));
|
|
}
|
|
// Entry has expired, remove it and return None
|
|
self.forget(key)?;
|
|
return Ok(None);
|
|
} else {
|
|
// No expiration, return cached data
|
|
return Ok(Some(entry.data));
|
|
}
|
|
}
|
|
Ok(None)
|
|
}
|
|
|
|
/// Store a value in the cache with optional TTL
|
|
pub fn put<T>(&self, key: &str, value: T, ttl: Option<Duration>) -> Result<(), CacheError>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
let expires_at = ttl.map(|duration| SystemTime::now() + duration);
|
|
let entry = CacheEntry {
|
|
data: value,
|
|
expires_at,
|
|
};
|
|
self.store_entry(key, &entry)
|
|
}
|
|
|
|
/// Remove a key from the cache
|
|
pub fn forget(&self, key: &str) -> Result<(), CacheError> {
|
|
let write_txn = self.db.begin_write()?;
|
|
{
|
|
let mut table = write_txn.open_table(CACHE_TABLE)?;
|
|
table.remove(key)?;
|
|
}
|
|
write_txn.commit()?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Clear all cached entries
|
|
pub fn clear(&self) -> Result<(), CacheError> {
|
|
let write_txn = self.db.begin_write()?;
|
|
{
|
|
let mut table = write_txn.open_table(CACHE_TABLE)?;
|
|
// Remove all entries
|
|
let keys: Vec<String> = {
|
|
let mut keys = Vec::new();
|
|
let mut iter = table.iter()?;
|
|
while let Some(Ok((key, _))) = iter.next() {
|
|
keys.push(key.value().to_string());
|
|
}
|
|
keys
|
|
};
|
|
|
|
for key in keys {
|
|
table.remove(key.as_str())?;
|
|
}
|
|
}
|
|
write_txn.commit()?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_cached_entry<T>(&self, key: &str) -> Result<Option<CacheEntry<T>>, CacheError>
|
|
where
|
|
T: DeserializeOwned,
|
|
{
|
|
let read_txn = self.db.begin_read()?;
|
|
let table = read_txn.open_table(CACHE_TABLE)?;
|
|
|
|
if let Some(cached_bytes) = table.get(key)? {
|
|
match bincode::deserialize::<CacheEntry<T>>(cached_bytes.value()) {
|
|
Ok(entry) => Ok(Some(entry)),
|
|
Err(_) => {
|
|
// Failed to deserialize, probably corrupted or wrong format
|
|
Ok(None)
|
|
}
|
|
}
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
fn store_entry<T>(&self, key: &str, entry: &CacheEntry<T>) -> Result<(), CacheError>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
let write_txn = self.db.begin_write()?;
|
|
{
|
|
let mut table = write_txn.open_table(CACHE_TABLE)?;
|
|
let serialized = bincode::serialize(entry)
|
|
.map_err(|e| CacheError::SerializationError(e.to_string()))?;
|
|
table.insert(key, serialized.as_slice())?;
|
|
}
|
|
write_txn.commit()?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use std::time::Duration;
|
|
use tempfile::NamedTempFile;
|
|
|
|
#[test]
|
|
fn test_cache_remember() {
|
|
let temp_file = NamedTempFile::new().unwrap();
|
|
let cache = new(temp_file.path()).unwrap();
|
|
|
|
let result = cache
|
|
.remember("test_key", || "computed_value".to_string())
|
|
.unwrap();
|
|
assert_eq!(result, "computed_value");
|
|
|
|
// Second call should return cached value
|
|
let result2 = cache
|
|
.remember("test_key", || "different_value".to_string())
|
|
.unwrap();
|
|
assert_eq!(result2, "computed_value");
|
|
}
|
|
|
|
#[test]
|
|
fn test_cache_ttl() {
|
|
let temp_file = NamedTempFile::new().unwrap();
|
|
let cache = new(temp_file.path()).unwrap();
|
|
|
|
// Cache with very short TTL
|
|
let result = cache
|
|
.remember_with_ttl("ttl_key", Some(Duration::from_millis(1)), || {
|
|
"cached_value".to_string()
|
|
})
|
|
.unwrap();
|
|
assert_eq!(result, "cached_value");
|
|
|
|
// Wait for expiration
|
|
std::thread::sleep(Duration::from_millis(10));
|
|
|
|
// Should recompute after expiration
|
|
let result2 = cache
|
|
.remember_with_ttl("ttl_key", Some(Duration::from_millis(1)), || {
|
|
"new_value".to_string()
|
|
})
|
|
.unwrap();
|
|
assert_eq!(result2, "new_value");
|
|
}
|
|
}
|