From 90d3371eee892dc7d56b92034e326bf3d733e290 Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Wed, 18 Jan 2023 16:17:43 -0800 Subject: [PATCH] improved rate limiting on websockets --- TODO.md | 4 +- web3_proxy/src/app/mod.rs | 8 +- web3_proxy/src/app_stats.rs | 2 +- web3_proxy/src/bin/web3_proxy_cli/daemon.rs | 2 - web3_proxy/src/frontend/authorization.rs | 35 +++- web3_proxy/src/frontend/errors.rs | 27 +-- web3_proxy/src/frontend/rpc_proxy_ws.rs | 174 +++++++++++++------- web3_proxy/src/rpcs/request.rs | 18 +- 8 files changed, 175 insertions(+), 95 deletions(-) diff --git a/TODO.md b/TODO.md index 081844ac..2f4b8754 100644 --- a/TODO.md +++ b/TODO.md @@ -304,6 +304,7 @@ These are not yet ordered. There might be duplicates. We might not actually need - [x] standalone healthcheck daemon (sentryd) - [x] status page should show version - [x] combine the proxy and cli into one bin +- [x] improve rate limiting on websockets - [-] proxy mode for benchmarking all backends - [-] proxy mode for sending to multiple backends - [-] let users choose a % of reverts to log (or maybe x/second). someone like curve logging all reverts will be a BIG database very quickly @@ -514,7 +515,8 @@ in another repo: event subscriber - [ ] if the call is something simple like "symbol" or "decimals", cache that too. though i think this could bite us. - [ ] add a subscription that returns the head block number and hash but nothing else - [ ] if chain split detected, what should we do? don't send transactions? -- [ ] archive check works well for local servers, but public nodes (especially on other chains) seem to give unreliable results. likely because of load balancers. maybe have a "max block data limit" +- [ ] archive check works well for local servers, but public nodes (especially on other chains) seem to give unreliable results. likely because of load balancers. + - [x] configurable block data limit until better checks - [ ] https://docs.rs/derive_builder/latest/derive_builder/ - [ ] Detect orphaned transactions - [ ] https://crates.io/crates/reqwest-middleware easy retry with exponential back off diff --git a/web3_proxy/src/app/mod.rs b/web3_proxy/src/app/mod.rs index 44df90af..1e3a327a 100644 --- a/web3_proxy/src/app/mod.rs +++ b/web3_proxy/src/app/mod.rs @@ -4,7 +4,7 @@ mod ws; use crate::app_stats::{ProxyResponseStat, StatEmitter, Web3ProxyStat}; use crate::block_number::{block_needed, BlockNeeded}; use crate::config::{AppConfig, TopConfig}; -use crate::frontend::authorization::{Authorization, RequestMetadata}; +use crate::frontend::authorization::{Authorization, RequestMetadata, RpcSecretKey}; use crate::frontend::errors::FrontendErrorResponse; use crate::frontend::rpc_proxy_ws::ProxyMode; use crate::jsonrpc::{ @@ -136,12 +136,14 @@ pub type AnyhowJoinHandle = JoinHandle>; #[derive(Clone, Debug, Default, From)] pub struct AuthorizationChecks { - /// database id of the primary user. + /// database id of the primary user. 0 if anon /// TODO: do we need this? its on the authorization so probably not pub user_id: u64, + /// the key used (if any) + pub rpc_secret_key: Option, /// database id of the rpc key /// if this is None, then this request is being rate limited by ip - pub rpc_key_id: Option, + pub rpc_secret_key_id: Option, /// if None, allow unlimited queries. inherited from the user_tier pub max_requests_per_period: Option, // if None, allow unlimited concurrent requests. inherited from the user_tier diff --git a/web3_proxy/src/app_stats.rs b/web3_proxy/src/app_stats.rs index 204effd5..681dfcea 100644 --- a/web3_proxy/src/app_stats.rs +++ b/web3_proxy/src/app_stats.rs @@ -36,7 +36,7 @@ impl ProxyResponseStat { fn key(&self) -> ProxyResponseAggregateKey { // include either the rpc_key_id or the origin let (mut rpc_key_id, origin) = match ( - self.authorization.checks.rpc_key_id, + self.authorization.checks.rpc_secret_key_id, &self.authorization.origin, ) { (Some(rpc_key_id), _) => { diff --git a/web3_proxy/src/bin/web3_proxy_cli/daemon.rs b/web3_proxy/src/bin/web3_proxy_cli/daemon.rs index 09998ea4..69d0e2c7 100644 --- a/web3_proxy/src/bin/web3_proxy_cli/daemon.rs +++ b/web3_proxy/src/bin/web3_proxy_cli/daemon.rs @@ -204,7 +204,6 @@ mod tests { disabled: false, display_name: None, url: anvil.endpoint(), - backup: None, block_data_limit: None, soft_limit: 100, hard_limit: None, @@ -219,7 +218,6 @@ mod tests { disabled: false, display_name: None, url: anvil.ws_endpoint(), - backup: None, block_data_limit: None, soft_limit: 100, hard_limit: None, diff --git a/web3_proxy/src/frontend/authorization.rs b/web3_proxy/src/frontend/authorization.rs index f98cf7d0..c04ba8c2 100644 --- a/web3_proxy/src/frontend/authorization.rs +++ b/web3_proxy/src/frontend/authorization.rs @@ -660,13 +660,11 @@ impl Web3ProxyApp { let db_replica = self.db_replica().context("Getting database connection")?; - let rpc_secret_key: Uuid = rpc_secret_key.into(); - // TODO: join the user table to this to return the User? we don't always need it // TODO: join on secondary users // TODO: join on user tier match rpc_key::Entity::find() - .filter(rpc_key::Column::SecretKey.eq(rpc_secret_key)) + .filter(rpc_key::Column::SecretKey.eq(::from(rpc_secret_key))) .filter(rpc_key::Column::Active.eq(true)) .one(db_replica.conn()) .await? @@ -741,7 +739,8 @@ impl Web3ProxyApp { Ok(AuthorizationChecks { user_id: rpc_key_model.user_id, - rpc_key_id, + rpc_secret_key: Some(rpc_secret_key), + rpc_secret_key_id: rpc_key_id, allowed_ips, allowed_origins, allowed_referers, @@ -774,7 +773,7 @@ impl Web3ProxyApp { let authorization_checks = self.authorization_checks(rpc_key).await?; // if no rpc_key_id matching the given rpc was found, then we can't rate limit by key - if authorization_checks.rpc_key_id.is_none() { + if authorization_checks.rpc_secret_key_id.is_none() { return Ok(RateLimitResult::UnknownKey); } @@ -845,3 +844,29 @@ impl Web3ProxyApp { } } } + +impl Authorization { + pub async fn check_again( + &self, + app: &Arc, + ) -> Result<(Arc, Option), FrontendErrorResponse> { + // TODO: we could probably do this without clones. but this is easy + let (a, s) = if let Some(rpc_secret_key) = self.checks.rpc_secret_key { + key_is_authorized( + app, + rpc_secret_key, + self.ip, + self.origin.clone(), + self.referer.clone(), + self.user_agent.clone(), + ) + .await? + } else { + ip_is_authorized(app, self.ip, self.origin.clone()).await? + }; + + let a = Arc::new(a); + + Ok((a, s)) + } +} diff --git a/web3_proxy/src/frontend/errors.rs b/web3_proxy/src/frontend/errors.rs index 30ee053f..22f048ee 100644 --- a/web3_proxy/src/frontend/errors.rs +++ b/web3_proxy/src/frontend/errors.rs @@ -35,7 +35,6 @@ pub enum FrontendErrorResponse { NotFound, RateLimited(Authorization, Option), Redis(RedisError), - Response(Response), /// simple way to return an error message to the user and an anyhow to our logs StatusCode(StatusCode, String, Option), /// TODO: what should be attached to the timout? @@ -44,11 +43,9 @@ pub enum FrontendErrorResponse { UnknownKey, } -impl IntoResponse for FrontendErrorResponse { - fn into_response(self) -> Response { - // TODO: include the request id in these so that users can give us something that will point to logs - // TODO: status code is in the jsonrpc response and is also the first item in the tuple. DRY - let (status_code, response) = match self { +impl FrontendErrorResponse { + pub fn into_response_parts(self) -> (StatusCode, JsonRpcForwardedResponse) { + match self { Self::AccessDenied => { // TODO: attach something to this trace. probably don't include much in the message though. don't want to leak creds by accident trace!("access denied"); @@ -174,12 +171,12 @@ impl IntoResponse for FrontendErrorResponse { }; // create a string with either the IP or the rpc_key_id - let msg = if authorization.checks.rpc_key_id.is_none() { + let msg = if authorization.checks.rpc_secret_key_id.is_none() { format!("too many requests from {}.{}", authorization.ip, retry_msg) } else { format!( "too many requests from rpc key #{}.{}", - authorization.checks.rpc_key_id.unwrap(), + authorization.checks.rpc_secret_key_id.unwrap(), retry_msg ) }; @@ -204,10 +201,6 @@ impl IntoResponse for FrontendErrorResponse { ), ) } - Self::Response(r) => { - debug_assert_ne!(r.status(), StatusCode::OK); - return r; - } Self::SemaphoreAcquireError(err) => { warn!("semaphore acquire err={:?}", err); ( @@ -274,7 +267,15 @@ impl IntoResponse for FrontendErrorResponse { None, ), ), - }; + } + } +} + +impl IntoResponse for FrontendErrorResponse { + fn into_response(self) -> Response { + // TODO: include the request id in these so that users can give us something that will point to logs + // TODO: status code is in the jsonrpc response and is also the first item in the tuple. DRY + let (status_code, response) = self.into_response_parts(); (status_code, Json(response)).into_response() } diff --git a/web3_proxy/src/frontend/rpc_proxy_ws.rs b/web3_proxy/src/frontend/rpc_proxy_ws.rs index 23516738..ae6b700b 100644 --- a/web3_proxy/src/frontend/rpc_proxy_ws.rs +++ b/web3_proxy/src/frontend/rpc_proxy_ws.rs @@ -32,6 +32,7 @@ use serde_json::json; use serde_json::value::to_raw_value; use std::sync::Arc; use std::{str::from_utf8_mut, sync::atomic::AtomicUsize}; +use tokio::sync::{broadcast, OwnedSemaphorePermit, RwLock}; #[derive(Copy, Clone)] pub enum ProxyMode { @@ -52,7 +53,7 @@ pub async fn websocket_handler( origin: Option>, ws_upgrade: Option, ) -> FrontendResult { - _websocket_handler(ProxyMode::Fastest(1), app, ip, origin, ws_upgrade).await + _websocket_handler(ProxyMode::Best, app, ip, origin, ws_upgrade).await } /// Public entrypoint for WebSocket JSON-RPC requests that uses all synced servers. @@ -226,7 +227,7 @@ async fn _websocket_handler_with_key( match ( &app.config.redirect_public_url, &app.config.redirect_rpc_key_url, - authorization.checks.rpc_key_id, + authorization.checks.rpc_secret_key_id, ) { (None, None, _) => Err(FrontendErrorResponse::StatusCode( StatusCode::BAD_REQUEST, @@ -239,7 +240,7 @@ async fn _websocket_handler_with_key( (_, Some(redirect_rpc_key_url), rpc_key_id) => { let reg = Handlebars::new(); - if authorization.checks.rpc_key_id.is_none() { + if authorization.checks.rpc_secret_key_id.is_none() { // i don't think this is possible Err(FrontendErrorResponse::StatusCode( StatusCode::UNAUTHORIZED, @@ -298,9 +299,20 @@ async fn handle_socket_payload( payload: &str, response_sender: &flume::Sender, subscription_count: &AtomicUsize, - subscriptions: &mut HashMap, + subscriptions: Arc>>, proxy_mode: ProxyMode, -) -> Message { +) -> (Message, Option) { + let (authorization, semaphore) = match authorization.check_again(&app).await { + Ok((a, s)) => (a, s), + Err(err) => { + let (_, err) = err.into_response_parts(); + + let err = serde_json::to_string(&err).expect("to_string should always work here"); + + return (Message::Text(err), None); + } + }; + // TODO: do any clients send batches over websockets? let (id, response) = match serde_json::from_str::(payload) { Ok(json_request) => { @@ -322,7 +334,9 @@ async fn handle_socket_payload( { Ok((handle, response)) => { // TODO: better key - subscriptions.insert( + let mut x = subscriptions.write().await; + + x.insert( response .result .as_ref() @@ -346,8 +360,10 @@ async fn handle_socket_payload( let subscription_id = json_request.params.unwrap().to_string(); + let mut x = subscriptions.write().await; + // TODO: is this the right response? - let partial_response = match subscriptions.remove(&subscription_id) { + let partial_response = match x.remove(&subscription_id) { None => false, Some(handle) => { handle.abort(); @@ -355,6 +371,8 @@ async fn handle_socket_payload( } }; + drop(x); + let response = JsonRpcForwardedResponse::from_value(json!(partial_response), id.clone()); @@ -409,9 +427,7 @@ async fn handle_socket_payload( } }; - // TODO: what error should this be? - - Message::Text(response_str) + (Message::Text(response_str), semaphore) } async fn read_web3_socket( @@ -421,61 +437,97 @@ async fn read_web3_socket( response_sender: flume::Sender, proxy_mode: ProxyMode, ) { - let mut subscriptions = HashMap::new(); - let subscription_count = AtomicUsize::new(1); + // TODO: need a concurrent hashmap + let subscriptions = Arc::new(RwLock::new(HashMap::new())); + let subscription_count = Arc::new(AtomicUsize::new(1)); - while let Some(Ok(msg)) = ws_rx.next().await { - // TODO: spawn this? - // new message from our client. forward to a backend and then send it through response_tx - let response_msg = match msg { - Message::Text(payload) => { - handle_socket_payload( - app.clone(), - &authorization, - &payload, - &response_sender, - &subscription_count, - &mut subscriptions, - proxy_mode, - ) - .await + let (close_sender, mut close_receiver) = broadcast::channel(1); + + loop { + tokio::select! { + msg = ws_rx.next() => { + if let Some(Ok(msg)) = msg { + // spawn so that we can serve responses from this loop even faster + // TODO: only do these clones if the msg is text/binary? + let close_sender = close_sender.clone(); + let app = app.clone(); + let authorization = authorization.clone(); + let response_sender = response_sender.clone(); + let subscriptions = subscriptions.clone(); + let subscription_count = subscription_count.clone(); + + let f = async move { + let mut _semaphore = None; + + // new message from our client. forward to a backend and then send it through response_tx + let response_msg = match msg { + Message::Text(payload) => { + let (msg, s) = handle_socket_payload( + app.clone(), + &authorization, + &payload, + &response_sender, + &subscription_count, + subscriptions, + proxy_mode, + ) + .await; + + _semaphore = s; + + msg + } + Message::Ping(x) => { + trace!("ping: {:?}", x); + Message::Pong(x) + } + Message::Pong(x) => { + trace!("pong: {:?}", x); + return; + } + Message::Close(_) => { + info!("closing websocket connection"); + // TODO: do something to close subscriptions? + let _ = close_sender.send(true); + return; + } + Message::Binary(mut payload) => { + let payload = from_utf8_mut(&mut payload).unwrap(); + + let (msg, s) = handle_socket_payload( + app.clone(), + &authorization, + payload, + &response_sender, + &subscription_count, + subscriptions, + proxy_mode, + ) + .await; + + _semaphore = s; + + msg + } + }; + + if response_sender.send_async(response_msg).await.is_err() { + let _ = close_sender.send(true); + return; + }; + + _semaphore = None; + }; + + tokio::spawn(f); + } else { + break; + } } - Message::Ping(x) => { - trace!("ping: {:?}", x); - Message::Pong(x) - } - Message::Pong(x) => { - trace!("pong: {:?}", x); - continue; - } - Message::Close(_) => { - info!("closing websocket connection"); + _ = close_receiver.recv() => { break; } - Message::Binary(mut payload) => { - // TODO: poke rate limit for the user/ip - let payload = from_utf8_mut(&mut payload).unwrap(); - - handle_socket_payload( - app.clone(), - &authorization, - payload, - &response_sender, - &subscription_count, - &mut subscriptions, - proxy_mode, - ) - .await - } - }; - - match response_sender.send_async(response_msg).await { - Ok(_) => {} - Err(err) => { - error!("{}", err); - break; - } - }; + } } } diff --git a/web3_proxy/src/rpcs/request.rs b/web3_proxy/src/rpcs/request.rs index 7db16fd5..8cf22bbf 100644 --- a/web3_proxy/src/rpcs/request.rs +++ b/web3_proxy/src/rpcs/request.rs @@ -84,7 +84,7 @@ impl Authorization { method: Method, params: EthCallFirstParams, ) -> anyhow::Result<()> { - let rpc_key_id = match self.checks.rpc_key_id { + let rpc_key_id = match self.checks.rpc_secret_key_id { Some(rpc_key_id) => rpc_key_id.into(), None => { // // trace!(?self, "cannot save revert without rpc_key_id"); @@ -240,14 +240,14 @@ impl OpenRequestHandle { Web3Provider::Ws(provider) => provider.request(method, params).await, }; - // TODO: i think ethers already has trace logging (and does it much more fancy) - trace!( - "response from {} for {} {:?}: {:?}", - self.conn, - method, - params, - response, - ); + // // TODO: i think ethers already has trace logging (and does it much more fancy) + // trace!( + // "response from {} for {} {:?}: {:?}", + // self.conn, + // method, + // params, + // response, + // ); if let Err(err) = &response { // only save reverts for some types of calls