2022-09-15 20:57:24 +03:00
|
|
|
//#![warn(missing_docs)]
|
|
|
|
use moka::future::Cache;
|
|
|
|
use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter};
|
|
|
|
use std::cmp::Eq;
|
2022-09-17 02:02:55 +03:00
|
|
|
use std::fmt::{Debug, Display};
|
2022-09-15 20:57:24 +03:00
|
|
|
use std::hash::Hash;
|
2022-09-20 01:17:24 +03:00
|
|
|
use std::sync::atomic::Ordering;
|
2022-09-15 20:57:24 +03:00
|
|
|
use std::sync::{atomic::AtomicU64, Arc};
|
2022-09-17 04:06:10 +03:00
|
|
|
use tokio::sync::Mutex;
|
2022-09-17 04:19:11 +03:00
|
|
|
use tokio::time::{Duration, Instant};
|
2022-09-17 02:02:55 +03:00
|
|
|
use tracing::error;
|
2022-09-15 20:57:24 +03:00
|
|
|
|
|
|
|
/// A local cache that sits in front of a RedisRateLimiter
|
|
|
|
/// Generic accross the key so it is simple to use with IPs or user keys
|
|
|
|
pub struct DeferredRateLimiter<K>
|
|
|
|
where
|
|
|
|
K: Send + Sync,
|
|
|
|
{
|
2022-09-20 04:33:39 +03:00
|
|
|
local_cache: Cache<K, Arc<AtomicU64>, hashbrown::hash_map::DefaultHashBuilder>,
|
2022-09-15 20:57:24 +03:00
|
|
|
prefix: String,
|
|
|
|
rrl: RedisRateLimiter,
|
|
|
|
}
|
|
|
|
|
|
|
|
/// The outcome of a `DeferredRateLimiter::throttle` call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeferredRateLimitResult {
    /// The request may proceed.
    Allowed,
    /// The request is over the limit; it may be retried at the given instant.
    RetryAt(Instant),
    /// The request will never be allowed (e.g. the effective limit is 0).
    RetryNever,
}
|
|
|
|
|
|
|
|
impl<K> DeferredRateLimiter<K>
|
|
|
|
where
|
2022-09-17 02:02:55 +03:00
|
|
|
K: Copy + Debug + Display + Hash + Eq + Send + Sync + 'static,
|
2022-09-15 20:57:24 +03:00
|
|
|
{
|
|
|
|
pub fn new(cache_size: u64, prefix: &str, rrl: RedisRateLimiter) -> Self {
|
2022-09-17 04:19:11 +03:00
|
|
|
let ttl = rrl.period as u64;
|
|
|
|
|
2022-09-20 01:24:56 +03:00
|
|
|
// TODO: time to live is not exactly right. we want this ttl counter to start only after redis is down. this works for now
|
2022-09-17 04:19:11 +03:00
|
|
|
let local_cache = Cache::builder()
|
|
|
|
.time_to_live(Duration::from_secs(ttl))
|
|
|
|
.max_capacity(cache_size)
|
|
|
|
.name(prefix)
|
2022-09-20 04:33:39 +03:00
|
|
|
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
|
2022-09-17 04:19:11 +03:00
|
|
|
|
2022-09-15 20:57:24 +03:00
|
|
|
Self {
|
2022-09-17 04:19:11 +03:00
|
|
|
local_cache,
|
2022-09-15 20:57:24 +03:00
|
|
|
prefix: prefix.to_string(),
|
|
|
|
rrl,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// if setting max_per_period, be sure to keep the period the same for all requests to this label
|
2022-09-20 01:17:24 +03:00
|
|
|
/// TODO: max_per_period being None means two things. some places it means unlimited, but here it means to use the default. make an enum
|
2022-09-15 20:57:24 +03:00
|
|
|
pub async fn throttle(
|
|
|
|
&self,
|
2022-09-20 06:26:12 +03:00
|
|
|
key: K,
|
2022-09-15 20:57:24 +03:00
|
|
|
max_per_period: Option<u64>,
|
|
|
|
count: u64,
|
|
|
|
) -> anyhow::Result<DeferredRateLimitResult> {
|
|
|
|
let max_per_period = max_per_period.unwrap_or(self.rrl.max_requests_per_period);
|
|
|
|
|
|
|
|
if max_per_period == 0 {
|
|
|
|
return Ok(DeferredRateLimitResult::RetryNever);
|
|
|
|
}
|
|
|
|
|
2022-09-20 01:17:24 +03:00
|
|
|
let arc_deferred_rate_limit_result = Arc::new(Mutex::new(None));
|
2022-09-15 20:57:24 +03:00
|
|
|
|
2022-09-17 02:02:55 +03:00
|
|
|
let redis_key = format!("{}:{}", self.prefix, key);
|
|
|
|
|
2022-09-15 20:57:24 +03:00
|
|
|
// TODO: DO NOT UNWRAP HERE. figure out how to handle anyhow error being wrapped in an Arc
|
|
|
|
// TODO: i'm sure this could be a lot better. but race conditions make this hard to think through. brain needs sleep
|
2022-09-20 01:41:53 +03:00
|
|
|
let local_key_count: Arc<AtomicU64> = {
|
2022-09-20 01:17:24 +03:00
|
|
|
// clone things outside of the `async move`
|
2022-09-20 01:41:53 +03:00
|
|
|
let deferred_rate_limit_result = arc_deferred_rate_limit_result.clone();
|
2022-09-17 02:02:55 +03:00
|
|
|
let redis_key = redis_key.clone();
|
|
|
|
let rrl = Arc::new(self.rrl.clone());
|
|
|
|
|
2022-09-20 01:17:24 +03:00
|
|
|
// set arc_deferred_rate_limit_result and return the coun
|
2022-09-17 04:06:10 +03:00
|
|
|
self.local_cache
|
2022-09-20 06:26:12 +03:00
|
|
|
.get_with(key, async move {
|
2022-09-17 04:06:10 +03:00
|
|
|
// we do not use the try operator here because we want to be okay with redis errors
|
|
|
|
let redis_count = match rrl
|
|
|
|
.throttle_label(&redis_key, Some(max_per_period), count)
|
|
|
|
.await
|
|
|
|
{
|
2022-09-20 01:17:24 +03:00
|
|
|
Ok(RedisRateLimitResult::Allowed(count)) => {
|
2022-09-20 01:41:53 +03:00
|
|
|
let _ = deferred_rate_limit_result
|
2022-09-20 01:17:24 +03:00
|
|
|
.lock()
|
|
|
|
.await
|
|
|
|
.insert(DeferredRateLimitResult::Allowed);
|
|
|
|
count
|
|
|
|
}
|
2022-09-17 04:06:10 +03:00
|
|
|
Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
|
2022-09-20 01:41:53 +03:00
|
|
|
let _ = deferred_rate_limit_result
|
2022-09-20 01:17:24 +03:00
|
|
|
.lock()
|
|
|
|
.await
|
|
|
|
.insert(DeferredRateLimitResult::RetryAt(retry_at));
|
2022-09-17 04:06:10 +03:00
|
|
|
count
|
|
|
|
}
|
2022-09-20 01:17:24 +03:00
|
|
|
Ok(RedisRateLimitResult::RetryNever) => {
|
|
|
|
panic!("RetryNever shouldn't happen")
|
|
|
|
}
|
2022-09-17 04:06:10 +03:00
|
|
|
Err(err) => {
|
2022-09-20 01:41:53 +03:00
|
|
|
let _ = deferred_rate_limit_result
|
2022-09-20 01:17:24 +03:00
|
|
|
.lock()
|
|
|
|
.await
|
|
|
|
.insert(DeferredRateLimitResult::Allowed);
|
|
|
|
|
|
|
|
// if we get a redis error, just let the user through.
|
|
|
|
// if users are sticky on a server, local caches will work well enough
|
|
|
|
// though now that we do this, we need to reset rate limits every minute! cache must have ttl!
|
2022-09-17 04:06:10 +03:00
|
|
|
error!(?err, "unable to rate limit! creating empty cache");
|
|
|
|
0
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
Arc::new(AtomicU64::new(redis_count))
|
2022-09-15 20:57:24 +03:00
|
|
|
})
|
2022-09-17 04:06:10 +03:00
|
|
|
.await
|
2022-09-15 20:57:24 +03:00
|
|
|
};
|
|
|
|
|
2022-09-20 01:17:24 +03:00
|
|
|
let mut locked = arc_deferred_rate_limit_result.lock().await;
|
|
|
|
|
|
|
|
if let Some(deferred_rate_limit_result) = locked.take() {
|
2022-09-17 02:02:55 +03:00
|
|
|
// new entry. redis was already incremented
|
|
|
|
// return the retry_at that we got from
|
2022-09-20 01:17:24 +03:00
|
|
|
Ok(deferred_rate_limit_result)
|
2022-09-15 20:57:24 +03:00
|
|
|
} else {
|
|
|
|
// we have a cached amount here
|
2022-09-20 01:41:53 +03:00
|
|
|
let cached_key_count = local_key_count.fetch_add(count, Ordering::Acquire);
|
2022-09-17 02:02:55 +03:00
|
|
|
|
|
|
|
// assuming no other parallel futures incremented this key, this is the count that redis has
|
|
|
|
let expected_key_count = cached_key_count + count;
|
|
|
|
|
|
|
|
if expected_key_count > max_per_period {
|
|
|
|
// rate limit overshot!
|
|
|
|
let now = self.rrl.now_as_secs();
|
|
|
|
|
|
|
|
// do not fetch_sub
|
|
|
|
// another row might have queued a redis throttle_label to keep our count accurate
|
2022-09-15 20:57:24 +03:00
|
|
|
|
2022-09-17 02:02:55 +03:00
|
|
|
// show that we are rate limited without even querying redis
|
|
|
|
let retry_at = self.rrl.next_period(now);
|
2022-09-17 04:06:10 +03:00
|
|
|
Ok(DeferredRateLimitResult::RetryAt(retry_at))
|
2022-09-17 02:02:55 +03:00
|
|
|
} else {
|
|
|
|
// local caches think rate limit should be okay
|
|
|
|
|
2022-09-17 04:06:10 +03:00
|
|
|
// prepare a future to update redis
|
|
|
|
let rate_limit_f = {
|
2022-09-17 02:02:55 +03:00
|
|
|
let rrl = self.rrl.clone();
|
|
|
|
async move {
|
2022-09-17 04:06:10 +03:00
|
|
|
match rrl
|
2022-09-17 02:02:55 +03:00
|
|
|
.throttle_label(&redis_key, Some(max_per_period), count)
|
2022-09-17 04:06:10 +03:00
|
|
|
.await
|
|
|
|
{
|
|
|
|
Ok(RedisRateLimitResult::Allowed(count)) => {
|
2022-09-20 01:41:53 +03:00
|
|
|
local_key_count.store(count, Ordering::Release);
|
2022-09-17 04:06:10 +03:00
|
|
|
DeferredRateLimitResult::Allowed
|
|
|
|
}
|
|
|
|
Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
|
2022-09-20 01:41:53 +03:00
|
|
|
local_key_count.store(count, Ordering::Release);
|
2022-09-17 04:06:10 +03:00
|
|
|
DeferredRateLimitResult::RetryAt(retry_at)
|
|
|
|
}
|
|
|
|
Ok(RedisRateLimitResult::RetryNever) => {
|
|
|
|
// TODO: what should we do to arc_key_count?
|
|
|
|
DeferredRateLimitResult::RetryNever
|
|
|
|
}
|
|
|
|
Err(err) => {
|
|
|
|
// don't let redis errors block our users!
|
|
|
|
error!(
|
2022-09-20 06:26:12 +03:00
|
|
|
?key,
|
2022-09-17 04:06:10 +03:00
|
|
|
?err,
|
2022-09-20 06:26:12 +03:00
|
|
|
"unable to query rate limits, but local cache is available"
|
2022-09-17 04:06:10 +03:00
|
|
|
);
|
|
|
|
// TODO: we need to start a timer that resets this count every minute
|
|
|
|
DeferredRateLimitResult::Allowed
|
|
|
|
}
|
|
|
|
}
|
2022-09-17 02:02:55 +03:00
|
|
|
}
|
2022-09-15 20:57:24 +03:00
|
|
|
};
|
|
|
|
|
2022-09-17 02:02:55 +03:00
|
|
|
// if close to max_per_period, wait for redis
|
|
|
|
// TODO: how close should we allow? depends on max expected concurent requests from one user
|
2022-09-17 04:06:10 +03:00
|
|
|
if expected_key_count > max_per_period * 99 / 100 {
|
2022-09-17 02:02:55 +03:00
|
|
|
// close to period. don't risk it. wait on redis
|
2022-09-17 04:06:10 +03:00
|
|
|
Ok(rate_limit_f.await)
|
2022-09-17 02:02:55 +03:00
|
|
|
} else {
|
|
|
|
// rate limit has enough headroom that it should be safe to do this in the background
|
2022-09-17 04:06:10 +03:00
|
|
|
tokio::spawn(rate_limit_f);
|
|
|
|
|
|
|
|
Ok(DeferredRateLimitResult::Allowed)
|
2022-09-17 02:02:55 +03:00
|
|
|
}
|
|
|
|
}
|
2022-09-15 20:57:24 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|