rate limit on websockets

Bryan Stitt 2023-05-31 16:05:41 -07:00
parent cadab50692
commit fbe0ecfbff
3 changed files with 120 additions and 40 deletions
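In short: each websocket subscription loop (newHeads, newPendingTransactions, and the raw/full variants) now asks a new `rate_limit_close_websocket` helper before pushing a message; when the anonymous ip limit is exhausted, it sends a close frame with a 429 code and a reason string and then stops the subscription. Below is a self-contained sketch of that pattern. It is not the committed code: the real `DeferredRateLimiter` and request metadata are replaced by a trivial counter, and only the names `rate_limit_close_websocket`, `response_sender`, and the close-frame shape come from the diff.

// sketch only: a counter stands in for the real per-ip DeferredRateLimiter
use axum::extract::ws::{CloseFrame, Message};
use http::StatusCode;

fn rate_limit_close_websocket(messages_sent: u64, max_messages: u64) -> Option<Message> {
    if messages_sent >= max_messages {
        Some(Message::Close(Some(CloseFrame {
            code: StatusCode::TOO_MANY_REQUESTS.as_u16(),
            reason: "rate limited. upgrade to premium for unlimited websocket messages".into(),
        })))
    } else {
        None
    }
}

#[tokio::main]
async fn main() {
    let (response_sender, response_receiver) = flume::unbounded::<Message>();

    // stand-in for the newHeads / newPendingTransactions subscription loops
    tokio::spawn(async move {
        let mut sent = 0;
        loop {
            // check the limiter before doing any work for this message
            if let Some(close_message) = rate_limit_close_websocket(sent, 3) {
                // tell the client why we are hanging up, then stop the subscription
                let _ = response_sender.send_async(close_message).await;
                break;
            }
            let _ = response_sender
                .send_async(Message::Text(format!("message {sent}")))
                .await;
            sent += 1;
        }
    });

    // stand-in for write_web3_socket forwarding messages to the client
    while let Ok(msg) = response_receiver.recv_async().await {
        println!("{msg:?}");
    }
}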

View File

@@ -985,7 +985,7 @@ impl Web3ProxyApp {
.proxy_web3_rpc_requests(&authorization, requests)
.await?;
// TODO: real status code. i don't think we are following the spec here
// TODO: real status code. if an error happens, i don't think we are following the spec here
(
StatusCode::OK,
JsonRpcForwardedResponseEnum::Batch(responses),

View File

@@ -7,15 +7,18 @@ use crate::jsonrpc::JsonRpcForwardedResponse;
use crate::jsonrpc::JsonRpcRequest;
use crate::response_cache::JsonRpcResponseEnum;
use crate::rpcs::transactions::TxStatus;
use axum::extract::ws::Message;
use axum::extract::ws::{CloseFrame, Message};
use deferred_rate_limiter::DeferredRateLimitResult;
use ethers::types::U64;
use futures::future::AbortHandle;
use futures::future::Abortable;
use futures::stream::StreamExt;
use log::trace;
use http::StatusCode;
use log::{error, trace};
use serde_json::json;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::atomic::{self, AtomicU64};
use std::sync::Arc;
use tokio::time::Instant;
use tokio_stream::wrappers::{BroadcastStream, WatchStream};
impl Web3ProxyApp {
@@ -23,7 +26,7 @@ impl Web3ProxyApp {
self: &'a Arc<Self>,
authorization: Arc<Authorization>,
jsonrpc_request: JsonRpcRequest,
subscription_count: &'a AtomicUsize,
subscription_count: &'a AtomicU64,
// TODO: taking a sender for Message instead of the exact json we are planning to send feels wrong, but it's easier for now
response_sender: flume::Sender<Message>,
) -> Web3ProxyResult<(AbortHandle, JsonRpcForwardedResponse)> {
@@ -40,18 +43,25 @@ impl Web3ProxyApp {
// TODO: this only needs to be unique per connection. we don't need it to be globally unique
// TODO: have a max number of subscriptions per key/ip. have a global max number of subscriptions? how should this be calculated?
let subscription_id = subscription_count.fetch_add(1, atomic::Ordering::SeqCst);
let subscription_id = U64::from(subscription_id as u64);
let subscription_id = U64::from(subscription_id);
// save the id so we can use it in the response
let id = jsonrpc_request.id.clone();
let subscribe_to = jsonrpc_request
.params
.get(0)
.and_then(|x| x.as_str())
.ok_or_else(|| {
Web3ProxyError::BadRequest("unable to subscribe using these params".into())
})?;
// TODO: calling json! on every request is probably not fast. but we can only match against
// TODO: i think we need a stricter EthSubscribeRequest type that JsonRpcRequest can turn into
if jsonrpc_request.params == json!(["newHeads"]) {
if subscribe_to == "newHeads" {
let head_block_receiver = self.watch_consensus_head_receiver.clone();
let app = self.clone();
trace!("newHeads subscription {:?}", subscription_id);
tokio::spawn(async move {
let mut head_block_receiver = Abortable::new(
WatchStream::new(head_block_receiver),
@@ -73,6 +83,14 @@ impl Web3ProxyApp {
)
.await;
if let Some(close_message) = app
.rate_limit_close_websocket(&subscription_request_metadata)
.await
{
let _ = response_sender.send_async(close_message).await;
break;
}
// TODO: make a struct for this? using our JsonRpcForwardedResponse won't work because it needs an id
let response_json = json!({
"jsonrpc": "2.0",
@@ -105,7 +123,7 @@ impl Web3ProxyApp {
trace!("closed newHeads subscription {:?}", subscription_id);
});
} else if jsonrpc_request.params == json!(["newPendingTransactions"]) {
} else if subscribe_to == "newPendingTransactions" {
let pending_tx_receiver = self.pending_tx_sender.subscribe();
let app = self.clone();
@@ -119,7 +137,6 @@ impl Web3ProxyApp {
subscription_id
);
// TODO: do something with this handle?
tokio::spawn(async move {
while let Some(Ok(new_tx_state)) = pending_tx_receiver.next().await {
let subscription_request_metadata = RequestMetadata::new(
@@ -130,6 +147,14 @@ impl Web3ProxyApp {
)
.await;
if let Some(close_message) = app
.rate_limit_close_websocket(&subscription_request_metadata)
.await
{
let _ = response_sender.send_async(close_message).await;
break;
}
let new_tx = match new_tx_state {
TxStatus::Pending(tx) => tx,
TxStatus::Confirmed(..) => continue,
@@ -154,7 +179,7 @@ impl Web3ProxyApp {
subscription_request_metadata.add_response(response_bytes);
// TODO: do clients support binary messages?
// TODO: do clients support binary messages? reply with binary if that's what we were sent
let response_msg = Message::Text(response_str);
if response_sender.send_async(response_msg).await.is_err() {
@@ -168,7 +193,7 @@ impl Web3ProxyApp {
subscription_id
);
});
} else if jsonrpc_request.params == json!(["newPendingFullTransactions"]) {
} else if subscribe_to == "newPendingFullTransactions" {
// TODO: too much copy/pasta with newPendingTransactions
let pending_tx_receiver = self.pending_tx_sender.subscribe();
let app = self.clone();
@@ -183,7 +208,6 @@ impl Web3ProxyApp {
subscription_id
);
// TODO: do something with this handle?
tokio::spawn(async move {
while let Some(Ok(new_tx_state)) = pending_tx_receiver.next().await {
let subscription_request_metadata = RequestMetadata::new(
@@ -194,6 +218,14 @@ impl Web3ProxyApp {
)
.await;
if let Some(close_message) = app
.rate_limit_close_websocket(&subscription_request_metadata)
.await
{
let _ = response_sender.send_async(close_message).await;
break;
}
let new_tx = match new_tx_state {
TxStatus::Pending(tx) => tx,
TxStatus::Confirmed(..) => continue,
@@ -230,7 +262,7 @@ impl Web3ProxyApp {
subscription_id
);
});
} else if jsonrpc_request.params == json!(["newPendingRawTransactions"]) {
} else if subscribe_to == "newPendingRawTransactions" {
// TODO: too much copy/pasta with newPendingTransactions
let pending_tx_receiver = self.pending_tx_sender.subscribe();
let app = self.clone();
@@ -245,7 +277,6 @@ impl Web3ProxyApp {
subscription_id
);
// TODO: do something with this handle?
tokio::spawn(async move {
while let Some(Ok(new_tx_state)) = pending_tx_receiver.next().await {
let subscription_request_metadata = RequestMetadata::new(
@@ -256,6 +287,14 @@ impl Web3ProxyApp {
)
.await;
if let Some(close_message) = app
.rate_limit_close_websocket(&subscription_request_metadata)
.await
{
let _ = response_sender.send_async(close_message).await;
break;
}
let new_tx = match new_tx_state {
TxStatus::Pending(tx) => tx,
TxStatus::Confirmed(..) => continue,
@@ -311,4 +350,55 @@ impl Web3ProxyApp {
// TODO: make a `SubscriptionHandle(AbortHandle, JoinHandle)` struct?
Ok((subscription_abort_handle, response))
}
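/// Check the frontend ip rate limiter for anonymous (no rpc key) connections.
/// Returns `Some(Message::Close(..))` with a 429 code and a human-readable reason
/// when the client should be disconnected, or `None` when the message may be sent.
/// Internal rate limiter errors are logged and the message is allowed through.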
async fn rate_limit_close_websocket(
&self,
request_metadata: &RequestMetadata,
) -> Option<Message> {
if let Some(authorization) = request_metadata.authorization.as_ref() {
if authorization.checks.rpc_secret_key_id.is_none() {
if let Some(rate_limiter) = &self.frontend_ip_rate_limiter {
match rate_limiter
.throttle(
authorization.ip,
authorization.checks.max_requests_per_period,
1,
)
.await
{
Ok(DeferredRateLimitResult::RetryNever) => {
let close_frame = CloseFrame {
code: StatusCode::TOO_MANY_REQUESTS.as_u16(),
reason:
"rate limited. upgrade to premium for unlimited websocket messages"
.into(),
};
return Some(Message::Close(Some(close_frame)));
}
Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
let retry_at = retry_at.duration_since(Instant::now());
let reason = format!("rate limited. upgrade to premium for unlimited websocket messages. retry in {}s", retry_at.as_secs_f32());
let close_frame = CloseFrame {
code: StatusCode::TOO_MANY_REQUESTS.as_u16(),
reason: reason.into(),
};
return Some(Message::Close(Some(close_frame)));
}
Ok(_) => {}
Err(err) => {
// this is an internal error of some kind, not the rate limit being hit
// TODO: i really want axum to do this for us in a single place.
error!("rate limiter is unhappy. allowing ip. err={:?}", err);
}
}
}
}
}
None
}
}
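For context, not part of the commit: a subscriber sees this throttle as a websocket close frame rather than a JSON-RPC error. A rough client-side sketch, assuming a client built on the tokio-tungstenite crate and a hypothetical local endpoint (the proxy itself does not depend on this crate):

use futures::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};

#[tokio::main]
async fn main() {
    // hypothetical endpoint; substitute the real proxy url
    let (mut ws, _) = connect_async("ws://localhost:8544").await.expect("connect");

    let subscribe = r#"{"jsonrpc":"2.0","id":1,"method":"eth_subscribe","params":["newHeads"]}"#;
    ws.send(Message::Text(subscribe.into())).await.expect("send");

    while let Some(msg) = ws.next().await {
        match msg {
            Ok(Message::Text(body)) => println!("subscription data: {body}"),
            // a rate-limited subscription ends here instead of with more data
            Ok(Message::Close(frame)) => {
                println!("server closed the stream: {frame:?}");
                break;
            }
            Ok(_) => {}
            Err(err) => {
                println!("websocket error: {err}");
                break;
            }
        }
    }
}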

View File

@@ -32,6 +32,7 @@ use hashbrown::HashMap;
use http::StatusCode;
use log::{info, trace};
use serde_json::json;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::{str::from_utf8_mut, sync::atomic::AtomicUsize};
use tokio::sync::{broadcast, OwnedSemaphorePermit, RwLock};
@@ -318,13 +319,12 @@ async fn proxy_web3_socket(
}
/// websockets support a few more methods than http clients
/// TODO: i think this subscriptions hashmap grows unbounded
async fn handle_socket_payload(
app: Arc<Web3ProxyApp>,
authorization: &Arc<Authorization>,
payload: &str,
response_sender: &flume::Sender<Message>,
subscription_count: &AtomicUsize,
subscription_count: &AtomicU64,
subscriptions: Arc<RwLock<HashMap<U64, AbortHandle>>>,
) -> Web3ProxyResult<(Message, Option<OwnedSemaphorePermit>)> {
let (authorization, semaphore) = match authorization.check_again(&app).await {
@@ -456,7 +456,7 @@ async fn read_web3_socket(
) {
// RwLock should be fine here. a user isn't going to be opening tons of subscriptions
let subscriptions = Arc::new(RwLock::new(HashMap::new()));
let subscription_count = Arc::new(AtomicUsize::new(1));
let subscription_count = Arc::new(AtomicU64::new(1));
let (close_sender, mut close_receiver) = broadcast::channel(1);
@@ -464,8 +464,7 @@ async fn read_web3_socket(
tokio::select! {
msg = ws_rx.next() => {
if let Some(Ok(msg)) = msg {
// spawn so that we can serve responses from this loop even faster
// TODO: only do these clones if the msg is text/binary?
// clone things so we can handle multiple messages in parallel
let close_sender = close_sender.clone();
let app = app.clone();
let authorization = authorization.clone();
@@ -474,13 +473,12 @@ async fn read_web3_socket(
let subscription_count = subscription_count.clone();
let f = async move {
let mut _semaphore = None;
// new message from our client. forward to a backend and then send it through response_tx
let response_msg = match msg {
// new message from our client. forward to a backend and then send it through response_sender
let (response_msg, _semaphore) = match msg {
Message::Text(ref payload) => {
// TODO: do not unwrap!
let (msg, s) = handle_socket_payload(
// TODO: do not unwrap! turn errors into a jsonrpc response and send that instead
// TODO: some providers close the connection on error. i don't like that
let (m, s) = handle_socket_payload(
app.clone(),
&authorization,
payload,
@@ -490,13 +488,11 @@ async fn read_web3_socket(
)
.await.unwrap();
_semaphore = s;
msg
(m, Some(s))
}
Message::Ping(x) => {
trace!("ping: {:?}", x);
Message::Pong(x)
(Message::Pong(x), None)
}
Message::Pong(x) => {
trace!("pong: {:?}", x);
@@ -511,8 +507,8 @@ async fn read_web3_socket(
Message::Binary(mut payload) => {
let payload = from_utf8_mut(&mut payload).unwrap();
// TODO: do not unwrap!
let (msg, s) = handle_socket_payload(
// TODO: do not unwrap! turn errors into a jsonrpc response and send that instead
let (m, s) = handle_socket_payload(
app.clone(),
&authorization,
payload,
@@ -522,18 +518,13 @@ async fn read_web3_socket(
)
.await.unwrap();
_semaphore = s;
msg
(m, Some(s))
}
};
if response_sender.send_async(response_msg).await.is_err() {
let _ = close_sender.send(true);
return;
};
_semaphore = None;
};
tokio::spawn(f);
@@ -554,11 +545,10 @@ async fn write_web3_socket(
) {
// TODO: increment counter for open websockets
// TODO: is there any way to make this stream receive?
while let Ok(msg) = response_rx.recv_async().await {
// a response is ready
// TODO: poke rate limits for this user?
// we do not check rate limits here. they are checked before putting things into response_sender.
// forward the response through the websocket
if let Err(err) = ws_tx.send(msg).await {