diff --git a/deferred-rate-limiter/src/lib.rs b/deferred-rate-limiter/src/lib.rs
index 2dc586c1..1e751da4 100644
--- a/deferred-rate-limiter/src/lib.rs
+++ b/deferred-rate-limiter/src/lib.rs
@@ -3,11 +3,12 @@
 use moka::future::Cache;
 use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter};
 use std::cell::Cell;
 use std::cmp::Eq;
-use std::fmt::Display;
+use std::fmt::{Debug, Display};
 use std::hash::Hash;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::{atomic::AtomicU64, Arc};
 use tokio::time::Instant;
+use tracing::error;
 /// A local cache that sits in front of a RedisRateLimiter
 /// Generic accross the key so it is simple to use with IPs or user keys
@@ -28,7 +29,7 @@ pub enum DeferredRateLimitResult {
 
 impl<K> DeferredRateLimiter<K>
 where
-    K: Copy + Display + Hash + Eq + Send + Sync + 'static,
+    K: Copy + Debug + Display + Hash + Eq + Send + Sync + 'static,
 {
     pub fn new(cache_size: u64, prefix: &str, rrl: RedisRateLimiter) -> Self {
         Self {
@@ -54,39 +55,55 @@ where
         let arc_new_entry = Arc::new(AtomicBool::new(false));
         let arc_retry_at = Arc::new(Cell::new(None));
 
+        let redis_key = format!("{}:{}", self.prefix, key);
+
         // TODO: DO NOT UNWRAP HERE. figure out how to handle anyhow error being wrapped in an Arc
         // TODO: i'm sure this could be a lot better. but race conditions make this hard to think through. brain needs sleep
-        let key_count = {
+        let arc_key_count: Arc<AtomicU64> = {
+            // clone things outside of the
             let arc_new_entry = arc_new_entry.clone();
             let arc_retry_at = arc_retry_at.clone();
+            let redis_key = redis_key.clone();
+            let rrl = Arc::new(self.rrl.clone());
 
-            self.local_cache
-                .try_get_with(*key, async move {
-                    arc_new_entry.store(true, Ordering::Release);
+            self.local_cache.get_with(*key, async move { todo!() });
 
-                    let label = format!("{}:{}", self.prefix, key);
+            /*
+            let x = self
+                .local_cache
+                .get_with(*key, move {
+                    async move {
+                        arc_new_entry.store(true, Ordering::Release);
 
-                    let redis_count = match self
-                        .rrl
-                        .throttle_label(&label, Some(max_per_period), count)
-                        .await?
-                    {
-                        RedisRateLimitResult::Allowed(count) => count,
-                        RedisRateLimitResult::RetryAt(retry_at, count) => {
-                            arc_retry_at.set(Some(retry_at));
-                            count
-                        }
-                        RedisRateLimitResult::RetryNever => unimplemented!(),
-                    };
+                        // we do not use the try operator here because we want to be okay with redis errors
+                        let redis_count = match rrl
+                            .throttle_label(&redis_key, Some(max_per_period), count)
+                            .await
+                        {
+                            Ok(RedisRateLimitResult::Allowed(count)) => count,
+                            Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
+                                arc_retry_at.set(Some(retry_at));
+                                count
+                            }
+                            Ok(RedisRateLimitResult::RetryNever) => unimplemented!(),
+                            Err(x) => 0, // Err(err) => todo!("allow rate limit errors"),
+                            // Err(RedisSomething)
+                            // if we get a redis error, just let the user through. local caches will work fine
+                            // though now that we do this, we need to reset rate limits every minute!
+                        };
 
-                    Ok::<_, anyhow::Error>(Arc::new(AtomicU64::new(redis_count)))
+                        Arc::new(AtomicU64::new(redis_count))
+                    }
                 })
-                .await
-                .unwrap()
+                .await;
+            */
+
+            todo!("write more")
         };
 
         if arc_new_entry.load(Ordering::Acquire) {
-            // new entry
+            // new entry. redis was already incremented
+            // return the retry_at that we got from
             if let Some(retry_at) = arc_retry_at.get() {
                 Ok(DeferredRateLimitResult::RetryAt(retry_at))
             } else {
@@ -94,24 +111,70 @@ where
             }
         } else {
             // we have a cached amount here
+            let cached_key_count = arc_key_count.fetch_add(count, Ordering::Acquire);
 
-            // increment our local count if
+            // assuming no other parallel futures incremented this key, this is the count that redis has
+            let expected_key_count = cached_key_count + count;
 
-            let f = async move {
-                let label = format!("{}:{}", self.prefix, key);
+            if expected_key_count > max_per_period {
+                // rate limit overshot!
+                let now = self.rrl.now_as_secs();
 
-                let redis_count = match self
-                    .rrl
-                    .throttle_label(&label, Some(max_per_period), count)
-                    .await?
-                {
-                    RedisRateLimitResult::Allowed(count) => todo!("do something with allow"),
-                    RedisRateLimitResult::RetryAt(retry_at, count) => todo!("do something with retry at")
-                    RedisRateLimitResult::RetryNever => unimplemented!(),
+                // do not fetch_sub
+                // another row might have queued a redis throttle_label to keep our count accurate
+
+                // show that we are rate limited without even querying redis
+                let retry_at = self.rrl.next_period(now);
+                return Ok(DeferredRateLimitResult::RetryAt(retry_at));
+            } else {
+                // local caches think rate limit should be okay
+
+                // prepare a future to increment redis
+                let increment_redis_f = {
+                    let rrl = self.rrl.clone();
+                    async move {
+                        let rate_limit_result = rrl
+                            .throttle_label(&redis_key, Some(max_per_period), count)
+                            .await?;
+
+                        // TODO: log bad responses?
+
+                        Ok::<_, anyhow::Error>(rate_limit_result)
+                    }
                 };
 
-                Ok::<_, anyhow::Error>(())
-            };
+                // if close to max_per_period, wait for redis
+                // TODO: how close should we allow? depends on max expected concurent requests from one user
+                if expected_key_count > max_per_period * 98 / 100 {
+                    // close to period. don't risk it. wait on redis
+                    // match increment_redis_f.await {
+                    //     Ok(RedisRateLimitResult::Allowed(redis_count)) => todo!("1"),
+                    //     Ok(RedisRateLimitResult::RetryAt(retry_at, redis_count)) => todo!("2"),
+                    //     Ok(RedisRateLimitResult::RetryNever) => todo!("3"),
+                    //     Err(err) => {
+                    //         // don't let redis errors block our users!
+                    //         // error!(
+                    //         //     // ?key, // TODO: this errors
+                    //         //     ?err,
+                    //         //     "unable to query rate limits. local cache available"
+                    //         // );
+                    //         todo!("4");
+                    //     }
+                    // };
+                    todo!("i think we can't move f around like this");
+                } else {
+                    // rate limit has enough headroom that it should be safe to do this in the background
+                    tokio::spawn(increment_redis_f);
+                }
+            }
+
+            let new_count = arc_key_count.fetch_add(count, Ordering::Release);
+
+            todo!("check new_count");
+
+            // increment our local count and check what happened
+            // THERE IS A SMALL RACE HERE, but thats the price we pay for speed. a few queries leaking through is okay
+            // if it becomes a problem, we can lock, but i don't love that. it might be short enough to be okay though
 
             todo!("write more");
         }
diff --git a/redis-rate-limiter/src/lib.rs b/redis-rate-limiter/src/lib.rs
index 465101b2..2ac85b07 100644
--- a/redis-rate-limiter/src/lib.rs
+++ b/redis-rate-limiter/src/lib.rs
@@ -45,6 +45,25 @@ impl RedisRateLimiter {
         }
     }
 
+    pub fn now_as_secs(&self) -> f32 {
+        // TODO: if system time doesn't match redis, this won't work great
+        // TODO: now that we fixed
+        SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("cannot tell the time")
+            .as_secs_f32()
+    }
+
+    pub fn period_id(&self, now_as_secs: f32) -> f32 {
+        (now_as_secs / self.period) % self.period
+    }
+
+    pub fn next_period(&self, now_as_secs: f32) -> Instant {
+        let seconds_left_in_period = self.period - (now_as_secs % self.period);
+
+        Instant::now().add(Duration::from_secs_f32(seconds_left_in_period))
+    }
+
     /// label might be an ip address or a user_key id.
     /// if setting max_per_period, be sure to keep the period the same for all requests to this label
     pub async fn throttle_label(
@@ -59,13 +78,10 @@ impl RedisRateLimiter {
             return Ok(RedisRateLimitResult::RetryNever);
         }
 
-        let now = SystemTime::now()
-            .duration_since(UNIX_EPOCH)
-            .context("cannot tell the time")?
-            .as_secs_f32();
+        let now = self.now_as_secs();
 
         // if self.period is 60, period_id will be the minute of the current time
-        let period_id = (now / self.period) % self.period;
+        let period_id = self.period_id(now);
 
         // TODO: include max per period in the throttle key?
         let throttle_key = format!("{}:{}:{}", self.key_prefix, label, period_id);
@@ -91,9 +107,8 @@ impl RedisRateLimiter {
         let new_count = *x.first().context("check rate limit result")?;
 
         if new_count > max_per_period {
-            let seconds_left_in_period = self.period - (now % self.period);
-
-            let retry_at = Instant::now().add(Duration::from_secs_f32(seconds_left_in_period));
+            // TODO: this might actually be early if we are way over the count
+            let retry_at = self.next_period(now);
 
             debug!(%label, ?retry_at, "rate limited: {}/{}", new_count, max_per_period);