save rpc_key_id or origin. needs some testing

This commit is contained in:
Bryan Stitt 2022-11-09 23:58:07 +00:00
parent c150ca612b
commit c35dd96cfb
13 changed files with 124 additions and 39 deletions

6
Cargo.lock generated
View File

@ -1354,7 +1354,7 @@ dependencies = [
[[package]]
name = "entities"
version = "0.9.0"
version = "0.10.0"
dependencies = [
"sea-orm",
"serde",
@ -2663,7 +2663,7 @@ dependencies = [
[[package]]
name = "migration"
version = "0.9.0"
version = "0.10.0"
dependencies = [
"sea-orm-migration",
"tokio",
@ -5555,7 +5555,7 @@ dependencies = [
[[package]]
name = "web3_proxy"
version = "0.9.1"
version = "0.10.0"
dependencies = [
"anyhow",
"arc-swap",

View File

@ -1,6 +1,6 @@
[package]
name = "entities"
version = "0.9.0"
version = "0.10.0"
edition = "2021"
[lib]

View File

@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
pub struct Model {
#[sea_orm(primary_key)]
pub id: u64,
pub rpc_key_id: u64,
pub rpc_key_id: Option<u64>,
pub chain_id: u64,
pub method: String,
pub archive_request: bool,
@ -41,6 +41,7 @@ pub struct Model {
pub p90_response_bytes: u64,
pub p99_response_bytes: u64,
pub max_response_bytes: u64,
pub origin: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@ -1,6 +1,6 @@
[package]
name = "migration"
version = "0.9.0"
version = "0.10.0"
edition = "2021"
publish = false

View File

@ -0,0 +1,47 @@
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.alter_table(
Table::alter()
.table(RpcAccounting::Table)
.modify_column(
ColumnDef::new(RpcAccounting::RpcKeyId)
.big_unsigned()
.null(),
)
.add_column(ColumnDef::new(RpcAccounting::Origin).string().null())
.to_owned(),
)
.await
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.alter_table(
Table::alter()
.table(RpcAccounting::Table)
.modify_column(
ColumnDef::new(RpcAccounting::RpcKeyId)
.big_unsigned()
.not_null(),
)
.drop_column(RpcAccounting::Origin)
.to_owned(),
)
.await
}
}
/// Learn more at https://docs.rs/sea-query#iden
#[derive(Iden)]
enum RpcAccounting {
Table,
RpcKeyId,
Origin,
}

View File

@ -1,6 +1,6 @@
[package]
name = "web3_proxy"
version = "0.9.1"
version = "0.10.0"
edition = "2021"
default-run = "web3_proxy"

View File

@ -34,6 +34,7 @@ use serde::Serialize;
use serde_json::json;
use std::fmt;
use std::net::IpAddr;
use std::num::NonZeroU64;
use std::str::FromStr;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::Arc;
@ -67,8 +68,8 @@ pub struct AuthorizationChecks {
/// TODO: do we need this? its on the authorization so probably not
pub user_id: u64,
/// database id of the rpc key
/// if this is 0, then this request is being rate limited by ip
pub rpc_key_id: u64,
/// if this is None, then this request is being rate limited by ip
pub rpc_key_id: Option<NonZeroU64>,
/// if None, allow unlimited queries. inherited from the user_tier
pub max_requests_per_period: Option<u64>,
// if None, allow unlimited concurrent requests. inherited from the user_tier
@ -113,7 +114,8 @@ pub struct Web3ProxyApp {
// TODO: this key should be our RpcSecretKey class, not Ulid
pub rpc_secret_key_cache:
Cache<Ulid, AuthorizationChecks, hashbrown::hash_map::DefaultHashBuilder>,
pub rpc_key_semaphores: Cache<u64, Arc<Semaphore>, hashbrown::hash_map::DefaultHashBuilder>,
pub rpc_key_semaphores:
Cache<NonZeroU64, Arc<Semaphore>, hashbrown::hash_map::DefaultHashBuilder>,
pub ip_semaphores: Cache<IpAddr, Arc<Semaphore>, hashbrown::hash_map::DefaultHashBuilder>,
pub bearer_token_semaphores:
Cache<String, Arc<Semaphore>, hashbrown::hash_map::DefaultHashBuilder>,

View File

@ -1,11 +1,13 @@
use crate::frontend::authorization::{Authorization, RequestMetadata};
use crate::jsonrpc::JsonRpcForwardedResponse;
use axum::headers::Origin;
use chrono::{TimeZone, Utc};
use derive_more::From;
use entities::rpc_accounting;
use hashbrown::HashMap;
use hdrhistogram::{Histogram, RecordError};
use sea_orm::{ActiveModelTrait, DatabaseConnection, DbErr};
use std::num::NonZeroU64;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
@ -32,12 +34,31 @@ pub struct ProxyResponseStat {
impl ProxyResponseStat {
/// TODO: think more about this. probably rename it
fn key(&self) -> ProxyResponseAggregateKey {
// include either the rpc_key_id or the origin
let (rpc_key_id, origin) = match (
self.authorization.checks.rpc_key_id,
&self.authorization.origin,
) {
(Some(rpc_key_id), _) => {
// TODO: allow the user to opt into saving the origin
(Some(rpc_key_id), None)
}
(None, Some(origin)) => {
// we save the origin for anonymous access
(None, Some(origin.clone()))
}
(None, None) => {
// TODO: what should we do here? log ip? i really don't want to save any ips
(None, None)
}
};
ProxyResponseAggregateKey {
rpc_key_id: self.authorization.checks.rpc_key_id,
// TODO: include Origin here?
method: self.method.clone(),
archive_request: self.archive_request,
error_response: self.error_response,
method: self.method.clone(),
origin,
rpc_key_id,
}
}
}
@ -66,10 +87,12 @@ impl Default for ProxyResponseHistograms {
// TODO: think more about if we should include IP address in this
#[derive(Clone, From, Hash, PartialEq, Eq)]
struct ProxyResponseAggregateKey {
rpc_key_id: u64,
method: String,
error_response: bool,
archive_request: bool,
error_response: bool,
rpc_key_id: Option<NonZeroU64>,
method: String,
/// TODO: should this be Origin or String?
origin: Option<Origin>,
}
#[derive(Default)]
@ -182,7 +205,8 @@ impl ProxyResponseAggregate {
let aggregated_stat_model = rpc_accounting::ActiveModel {
id: sea_orm::NotSet,
// origin: sea_orm::Set(key.authorization.origin.to_string()),
rpc_key_id: sea_orm::Set(key.rpc_key_id),
rpc_key_id: sea_orm::Set(key.rpc_key_id.map(Into::into)),
origin: sea_orm::Set(key.origin.map(|x| x.to_string())),
chain_id: sea_orm::Set(chain_id),
method: sea_orm::Set(key.method),
archive_request: sea_orm::Set(key.archive_request),

View File

@ -373,33 +373,34 @@ impl Web3ProxyApp {
}
}
/// Limit the number of concurrent requests from the given key address.
/// Limit the number of concurrent requests from the given rpc key.
#[instrument(level = "trace")]
pub async fn authorization_checks_semaphore(
pub async fn rpc_key_semaphore(
&self,
authorization_checks: &AuthorizationChecks,
) -> anyhow::Result<Option<OwnedSemaphorePermit>> {
if let Some(max_concurrent_requests) = authorization_checks.max_concurrent_requests {
let rpc_key_id = authorization_checks.rpc_key_id.context("no rpc_key_id")?;
let semaphore = self
.rpc_key_semaphores
.get_with(authorization_checks.rpc_key_id, async move {
.get_with(rpc_key_id, async move {
let s = Semaphore::new(max_concurrent_requests as usize);
trace!(
"new semaphore for rpc_key_id {}",
authorization_checks.rpc_key_id
);
trace!("new semaphore for rpc_key_id {}", rpc_key_id);
Arc::new(s)
})
.await;
// if semaphore.available_permits() == 0 {
// // TODO: concurrent limit hit! emit a stat
// // TODO: concurrent limit hit! emit a stat? this has a race condition though.
// // TODO: maybe have a stat on how long we wait to acquire the semaphore instead?
// }
let semaphore_permit = semaphore.acquire_owned().await?;
Ok(Some(semaphore_permit))
} else {
// unlimited requests allowed
Ok(None)
}
}
@ -645,9 +646,12 @@ impl Web3ProxyApp {
None
};
let rpc_key_id =
Some(rpc_key_model.id.try_into().expect("db ids are never 0"));
Ok(AuthorizationChecks {
user_id: rpc_key_model.user_id,
rpc_key_id: rpc_key_model.id,
rpc_key_id,
allowed_ips,
allowed_origins,
allowed_referers,
@ -679,15 +683,13 @@ impl Web3ProxyApp {
let authorization_checks = self.authorization_checks(rpc_key).await?;
// if no rpc_key_id matching the given rpc was found, then we can't rate limit by key
if authorization_checks.rpc_key_id == 0 {
if authorization_checks.rpc_key_id.is_none() {
return Ok(RateLimitResult::UnknownKey);
}
// only allow this rpc_key to run a limited amount of concurrent requests
// TODO: rate limit should be BEFORE the semaphore!
let semaphore = self
.authorization_checks_semaphore(&authorization_checks)
.await?;
let semaphore = self.rpc_key_semaphore(&authorization_checks).await?;
let authorization = Authorization::try_new(
authorization_checks,

View File

@ -151,12 +151,13 @@ impl IntoResponse for FrontendErrorResponse {
};
// create a string with either the IP or the rpc_key_id
let msg = if authorization.checks.rpc_key_id == 0 {
let msg = if authorization.checks.rpc_key_id.is_none() {
format!("too many requests from {}.{}", authorization.ip, retry_msg)
} else {
format!(
"too many requests from rpc key #{}.{}",
authorization.checks.rpc_key_id, retry_msg
authorization.checks.rpc_key_id.unwrap(),
retry_msg
)
};

View File

@ -124,15 +124,15 @@ pub async fn websocket_handler_with_key(
"redirect_rpc_key_url not set. only websockets work here"
)
.into()),
(Some(redirect_public_url), _, 0) => {
(Some(redirect_public_url), _, None) => {
Ok(Redirect::to(redirect_public_url).into_response())
}
(_, Some(redirect_rpc_key_url), rpc_key_id) => {
let reg = Handlebars::new();
if authorization.checks.rpc_key_id == 0 {
if authorization.checks.rpc_key_id.is_none() {
// TODO: i think this is impossible
Err(anyhow::anyhow!("this page is for rpcs").into())
Err(anyhow::anyhow!("only authenticated websockets work here").into())
} else {
let redirect_rpc_key_url = reg
.render_template(

View File

@ -27,7 +27,7 @@ use itertools::Itertools;
use redis_rate_limiter::redis::AsyncCommands;
use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder,
TransactionTrait,
TransactionTrait, TryIntoModel,
};
use serde::Deserialize;
use serde_json::json;
@ -670,7 +670,7 @@ pub async fn rpc_keys_management(
uk
};
let uk: rpc_key::Model = uk.try_into()?;
let uk = uk.try_into_model()?;
Ok(Json(uk).into_response())
}

View File

@ -13,8 +13,8 @@ use metered::HitCount;
use metered::ResponseTime;
use metered::Throughput;
use rand::Rng;
use sea_orm::ActiveEnum;
use sea_orm::ActiveModelTrait;
use sea_orm::{ActiveEnum};
use serde_json::json;
use std::fmt;
use std::sync::atomic::{self, AtomicBool, Ordering};
@ -82,6 +82,14 @@ impl Authorization {
method: Method,
params: EthCallFirstParams,
) -> anyhow::Result<()> {
let rpc_key_id = match self.checks.rpc_key_id {
Some(rpc_key_id) => rpc_key_id.into(),
None => {
trace!(?self, "cannot save revert without rpc_key_id");
return Ok(());
}
};
let db_conn = self.db_conn.as_ref().context("no database connection")?;
// TODO: should the database set the timestamp?
@ -96,7 +104,7 @@ impl Authorization {
let call_data = params.data.map(|x| format!("{}", x));
let rl = revert_log::ActiveModel {
rpc_key_id: sea_orm::Set(self.checks.rpc_key_id),
rpc_key_id: sea_orm::Set(rpc_key_id),
method: sea_orm::Set(method),
to: sea_orm::Set(to),
call_data: sea_orm::Set(call_data),