need a mutex, not a cell

Bryan Stitt 2022-09-17 01:06:10 +00:00
parent 5cc4ca8d9e
commit 6182b5f8e6
3 changed files with 78 additions and 69 deletions
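Context for the one-line summary: the retry_at slot is written from inside the future passed to moka's get_with, and that future has to be Send. Capturing an Arc<Cell<Option<Instant>>> makes it !Send, because Cell is !Sync; wrapping the Option in tokio::sync::Mutex restores Send + Sync. A minimal standalone sketch of the difference (hypothetical names, not code from this repo):

use std::sync::Arc;

use tokio::sync::Mutex;
use tokio::time::Instant;

#[tokio::main]
async fn main() {
    // With Arc<std::cell::Cell<Option<Instant>>> this would not compile:
    // Cell is !Sync, so the Arc (and any future capturing it) is !Send.
    let retry_at = Arc::new(Mutex::new(None::<Instant>));

    let writer = retry_at.clone();
    tokio::spawn(async move {
        // the async-aware mutex can be locked from any worker thread
        *writer.lock().await = Some(Instant::now());
    })
    .await
    .unwrap();

    assert!(retry_at.lock().await.is_some());
}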

Cargo.lock (generated)

@@ -1097,6 +1097,20 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "crossbeam"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
+dependencies = [
+ "cfg-if",
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-epoch",
+ "crossbeam-queue",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-channel"
 version = "0.5.6"
@@ -1261,6 +1275,7 @@ name = "deferred-rate-limiter"
 version = "0.2.0"
 dependencies = [
  "anyhow",
+ "crossbeam",
  "moka",
  "redis-rate-limiter",
  "tokio",

deferred-rate-limiter/Cargo.toml

@@ -8,6 +8,7 @@ edition = "2021"
 redis-rate-limiter = { path = "../redis-rate-limiter" }
 anyhow = "1.0.65"
+crossbeam = "*"
 moka = { version = "0.9.4", default-features = false, features = ["future"] }
 tokio = "1.21.1"
 tracing = "0.1.36"

deferred-rate-limiter/src/lib.rs

@@ -1,12 +1,12 @@
 //#![warn(missing_docs)]
 
 use moka::future::Cache;
 use redis_rate_limiter::{RedisRateLimitResult, RedisRateLimiter};
-use std::cell::Cell;
 use std::cmp::Eq;
 use std::fmt::{Debug, Display};
 use std::hash::Hash;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::{atomic::AtomicU64, Arc};
+use tokio::sync::Mutex;
 use tokio::time::Instant;
 use tracing::error;
@@ -53,7 +53,9 @@ where
         }
 
         let arc_new_entry = Arc::new(AtomicBool::new(false));
-        let arc_retry_at = Arc::new(Cell::new(None));
+
+        // TODO: this can't be right. what type do we actually want here?
+        let arc_retry_at = Arc::new(Mutex::new(None));
 
         let redis_key = format!("{}:{}", self.prefix, key);
@@ -66,13 +68,8 @@ where
         let redis_key = redis_key.clone();
         let rrl = Arc::new(self.rrl.clone());
 
-        self.local_cache.get_with(*key, async move { todo!() });
-
-        /*
-        let x = self
-            .local_cache
-            .get_with(*key, move {
-                async move {
+        self.local_cache
+            .get_with(*key, async move {
                 arc_new_entry.store(true, Ordering::Release);
 
                 // we do not use the try operator here because we want to be okay with redis errors
@@ -82,29 +79,27 @@ where
                 {
                     Ok(RedisRateLimitResult::Allowed(count)) => count,
                     Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
-                        arc_retry_at.set(Some(retry_at));
+                        let _ = arc_retry_at.lock().await.insert(Some(retry_at));
 
                         count
                     }
                     Ok(RedisRateLimitResult::RetryNever) => unimplemented!(),
-                    Err(x) => 0, // Err(err) => todo!("allow rate limit errors"),
-                    // Err(RedisSomething)
-                    // if we get a redis error, just let the user through. local caches will work fine
-                    // though now that we do this, we need to reset rate limits every minute!
+                    Err(err) => {
+                        // if we get a redis error, just let the user through. local caches will work fine
+                        // though now that we do this, we need to reset rate limits every minute!
+                        error!(?err, "unable to rate limit! creating empty cache");
+
+                        0
+                    }
                 };
 
                 Arc::new(AtomicU64::new(redis_count))
-                }
             })
-            .await;
-        */
-
-        todo!("write more")
+            .await
     };
 
         if arc_new_entry.load(Ordering::Acquire) {
             // new entry. redis was already incremented
             // return the retry_at that we got from
-            if let Some(retry_at) = arc_retry_at.get() {
+            if let Some(Some(retry_at)) = arc_retry_at.lock().await.take() {
                 Ok(DeferredRateLimitResult::RetryAt(retry_at))
             } else {
                 Ok(DeferredRateLimitResult::Allowed)
@@ -125,58 +120,56 @@ where
                 // show that we are rate limited without even querying redis
                 let retry_at = self.rrl.next_period(now);
-                return Ok(DeferredRateLimitResult::RetryAt(retry_at));
+
+                Ok(DeferredRateLimitResult::RetryAt(retry_at))
             } else {
                 // local caches think rate limit should be okay
-                // prepare a future to increment redis
-                let increment_redis_f = {
+
+                // prepare a future to update redis
+                let rate_limit_f = {
                     let rrl = self.rrl.clone();
 
                     async move {
-                        let rate_limit_result = rrl
+                        match rrl
                             .throttle_label(&redis_key, Some(max_per_period), count)
-                            .await?;
-
-                        // TODO: log bad responses?
-
-                        Ok::<_, anyhow::Error>(rate_limit_result)
+                            .await
+                        {
+                            Ok(RedisRateLimitResult::Allowed(count)) => {
+                                arc_key_count.store(count, Ordering::Release);
+
+                                DeferredRateLimitResult::Allowed
+                            }
+                            Ok(RedisRateLimitResult::RetryAt(retry_at, count)) => {
+                                arc_key_count.store(count, Ordering::Release);
+
+                                DeferredRateLimitResult::RetryAt(retry_at)
+                            }
+                            Ok(RedisRateLimitResult::RetryNever) => {
+                                // TODO: what should we do to arc_key_count?
+                                DeferredRateLimitResult::RetryNever
+                            }
+                            Err(err) => {
+                                // don't let redis errors block our users!
+                                error!(
+                                    // ?key, // TODO: this errors
+                                    ?err,
+                                    "unable to query rate limits. local cache available"
+                                );
+
+                                // TODO: we need to start a timer that resets this count every minute
+                                DeferredRateLimitResult::Allowed
+                            }
+                        }
                     }
                 };
 
                 // if close to max_per_period, wait for redis
                 // TODO: how close should we allow? depends on max expected concurent requests from one user
-                if expected_key_count > max_per_period * 98 / 100 {
+                if expected_key_count > max_per_period * 99 / 100 {
                     // close to period. don't risk it. wait on redis
-                    // match increment_redis_f.await {
-                    //     Ok(RedisRateLimitResult::Allowed(redis_count)) => todo!("1"),
-                    //     Ok(RedisRateLimitResult::RetryAt(retry_at, redis_count)) => todo!("2"),
-                    //     Ok(RedisRateLimitResult::RetryNever) => todo!("3"),
-                    //     Err(err) => {
-                    //         // don't let redis errors block our users!
-                    //         // error!(
-                    //         //     // ?key, // TODO: this errors
-                    //         //     ?err,
-                    //         //     "unable to query rate limits. local cache available"
-                    //         // );
-                    //         todo!("4");
-                    //     }
-                    // };
-                    todo!("i think we can't move f around like this");
+                    Ok(rate_limit_f.await)
                 } else {
                     // rate limit has enough headroom that it should be safe to do this in the background
-                    tokio::spawn(increment_redis_f);
-                }
-
-                let new_count = arc_key_count.fetch_add(count, Ordering::Release);
-
-                todo!("check new_count");
-                // increment our local count and check what happened
-                // THERE IS A SMALL RACE HERE, but thats the price we pay for speed. a few queries leaking through is okay
-                // if it becomes a problem, we can lock, but i don't love that. it might be short enough to be okay though
-                todo!("write more");
+                    tokio::spawn(rate_limit_f);
+
+                    Ok(DeferredRateLimitResult::Allowed)
+                }
             }
         }
     }
 }
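One loose end the new TODO points at: the mutex is created as Mutex::new(None) but written through Option::insert(Some(retry_at)), so the payload type infers to Option<Option<Instant>>, which is why the reader has to match Some(Some(retry_at)) after take(). A standalone sketch of that inference (illustrative, not repo code):

use tokio::sync::Mutex;
use tokio::time::Instant;

#[tokio::main]
async fn main() {
    // Mutex::new(None) + insert(Some(..)) infers the payload as
    // Option<Option<Instant>>: outer Option = "was anything stored?",
    // inner Option = the value that was stored.
    let retry_at = Mutex::new(None);

    let _ = retry_at.lock().await.insert(Some(Instant::now()));

    // take() drains the slot, so the nested pattern is required
    if let Some(Some(at)) = retry_at.lock().await.take() {
        println!("retry at {:?}", at);
    }
}

Storing the Instant directly (insert(retry_at), then matching Some(retry_at)) would flatten the nesting; presumably that is what the "what type do we actually want here?" comment is about.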