diff --git a/web3_proxy/src/frontend/authorization.rs b/web3_proxy/src/frontend/authorization.rs
new file mode 100644
index 00000000..b1963c37
--- /dev/null
+++ b/web3_proxy/src/frontend/authorization.rs
@@ -0,0 +1,243 @@
+use super::errors::FrontendErrorResponse;
+use crate::app::{UserKeyData, Web3ProxyApp};
+use anyhow::Context;
+use axum::headers::{Referer, UserAgent};
+use deferred_rate_limiter::DeferredRateLimitResult;
+use entities::user_keys;
+use sea_orm::{
+    ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect,
+};
+use serde::Serialize;
+use std::{net::IpAddr, sync::Arc};
+use tokio::time::Instant;
+use tracing::{error, trace, warn};
+use uuid::Uuid;
+
+#[derive(Debug)]
+pub enum RateLimitResult {
+    /// contains the IP of the anonymous user
+    AllowedIp(IpAddr),
+    /// contains the user_key_id of an authenticated user
+    AllowedUser(UserKeyData),
+    /// contains the IP and retry_at of the anonymous user
+    RateLimitedIp(IpAddr, Option<Instant>),
+    /// contains the user_key_id and retry_at of an authenticated user key
+    RateLimitedUser(UserKeyData, Option<Instant>),
+    /// This key is not in our database. Deny access!
+    UnknownKey,
+}
+
+#[derive(Debug, Serialize)]
+pub struct AuthorizedKey {
+    ip: IpAddr,
+    user_key_id: u64,
+    // TODO: what else?
+}
+
+impl AuthorizedKey {
+    pub fn try_new(
+        ip: IpAddr,
+        user_data: UserKeyData,
+        referer: Option<Referer>,
+        user_agent: Option<UserAgent>,
+    ) -> anyhow::Result<Self> {
+        warn!("todo: check referer and user_agent against user_data");
+
+        Ok(Self {
+            ip,
+            user_key_id: user_data.user_key_id,
+        })
+    }
+}
+
+#[derive(Debug, Serialize)]
+pub enum AuthorizedRequest {
+    /// Request from the app itself
+    Internal,
+    /// Request from an anonymous IP address
+    Ip(IpAddr),
+    /// Request from an authenticated and authorized user
+    User(AuthorizedKey),
+}
+
+pub async fn ip_is_authorized(
+    app: &Web3ProxyApp,
+    ip: IpAddr,
+) -> Result<AuthorizedRequest, FrontendErrorResponse> {
+    // TODO: i think we could write an `impl From` for this
+    let ip = match app.rate_limit_by_ip(ip).await? {
+        RateLimitResult::AllowedIp(x) => x,
+        RateLimitResult::RateLimitedIp(x, retry_at) => {
+            return Err(FrontendErrorResponse::RateLimitedIp(x, retry_at));
+        }
+        // TODO: don't panic. give the user an error
+        x => unimplemented!("rate_limit_by_ip shouldn't ever see these: {:?}", x),
+    };
+
+    Ok(AuthorizedRequest::Ip(ip))
+}
+
+pub async fn key_is_authorized(
+    app: &Web3ProxyApp,
+    user_key: Uuid,
+    ip: IpAddr,
+    referer: Option<Referer>,
+    user_agent: Option<UserAgent>,
+) -> Result<AuthorizedRequest, FrontendErrorResponse> {
+    // check the rate limits. error if over the limit
+    let user_data = match app.rate_limit_by_key(user_key).await? {
+        RateLimitResult::AllowedUser(x) => x,
+        RateLimitResult::RateLimitedUser(x, retry_at) => {
+            return Err(FrontendErrorResponse::RateLimitedUser(x, retry_at));
+        }
+        RateLimitResult::UnknownKey => return Err(FrontendErrorResponse::UnknownKey),
+        // TODO: don't panic. give the user an error
+        x => unimplemented!("rate_limit_by_key shouldn't ever see these: {:?}", x),
+    };
+
+    let authorized_user = AuthorizedKey::try_new(ip, user_data, referer, user_agent)?;
+
+    Ok(AuthorizedRequest::User(authorized_user))
+}
+
+impl Web3ProxyApp {
+    pub async fn rate_limit_by_ip(&self, ip: IpAddr) -> anyhow::Result<RateLimitResult> {
+        // TODO: dry this up with rate_limit_by_key
+        // TODO: have a local cache because if we hit redis too hard we get errors
+        // TODO: query redis in the background so that users don't have to wait on this network request
+        if let Some(rate_limiter) = &self.frontend_ip_rate_limiter {
+            match rate_limiter.throttle(ip, None, 1).await {
+                Ok(DeferredRateLimitResult::Allowed) => Ok(RateLimitResult::AllowedIp(ip)),
+                Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
+                    // TODO: set headers so they know when they can retry
+                    // TODO: debug or trace?
+                    // this is too verbose, but a stat might be good
+                    trace!(?ip, "rate limit exceeded until {:?}", retry_at);
+                    Ok(RateLimitResult::RateLimitedIp(ip, Some(retry_at)))
+                }
+                Ok(DeferredRateLimitResult::RetryNever) => {
+                    // TODO: i don't think we'll get here. maybe if we ban an IP forever? seems unlikely
+                    trace!(?ip, "rate limit is 0");
+                    Ok(RateLimitResult::RateLimitedIp(ip, None))
+                }
+                Err(err) => {
+                    // internal error, not rate limit being hit
+                    // TODO: i really want axum to do this for us in a single place.
+                    error!(?err, "rate limiter is unhappy. allowing ip");
+                    Ok(RateLimitResult::AllowedIp(ip))
+                }
+            }
+        } else {
+            // TODO: if no redis, rate limit with a local cache? "warn!" probably isn't right
+            todo!("no rate limiter");
+        }
+    }
+
+    // check the local cache for user data, or query the database
+    pub(crate) async fn user_data(&self, user_key: Uuid) -> anyhow::Result<UserKeyData> {
+        let user_data: Result<_, Arc<anyhow::Error>> = self
+            .user_cache
+            .try_get_with(user_key, async move {
+                trace!(?user_key, "user_cache miss");
+
+                let db = self.db_conn.as_ref().context("no database")?;
+
+                /// helper enum for querying just a few columns instead of the entire table
+                /// TODO: query more! we need allowed ips, referers, and probably other things
+                #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+                enum QueryAs {
+                    Id,
+                    RequestsPerMinute,
+                }
+
+                // TODO: join the user table to this to return the User? we don't always need it
+                match user_keys::Entity::find()
+                    .select_only()
+                    .column_as(user_keys::Column::Id, QueryAs::Id)
+                    .column_as(
+                        user_keys::Column::RequestsPerMinute,
+                        QueryAs::RequestsPerMinute,
+                    )
+                    .filter(user_keys::Column::ApiKey.eq(user_key))
+                    .filter(user_keys::Column::Active.eq(true))
+                    .into_values::<_, QueryAs>()
+                    .one(db)
+                    .await?
+                {
+                    Some((user_key_id, requests_per_minute)) => {
+                        // TODO: add a column here for max, or is u64::MAX fine?
+                        let user_count_per_period = if requests_per_minute == u64::MAX {
+                            None
+                        } else {
+                            Some(requests_per_minute)
+                        };
+
+                        Ok(UserKeyData {
+                            user_key_id,
+                            user_count_per_period,
+                            allowed_ip: None,
+                            allowed_referer: None,
+                            allowed_user_agent: None,
+                        })
+                    }
+                    None => Ok(UserKeyData {
+                        user_key_id: 0,
+                        user_count_per_period: Some(0),
+                        allowed_ip: None,
+                        allowed_referer: None,
+                        allowed_user_agent: None,
+                    }),
+                }
+            })
+            .await;
+
+        // TODO: i'm not actually sure about this expect
+        user_data.map_err(|err| Arc::try_unwrap(err).expect("this should be the only reference"))
+    }
+
+    pub async fn rate_limit_by_key(&self, user_key: Uuid) -> anyhow::Result<RateLimitResult> {
+        let user_data = self.user_data(user_key).await?;
+
+        if user_data.user_key_id == 0 {
+            return Ok(RateLimitResult::UnknownKey);
+        }
+
+        let user_count_per_period = match user_data.user_count_per_period {
+            None => return Ok(RateLimitResult::AllowedUser(user_data)),
+            Some(x) => x,
+        };
+
+        // user key is valid. now check rate limits
+        if let Some(rate_limiter) = &self.frontend_key_rate_limiter {
+            match rate_limiter
+                .throttle(user_key, Some(user_count_per_period), 1)
+                .await
+            {
+                Ok(DeferredRateLimitResult::Allowed) => Ok(RateLimitResult::AllowedUser(user_data)),
+                Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
+                    // TODO: set headers so they know when they can retry
+                    // TODO: debug or trace?
+                    // this is too verbose, but a stat might be good
+                    // TODO: keys are secrets! use the id instead
+                    trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
+                    Ok(RateLimitResult::RateLimitedUser(user_data, Some(retry_at)))
+                }
+                Ok(DeferredRateLimitResult::RetryNever) => {
+                    // TODO: keys are secret. don't log them!
+                    trace!(?user_key, "rate limit is 0");
+                    Ok(RateLimitResult::RateLimitedUser(user_data, None))
+                }
+                Err(err) => {
+                    // internal error, not rate limit being hit
+                    // TODO: i really want axum to do this for us in a single place.
+                    error!(?err, "rate limiter is unhappy. allowing ip");
+                    Ok(RateLimitResult::AllowedUser(user_data))
+                }
+            }
+        } else {
+            // TODO: if no redis, rate limit with just a local cache?
+            // if we don't have redis, we probably don't have a db, so this probably will never happen
+            Err(anyhow::anyhow!("no redis. cannot rate limit"))
+        }
+    }
+}
diff --git a/web3_proxy/src/rpcs/connection.rs b/web3_proxy/src/rpcs/connection.rs
index 7d4b0b96..6804af0e 100644
--- a/web3_proxy/src/rpcs/connection.rs
+++ b/web3_proxy/src/rpcs/connection.rs
@@ -531,7 +531,10 @@ impl Web3Connection {

                     loop {
                         // TODO: what should the max_wait be?
-                        match self.wait_for_request_handle(None, Duration::from_secs(30)).await {
+                        match self
+                            .wait_for_request_handle(None, Duration::from_secs(30))
+                            .await
+                        {
                             Ok(active_request_handle) => {
                                 let block: Result<Option<ArcBlock>, _> = active_request_handle
                                     .request(
@@ -598,8 +601,9 @@ impl Web3Connection {
                 }
                 Web3Provider::Ws(provider) => {
                     // todo: move subscribe_blocks onto the request handle?
-                    let active_request_handle =
-                        self.wait_for_request_handle(None, Duration::from_secs(30)).await;
+                    let active_request_handle = self
+                        .wait_for_request_handle(None, Duration::from_secs(30))
+                        .await;
                     let mut stream = provider.subscribe_blocks().await?;

                     drop(active_request_handle);
@@ -697,8 +701,9 @@ impl Web3Connection {
                 }
                 Web3Provider::Ws(provider) => {
                     // TODO: maybe the subscribe_pending_txs function should be on the active_request_handle
-                    let active_request_handle =
-                        self.wait_for_request_handle(None, Duration::from_secs(30)).await;
+                    let active_request_handle = self
+                        .wait_for_request_handle(None, Duration::from_secs(30))
+                        .await;

                     let mut stream = provider.subscribe_pending_txs().await?;

@@ -734,7 +739,7 @@ impl Web3Connection {
         let max_wait = Instant::now() + max_wait;

         loop {
-            let x = self.try_request_handle(authorized_request.clone()).await;
+            let x = self.try_request_handle(authorized_request).await;

             trace!(?x, "try_request_handle");

@@ -764,11 +769,12 @@ impl Web3Connection {
     #[instrument]
     pub async fn try_request_handle(
         self: &Arc<Self>,
-        authorized_user: Option<&Arc<AuthorizedRequest>>,
+        authorized_request: Option<&Arc<AuthorizedRequest>>,
     ) -> anyhow::Result<OpenRequestResult> {
         // check that we are connected
         if !self.has_provider().await {
            // TODO: emit a stat?
+            // TODO: wait until we have a provider?
             return Ok(OpenRequestResult::RetryNever);
         }

@@ -794,7 +800,7 @@ impl Web3Connection {
             }
         };

-        let handle = OpenRequestHandle::new(self.clone());
+        let handle = OpenRequestHandle::new(self.clone(), authorized_request.cloned());

         Ok(OpenRequestResult::Handle(handle))
     }
diff --git a/web3_proxy/src/rpcs/connections.rs b/web3_proxy/src/rpcs/connections.rs
index e987ecd9..880f0f27 100644
--- a/web3_proxy/src/rpcs/connections.rs
+++ b/web3_proxy/src/rpcs/connections.rs
@@ -423,7 +423,7 @@ impl Web3Connections {
         // now that the rpcs are sorted, try to get an active request handle for one of them
         for rpc in synced_rpcs.into_iter() {
             // increment our connection counter
-            match rpc.try_request_handle(authorized_request.clone()).await {
+            match rpc.try_request_handle(authorized_request).await {
                 Ok(OpenRequestResult::Handle(handle)) => {
                     trace!("next server on {:?}: {:?}", self, rpc);
                     return Ok(OpenRequestResult::Handle(handle));
diff --git a/web3_proxy/src/rpcs/request.rs b/web3_proxy/src/rpcs/request.rs
index 6437bf2b..e342906f 100644
--- a/web3_proxy/src/rpcs/request.rs
+++ b/web3_proxy/src/rpcs/request.rs
@@ -1,14 +1,14 @@
 use super::connection::Web3Connection;
 use super::provider::Web3Provider;
+use crate::frontend::authorization::AuthorizedRequest;
 use crate::metered::{JsonRpcErrorCount, ProviderErrorCount};
 use ethers::providers::{HttpClientError, ProviderError, WsClientError};
 use metered::metered;
 use metered::HitCount;
 use metered::ResponseTime;
 use metered::Throughput;
-use parking_lot::Mutex;
 use std::fmt;
-use std::sync::atomic;
+use std::sync::atomic::{self, AtomicBool, Ordering};
 use std::sync::Arc;
 use tokio::time::{sleep, Duration, Instant};
 use tracing::Level;
@@ -26,9 +26,11 @@ pub enum OpenRequestResult {
 /// Make RPC requests through this handle and drop it when you are done.
 #[derive(Debug)]
 pub struct OpenRequestHandle {
-    conn: Mutex<Option<Arc<Web3Connection>>>,
+    authorized_request: Arc<AuthorizedRequest>,
+    conn: Arc<Web3Connection>,
     // TODO: this is the same metrics on the conn. use a reference?
     metrics: Arc<OpenRequestHandleMetrics>,
+    used: AtomicBool,
 }

 pub enum RequestErrorHandler {
@@ -51,7 +53,10 @@ impl From<Level> for RequestErrorHandler {

 #[metered(registry = OpenRequestHandleMetrics, visibility = pub)]
 impl OpenRequestHandle {
-    pub fn new(conn: Arc<Web3Connection>) -> Self {
+    pub fn new(
+        conn: Arc<Web3Connection>,
+        authorized_request: Option<Arc<AuthorizedRequest>>,
+    ) -> Self {
         // TODO: take request_id as an argument?
         // TODO: attach a unique id to this? customer requests have one, but not internal queries
         // TODO: what ordering?!
@@ -64,18 +69,21 @@ impl OpenRequestHandle {
         conn.total_requests.fetch_add(1, atomic::Ordering::Relaxed);

         let metrics = conn.open_request_handle_metrics.clone();
+        let used = false.into();

-        let conn = Mutex::new(Some(conn));
+        let authorized_request =
+            authorized_request.unwrap_or_else(|| Arc::new(AuthorizedRequest::Internal));

-        Self { conn, metrics }
+        Self {
+            authorized_request,
+            conn,
+            metrics,
+            used,
+        }
     }

     pub fn clone_connection(&self) -> Arc<Web3Connection> {
-        if let Some(conn) = self.conn.lock().as_ref() {
-            conn.clone()
-        } else {
-            unimplemented!("this shouldn't happen")
-        }
+        self.conn.clone()
     }

     /// Send a web3 request
@@ -93,23 +101,22 @@ impl OpenRequestHandle {
         T: fmt::Debug + serde::Serialize + Send + Sync,
         R: serde::Serialize + serde::de::DeserializeOwned + fmt::Debug,
     {
-        let conn = self
-            .conn
-            .lock()
-            .take()
-            .expect("cannot use request multiple times");
+        // ensure this function only runs once
+        if self.used.swap(true, Ordering::Release) {
+            unimplemented!("a request handle should only be used once");
+        }

         // TODO: use tracing spans properly
         // TODO: requests from customers have request ids, but we should add
         // TODO: including params in this is way too verbose
-        trace!(rpc=%conn, %method, "request");
+        trace!(rpc=%self.conn, %method, "request");

         let mut provider = None;

         while provider.is_none() {
-            match conn.provider.read().await.as_ref() {
+            match self.conn.provider.read().await.as_ref() {
                 None => {
-                    warn!(rpc=%conn, "no provider!");
+                    warn!(rpc=%self.conn, "no provider!");
                     // TODO: how should this work? a reconnect should be in progress. but maybe force one now?
                     // TODO: maybe use a watch handle?
                     // TODO: sleep how long? subscribe to something instead?
@@ -127,20 +134,18 @@ impl OpenRequestHandle {
             Web3Provider::Ws(provider) => provider.request(method, params).await,
         };

-        conn.active_requests.fetch_sub(1, atomic::Ordering::AcqRel);
-
         if let Err(err) = &response {
             match error_handler {
                 RequestErrorHandler::ErrorLevel => {
-                    error!(?err, %method, rpc=%conn, "bad response!");
+                    error!(?err, %method, rpc=%self.conn, "bad response!");
                 }
                 RequestErrorHandler::DebugLevel => {
-                    debug!(?err, %method, rpc=%conn, "bad response!");
+                    debug!(?err, %method, rpc=%self.conn, "bad response!");
                 }
                 RequestErrorHandler::WarnLevel => {
-                    warn!(?err, %method, rpc=%conn, "bad response!");
+                    warn!(?err, %method, rpc=%self.conn, "bad response!");
                 }
-                RequestErrorHandler::SaveReverts(chance) => {
+                RequestErrorHandler::SaveReverts(save_chance) => {
                     // TODO: only set SaveReverts if this is an eth_call or eth_estimateGas? we'll need eth_sendRawTransaction somewhere else
                     // TODO: logging every one is going to flood the database
                     // TODO: have a percent chance to do this. or maybe a "logged reverts per second"
@@ -154,7 +159,7 @@ impl OpenRequestHandle {
                                         debug!(%method, ?params, "TODO: save the request");
                                         // TODO: don't do this on the hot path. spawn it
                                     } else {
-                                        debug!(?err, %method, rpc=%conn, "bad response!");
+                                        debug!(?err, %method, rpc=%self.conn, "bad response!");
                                     }
                                 }
                             }
@@ -166,7 +171,7 @@ impl OpenRequestHandle {
                                         debug!(%method, ?params, "TODO: save the request");
                                         // TODO: don't do this on the hot path. spawn it
                                     } else {
-                                        debug!(?err, %method, rpc=%conn, "bad response!");
+                                        debug!(?err, %method, rpc=%self.conn, "bad response!");
                                     }
                                 }
                             }
@@ -177,8 +182,8 @@
         } else {
             // TODO: i think ethers already has trace logging (and does it much more fancy)
             // TODO: opt-in response inspection to log reverts with their request. put into redis or what?
-            // trace!(rpc=%self.0, %method, ?response);
-            trace!(%method, rpc=%conn, "response");
+            // trace!(rpc=%self.conn, %method, ?response);
+            trace!(%method, rpc=%self.conn, "response");
         }

         response
@@ -187,8 +192,8 @@

 impl Drop for OpenRequestHandle {
     fn drop(&mut self) {
-        if let Some(conn) = self.conn.lock().take() {
-            conn.active_requests.fetch_sub(1, atomic::Ordering::AcqRel);
-        }
+        self.conn
+            .active_requests
+            .fetch_sub(1, atomic::Ordering::AcqRel);
     }
 }