web3-proxy/redis-rate-limit/src/lib.rs


//#![warn(missing_docs)]
mod errors;
use anyhow::Context;
use deadpool_redis::redis::pipe;
use std::ops::Add;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::{Duration, Instant};
use tracing::{debug, trace};

pub use deadpool_redis::redis;
pub use deadpool_redis::{Config, Connection, Manager, Pool, Runtime};

// pub use crate::errors::{RedisError, RedisErrorSink};
// pub use bb8_redis::{bb8, redis, RedisConnectionManager};
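/// A fixed-window rate limiter that keeps its counters in Redis.
///
/// A minimal usage sketch (the connection URL, app name, label, and limits below
/// are placeholders, not values defined by this crate):
///
/// ```ignore
/// let cfg = Config::from_url("redis://127.0.0.1:6379");
/// let pool = cfg.create_pool(Some(Runtime::Tokio1))?;
///
/// // allow up to 100 requests per 60 second period
/// let limiter = RedisRateLimit::new(pool, "web3_proxy", "frontend", 100, 60.0);
///
/// match limiter.throttle_label("10.0.0.1", None, 1).await? {
///     ThrottleResult::Allowed => { /* handle the request */ }
///     ThrottleResult::RetryAt(retry_at) => { /* ask the caller to retry at `retry_at` */ }
///     ThrottleResult::RetryNever => { /* the limit is zero; reject */ }
/// }
/// ```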
pub struct RedisRateLimit {
    pool: Pool,
    key_prefix: String,
    /// The default maximum requests allowed in a period.
    max_requests_per_period: u64,
    /// The period length, in seconds.
    period: f32,
}

pub enum ThrottleResult {
    /// The request is within the limit and may proceed.
    Allowed,
    /// The limit was hit. Retry no sooner than the given instant.
    RetryAt(Instant),
    /// The limit is zero, so this label is never allowed.
    RetryNever,
}

impl RedisRateLimit {
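    /// Build a rate limiter. All of its Redis keys are prefixed with `{app}:rrl:{label}`.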
    pub fn new(
        pool: Pool,
        app: &str,
        label: &str,
        max_requests_per_period: u64,
        period: f32,
    ) -> Self {
        let key_prefix = format!("{}:rrl:{}", app, label);

        Self {
            pool,
            key_prefix,
            max_requests_per_period,
            period,
        }
    }

    /// A label might be an IP address or a user_key id.
    /// If overriding max_per_period, keep the period the same for all requests to this label.
    /// TODO:
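    ///
    /// A rough sketch of a per-label override (the label and limit below are
    /// placeholders, not values defined by this crate):
    ///
    /// ```ignore
    /// // let this user_key make 10 requests this period instead of the default
    /// let result = limiter.throttle_label("user_key:42", Some(10), 1).await?;
    /// ```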
    pub async fn throttle_label(
        &self,
        label: &str,
        max_per_period: Option<u64>,
        count: u64,
    ) -> anyhow::Result<ThrottleResult> {
        let max_per_period = max_per_period.unwrap_or(self.max_requests_per_period);

        if max_per_period == 0 {
            return Ok(ThrottleResult::RetryNever);
        }

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .context("cannot tell the time")?
            .as_secs_f32();

        // if self.period is 60, period_id will be the minute of the current time
        let period_id = (now / self.period) % self.period;

        let throttle_key = format!("{}:{}:{}", self.key_prefix, label, period_id);

        let mut conn = self.pool.get().await.context("throttle")?;

        // TODO: at high concurrency, i think this is giving errors
        // TODO: i'm starting to think that bb8 has a bug
        let x: Vec<u64> = pipe()
            // we could get the key first, but that means an extra redis call for every check. this seems better
            .incr(&throttle_key, count)
            // set expiration each time we set the key. ignore the result
            .expire(&throttle_key, self.period as usize)
            // TODO: NX will make it only set the expiration the first time. works in redis, but not elasticache
            // .arg("NX")
            .ignore()
            // do the query
            .query_async(&mut *conn)
            .await
            .context("increment rate limit")?;

        let new_count = x.first().context("check rate limit result")?;

        if new_count > &max_per_period {
            let seconds_left_in_period = self.period - (now % self.period);

            let retry_at = Instant::now().add(Duration::from_secs_f32(seconds_left_in_period));

            debug!(%label, ?retry_at, "rate limited: {}/{}", new_count, max_per_period);

            Ok(ThrottleResult::RetryAt(retry_at))
        } else {
            trace!(%label, "NOT rate limited: {}/{}", new_count, max_per_period);

            Ok(ThrottleResult::Allowed)
        }
    }
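    /// Throttle the default (empty) label by a count of 1.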
    #[inline]
    pub async fn throttle(&self) -> anyhow::Result<ThrottleResult> {
        self.throttle_label("", None, 1).await
    }

    pub fn max_requests_per_period(&self) -> u64 {
        self.max_requests_per_period
    }
}