diff --git a/Cargo.lock b/Cargo.lock index 5c7130c6..84fe196e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1097,6 +1097,20 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.6" @@ -1261,6 +1275,7 @@ name = "deferred-rate-limiter" version = "0.2.0" dependencies = [ "anyhow", + "crossbeam", "moka", "redis-rate-limiter", "tokio", diff --git a/deferred-rate-limiter/Cargo.toml b/deferred-rate-limiter/Cargo.toml index 44078934..f48d216c 100644 --- a/deferred-rate-limiter/Cargo.toml +++ b/deferred-rate-limiter/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" redis-rate-limiter = { path = "../redis-rate-limiter" } anyhow = "1.0.65" +crossbeam = "*" moka = { version = "0.9.4", default-features = false, features = ["future"] } tokio = "1.21.1" tracing = "0.1.36" diff --git a/deferred-rate-limiter/src/lib.rs b/deferred-rate-limiter/src/lib.rs index 1e751da4..e26bdc39 100644 --- a/deferred-rate-limiter/src/lib.rs +++ b/deferred-rate-limiter/src/lib.rs @@ -1,12 +1,12 @@ //#![warn(missing_docs)] use moka::future::Cache; use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter}; -use std::cell::Cell; use std::cmp::Eq; use std::fmt::{Debug, Display}; use std::hash::Hash; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{atomic::AtomicU64, Arc}; +use tokio::sync::Mutex; use tokio::time::Instant; use tracing::error; @@ -53,7 +53,9 @@ where } let arc_new_entry = Arc::new(AtomicBool::new(false)); - let arc_retry_at = Arc::new(Cell::new(None)); + + // TODO: this can't be right. what type do we actually want here? + let arc_retry_at = Arc::new(Mutex::new(None)); let redis_key = format!("{}:{}", self.prefix, key); @@ -66,45 +68,38 @@ where let redis_key = redis_key.clone(); let rrl = Arc::new(self.rrl.clone()); - self.local_cache.get_with(*key, async move { todo!() }); + self.local_cache + .get_with(*key, async move { + arc_new_entry.store(true, Ordering::Release); - /* - let x = self - .local_cache - .get_with(*key, move { - async move { - arc_new_entry.store(true, Ordering::Release); + // we do not use the try operator here because we want to be okay with redis errors + let redis_count = match rrl + .throttle_label(&redis_key, Some(max_per_period), count) + .await + { + Ok(RedisRateLimitResult::Allowed(count)) => count, + Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => { + let _ = arc_retry_at.lock().await.insert(Some(retry_at)); + count + } + Ok(RedisRateLimitResult::RetryNever) => unimplemented!(), + Err(err) => { + // if we get a redis error, just let the user through. local caches will work fine + // though now that we do this, we need to reset rate limits every minute! + error!(?err, "unable to rate limit! creating empty cache"); + 0 + } + }; - // we do not use the try operator here because we want to be okay with redis errors - let redis_count = match rrl - .throttle_label(&redis_key, Some(max_per_period), count) - .await - { - Ok(RedisRateLimitResult::Allowed(count)) => count, - Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => { - arc_retry_at.set(Some(retry_at)); - count - } - Ok(RedisRateLimitResult::RetryNever) => unimplemented!(), - Err(x) => 0, // Err(err) => todo!("allow rate limit errors"), - // Err(RedisSomething) - // if we get a redis error, just let the user through. local caches will work fine - // though now that we do this, we need to reset rate limits every minute! - }; - - Arc::new(AtomicU64::new(redis_count)) - } + Arc::new(AtomicU64::new(redis_count)) }) - .await; - */ - - todo!("write more") + .await }; if arc_new_entry.load(Ordering::Acquire) { // new entry. redis was already incremented // return the retry_at that we got from - if let Some(retry_at) = arc_retry_at.get() { + if let Some(Some(retry_at)) = arc_retry_at.lock().await.take() { Ok(DeferredRateLimitResult::RetryAt(retry_at)) } else { Ok(DeferredRateLimitResult::Allowed) @@ -125,58 +120,56 @@ where // show that we are rate limited without even querying redis let retry_at = self.rrl.next_period(now); - return Ok(DeferredRateLimitResult::RetryAt(retry_at)); + Ok(DeferredRateLimitResult::RetryAt(retry_at)) } else { // local caches think rate limit should be okay - // prepare a future to increment redis - let increment_redis_f = { + // prepare a future to update redis + let rate_limit_f = { let rrl = self.rrl.clone(); async move { - let rate_limit_result = rrl + match rrl .throttle_label(&redis_key, Some(max_per_period), count) - .await?; - - // TODO: log bad responses? - - Ok::<_, anyhow::Error>(rate_limit_result) + .await + { + Ok(RedisRateLimitResult::Allowed(count)) => { + arc_key_count.store(count, Ordering::Release); + DeferredRateLimitResult::Allowed + } + Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => { + arc_key_count.store(count, Ordering::Release); + DeferredRateLimitResult::RetryAt(retry_at) + } + Ok(RedisRateLimitResult::RetryNever) => { + // TODO: what should we do to arc_key_count? + DeferredRateLimitResult::RetryNever + } + Err(err) => { + // don't let redis errors block our users! + error!( + // ?key, // TODO: this errors + ?err, + "unable to query rate limits. local cache available" + ); + // TODO: we need to start a timer that resets this count every minute + DeferredRateLimitResult::Allowed + } + } } }; // if close to max_per_period, wait for redis // TODO: how close should we allow? depends on max expected concurent requests from one user - if expected_key_count > max_per_period * 98 / 100 { + if expected_key_count > max_per_period * 99 / 100 { // close to period. don't risk it. wait on redis - // match increment_redis_f.await { - // Ok(RedisRateLimitResult::Allowed(redis_count)) => todo!("1"), - // Ok(RedisRateLimitResult::RetryAt(retry_at, redis_count)) => todo!("2"), - // Ok(RedisRateLimitResult::RetryNever) => todo!("3"), - // Err(err) => { - // // don't let redis errors block our users! - // // error!( - // // // ?key, // TODO: this errors - // // ?err, - // // "unable to query rate limits. local cache available" - // // ); - // todo!("4"); - // } - // }; - todo!("i think we can't move f around like this"); + Ok(rate_limit_f.await) } else { // rate limit has enough headroom that it should be safe to do this in the background - tokio::spawn(increment_redis_f); + tokio::spawn(rate_limit_f); + + Ok(DeferredRateLimitResult::Allowed) } } - - let new_count = arc_key_count.fetch_add(count, Ordering::Release); - - todo!("check new_count"); - - // increment our local count and check what happened - // THERE IS A SMALL RACE HERE, but thats the price we pay for speed. a few queries leaking through is okay - // if it becomes a problem, we can lock, but i don't love that. it might be short enough to be okay though - - todo!("write more"); } } }