From fbe0ecfbfff73a5cf8845d2628747f248ac2e507 Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Wed, 31 May 2023 16:05:41 -0700 Subject: [PATCH] rate limit on websockets --- web3_proxy/src/app/mod.rs | 2 +- web3_proxy/src/app/ws.rs | 118 +++++++++++++++++++++--- web3_proxy/src/frontend/rpc_proxy_ws.rs | 40 +++----- 3 files changed, 120 insertions(+), 40 deletions(-) diff --git a/web3_proxy/src/app/mod.rs b/web3_proxy/src/app/mod.rs index f2680bea..f7009192 100644 --- a/web3_proxy/src/app/mod.rs +++ b/web3_proxy/src/app/mod.rs @@ -985,7 +985,7 @@ impl Web3ProxyApp { .proxy_web3_rpc_requests(&authorization, requests) .await?; - // TODO: real status code. i don't think we are following the spec here + // TODO: real status code. if an error happens, i don't think we are following the spec here ( StatusCode::OK, JsonRpcForwardedResponseEnum::Batch(responses), diff --git a/web3_proxy/src/app/ws.rs b/web3_proxy/src/app/ws.rs index cc13b44f..441b6537 100644 --- a/web3_proxy/src/app/ws.rs +++ b/web3_proxy/src/app/ws.rs @@ -7,15 +7,18 @@ use crate::jsonrpc::JsonRpcForwardedResponse; use crate::jsonrpc::JsonRpcRequest; use crate::response_cache::JsonRpcResponseEnum; use crate::rpcs::transactions::TxStatus; -use axum::extract::ws::Message; +use axum::extract::ws::{CloseFrame, Message}; +use deferred_rate_limiter::DeferredRateLimitResult; use ethers::types::U64; use futures::future::AbortHandle; use futures::future::Abortable; use futures::stream::StreamExt; -use log::trace; +use http::StatusCode; +use log::{error, trace}; use serde_json::json; -use std::sync::atomic::{self, AtomicUsize}; +use std::sync::atomic::{self, AtomicU64}; use std::sync::Arc; +use tokio::time::Instant; use tokio_stream::wrappers::{BroadcastStream, WatchStream}; impl Web3ProxyApp { @@ -23,7 +26,7 @@ impl Web3ProxyApp { self: &'a Arc, authorization: Arc, jsonrpc_request: JsonRpcRequest, - subscription_count: &'a AtomicUsize, + subscription_count: &'a AtomicU64, // TODO: taking a sender for Message instead of the exact json we are planning to send feels wrong, but its easier for now response_sender: flume::Sender, ) -> Web3ProxyResult<(AbortHandle, JsonRpcForwardedResponse)> { @@ -40,18 +43,25 @@ impl Web3ProxyApp { // TODO: this only needs to be unique per connection. we don't need it globably unique // TODO: have a max number of subscriptions per key/ip. have a global max number of subscriptions? how should this be calculated? let subscription_id = subscription_count.fetch_add(1, atomic::Ordering::SeqCst); - let subscription_id = U64::from(subscription_id as u64); + let subscription_id = U64::from(subscription_id); // save the id so we can use it in the response let id = jsonrpc_request.id.clone(); + let subscribe_to = jsonrpc_request + .params + .get(0) + .and_then(|x| x.as_str()) + .ok_or_else(|| { + Web3ProxyError::BadRequest("unable to subscribe using these params".into()) + })?; + // TODO: calling json! on every request is probably not fast. but we can only match against // TODO: i think we need a stricter EthSubscribeRequest type that JsonRpcRequest can turn into - if jsonrpc_request.params == json!(["newHeads"]) { + if subscribe_to == "newHeads" { let head_block_receiver = self.watch_consensus_head_receiver.clone(); let app = self.clone(); - trace!("newHeads subscription {:?}", subscription_id); tokio::spawn(async move { let mut head_block_receiver = Abortable::new( WatchStream::new(head_block_receiver), @@ -73,6 +83,14 @@ impl Web3ProxyApp { ) .await; + if let Some(close_message) = app + .rate_limit_close_websocket(&subscription_request_metadata) + .await + { + let _ = response_sender.send_async(close_message).await; + break; + } + // TODO: make a struct for this? using our JsonRpcForwardedResponse won't work because it needs an id let response_json = json!({ "jsonrpc": "2.0", @@ -105,7 +123,7 @@ impl Web3ProxyApp { trace!("closed newHeads subscription {:?}", subscription_id); }); - } else if jsonrpc_request.params == json!(["newPendingTransactions"]) { + } else if subscribe_to == "newPendingTransactions" { let pending_tx_receiver = self.pending_tx_sender.subscribe(); let app = self.clone(); @@ -119,7 +137,6 @@ impl Web3ProxyApp { subscription_id ); - // TODO: do something with this handle? tokio::spawn(async move { while let Some(Ok(new_tx_state)) = pending_tx_receiver.next().await { let subscription_request_metadata = RequestMetadata::new( @@ -130,6 +147,14 @@ impl Web3ProxyApp { ) .await; + if let Some(close_message) = app + .rate_limit_close_websocket(&subscription_request_metadata) + .await + { + let _ = response_sender.send_async(close_message).await; + break; + } + let new_tx = match new_tx_state { TxStatus::Pending(tx) => tx, TxStatus::Confirmed(..) => continue, @@ -154,7 +179,7 @@ impl Web3ProxyApp { subscription_request_metadata.add_response(response_bytes); - // TODO: do clients support binary messages? + // TODO: do clients support binary messages? reply with binary if thats what we were sent let response_msg = Message::Text(response_str); if response_sender.send_async(response_msg).await.is_err() { @@ -168,7 +193,7 @@ impl Web3ProxyApp { subscription_id ); }); - } else if jsonrpc_request.params == json!(["newPendingFullTransactions"]) { + } else if subscribe_to == "newPendingFullTransactions" { // TODO: too much copy/pasta with newPendingTransactions let pending_tx_receiver = self.pending_tx_sender.subscribe(); let app = self.clone(); @@ -183,7 +208,6 @@ impl Web3ProxyApp { subscription_id ); - // TODO: do something with this handle? tokio::spawn(async move { while let Some(Ok(new_tx_state)) = pending_tx_receiver.next().await { let subscription_request_metadata = RequestMetadata::new( @@ -194,6 +218,14 @@ impl Web3ProxyApp { ) .await; + if let Some(close_message) = app + .rate_limit_close_websocket(&subscription_request_metadata) + .await + { + let _ = response_sender.send_async(close_message).await; + break; + } + let new_tx = match new_tx_state { TxStatus::Pending(tx) => tx, TxStatus::Confirmed(..) => continue, @@ -230,7 +262,7 @@ impl Web3ProxyApp { subscription_id ); }); - } else if jsonrpc_request.params == json!(["newPendingRawTransactions"]) { + } else if subscribe_to == "newPendingRawTransactions" { // TODO: too much copy/pasta with newPendingTransactions let pending_tx_receiver = self.pending_tx_sender.subscribe(); let app = self.clone(); @@ -245,7 +277,6 @@ impl Web3ProxyApp { subscription_id ); - // TODO: do something with this handle? tokio::spawn(async move { while let Some(Ok(new_tx_state)) = pending_tx_receiver.next().await { let subscription_request_metadata = RequestMetadata::new( @@ -256,6 +287,14 @@ impl Web3ProxyApp { ) .await; + if let Some(close_message) = app + .rate_limit_close_websocket(&subscription_request_metadata) + .await + { + let _ = response_sender.send_async(close_message).await; + break; + } + let new_tx = match new_tx_state { TxStatus::Pending(tx) => tx, TxStatus::Confirmed(..) => continue, @@ -311,4 +350,55 @@ impl Web3ProxyApp { // TODO: make a `SubscriptonHandle(AbortHandle, JoinHandle)` struct? Ok((subscription_abort_handle, response)) } + + async fn rate_limit_close_websocket( + &self, + request_metadata: &RequestMetadata, + ) -> Option { + if let Some(authorization) = request_metadata.authorization.as_ref() { + if authorization.checks.rpc_secret_key_id.is_none() { + if let Some(rate_limiter) = &self.frontend_ip_rate_limiter { + match rate_limiter + .throttle( + authorization.ip, + authorization.checks.max_requests_per_period, + 1, + ) + .await + { + Ok(DeferredRateLimitResult::RetryNever) => { + let close_frame = CloseFrame { + code: StatusCode::TOO_MANY_REQUESTS.as_u16(), + reason: + "rate limited. upgrade to premium for unlimited websocket messages" + .into(), + }; + + return Some(Message::Close(Some(close_frame))); + } + Ok(DeferredRateLimitResult::RetryAt(retry_at)) => { + let retry_at = retry_at.duration_since(Instant::now()); + + let reason = format!("rate limited. upgrade to premium for unlimited websocket messages. retry in {}s", retry_at.as_secs_f32()); + + let close_frame = CloseFrame { + code: StatusCode::TOO_MANY_REQUESTS.as_u16(), + reason: reason.into(), + }; + + return Some(Message::Close(Some(close_frame))); + } + Ok(_) => {} + Err(err) => { + // this an internal error of some kind, not the rate limit being hit + // TODO: i really want axum to do this for us in a single place. + error!("rate limiter is unhappy. allowing ip. err={:?}", err); + } + } + } + } + } + + None + } } diff --git a/web3_proxy/src/frontend/rpc_proxy_ws.rs b/web3_proxy/src/frontend/rpc_proxy_ws.rs index fc376e12..74b8f3ba 100644 --- a/web3_proxy/src/frontend/rpc_proxy_ws.rs +++ b/web3_proxy/src/frontend/rpc_proxy_ws.rs @@ -32,6 +32,7 @@ use hashbrown::HashMap; use http::StatusCode; use log::{info, trace}; use serde_json::json; +use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::{str::from_utf8_mut, sync::atomic::AtomicUsize}; use tokio::sync::{broadcast, OwnedSemaphorePermit, RwLock}; @@ -318,13 +319,12 @@ async fn proxy_web3_socket( } /// websockets support a few more methods than http clients -/// TODO: i think this subscriptions hashmap grows unbounded async fn handle_socket_payload( app: Arc, authorization: &Arc, payload: &str, response_sender: &flume::Sender, - subscription_count: &AtomicUsize, + subscription_count: &AtomicU64, subscriptions: Arc>>, ) -> Web3ProxyResult<(Message, Option)> { let (authorization, semaphore) = match authorization.check_again(&app).await { @@ -456,7 +456,7 @@ async fn read_web3_socket( ) { // RwLock should be fine here. a user isn't going to be opening tons of subscriptions let subscriptions = Arc::new(RwLock::new(HashMap::new())); - let subscription_count = Arc::new(AtomicUsize::new(1)); + let subscription_count = Arc::new(AtomicU64::new(1)); let (close_sender, mut close_receiver) = broadcast::channel(1); @@ -464,8 +464,7 @@ async fn read_web3_socket( tokio::select! { msg = ws_rx.next() => { if let Some(Ok(msg)) = msg { - // spawn so that we can serve responses from this loop even faster - // TODO: only do these clones if the msg is text/binary? + // clone things so we can handle multiple messages in parallel let close_sender = close_sender.clone(); let app = app.clone(); let authorization = authorization.clone(); @@ -474,13 +473,12 @@ async fn read_web3_socket( let subscription_count = subscription_count.clone(); let f = async move { - let mut _semaphore = None; - - // new message from our client. forward to a backend and then send it through response_tx - let response_msg = match msg { + // new message from our client. forward to a backend and then send it through response_sender + let (response_msg, _semaphore) = match msg { Message::Text(ref payload) => { - // TODO: do not unwrap! - let (msg, s) = handle_socket_payload( + // TODO: do not unwrap! turn errors into a jsonrpc response and send that instead + // TODO: some providers close the connection on error. i don't like that + let (m, s) = handle_socket_payload( app.clone(), &authorization, payload, @@ -490,13 +488,11 @@ async fn read_web3_socket( ) .await.unwrap(); - _semaphore = s; - - msg + (m, Some(s)) } Message::Ping(x) => { trace!("ping: {:?}", x); - Message::Pong(x) + (Message::Pong(x), None) } Message::Pong(x) => { trace!("pong: {:?}", x); @@ -511,8 +507,8 @@ async fn read_web3_socket( Message::Binary(mut payload) => { let payload = from_utf8_mut(&mut payload).unwrap(); - // TODO: do not unwrap! - let (msg, s) = handle_socket_payload( + // TODO: do not unwrap! turn errors into a jsonrpc response and send that instead + let (m, s) = handle_socket_payload( app.clone(), &authorization, payload, @@ -522,18 +518,13 @@ async fn read_web3_socket( ) .await.unwrap(); - _semaphore = s; - - msg + (m, Some(s)) } }; if response_sender.send_async(response_msg).await.is_err() { let _ = close_sender.send(true); - return; }; - - _semaphore = None; }; tokio::spawn(f); @@ -554,11 +545,10 @@ async fn write_web3_socket( ) { // TODO: increment counter for open websockets - // TODO: is there any way to make this stream receive. while let Ok(msg) = response_rx.recv_async().await { // a response is ready - // TODO: poke rate limits for this user? + // we do not check rate limits here. they are checked before putting things into response_sender; // forward the response to through the websocket if let Err(err) = ws_tx.send(msg).await {