use tracing::{enabled, Level, trace}; use quick_cache::sync::KQCache; use quick_cache::{PlaceholderGuard, Weighter}; use serde::ser::SerializeStruct; use serde::{Serialize, Serializer}; use std::convert::Infallible; use std::fmt::Debug; use std::future::Future; use std::hash::{BuildHasher, Hash}; use std::num::NonZeroU32; use std::sync::Arc; use std::time::Duration; use tokio::task::JoinHandle; use tokio::time::{sleep_until, Instant}; pub struct KQCacheWithTTL { cache: Arc>, max_item_weight: NonZeroU32, name: &'static str, ttl: Duration, tx: flume::Sender<(Instant, Key, Qey)>, weighter: We, pub task_handle: JoinHandle<()>, } struct KQCacheWithTTLTask { cache: Arc>, name: &'static str, rx: flume::Receiver<(Instant, Key, Qey)>, } pub struct PlaceholderGuardWithTTL<'a, Key, Qey, Val, We, B> { cache: &'a KQCacheWithTTL, inner: PlaceholderGuard<'a, Key, Qey, Val, We, B>, key: Key, qey: Qey, } impl< Key: Clone + Debug + Eq + Hash + Send + Sync + 'static, Qey: Clone + Debug + Eq + Hash + Send + Sync + 'static, Val: Clone + Send + Sync + 'static, We: Weighter + Clone + Send + Sync + 'static, B: BuildHasher + Clone + Send + Sync + 'static, > KQCacheWithTTL { pub async fn new_with_options( name: &'static str, estimated_items_capacity: usize, max_item_weight: NonZeroU32, weight_capacity: u64, weighter: We, hash_builder: B, ttl: Duration, ) -> Self { let (tx, rx) = flume::unbounded(); let cache = KQCache::with( estimated_items_capacity, weight_capacity, weighter.clone(), hash_builder, ); let cache = Arc::new(cache); let task = KQCacheWithTTLTask { cache: cache.clone(), name, rx, }; let task_handle = tokio::spawn(task.run()); Self { cache, max_item_weight, name, task_handle, ttl, tx, weighter, } } #[inline] pub fn get(&self, key: &Key, qey: &Qey) -> Option { self.cache.get(key, qey) } #[inline] pub async fn get_or_insert_async(&self, key: &Key, qey: &Qey, f: Fut) -> Val where Fut: Future, { self.try_get_or_insert_async::(key, qey, async move { Ok(f.await) }) .await .expect("infallible") } #[inline] pub async fn try_get_or_insert_async( &self, key: &Key, qey: &Qey, f: Fut, ) -> Result where Fut: Future>, { self.cache .get_or_insert_async(key, qey, async move { let x = f.await; if x.is_ok() { let expire_at = Instant::now() + self.ttl; trace!( "{}, {:?}, {:?} expiring in {}s", self.name, &key, &qey, expire_at.duration_since(Instant::now()).as_secs_f32() ); self.tx.send((expire_at, key.clone(), qey.clone())).unwrap(); } x }) .await } #[inline] pub async fn get_value_or_guard_async( &self, key: Key, qey: Qey, ) -> Result> { match self.cache.get_value_or_guard_async(&key, &qey).await { Ok(x) => Ok(x), Err(inner) => Err(PlaceholderGuardWithTTL { cache: self, inner, key, qey, }), } } /// if the item was too large to insert, it is returned with the error /// IMPORTANT! Inserting the same key multiple times does NOT reset the TTL! #[inline] pub fn try_insert(&self, key: Key, qey: Qey, val: Val) -> Result<(), (Key, Qey, Val)> { let expire_at = Instant::now() + self.ttl; let weight = self.weighter.weight(&key, &qey, &val); if weight <= self.max_item_weight { self.cache.insert(key.clone(), qey.clone(), val); trace!( "{}, {:?}, {:?} expiring in {}s", self.name, &key, &qey, expire_at.duration_since(Instant::now()).as_secs_f32() ); self.tx.send((expire_at, key, qey)).unwrap(); Ok(()) } else { Err((key, qey, val)) } } #[inline] pub fn peek(&self, key: &Key, qey: &Qey) -> Option { self.cache.peek(key, qey) } #[inline] pub fn remove(&self, key: &Key, qey: &Qey) -> bool { self.cache.remove(key, qey) } } impl< Key: Debug + Eq + Hash, Qey: Debug + Eq + Hash, Val: Clone, We: Weighter + Clone, B: BuildHasher + Clone, > KQCacheWithTTLTask { async fn run(self) { trace!("watching for expirations on {}", self.name); while let Ok((expire_at, key, qey)) = self.rx.recv_async().await { let now = Instant::now(); if expire_at > now { if enabled!(Level::TRACE) { trace!( "{}, {:?}, {:?} sleeping for {}ms.", self.name, key, qey, expire_at.duration_since(now).as_millis(), ); } sleep_until(expire_at).await; trace!("{}, {:?}, {:?} done sleeping", self.name, key, qey); } else { trace!("no need to sleep!"); } if self.cache.remove(&key, &qey) { trace!("removed {}, {:?}, {:?}", self.name, key, qey); } else { trace!("empty {}, {:?}, {:?}", self.name, key, qey); }; } trace!("watching for expirations on {}", self.name) } } impl< 'a, Key: Clone + Debug + Hash + Eq, Qey: Clone + Debug + Hash + Eq, Val: Clone, We: Weighter, B: BuildHasher, > PlaceholderGuardWithTTL<'a, Key, Qey, Val, We, B> { pub fn insert(self, val: Val) { let expire_at = Instant::now() + self.cache.ttl; let weight = self.cache.weighter.weight(&self.key, &self.qey, &val); if weight <= self.cache.max_item_weight { self.inner.insert(val); if enabled!(Level::TRACE) { trace!( "{}, {:?}, {:?} expiring in {}s", self.cache.name, self.key, self.qey, expire_at.duration_since(Instant::now()).as_secs_f32() ); } self.cache.tx.send((expire_at, self.key, self.qey)).unwrap(); } } } impl< Key: Clone + Debug + Eq + Hash + Send + Sync + 'static, Qey: Clone + Debug + Eq + Hash + Send + Sync + 'static, Val: Clone + Send + Sync + 'static, We: Weighter + Clone + Send + Sync + 'static, B: BuildHasher + Clone + Send + Sync + 'static, > Serialize for KQCacheWithTTL { fn serialize(&self, serializer: S) -> Result where S: Serializer, { let mut state = serializer.serialize_struct(self.name, 5)?; state.serialize_field("len", &self.cache.len())?; state.serialize_field("weight", &self.cache.weight())?; state.serialize_field("capacity", &self.cache.capacity())?; state.serialize_field("hits", &self.cache.hits())?; state.serialize_field("misses", &self.cache.misses())?; state.end() } }