ip, origin, referer, and user agent checks

This commit is contained in:
Bryan Stitt 2022-09-23 05:22:33 +00:00
parent d55aea2d98
commit 961ccf7cf2
13 changed files with 171 additions and 81 deletions

1
Cargo.lock generated

@ -5580,6 +5580,7 @@ dependencies = [
"handlebars",
"hashbrown",
"http",
"ipnet",
"metered",
"migration",
"moka",

@ -160,7 +160,7 @@ These are roughly in order of completition
- [-] opt-in debug mode that inspects responses for reverts and saves the request to the database for the user.
- [-] let them choose a % to log (or maybe x/second). someone like curve logging all reverts will be a BIG database very quickly
- this must be opt-in or spawned since it will slow things down and will make their calls less private
- [-] Api keys need option to lock to IP, cors header, referer, etc
- [-] Api keys need option to lock to IP, cors header, referer, user agent, etc
- [ ] active requests per second per api key
- [ ] distribution of methods per api key (eth_call, eth_getLogs, etc.)
- [-] add configurable size limits to all the Caches

@ -15,7 +15,11 @@ pub struct Model {
pub description: Option<String>,
pub private_txs: bool,
pub active: bool,
pub requests_per_minute: u64,
pub requests_per_minute: Option<u64>,
pub allowed_ips: Option<String>,
pub allowed_origins: Option<String>,
pub allowed_referers: Option<String>,
pub allowed_user_agents: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

@ -16,15 +16,20 @@ impl MigrationTrait for Migration {
.modify_column(
ColumnDef::new(UserKeys::RequestsPerMinute)
.big_unsigned()
.not_null(),
.null(),
)
// add a column for logging reverts in the RevertLogs table
.add_column(
ColumnDef::new(UserKeys::LogReverts)
.boolean()
.decimal_len(5, 4)
.not_null()
.default(false),
.default("0.0"),
)
// add columns for more advanced authorization
.add_column(ColumnDef::new(UserKeys::AllowedIps).text().null())
.add_column(ColumnDef::new(UserKeys::AllowedOrigins).text().null())
.add_column(ColumnDef::new(UserKeys::AllowedReferers).text().null())
.add_column(ColumnDef::new(UserKeys::AllowedUserAgents).text().null())
.to_owned(),
)
.await?;
@ -97,6 +102,7 @@ impl MigrationTrait for Migration {
pub enum UserKeys {
Table,
Id,
// we don't touch some of the columns
// UserId,
// ApiKey,
// Description,
@ -104,6 +110,10 @@ pub enum UserKeys {
// Active,
RequestsPerMinute,
LogReverts,
AllowedIps,
AllowedOrigins,
AllowedReferers,
AllowedUserAgents,
}
#[derive(Iden)]

@ -36,6 +36,7 @@ flume = "0.10.14"
futures = { version = "0.3.24", features = ["thread-pool"] }
hashbrown = { version = "0.12.3", features = ["serde"] }
http = "0.2.8"
ipnet = "*"
metered = { version = "0.9.0", features = ["serialize"] }
moka = { version = "0.9.4", default-features = false, features = ["future"] }
notify = "5.0.0"

@ -24,6 +24,7 @@ use futures::stream::FuturesUnordered;
use futures::stream::StreamExt;
use futures::Future;
use hashbrown::HashMap;
use ipnet::IpNet;
use metered::{metered, ErrorCount, HitCount, ResponseTime, Throughput};
use migration::{Migrator, MigratorTrait};
use moka::future::Cache;
@ -66,14 +67,16 @@ pub type AnyhowJoinHandle<T> = JoinHandle<anyhow::Result<T>>;
/// TODO: rename this?
pub struct UserKeyData {
pub user_key_id: u64,
/// if None, allow unlimited queries
pub user_count_per_period: Option<u64>,
/// if u64::MAX, allow unlimited queries
pub user_max_requests_per_period: Option<u64>,
/// if None, allow any Origin
pub allowed_origins: Option<Vec<String>>,
/// if None, allow any Referer
pub allowed_referer: Option<Referer>,
pub allowed_referers: Option<Vec<Referer>>,
/// if None, allow any UserAgent
pub allowed_user_agent: Option<UserAgent>,
/// if None, allow any IpAddr
pub allowed_ip: Option<IpAddr>,
pub allowed_user_agents: Option<Vec<UserAgent>>,
/// if None, allow any IP Address
pub allowed_ips: Option<Vec<IpNet>>,
}
/// The application

@ -210,10 +210,10 @@ mod tests {
let app_config = TopConfig {
app: AppConfig {
chain_id: 31337,
default_requests_per_minute: 6_000_000,
default_requests_per_minute: Some(6_000_000),
min_sum_soft_limit: 1,
min_synced_rpcs: 1,
public_rate_limit_per_minute: 6_000_000,
public_rate_limit_per_minute: 1_000_000,
response_cache_max_bytes: 10_usize.pow(7),
redirect_public_url: "example.com/".to_string(),
redirect_user_url: "example.com/{{user_id}}".to_string(),

@ -7,11 +7,6 @@ use tracing::info;
use uuid::Uuid;
use web3_proxy::users::new_api_key;
/// default to max int which the code sees as "unlimited" requests
fn default_rpm() -> u64 {
u64::MAX
}
#[derive(FromArgs, PartialEq, Debug, Eq)]
/// Create a new user and api key
#[argh(subcommand, name = "create_user")]
@ -29,9 +24,10 @@ pub struct CreateUserSubCommand {
/// If none given, one will be generated randomly.
api_key: Uuid,
#[argh(option, default = "default_rpm()")]
#[argh(option)]
/// maximum requests per minute
rpm: u64,
/// default to "None" which the code sees as "unlimited" requests
rpm: Option<u64>,
}
impl CreateUserSubCommand {

@ -54,8 +54,7 @@ pub struct AppConfig {
/// If none, the minimum * 2 is used
pub db_max_connections: Option<u32>,
pub influxdb_url: Option<String>,
#[serde(default = "default_default_requests_per_minute")]
pub default_requests_per_minute: u64,
pub default_requests_per_minute: Option<u64>,
pub invite_code: Option<String>,
#[serde(default = "default_min_sum_soft_limit")]
pub min_sum_soft_limit: u32,
@ -76,12 +75,6 @@ pub struct AppConfig {
pub redirect_user_url: String,
}
/// default to unlimited requests
/// TODO: pick a lower limit so we don't get DOSd
fn default_default_requests_per_minute() -> u64 {
u64::MAX
}
fn default_min_sum_soft_limit() -> u32 {
1
}

@ -1,17 +1,15 @@
use super::errors::FrontendErrorResponse;
use crate::app::{UserKeyData, Web3ProxyApp};
use anyhow::Context;
use axum::headers::{Referer, UserAgent};
use axum::headers::{Origin, Referer, UserAgent};
use deferred_rate_limiter::DeferredRateLimitResult;
use entities::user_keys;
use sea_orm::{
ColumnTrait, DatabaseConnection, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter,
QuerySelect,
};
use ipnet::IpNet;
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
use serde::Serialize;
use std::{net::IpAddr, sync::Arc};
use tokio::time::Instant;
use tracing::{error, trace, warn};
use tracing::{error, trace};
use uuid::Uuid;
#[derive(Debug)]
@ -31,6 +29,7 @@ pub enum RateLimitResult {
#[derive(Debug, Serialize)]
pub struct AuthorizedKey {
ip: IpAddr,
origin: Option<String>,
user_key_id: u64,
// TODO: what else?
}
@ -38,14 +37,64 @@ pub struct AuthorizedKey {
impl AuthorizedKey {
pub fn try_new(
ip: IpAddr,
user_data: UserKeyData,
origin: Option<Origin>,
referer: Option<Referer>,
user_agent: Option<UserAgent>,
user_data: UserKeyData,
) -> anyhow::Result<Self> {
warn!("todo: check referer and user_agent against user_data");
// check ip
match &user_data.allowed_ips {
None => {}
Some(allowed_ips) => {
if !allowed_ips.iter().any(|x| x.contains(&ip)) {
return Err(anyhow::anyhow!("IP is not allowed!"));
}
}
}
// check origin
// TODO: do this with the Origin type instead of a String?
let origin = origin.map(|x| x.to_string());
match (&origin, &user_data.allowed_origins) {
(None, None) => {}
(Some(_), None) => {}
(None, Some(_)) => return Err(anyhow::anyhow!("Origin required")),
(Some(origin), Some(allowed_origins)) => {
let origin = origin.to_string();
if !allowed_origins.contains(&origin) {
return Err(anyhow::anyhow!("IP is not allowed!"));
}
}
}
// check referer
match (referer, &user_data.allowed_referers) {
(None, None) => {}
(Some(_), None) => {}
(None, Some(_)) => return Err(anyhow::anyhow!("Referer required")),
(Some(referer), Some(allowed_referers)) => {
if !allowed_referers.contains(&referer) {
return Err(anyhow::anyhow!("Referer is not allowed!"));
}
}
}
// check user_agent
match (user_agent, &user_data.allowed_user_agents) {
(None, None) => {}
(Some(_), None) => {}
(None, Some(_)) => return Err(anyhow::anyhow!("User agent required")),
(Some(user_agent), Some(allowed_user_agents)) => {
if !allowed_user_agents.contains(&user_agent) {
return Err(anyhow::anyhow!("User agent is not allowed!"));
}
}
}
Ok(Self {
ip,
origin,
user_key_id: user_data.user_key_id,
})
}
@ -62,14 +111,12 @@ pub enum AuthorizedRequest {
}
impl AuthorizedRequest {
pub fn has_db(&self) -> bool {
let db_conn = match self {
Self::Internal(db_conn) => db_conn,
Self::Ip(db_conn, _) => db_conn,
Self::User(db_conn, _) => db_conn,
};
db_conn.is_some()
pub fn db_conn(&self) -> Option<&DatabaseConnection> {
match self {
Self::Internal(x) => x.as_ref(),
Self::Ip(x, _) => x.as_ref(),
Self::User(x, _) => x.as_ref(),
}
}
}
@ -96,6 +143,7 @@ pub async fn key_is_authorized(
app: &Web3ProxyApp,
user_key: Uuid,
ip: IpAddr,
origin: Option<Origin>,
referer: Option<Referer>,
user_agent: Option<UserAgent>,
) -> Result<AuthorizedRequest, FrontendErrorResponse> {
@ -110,7 +158,7 @@ pub async fn key_is_authorized(
x => unimplemented!("rate_limit_by_key shouldn't ever see these: {:?}", x),
};
let authorized_user = AuthorizedKey::try_new(ip, user_data, referer, user_agent)?;
let authorized_user = AuthorizedKey::try_new(ip, origin, referer, user_agent, user_data)?;
let db = app.db_conn.clone();
@ -159,50 +207,76 @@ impl Web3ProxyApp {
let db = self.db_conn.as_ref().context("no database")?;
/// helper enum for querying just a few columns instead of the entire table
/// TODO: query more! we need allowed ips, referers, and probably other things
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
Id,
RequestsPerMinute,
}
// TODO: join the user table to this to return the User? we don't always need it
match user_keys::Entity::find()
.select_only()
.column_as(user_keys::Column::Id, QueryAs::Id)
.column_as(
user_keys::Column::RequestsPerMinute,
QueryAs::RequestsPerMinute,
)
.filter(user_keys::Column::ApiKey.eq(user_key))
.filter(user_keys::Column::Active.eq(true))
.into_values::<_, QueryAs>()
.one(db)
.await?
{
Some((user_key_id, requests_per_minute)) => {
// TODO: add a column here for max, or is u64::MAX fine?
let user_count_per_period = if requests_per_minute == u64::MAX {
None
} else {
Some(requests_per_minute)
};
Some(user_key_model) => {
let allowed_ips: Option<Vec<IpNet>> =
user_key_model.allowed_ips.map(|allowed_ips| {
serde_json::from_str::<Vec<String>>(&allowed_ips)
.expect("allowed_ips should always parse")
.into_iter()
// TODO: try_for_each
.map(|x| {
x.parse::<IpNet>().expect("ip address should always parse")
})
.collect()
});
// TODO: should this be an Option<Vec<Origin>>?
let allowed_origins =
user_key_model.allowed_origins.map(|allowed_origins| {
serde_json::from_str::<Vec<String>>(&allowed_origins)
.expect("allowed_origins should always parse")
});
let allowed_referers =
user_key_model.allowed_referers.map(|allowed_referers| {
serde_json::from_str::<Vec<String>>(&allowed_referers)
.expect("allowed_referers should always parse")
.into_iter()
// TODO: try_for_each
.map(|x| {
x.parse::<Referer>().expect("referer should always parse")
})
.collect()
});
let allowed_user_agents =
user_key_model
.allowed_user_agents
.map(|allowed_user_agents| {
serde_json::from_str::<Vec<String>>(&allowed_user_agents)
.expect("allowed_user_agents should always parse")
.into_iter()
// TODO: try_for_each
.map(|x| {
x.parse::<UserAgent>()
.expect("user agent should always parse")
})
.collect()
});
Ok(UserKeyData {
user_key_id,
user_count_per_period,
allowed_ip: None,
allowed_referer: None,
allowed_user_agent: None,
user_key_id: user_key_model.id,
user_max_requests_per_period: user_key_model.requests_per_minute,
allowed_ips,
allowed_origins,
allowed_referers,
allowed_user_agents,
})
}
None => Ok(UserKeyData {
user_key_id: 0,
user_count_per_period: Some(0),
allowed_ip: None,
allowed_referer: None,
allowed_user_agent: None,
user_max_requests_per_period: Some(0),
allowed_ips: None,
allowed_origins: None,
allowed_referers: None,
allowed_user_agents: None,
}),
}
})
@ -219,7 +293,7 @@ impl Web3ProxyApp {
return Ok(RateLimitResult::UnknownKey);
}
let user_count_per_period = match user_data.user_count_per_period {
let user_max_requests_per_period = match user_data.user_max_requests_per_period {
None => return Ok(RateLimitResult::AllowedUser(user_data)),
Some(x) => x,
};
@ -227,7 +301,7 @@ impl Web3ProxyApp {
// user key is valid. now check rate limits
if let Some(rate_limiter) = &self.frontend_key_rate_limiter {
match rate_limiter
.throttle(user_key, Some(user_count_per_period), 1)
.throttle(user_key, Some(user_max_requests_per_period), 1)
.await
{
Ok(DeferredRateLimitResult::Allowed) => Ok(RateLimitResult::AllowedUser(user_data)),

@ -2,7 +2,7 @@ use super::authorization::{ip_is_authorized, key_is_authorized};
use super::errors::FrontendResult;
use crate::{app::Web3ProxyApp, jsonrpc::JsonRpcRequestEnum};
use axum::extract::Path;
use axum::headers::{Referer, UserAgent};
use axum::headers::{Origin, Referer, UserAgent};
use axum::TypedHeader;
use axum::{response::IntoResponse, Extension, Json};
use axum_client_ip::ClientIp;
@ -42,6 +42,7 @@ pub async fn user_proxy_web3_rpc(
Extension(app): Extension<Arc<Web3ProxyApp>>,
ClientIp(ip): ClientIp,
Json(payload): Json<JsonRpcRequestEnum>,
origin: Option<TypedHeader<Origin>>,
referer: Option<TypedHeader<Referer>>,
user_agent: Option<TypedHeader<UserAgent>>,
Path(user_key): Path<Uuid>,
@ -53,6 +54,7 @@ pub async fn user_proxy_web3_rpc(
&app,
user_key,
ip,
origin.map(|x| x.0),
referer.map(|x| x.0),
user_agent.map(|x| x.0),
)

@ -1,6 +1,6 @@
use super::authorization::{ip_is_authorized, key_is_authorized, AuthorizedRequest};
use super::errors::FrontendResult;
use axum::headers::{Referer, UserAgent};
use axum::headers::{Origin, Referer, UserAgent};
use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
extract::Path,
@ -57,6 +57,7 @@ pub async fn user_websocket_handler(
Extension(app): Extension<Arc<Web3ProxyApp>>,
ClientIp(ip): ClientIp,
Path(user_key): Path<Uuid>,
origin: Option<TypedHeader<Origin>>,
referer: Option<TypedHeader<Referer>>,
user_agent: Option<TypedHeader<UserAgent>>,
ws_upgrade: Option<WebSocketUpgrade>,
@ -65,6 +66,7 @@ pub async fn user_websocket_handler(
&app,
user_key,
ip,
origin.map(|x| x.0),
referer.map(|x| x.0),
user_agent.map(|x| x.0),
)

@ -2,6 +2,7 @@ use super::connection::Web3Connection;
use super::provider::Web3Provider;
use crate::frontend::authorization::AuthorizedRequest;
use crate::metered::{JsonRpcErrorCount, ProviderErrorCount};
use anyhow::Context;
use ethers::providers::{HttpClientError, ProviderError, WsClientError};
use metered::metered;
use metered::HitCount;
@ -63,6 +64,8 @@ impl AuthorizedRequest {
where
T: Clone + fmt::Debug + serde::Serialize + Send + Sync + 'static,
{
let db_conn = self.db_conn().context("db_conn needed to save reverts")?;
todo!("save the revert to the database");
}
}
@ -158,7 +161,8 @@ impl OpenRequestHandle {
let error_handler = if let RequestErrorHandler::SaveReverts(save_chance) = error_handler
{
if ["eth_call", "eth_estimateGas"].contains(&method)
&& self.authorization.has_db()
&& self.authorization.db_conn().is_some()
&& save_chance != 0.0
&& (save_chance == 1.0
|| rand::thread_rng().gen_range(0.0..=1.0) <= save_chance)
{