it compiles, but theres something wrong with moves
This commit is contained in:
parent
12b6d01434
commit
5cc4ca8d9e
|
@ -3,11 +3,12 @@ use moka::future::Cache;
|
||||||
use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter};
|
use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter};
|
||||||
use std::cell::Cell;
|
use std::cell::Cell;
|
||||||
use std::cmp::Eq;
|
use std::cmp::Eq;
|
||||||
use std::fmt::Display;
|
use std::fmt::{Debug, Display};
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::{atomic::AtomicU64, Arc};
|
use std::sync::{atomic::AtomicU64, Arc};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
use tracing::error;
|
||||||
|
|
||||||
/// A local cache that sits in front of a RedisRateLimiter
|
/// A local cache that sits in front of a RedisRateLimiter
|
||||||
/// Generic accross the key so it is simple to use with IPs or user keys
|
/// Generic accross the key so it is simple to use with IPs or user keys
|
||||||
|
@ -28,7 +29,7 @@ pub enum DeferredRateLimitResult {
|
||||||
|
|
||||||
impl<K> DeferredRateLimiter<K>
|
impl<K> DeferredRateLimiter<K>
|
||||||
where
|
where
|
||||||
K: Copy + Display + Hash + Eq + Send + Sync + 'static,
|
K: Copy + Debug + Display + Hash + Eq + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
pub fn new(cache_size: u64, prefix: &str, rrl: RedisRateLimiter) -> Self {
|
pub fn new(cache_size: u64, prefix: &str, rrl: RedisRateLimiter) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -54,39 +55,55 @@ where
|
||||||
let arc_new_entry = Arc::new(AtomicBool::new(false));
|
let arc_new_entry = Arc::new(AtomicBool::new(false));
|
||||||
let arc_retry_at = Arc::new(Cell::new(None));
|
let arc_retry_at = Arc::new(Cell::new(None));
|
||||||
|
|
||||||
|
let redis_key = format!("{}:{}", self.prefix, key);
|
||||||
|
|
||||||
// TODO: DO NOT UNWRAP HERE. figure out how to handle anyhow error being wrapped in an Arc
|
// TODO: DO NOT UNWRAP HERE. figure out how to handle anyhow error being wrapped in an Arc
|
||||||
// TODO: i'm sure this could be a lot better. but race conditions make this hard to think through. brain needs sleep
|
// TODO: i'm sure this could be a lot better. but race conditions make this hard to think through. brain needs sleep
|
||||||
let key_count = {
|
let arc_key_count: Arc<AtomicU64> = {
|
||||||
|
// clone things outside of the
|
||||||
let arc_new_entry = arc_new_entry.clone();
|
let arc_new_entry = arc_new_entry.clone();
|
||||||
let arc_retry_at = arc_retry_at.clone();
|
let arc_retry_at = arc_retry_at.clone();
|
||||||
|
let redis_key = redis_key.clone();
|
||||||
|
let rrl = Arc::new(self.rrl.clone());
|
||||||
|
|
||||||
self.local_cache
|
self.local_cache.get_with(*key, async move { todo!() });
|
||||||
.try_get_with(*key, async move {
|
|
||||||
arc_new_entry.store(true, Ordering::Release);
|
|
||||||
|
|
||||||
let label = format!("{}:{}", self.prefix, key);
|
/*
|
||||||
|
let x = self
|
||||||
|
.local_cache
|
||||||
|
.get_with(*key, move {
|
||||||
|
async move {
|
||||||
|
arc_new_entry.store(true, Ordering::Release);
|
||||||
|
|
||||||
let redis_count = match self
|
// we do not use the try operator here because we want to be okay with redis errors
|
||||||
.rrl
|
let redis_count = match rrl
|
||||||
.throttle_label(&label, Some(max_per_period), count)
|
.throttle_label(&redis_key, Some(max_per_period), count)
|
||||||
.await?
|
.await
|
||||||
{
|
{
|
||||||
RedisRateLimitResult::Allowed(count) => count,
|
Ok(RedisRateLimitResult::Allowed(count)) => count,
|
||||||
RedisRateLimitResult::RetryAt(retry_at, count) => {
|
Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
|
||||||
arc_retry_at.set(Some(retry_at));
|
arc_retry_at.set(Some(retry_at));
|
||||||
count
|
count
|
||||||
}
|
}
|
||||||
RedisRateLimitResult::RetryNever => unimplemented!(),
|
Ok(RedisRateLimitResult::RetryNever) => unimplemented!(),
|
||||||
};
|
Err(x) => 0, // Err(err) => todo!("allow rate limit errors"),
|
||||||
|
// Err(RedisSomething)
|
||||||
|
// if we get a redis error, just let the user through. local caches will work fine
|
||||||
|
// though now that we do this, we need to reset rate limits every minute!
|
||||||
|
};
|
||||||
|
|
||||||
Ok::<_, anyhow::Error>(Arc::new(AtomicU64::new(redis_count)))
|
Arc::new(AtomicU64::new(redis_count))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.await
|
.await;
|
||||||
.unwrap()
|
*/
|
||||||
|
|
||||||
|
todo!("write more")
|
||||||
};
|
};
|
||||||
|
|
||||||
if arc_new_entry.load(Ordering::Acquire) {
|
if arc_new_entry.load(Ordering::Acquire) {
|
||||||
// new entry
|
// new entry. redis was already incremented
|
||||||
|
// return the retry_at that we got from
|
||||||
if let Some(retry_at) = arc_retry_at.get() {
|
if let Some(retry_at) = arc_retry_at.get() {
|
||||||
Ok(DeferredRateLimitResult::RetryAt(retry_at))
|
Ok(DeferredRateLimitResult::RetryAt(retry_at))
|
||||||
} else {
|
} else {
|
||||||
|
@ -94,24 +111,70 @@ where
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// we have a cached amount here
|
// we have a cached amount here
|
||||||
|
let cached_key_count = arc_key_count.fetch_add(count, Ordering::Acquire);
|
||||||
|
|
||||||
// increment our local count if
|
// assuming no other parallel futures incremented this key, this is the count that redis has
|
||||||
|
let expected_key_count = cached_key_count + count;
|
||||||
|
|
||||||
let f = async move {
|
if expected_key_count > max_per_period {
|
||||||
let label = format!("{}:{}", self.prefix, key);
|
// rate limit overshot!
|
||||||
|
let now = self.rrl.now_as_secs();
|
||||||
|
|
||||||
let redis_count = match self
|
// do not fetch_sub
|
||||||
.rrl
|
// another row might have queued a redis throttle_label to keep our count accurate
|
||||||
.throttle_label(&label, Some(max_per_period), count)
|
|
||||||
.await?
|
// show that we are rate limited without even querying redis
|
||||||
{
|
let retry_at = self.rrl.next_period(now);
|
||||||
RedisRateLimitResult::Allowed(count) => todo!("do something with allow"),
|
return Ok(DeferredRateLimitResult::RetryAt(retry_at));
|
||||||
RedisRateLimitResult::RetryAt(retry_at, count) => todo!("do something with retry at")
|
} else {
|
||||||
RedisRateLimitResult::RetryNever => unimplemented!(),
|
// local caches think rate limit should be okay
|
||||||
|
|
||||||
|
// prepare a future to increment redis
|
||||||
|
let increment_redis_f = {
|
||||||
|
let rrl = self.rrl.clone();
|
||||||
|
async move {
|
||||||
|
let rate_limit_result = rrl
|
||||||
|
.throttle_label(&redis_key, Some(max_per_period), count)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// TODO: log bad responses?
|
||||||
|
|
||||||
|
Ok::<_, anyhow::Error>(rate_limit_result)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok::<_, anyhow::Error>(())
|
// if close to max_per_period, wait for redis
|
||||||
};
|
// TODO: how close should we allow? depends on max expected concurent requests from one user
|
||||||
|
if expected_key_count > max_per_period * 98 / 100 {
|
||||||
|
// close to period. don't risk it. wait on redis
|
||||||
|
// match increment_redis_f.await {
|
||||||
|
// Ok(RedisRateLimitResult::Allowed(redis_count)) => todo!("1"),
|
||||||
|
// Ok(RedisRateLimitResult::RetryAt(retry_at, redis_count)) => todo!("2"),
|
||||||
|
// Ok(RedisRateLimitResult::RetryNever) => todo!("3"),
|
||||||
|
// Err(err) => {
|
||||||
|
// // don't let redis errors block our users!
|
||||||
|
// // error!(
|
||||||
|
// // // ?key, // TODO: this errors
|
||||||
|
// // ?err,
|
||||||
|
// // "unable to query rate limits. local cache available"
|
||||||
|
// // );
|
||||||
|
// todo!("4");
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
todo!("i think we can't move f around like this");
|
||||||
|
} else {
|
||||||
|
// rate limit has enough headroom that it should be safe to do this in the background
|
||||||
|
tokio::spawn(increment_redis_f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let new_count = arc_key_count.fetch_add(count, Ordering::Release);
|
||||||
|
|
||||||
|
todo!("check new_count");
|
||||||
|
|
||||||
|
// increment our local count and check what happened
|
||||||
|
// THERE IS A SMALL RACE HERE, but thats the price we pay for speed. a few queries leaking through is okay
|
||||||
|
// if it becomes a problem, we can lock, but i don't love that. it might be short enough to be okay though
|
||||||
|
|
||||||
todo!("write more");
|
todo!("write more");
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,25 @@ impl RedisRateLimiter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn now_as_secs(&self) -> f32 {
|
||||||
|
// TODO: if system time doesn't match redis, this won't work great
|
||||||
|
// TODO: now that we fixed
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("cannot tell the time")
|
||||||
|
.as_secs_f32()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn period_id(&self, now_as_secs: f32) -> f32 {
|
||||||
|
(now_as_secs / self.period) % self.period
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next_period(&self, now_as_secs: f32) -> Instant {
|
||||||
|
let seconds_left_in_period = self.period - (now_as_secs % self.period);
|
||||||
|
|
||||||
|
Instant::now().add(Duration::from_secs_f32(seconds_left_in_period))
|
||||||
|
}
|
||||||
|
|
||||||
/// label might be an ip address or a user_key id.
|
/// label might be an ip address or a user_key id.
|
||||||
/// if setting max_per_period, be sure to keep the period the same for all requests to this label
|
/// if setting max_per_period, be sure to keep the period the same for all requests to this label
|
||||||
pub async fn throttle_label(
|
pub async fn throttle_label(
|
||||||
|
@ -59,13 +78,10 @@ impl RedisRateLimiter {
|
||||||
return Ok(RedisRateLimitResult::RetryNever);
|
return Ok(RedisRateLimitResult::RetryNever);
|
||||||
}
|
}
|
||||||
|
|
||||||
let now = SystemTime::now()
|
let now = self.now_as_secs();
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.context("cannot tell the time")?
|
|
||||||
.as_secs_f32();
|
|
||||||
|
|
||||||
// if self.period is 60, period_id will be the minute of the current time
|
// if self.period is 60, period_id will be the minute of the current time
|
||||||
let period_id = (now / self.period) % self.period;
|
let period_id = self.period_id(now);
|
||||||
|
|
||||||
// TODO: include max per period in the throttle key?
|
// TODO: include max per period in the throttle key?
|
||||||
let throttle_key = format!("{}:{}:{}", self.key_prefix, label, period_id);
|
let throttle_key = format!("{}:{}:{}", self.key_prefix, label, period_id);
|
||||||
|
@ -91,9 +107,8 @@ impl RedisRateLimiter {
|
||||||
let new_count = *x.first().context("check rate limit result")?;
|
let new_count = *x.first().context("check rate limit result")?;
|
||||||
|
|
||||||
if new_count > max_per_period {
|
if new_count > max_per_period {
|
||||||
let seconds_left_in_period = self.period - (now % self.period);
|
// TODO: this might actually be early if we are way over the count
|
||||||
|
let retry_at = self.next_period(now);
|
||||||
let retry_at = Instant::now().add(Duration::from_secs_f32(seconds_left_in_period));
|
|
||||||
|
|
||||||
debug!(%label, ?retry_at, "rate limited: {}/{}", new_count, max_per_period);
|
debug!(%label, ?retry_at, "rate limited: {}/{}", new_count, max_per_period);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue