no need for an atomic bool

This commit is contained in:
Bryan Stitt 2022-09-19 22:17:24 +00:00
parent 28fa424c2a
commit b6275aff1e
2 changed files with 78 additions and 60 deletions

View File

@ -4,7 +4,7 @@ use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter};
use std::cmp::Eq;
use std::fmt::{Debug, Display};
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::atomic::Ordering;
use std::sync::{atomic::AtomicU64, Arc};
use tokio::sync::Mutex;
use tokio::time::{Duration, Instant};
@ -49,6 +49,7 @@ where
}
/// if setting max_per_period, be sure to keep the period the same for all requests to this label
/// TODO: max_per_period being None means two things. some places it means unlimited, but here it means to use the default. make an enum
pub async fn throttle(
&self,
key: &K,
@ -61,40 +62,52 @@ where
return Ok(DeferredRateLimitResult::RetryNever);
}
let arc_new_entry = Arc::new(AtomicBool::new(false));
// TODO: this can't be right. what type do we actually want here?
let arc_retry_at = Arc::new(Mutex::new(None));
let arc_deferred_rate_limit_result = Arc::new(Mutex::new(None));
let redis_key = format!("{}:{}", self.prefix, key);
// TODO: DO NOT UNWRAP HERE. figure out how to handle anyhow error being wrapped in an Arc
// TODO: i'm sure this could be a lot better. but race conditions make this hard to think through. brain needs sleep
let arc_key_count: Arc<AtomicU64> = {
// clone things outside of the
let arc_new_entry = arc_new_entry.clone();
let arc_retry_at = arc_retry_at.clone();
// clone things outside of the `async move`
let arc_deferred_rate_limit_result = arc_deferred_rate_limit_result.clone();
let redis_key = redis_key.clone();
let rrl = Arc::new(self.rrl.clone());
// set arc_deferred_rate_limit_result and return the coun
self.local_cache
.get_with(*key, async move {
arc_new_entry.store(true, Ordering::Release);
// we do not use the try operator here because we want to be okay with redis errors
let redis_count = match rrl
.throttle_label(&redis_key, Some(max_per_period), count)
.await
{
Ok(RedisRateLimitResult::Allowed(count)) => count,
Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
let _ = arc_retry_at.lock().await.insert(Some(retry_at));
Ok(RedisRateLimitResult::Allowed(count)) => {
let _ = arc_deferred_rate_limit_result
.lock()
.await
.insert(DeferredRateLimitResult::Allowed);
count
}
Ok(RedisRateLimitResult::RetryNever) => unimplemented!(),
Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
let _ = arc_deferred_rate_limit_result
.lock()
.await
.insert(DeferredRateLimitResult::RetryAt(retry_at));
count
}
Ok(RedisRateLimitResult::RetryNever) => {
panic!("RetryNever shouldn't happen")
}
Err(err) => {
// if we get a redis error, just let the user through. local caches will work fine
// though now that we do this, we need to reset rate limits every minute!
let _ = arc_deferred_rate_limit_result
.lock()
.await
.insert(DeferredRateLimitResult::Allowed);
// if we get a redis error, just let the user through.
// if users are sticky on a server, local caches will work well enough
// though now that we do this, we need to reset rate limits every minute! cache must have ttl!
error!(?err, "unable to rate limit! creating empty cache");
0
}
@ -105,14 +118,12 @@ where
.await
};
if arc_new_entry.load(Ordering::Acquire) {
let mut locked = arc_deferred_rate_limit_result.lock().await;
if let Some(deferred_rate_limit_result) = locked.take() {
// new entry. redis was already incremented
// return the retry_at that we got from
if let Some(Some(retry_at)) = arc_retry_at.lock().await.take() {
Ok(DeferredRateLimitResult::RetryAt(retry_at))
} else {
Ok(DeferredRateLimitResult::Allowed)
}
Ok(deferred_rate_limit_result)
} else {
// we have a cached amount here
let cached_key_count = arc_key_count.fetch_add(count, Ordering::Acquire);

View File

@ -107,11 +107,17 @@ impl Web3ProxyApp {
.await?
{
Some((user_id, requests_per_minute)) => {
// TODO: add a column here for max, or is u64::MAX fine?
let user_count_per_period = if requests_per_minute == u64::MAX {
None
} else {
Some(requests_per_minute)
};
UserCacheValue::from((
// TODO: how long should this cache last? get this from config
Instant::now() + Duration::from_secs(60),
user_id,
requests_per_minute,
user_count_per_period,
))
}
None => {
@ -120,7 +126,7 @@ impl Web3ProxyApp {
// TODO: how long should this cache last? get this from config
Instant::now() + Duration::from_secs(60),
0,
0,
Some(0),
))
}
};
@ -132,7 +138,7 @@ impl Web3ProxyApp {
}
pub async fn rate_limit_by_key(&self, user_key: Uuid) -> anyhow::Result<RateLimitResult> {
// check the local cache
// check the local cache fo user data to save a database query
let user_data = if let Some(cached_user) = self.user_cache.get(&user_key) {
// TODO: also include the time this value was last checked! otherwise we cache forever!
if cached_user.expires_at < Instant::now() {
@ -148,7 +154,6 @@ impl Web3ProxyApp {
};
// if cache was empty, check the database
// TODO: i think there is a cleaner way to do this
let user_data = match user_data {
None => self
.cache_user_data(user_key)
@ -162,43 +167,45 @@ impl Web3ProxyApp {
}
// TODO: turn back on rate limiting once our alpha test is complete
return Ok(RateLimitResult::AllowedUser(user_data.user_id));
// TODO: if user_data.unlimited_queries
// return Ok(RateLimitResult::AllowedUser(user_data.user_id));
// user key is valid. now check rate limits
// TODO: this is throwing errors when curve-api hits us with high concurrency. investigate i think its bb8's fault
if let Some(rate_limiter) = &self.frontend_key_rate_limiter {
// TODO: query redis in the background so that users don't have to wait on this network request
// TODO: better key? have a prefix so its easy to delete all of these
// TODO: we should probably hash this or something
match rate_limiter
.throttle(&user_key, Some(user_data.user_count_per_period), 1)
.await
{
Ok(DeferredRateLimitResult::Allowed) => {
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
// TODO: set headers so they know when they can retry
// TODO: debug or trace?
// this is too verbose, but a stat might be good
// TODO: keys are secrets! use the id instead
trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
Ok(RateLimitResult::RateLimitedUser(
user_data.user_id,
Some(retry_at),
))
}
Ok(DeferredRateLimitResult::RetryNever) => {
// TODO: i don't think we'll get here. maybe if we ban an IP forever? seems unlikely
// TODO: keys are secret. don't log them!
trace!(?user_key, "rate limit is 0");
Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None))
}
Err(err) => {
// internal error, not rate limit being hit
// TODO: i really want axum to do this for us in a single place.
error!(?err, "rate limiter is unhappy. allowing ip");
Ok(RateLimitResult::AllowedUser(user_data.user_id))
if user_data.user_count_per_period.is_none() {
// None means unlimited rate limit
Ok(RateLimitResult::AllowedUser(user_data.user_id))
} else {
match rate_limiter
.throttle(&user_key, user_data.user_count_per_period, 1)
.await
{
Ok(DeferredRateLimitResult::Allowed) => {
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
// TODO: set headers so they know when they can retry
// TODO: debug or trace?
// this is too verbose, but a stat might be good
// TODO: keys are secrets! use the id instead
trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
Ok(RateLimitResult::RateLimitedUser(
user_data.user_id,
Some(retry_at),
))
}
Ok(DeferredRateLimitResult::RetryNever) => {
// TODO: i don't think we'll get here. maybe if we ban an IP forever? seems unlikely
// TODO: keys are secret. don't log them!
trace!(?user_key, "rate limit is 0");
Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None))
}
Err(err) => {
// internal error, not rate limit being hit
// TODO: i really want axum to do this for us in a single place.
error!(?err, "rate limiter is unhappy. allowing ip");
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
}
}
} else {