no need for an atomic bool

Bryan Stitt 2022-09-19 22:17:24 +00:00
parent 28fa424c2a
commit b6275aff1e
2 changed files with 78 additions and 60 deletions
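The commit message is the whole story: the old code tracked "did the cache-miss closure run?" in an Arc<AtomicBool> and smuggled the retry time out through a separate Arc<Mutex<Option<Instant>>>. Since the closure already takes an async lock to store the retry time, one Mutex<Option<DeferredRateLimitResult>> can carry both facts: Some(..) means the closure ran and holds its result. A minimal, self-contained sketch of the pattern (simplified names, not the actual web3-proxy types):

use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Debug)]
enum Outcome {
    Allowed,
}

#[tokio::main]
async fn main() {
    // one shared slot replaces AtomicBool + Mutex<Option<Instant>>:
    // Some(outcome) simultaneously means "new entry" and carries the result
    let slot: Arc<Mutex<Option<Outcome>>> = Arc::new(Mutex::new(None));

    {
        // stands in for the cache-miss closure in the diff; it only runs
        // when the key was not already cached
        let slot = slot.clone();
        tokio::spawn(async move {
            let _ = slot.lock().await.insert(Outcome::Allowed);
        })
        .await
        .unwrap();
    }

    // take() yields Some(..) exactly when the closure above ran,
    // so no separate "is new" flag is needed
    match slot.lock().await.take() {
        Some(outcome) => println!("new entry: {outcome:?}"),
        None => println!("cached entry: bump the local counter instead"),
    }
}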

View File

@@ -4,7 +4,7 @@ use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter};
 use std::cmp::Eq;
 use std::fmt::{Debug, Display};
 use std::hash::Hash;
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::Ordering;
 use std::sync::{atomic::AtomicU64, Arc};
 use tokio::sync::Mutex;
 use tokio::time::{Duration, Instant};
@@ -49,6 +49,7 @@ where
 }

 /// if setting max_per_period, be sure to keep the period the same for all requests to this label
+/// TODO: max_per_period being None means two things. some places it means unlimited, but here it means to use the default. make an enum
 pub async fn throttle(
     &self,
     key: &K,
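The new TODO above notes that max_per_period: Option<u64> is overloaded: inside throttle a None means "use the default", while at the call site in the second file a None means "unlimited". One possible shape for the enum it asks for, purely hypothetical and not part of this commit:

// hypothetical replacement for the overloaded Option<u64>
enum MaxPerPeriod {
    Default,    // fall back to the limiter's configured default
    Unlimited,  // no rate limiting for this key
    Limit(u64), // explicit cap per period (Limit(0) always limits)
}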
@@ -61,40 +62,52 @@ where
             return Ok(DeferredRateLimitResult::RetryNever);
         }

-        let arc_new_entry = Arc::new(AtomicBool::new(false));
-
-        // TODO: this can't be right. what type do we actually want here?
-        let arc_retry_at = Arc::new(Mutex::new(None));
+        let arc_deferred_rate_limit_result = Arc::new(Mutex::new(None));

         let redis_key = format!("{}:{}", self.prefix, key);

         // TODO: DO NOT UNWRAP HERE. figure out how to handle anyhow error being wrapped in an Arc
         // TODO: i'm sure this could be a lot better. but race conditions make this hard to think through. brain needs sleep
         let arc_key_count: Arc<AtomicU64> = {
-            // clone things outside of the
-            let arc_new_entry = arc_new_entry.clone();
-            let arc_retry_at = arc_retry_at.clone();
+            // clone things outside of the `async move`
+            let arc_deferred_rate_limit_result = arc_deferred_rate_limit_result.clone();
             let redis_key = redis_key.clone();
             let rrl = Arc::new(self.rrl.clone());

+            // set arc_deferred_rate_limit_result and return the count
             self.local_cache
                 .get_with(*key, async move {
-                    arc_new_entry.store(true, Ordering::Release);
-
                     // we do not use the try operator here because we want to be okay with redis errors
                     let redis_count = match rrl
                         .throttle_label(&redis_key, Some(max_per_period), count)
                         .await
                     {
-                        Ok(RedisRateLimitResult::Allowed(count)) => count,
+                        Ok(RedisRateLimitResult::Allowed(count)) => {
+                            let _ = arc_deferred_rate_limit_result
+                                .lock()
+                                .await
+                                .insert(DeferredRateLimitResult::Allowed);
+
+                            count
+                        }
                         Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
-                            let _ = arc_retry_at.lock().await.insert(Some(retry_at));
+                            let _ = arc_deferred_rate_limit_result
+                                .lock()
+                                .await
+                                .insert(DeferredRateLimitResult::RetryAt(retry_at));
+
                             count
                         }
-                        Ok(RedisRateLimitResult::RetryNever) => unimplemented!(),
+                        Ok(RedisRateLimitResult::RetryNever) => {
+                            panic!("RetryNever shouldn't happen")
+                        }
                         Err(err) => {
-                            // if we get a redis error, just let the user through. local caches will work fine
-                            // though now that we do this, we need to reset rate limits every minute!
+                            let _ = arc_deferred_rate_limit_result
+                                .lock()
+                                .await
+                                .insert(DeferredRateLimitResult::Allowed);
+
+                            // if we get a redis error, just let the user through.
+                            // if users are sticky on a server, local caches will work well enough
+                            // though now that we do this, we need to reset rate limits every minute! cache must have ttl!
                             error!(?err, "unable to rate limit! creating empty cache");
                             0
                         }
@@ -105,14 +118,12 @@ where
                 .await
         };

-        if arc_new_entry.load(Ordering::Acquire) {
+        let mut locked = arc_deferred_rate_limit_result.lock().await;
+
+        if let Some(deferred_rate_limit_result) = locked.take() {
             // new entry. redis was already incremented
             // return the retry_at that we got from
-            if let Some(Some(retry_at)) = arc_retry_at.lock().await.take() {
-                Ok(DeferredRateLimitResult::RetryAt(retry_at))
-            } else {
-                Ok(DeferredRateLimitResult::Allowed)
-            }
+            Ok(deferred_rate_limit_result)
         } else {
             // we have a cached amount here
             let cached_key_count = arc_key_count.fetch_add(count, Ordering::Acquire);
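Worth noting why a plain Mutex<Option<..>> is enough and why the closure's store is race-free: get_with runs the init future at most once per key, so only the caller that actually creates the entry ever writes to the slot. A sketch of that behavior, assuming local_cache is a moka::future::Cache (which the get_with call suggests, though the diff doesn't show the field's type):

use std::sync::Arc;
use moka::future::Cache;
use tokio::sync::Mutex;

#[tokio::main]
async fn main() {
    let cache: Cache<u32, u64> = Cache::new(100);
    let slot: Arc<Mutex<Option<&str>>> = Arc::new(Mutex::new(None));

    let count = {
        let slot = slot.clone();
        cache
            .get_with(1, async move {
                // only the first caller for key 1 runs this; in the diff
                // this is where redis is throttled and the result stored
                let _ = slot.lock().await.insert("new entry");
                0u64
            })
            .await
    };

    // Some(..) proves this call initialized the entry; a second
    // get_with for key 1 would leave the slot untouched
    println!("count={count}, slot={:?}", slot.lock().await.take());
}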

View File

@@ -107,11 +107,17 @@ impl Web3ProxyApp {
             .await?
         {
             Some((user_id, requests_per_minute)) => {
+                // TODO: add a column here for max, or is u64::MAX fine?
+                let user_count_per_period = if requests_per_minute == u64::MAX {
+                    None
+                } else {
+                    Some(requests_per_minute)
+                };
+
                 UserCacheValue::from((
                     // TODO: how long should this cache last? get this from config
                     Instant::now() + Duration::from_secs(60),
                     user_id,
-                    requests_per_minute,
+                    user_count_per_period,
                 ))
             }
             None => {
@@ -120,7 +126,7 @@ impl Web3ProxyApp {
                     // TODO: how long should this cache last? get this from config
                     Instant::now() + Duration::from_secs(60),
                     0,
-                    0,
+                    Some(0),
                 ))
             }
         };
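The two hunks above establish the sentinel convention: the database column stays a plain u64, u64::MAX is decoded to None ("unlimited") on the way into the cache, and unknown keys get Some(0) ("always limited"). The mapping in isolation:

// the decode step from the first hunk, extracted for clarity
fn user_count_per_period(requests_per_minute: u64) -> Option<u64> {
    if requests_per_minute == u64::MAX {
        None // unlimited: the rate limiter is skipped entirely
    } else {
        Some(requests_per_minute) // Some(0) means always rate limited
    }
}

fn main() {
    assert_eq!(user_count_per_period(u64::MAX), None);
    assert_eq!(user_count_per_period(60), Some(60));
    assert_eq!(user_count_per_period(0), Some(0)); // unknown keys
}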
@@ -132,7 +138,7 @@ impl Web3ProxyApp {
     }

     pub async fn rate_limit_by_key(&self, user_key: Uuid) -> anyhow::Result<RateLimitResult> {
-        // check the local cache
+        // check the local cache for user data to save a database query
         let user_data = if let Some(cached_user) = self.user_cache.get(&user_key) {
             // TODO: also include the time this value was last checked! otherwise we cache forever!
             if cached_user.expires_at < Instant::now() {
@@ -148,7 +154,6 @@ impl Web3ProxyApp {
         };

         // if cache was empty, check the database
-        // TODO: i think there is a cleaner way to do this
         let user_data = match user_data {
             None => self
                 .cache_user_data(user_key)
@@ -162,43 +167,45 @@ impl Web3ProxyApp {
         }

         // TODO: turn back on rate limiting once our alpha test is complete
-        return Ok(RateLimitResult::AllowedUser(user_data.user_id));
+        // TODO: if user_data.unlimited_queries
+        // return Ok(RateLimitResult::AllowedUser(user_data.user_id));

         // user key is valid. now check rate limits
-        // TODO: this is throwing errors when curve-api hits us with high concurrency. investigate i think its bb8's fault
         if let Some(rate_limiter) = &self.frontend_key_rate_limiter {
-            // TODO: query redis in the background so that users don't have to wait on this network request
-            // TODO: better key? have a prefix so its easy to delete all of these
-            // TODO: we should probably hash this or something
-            match rate_limiter
-                .throttle(&user_key, Some(user_data.user_count_per_period), 1)
-                .await
-            {
-                Ok(DeferredRateLimitResult::Allowed) => {
-                    Ok(RateLimitResult::AllowedUser(user_data.user_id))
-                }
-                Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
-                    // TODO: set headers so they know when they can retry
-                    // TODO: debug or trace?
-                    // this is too verbose, but a stat might be good
-                    // TODO: keys are secrets! use the id instead
-                    trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
-                    Ok(RateLimitResult::RateLimitedUser(
-                        user_data.user_id,
-                        Some(retry_at),
-                    ))
-                }
-                Ok(DeferredRateLimitResult::RetryNever) => {
-                    // TODO: i don't think we'll get here. maybe if we ban an IP forever? seems unlikely
-                    // TODO: keys are secret. don't log them!
-                    trace!(?user_key, "rate limit is 0");
-                    Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None))
-                }
-                Err(err) => {
-                    // internal error, not rate limit being hit
-                    // TODO: i really want axum to do this for us in a single place.
-                    error!(?err, "rate limiter is unhappy. allowing ip");
-                    Ok(RateLimitResult::AllowedUser(user_data.user_id))
+            if user_data.user_count_per_period.is_none() {
+                // None means unlimited rate limit
+                Ok(RateLimitResult::AllowedUser(user_data.user_id))
+            } else {
+                match rate_limiter
+                    .throttle(&user_key, user_data.user_count_per_period, 1)
+                    .await
+                {
+                    Ok(DeferredRateLimitResult::Allowed) => {
+                        Ok(RateLimitResult::AllowedUser(user_data.user_id))
+                    }
+                    Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
+                        // TODO: set headers so they know when they can retry
+                        // TODO: debug or trace?
+                        // this is too verbose, but a stat might be good
+                        // TODO: keys are secrets! use the id instead
+                        trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
+                        Ok(RateLimitResult::RateLimitedUser(
+                            user_data.user_id,
+                            Some(retry_at),
+                        ))
+                    }
+                    Ok(DeferredRateLimitResult::RetryNever) => {
+                        // TODO: i don't think we'll get here. maybe if we ban an IP forever? seems unlikely
+                        // TODO: keys are secret. don't log them!
+                        trace!(?user_key, "rate limit is 0");
+                        Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None))
+                    }
+                    Err(err) => {
+                        // internal error, not rate limit being hit
+                        // TODO: i really want axum to do this for us in a single place.
+                        error!(?err, "rate limiter is unhappy. allowing ip");
+                        Ok(RateLimitResult::AllowedUser(user_data.user_id))
+                    }
                 }
             }
         } else {
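The is_none() branch above is what makes the sentinel workable: None must be consumed here, at the caller, because forwarding it to throttle would be read as "use the default limit" rather than "unlimited" (the overload the first file's TODO flags). The caller-side rule as a tiny sketch, with simplified stand-ins:

// never forward None to throttle(): there it means "default limit"
fn must_throttle(user_count_per_period: Option<u64>) -> bool {
    user_count_per_period.is_some()
}

fn main() {
    assert!(!must_throttle(None));    // unlimited key: allow, skip redis
    assert!(must_throttle(Some(60))); // limited key: go through redis
    assert!(must_throttle(Some(0)));  // zero: always rate limited
}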