requests_per_minute, not requests_per_second

This commit is contained in:
Bryan Stitt 2022-08-07 19:33:16 +00:00
parent 439e27101d
commit 36cf8af511
5 changed files with 68 additions and 29 deletions

@ -15,7 +15,7 @@ pub struct Model {
pub description: Option<String>,
pub private_txs: bool,
pub active: bool,
pub requests_per_second: u32,
pub requests_per_minute: u32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

@ -27,7 +27,7 @@ use tokio::sync::{broadcast, watch};
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_stream::wrappers::{BroadcastStream, WatchStream};
use tracing::{debug, info, info_span, instrument, trace, warn, Instrument};
use tracing::{info, info_span, instrument, trace, warn, Instrument};
use crate::bb8_helpers;
use crate::config::AppConfig;

@ -43,6 +43,7 @@ impl CreateUserSubCommand {
let uk = user_keys::ActiveModel {
user_id: sea_orm::Set(u.id),
api_key: sea_orm::Set(new_api_key()),
requests_per_minute: sea_orm::Set(6_000_000),
..Default::default()
};

@ -12,13 +12,14 @@ use axum::{
Extension, Router,
};
use entities::user_keys;
use redis_cell_client::ThrottleResult;
use reqwest::StatusCode;
use sea_orm::{
ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect,
};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tracing::info;
use tracing::{debug, info};
use uuid::Uuid;
use crate::app::Web3ProxyApp;
@ -26,25 +27,38 @@ use crate::app::Web3ProxyApp;
use self::errors::handle_anyhow_error;
pub async fn rate_limit_by_ip(app: &Web3ProxyApp, ip: &IpAddr) -> Result<(), impl IntoResponse> {
let rate_limiter_key = format!("ip:{}", ip);
let rate_limiter_key = format!("ip-{}", ip);
// TODO: dry this up with rate_limit_by_key
if let Some(rate_limiter) = app.rate_limiter() {
if rate_limiter
match rate_limiter
.throttle_key(&rate_limiter_key, None, None, None)
.await
.is_err()
{
// TODO: set headers so they know when they can retry
// warn!(?ip, "public rate limit exceeded"); // this is too verbose, but a stat might be good
// TODO: use their id if possible
return Err(handle_anyhow_error(
Some(StatusCode::TOO_MANY_REQUESTS),
None,
anyhow::anyhow!(format!("too many requests from this ip: {}", ip)),
)
.await
.into_response());
Ok(ThrottleResult::Allowed) => {}
Ok(ThrottleResult::RetryAt(_retry_at)) => {
// TODO: set headers so they know when they can retry
debug!(?rate_limiter_key, "rate limit exceeded"); // this is too verbose, but a stat might be good
// TODO: use their id if possible
return Err(handle_anyhow_error(
Some(StatusCode::TOO_MANY_REQUESTS),
None,
anyhow::anyhow!(format!("too many requests from this ip: {}", ip)),
)
.await
.into_response());
}
Err(err) => {
// internal error, not rate limit being hit
// TODO: i really want axum to do this for us in a single place.
return Err(handle_anyhow_error(
Some(StatusCode::INTERNAL_SERVER_ERROR),
None,
anyhow::anyhow!(format!("too many requests from this ip: {}", ip)),
)
.await
.into_response());
}
}
} else {
// TODO: if no redis, rate limit with a local cache?
@ -61,9 +75,11 @@ pub async fn rate_limit_by_key(
) -> Result<(), impl IntoResponse> {
let db = app.db_conn();
/// query just a few columns instead of the entire table
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
UserId,
RequestsPerMinute,
}
// query the db to make sure this key is active
@ -71,20 +87,22 @@ pub async fn rate_limit_by_key(
match user_keys::Entity::find()
.select_only()
.column_as(user_keys::Column::UserId, QueryAs::UserId)
.column_as(
user_keys::Column::RequestsPerMinute,
QueryAs::RequestsPerMinute,
)
.filter(user_keys::Column::ApiKey.eq(user_key))
.filter(user_keys::Column::Active.eq(true))
.into_values::<_, QueryAs>()
.one(db)
.await
{
Ok::<Option<i64>, _>(Some(_)) => {
Ok::<Option<(i64, u32)>, _>(Some((_user_id, user_count_per_period))) => {
// user key is valid
if let Some(rate_limiter) = app.rate_limiter() {
// TODO: check the db for this? maybe add to the find above with a join?
let user_count_per_period = 100_000;
// TODO: how does max burst actually work? what should it be?
let user_max_burst = user_count_per_period;
let user_period = 1;
let user_max_burst = user_count_per_period / 3;
let user_period = 60;
if rate_limiter
.throttle_key(
@ -164,7 +182,17 @@ pub async fn run(port: u16, proxy_app: Arc<Web3ProxyApp>) -> anyhow::Result<()>
axum::Server::bind(&addr)
// TODO: option to use with_connect_info. we want it in dev, but not when running behind a proxy, but not
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(signal_shutdown())
// .serve(app.into_make_service())
.await
.map_err(Into::into)
}
/// Tokio signal handler that will wait for a user to press CTRL+C.
/// We use this in our hyper `Server` method `with_graceful_shutdown`.
async fn signal_shutdown() {
tokio::signal::ctrl_c()
.await
.expect("expect tokio signal ctrl-c");
info!("signal shutdown");
}

@ -1,7 +1,7 @@
use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
extract::Path,
response::IntoResponse,
response::{IntoResponse, Response},
Extension,
};
use axum_client_ip::ClientIp;
@ -27,20 +27,30 @@ use super::{rate_limit_by_ip, rate_limit_by_key};
pub async fn public_websocket_handler(
Extension(app): Extension<Arc<Web3ProxyApp>>,
ClientIp(ip): ClientIp,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
if let Err(x) = rate_limit_by_ip(&app, &ip).await {
return x.into_response();
}
ws: Option<WebSocketUpgrade>,
) -> Response {
match ws {
Some(ws) => {
if let Err(x) = rate_limit_by_ip(&app, &ip).await {
return x.into_response();
}
ws.on_upgrade(|socket| proxy_web3_socket(app, socket))
ws.on_upgrade(|socket| proxy_web3_socket(app, socket))
.into_response()
}
None => {
// this is not a websocket. give a friendly page
// TODO: make a friendly page
"hello, world".into_response()
}
}
}
pub async fn user_websocket_handler(
Extension(app): Extension<Arc<Web3ProxyApp>>,
ws: WebSocketUpgrade,
Path(user_key): Path<Uuid>,
) -> impl IntoResponse {
) -> Response {
if let Err(x) = rate_limit_by_key(&app, user_key).await {
return x.into_response();
}