web3-proxy/redis-rate-limiter/src/lib.rs

126 lines
3.9 KiB
Rust
Raw Normal View History

//#![warn(missing_docs)]
use anyhow::Context;
use std::ops::Add;
use tokio::time::{Duration, Instant};
2022-09-14 09:18:13 +03:00
pub use deadpool_redis::redis;
2022-09-15 20:57:24 +03:00
pub use deadpool_redis::{
Config as RedisConfig, Connection as RedisConnection, Manager as RedisManager,
Pool as RedisPool, PoolError as RedisPoolError, Runtime as DeadpoolRuntime,
2022-09-15 20:57:24 +03:00
};
2022-09-15 20:57:24 +03:00
#[derive(Clone)]
pub struct RedisRateLimiter {
key_prefix: String,
2022-08-30 23:01:42 +03:00
/// The default maximum requests allowed in a period.
2022-09-15 20:57:24 +03:00
pub max_requests_per_period: u64,
2022-08-30 23:01:42 +03:00
/// seconds
2022-09-15 20:57:24 +03:00
pub period: f32,
pool: RedisPool,
}
2022-09-15 20:57:24 +03:00
/// Outcome of [`RedisRateLimiter::throttle_label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RedisRateLimitResult {
    /// The request is within the limit.
    ///
    /// Holds the label's counter for the current period after this request's increment.
    Allowed(u64),
    /// The limit was exceeded.
    ///
    /// Holds the `Instant` at which the current period ends (and a fresh counter starts),
    /// followed by the label's counter after this request's increment.
    RetryAt(Instant, u64),
    /// The limit for this label is zero, so no request will ever be allowed.
    RetryNever,
}
2022-09-15 20:57:24 +03:00
impl RedisRateLimiter {
pub fn new(
app: &str,
label: &str,
2022-08-30 23:01:42 +03:00
max_requests_per_period: u64,
period: f32,
2022-09-15 20:57:24 +03:00
pool: RedisPool,
) -> Self {
let key_prefix = format!("{}:rrl:{}", app, label);
Self {
pool,
key_prefix,
2022-08-30 23:01:42 +03:00
max_requests_per_period,
period,
}
}
pub fn now_as_secs(&self) -> f32 {
// TODO: if system time doesn't match redis, this won't work great
(chrono::Utc::now().timestamp_millis() as f32) / 1_000.0
}
pub fn period_id(&self, now_as_secs: f32) -> f32 {
(now_as_secs / self.period) % self.period
}
pub fn next_period(&self, now_as_secs: f32) -> Instant {
let seconds_left_in_period = self.period - (now_as_secs % self.period);
Instant::now().add(Duration::from_secs_f32(seconds_left_in_period))
}
2022-10-27 03:12:42 +03:00
/// label might be an ip address or a rpc_key id.
2022-08-30 23:01:42 +03:00
/// if setting max_per_period, be sure to keep the period the same for all requests to this label
pub async fn throttle_label(
&self,
label: &str,
max_per_period: Option<u64>,
count: u64,
2022-09-15 20:57:24 +03:00
) -> anyhow::Result<RedisRateLimitResult> {
2022-08-30 23:01:42 +03:00
let max_per_period = max_per_period.unwrap_or(self.max_requests_per_period);
if max_per_period == 0 {
2022-09-15 20:57:24 +03:00
return Ok(RedisRateLimitResult::RetryNever);
}
let now = self.now_as_secs();
// if self.period is 60, period_id will be the minute of the current time
let period_id = self.period_id(now);
2022-09-15 20:57:24 +03:00
// TODO: include max per period in the throttle key?
let throttle_key = format!("{}:{}:{}", self.key_prefix, label, period_id);
let mut conn = self
.pool
.get()
.await
.context("get redis connection for rate limits")?;
2022-09-20 09:56:24 +03:00
// TODO: at high concurency, this gives "connection reset by peer" errors. at least they are off the hot path
// TODO: only set expire if this is a new key
2022-09-24 06:59:21 +03:00
// TODO: automatic retry
let x: Vec<_> = redis::pipe()
.atomic()
// we could get the key first, but that means an extra redis call for every check. this seems better
.incr(&throttle_key, count)
// set expiration each time we set the key. ignore the result
.expire(&throttle_key, 1 + self.period as usize)
// TODO: NX will make it only set the expiration the first time. works in redis, but not elasticache
// .arg("NX")
.ignore()
// do the query
.query_async(&mut *conn)
.await
.context("cannot increment rate limit or set expiration")?;
2022-09-24 06:59:21 +03:00
let new_count: u64 = *x.first().expect("check redis");
2022-09-15 20:57:24 +03:00
if new_count > max_per_period {
// TODO: this might actually be early if we are way over the count
let retry_at = self.next_period(now);
2022-08-30 23:01:42 +03:00
2022-09-15 20:57:24 +03:00
Ok(RedisRateLimitResult::RetryAt(retry_at, new_count))
} else {
2022-09-15 20:57:24 +03:00
Ok(RedisRateLimitResult::Allowed(new_count))
}
}
#[inline]
2022-09-24 06:59:21 +03:00
pub async fn throttle(&self) -> anyhow::Result<RedisRateLimitResult> {
self.throttle_label("", None, 1).await
}
}