2022-08-16 01:50:56 +03:00
|
|
|
//#![warn(missing_docs)]
|
|
|
|
mod errors;
|
|
|
|
|
|
|
|
use anyhow::Context;
|
|
|
|
use bb8_redis::redis::pipe;
|
|
|
|
use std::ops::Add;
|
|
|
|
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
2022-08-30 23:01:42 +03:00
|
|
|
use tracing::trace;
|
2022-08-16 01:50:56 +03:00
|
|
|
|
|
|
|
pub use crate::errors::{RedisError, RedisErrorSink};
|
2022-08-17 00:43:39 +03:00
|
|
|
pub use bb8_redis::{bb8, redis, RedisConnectionManager};
|
2022-08-16 01:50:56 +03:00
|
|
|
|
|
|
|
/// Pool of redis connections (bb8 + `RedisConnectionManager`) shared by every rate limit check.
pub type RedisPool = bb8::Pool<RedisConnectionManager>;
|
|
|
|
|
|
|
|
pub struct RedisRateLimit {
|
|
|
|
pool: RedisPool,
|
|
|
|
key_prefix: String,
|
2022-08-30 23:01:42 +03:00
|
|
|
/// The default maximum requests allowed in a period.
|
|
|
|
max_requests_per_period: u64,
|
|
|
|
/// seconds
|
|
|
|
period: f32,
|
2022-08-16 01:50:56 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Outcome of a single rate limit check.
///
/// Derives the standard value traits so callers can log, compare and copy results
/// (`Instant` implements all of them).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThrottleResult {
    /// The request fits within the current limit.
    Allowed,
    /// The limit was exceeded; the caller may try again at (or after) this instant.
    RetryAt(Instant),
    /// The effective limit is zero, so the request will never be allowed.
    RetryNever,
}
|
|
|
|
|
|
|
|
impl RedisRateLimit {
|
|
|
|
pub fn new(
|
|
|
|
pool: RedisPool,
|
|
|
|
app: &str,
|
|
|
|
label: &str,
|
2022-08-30 23:01:42 +03:00
|
|
|
max_requests_per_period: u64,
|
|
|
|
period: f32,
|
2022-08-16 01:50:56 +03:00
|
|
|
) -> Self {
|
|
|
|
let key_prefix = format!("{}:rrl:{}", app, label);
|
|
|
|
|
|
|
|
Self {
|
|
|
|
pool,
|
|
|
|
key_prefix,
|
2022-08-30 23:01:42 +03:00
|
|
|
max_requests_per_period,
|
2022-08-16 01:50:56 +03:00
|
|
|
period,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-30 23:01:42 +03:00
|
|
|
/// label might be an ip address or a user_key id.
|
|
|
|
/// if setting max_per_period, be sure to keep the period the same for all requests to this label
|
|
|
|
/// TODO:
|
2022-08-16 01:50:56 +03:00
|
|
|
pub async fn throttle_label(
|
|
|
|
&self,
|
|
|
|
label: &str,
|
|
|
|
max_per_period: Option<u64>,
|
|
|
|
count: u64,
|
|
|
|
) -> anyhow::Result<ThrottleResult> {
|
2022-08-30 23:01:42 +03:00
|
|
|
let max_per_period = max_per_period.unwrap_or(self.max_requests_per_period);
|
2022-08-16 01:50:56 +03:00
|
|
|
|
|
|
|
if max_per_period == 0 {
|
|
|
|
return Ok(ThrottleResult::RetryNever);
|
|
|
|
}
|
|
|
|
|
|
|
|
let now = SystemTime::now()
|
|
|
|
.duration_since(UNIX_EPOCH)
|
|
|
|
.context("cannot tell the time")?
|
2022-08-30 23:01:42 +03:00
|
|
|
.as_secs_f32();
|
2022-08-16 01:50:56 +03:00
|
|
|
|
2022-08-18 00:42:45 +03:00
|
|
|
// if self.period is 60, period_id will be the minute of the current time
|
|
|
|
let period_id = (now / self.period) % self.period;
|
2022-08-16 01:50:56 +03:00
|
|
|
|
|
|
|
let throttle_key = format!("{}:{}:{}", self.key_prefix, label, period_id);
|
|
|
|
|
|
|
|
let mut conn = self.pool.get().await?;
|
|
|
|
|
|
|
|
let x: Vec<u64> = pipe()
|
|
|
|
// we could get the key first, but that means an extra redis call for every check. this seems better
|
|
|
|
.incr(&throttle_key, count)
|
|
|
|
// set expiration the first time we set the key. ignore the result
|
2022-08-30 23:01:42 +03:00
|
|
|
.expire(&throttle_key, self.period as usize)
|
2022-08-18 00:42:45 +03:00
|
|
|
// .arg("NX") // TODO: this works in redis, but not elasticache
|
2022-08-16 01:50:56 +03:00
|
|
|
.ignore()
|
|
|
|
// do the query
|
|
|
|
.query_async(&mut *conn)
|
|
|
|
.await
|
|
|
|
.context("increment rate limit")?;
|
|
|
|
|
|
|
|
let new_count = x
|
|
|
|
.first()
|
|
|
|
.ok_or_else(|| anyhow::anyhow!("check rate limit result"))?;
|
|
|
|
|
|
|
|
if new_count > &max_per_period {
|
2022-08-30 23:01:42 +03:00
|
|
|
let seconds_left_in_period = self.period - (now % self.period);
|
2022-08-16 01:50:56 +03:00
|
|
|
|
2022-08-30 23:01:42 +03:00
|
|
|
let retry_at = Instant::now().add(Duration::from_secs_f32(seconds_left_in_period));
|
|
|
|
|
|
|
|
trace!(%label, ?retry_at, "rate limited");
|
2022-08-16 01:50:56 +03:00
|
|
|
|
|
|
|
return Ok(ThrottleResult::RetryAt(retry_at));
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok(ThrottleResult::Allowed)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
|
|
|
pub async fn throttle(&self) -> anyhow::Result<ThrottleResult> {
|
|
|
|
self.throttle_label("", None, 1).await
|
|
|
|
}
|
|
|
|
}
|