diff --git a/web3_proxy/src/app/ws.rs b/web3_proxy/src/app/ws.rs index 830059b8..3656d03c 100644 --- a/web3_proxy/src/app/ws.rs +++ b/web3_proxy/src/app/ws.rs @@ -28,7 +28,7 @@ impl Web3ProxyApp { jsonrpc_request: JsonRpcRequest, subscription_count: &'a AtomicU64, // TODO: taking a sender for Message instead of the exact json we are planning to send feels wrong, but its easier for now - response_sender: mpsc::UnboundedSender, + response_sender: mpsc::Sender, ) -> Web3ProxyResult<(AbortHandle, JsonRpcForwardedResponse)> { let request_metadata = RequestMetadata::new( self, @@ -87,7 +87,7 @@ impl Web3ProxyApp { .rate_limit_close_websocket(&subscription_request_metadata) .await { - let _ = response_sender.send(close_message); + let _ = response_sender.send(close_message).await; break; } @@ -112,7 +112,7 @@ impl Web3ProxyApp { // TODO: can we check a content type header? let response_msg = Message::Text(response_str); - if response_sender.send(response_msg).is_err() { + if response_sender.send(response_msg).await.is_err() { // TODO: increment error_response? i don't think so. i think this will happen once every time a client disconnects. // TODO: cancel this subscription earlier? select on head_block_receiver.next() and an abort handle? break; diff --git a/web3_proxy/src/compute_units.rs b/web3_proxy/src/compute_units.rs index b925043f..1f0cd909 100644 --- a/web3_proxy/src/compute_units.rs +++ b/web3_proxy/src/compute_units.rs @@ -121,6 +121,7 @@ impl ComputeUnit { (_, "trace_replayBlockTransactions") => 2983, (_, "trace_replayTransaction") => 2983, (_, "trace_transaction") => 26, + (_, "invalid_method") => 100, (_, "web3_clientVersion") => 15, (_, "web3_sha3") => 15, (_, method) => { diff --git a/web3_proxy/src/errors.rs b/web3_proxy/src/errors.rs index 9f908ee0..2ef5853c 100644 --- a/web3_proxy/src/errors.rs +++ b/web3_proxy/src/errors.rs @@ -902,7 +902,7 @@ impl Web3ProxyError { ( StatusCode::UNAUTHORIZED, JsonRpcErrorData { - message: format!("siwe verification error: {}", err.to_string()).into(), + message: format!("siwe verification error: {}", err).into(), code: StatusCode::UNAUTHORIZED.as_u16().into(), data: None, }, diff --git a/web3_proxy/src/frontend/authorization.rs b/web3_proxy/src/frontend/authorization.rs index ffec9528..e36b137c 100644 --- a/web3_proxy/src/frontend/authorization.rs +++ b/web3_proxy/src/frontend/authorization.rs @@ -445,6 +445,7 @@ impl<'a> From<&'a str> for RequestOrMethod<'a> { pub enum ResponseOrBytes<'a> { Json(&'a serde_json::Value), Response(&'a JsonRpcForwardedResponse), + Error(&'a Web3ProxyError), Bytes(usize), } @@ -464,6 +465,11 @@ impl ResponseOrBytes<'_> { .expect("this should always serialize") .len(), Self::Bytes(num_bytes) => *num_bytes, + Self::Error(x) => { + let (_, x) = x.as_response_parts::<()>(); + + x.num_bytes() as usize + } } } } diff --git a/web3_proxy/src/frontend/rpc_proxy_http.rs b/web3_proxy/src/frontend/rpc_proxy_http.rs index fbfc7036..86b697e0 100644 --- a/web3_proxy/src/frontend/rpc_proxy_http.rs +++ b/web3_proxy/src/frontend/rpc_proxy_http.rs @@ -15,6 +15,7 @@ use http::HeaderMap; use itertools::Itertools; use std::net::IpAddr; use std::sync::Arc; +use std::time::Duration; /// POST /rpc -- Public entrypoint for HTTP JSON-RPC requests. Web3 wallets use this. /// Defaults to rate limiting by IP address, but can also read the Authorization header for a bearer token. @@ -66,6 +67,10 @@ async fn _proxy_web3_rpc( let authorization = Arc::new(authorization); + payload + .tarpit_invalid(&app, &authorization, Duration::from_secs(5)) + .await?; + // TODO: calculate payload bytes here (before turning into serde_json::Value). that will save serializing later // TODO: is first_id the right thing to attach to this error? @@ -254,6 +259,10 @@ async fn _proxy_web3_rpc_with_key( let authorization = Arc::new(authorization); + payload + .tarpit_invalid(&app, &authorization, Duration::from_secs(2)) + .await?; + let rpc_secret_key_id = authorization.checks.rpc_secret_key_id; let (status_code, response, rpcs) = app diff --git a/web3_proxy/src/frontend/rpc_proxy_ws.rs b/web3_proxy/src/frontend/rpc_proxy_ws.rs index 83164d19..8283bb20 100644 --- a/web3_proxy/src/frontend/rpc_proxy_ws.rs +++ b/web3_proxy/src/frontend/rpc_proxy_ws.rs @@ -10,7 +10,6 @@ use crate::{ errors::Web3ProxyResult, jsonrpc::{JsonRpcForwardedResponse, JsonRpcForwardedResponseEnum, JsonRpcRequest}, }; -use anyhow::Context; use axum::headers::{Origin, Referer, UserAgent}; use axum::{ extract::ws::{Message, WebSocket, WebSocketUpgrade}, @@ -30,6 +29,7 @@ use handlebars::Handlebars; use hashbrown::HashMap; use http::{HeaderMap, StatusCode}; use serde_json::json; +use serde_json::value::RawValue; use std::net::IpAddr; use std::str::from_utf8_mut; use std::sync::atomic::AtomicU64; @@ -305,20 +305,110 @@ async fn proxy_web3_socket( // split the websocket so we can read and write concurrently let (ws_tx, ws_rx) = socket.split(); + let buffer = authorization.checks.max_concurrent_requests.unwrap_or(2048) as usize; + // create a channel for our reader and writer can communicate. todo: benchmark different channels // TODO: this should be bounded. async blocking on too many messages would be fine - let (response_sender, response_receiver) = mpsc::unbounded_channel::(); + let (response_sender, response_receiver) = mpsc::channel::(buffer); tokio::spawn(write_web3_socket(response_receiver, ws_tx)); tokio::spawn(read_web3_socket(app, authorization, ws_rx, response_sender)); } +async fn websocket_proxy_web3_rpc( + app: Arc, + authorization: Arc, + json_request: JsonRpcRequest, + response_sender: &mpsc::Sender, + subscription_count: &AtomicU64, + subscriptions: &AsyncRwLock>, +) -> (Box, Web3ProxyResult) { + let response_id = json_request.id.clone(); + + // TODO: move this to a seperate function so we can use the try operator + let response: Web3ProxyResult = match &json_request.method[..] { + "eth_subscribe" => { + // TODO: how can we subscribe with proxy_mode? + match app + .eth_subscribe( + authorization, + json_request, + subscription_count, + response_sender.clone(), + ) + .await + { + Ok((handle, response)) => { + if let Some(subscription_id) = response.result.clone() { + let mut x = subscriptions.write().await; + + let key: U64 = serde_json::from_str(subscription_id.get()).unwrap(); + + x.insert(key, handle); + } + + Ok(response.into()) + } + Err(err) => Err(err), + } + } + "eth_unsubscribe" => { + let request_metadata = + RequestMetadata::new(&app, authorization, &json_request, None).await; + + let maybe_id = json_request + .params + .get(0) + .cloned() + .unwrap_or(json_request.params); + + let subscription_id: U64 = match serde_json::from_value::(maybe_id) { + Ok(x) => x, + Err(err) => { + return ( + response_id, + Err(Web3ProxyError::BadRequest( + format!("unexpected params given for eth_unsubscribe: {:?}", err) + .into(), + )), + ) + } + }; + + // TODO: is this the right response? + let partial_response = { + let mut x = subscriptions.write().await; + match x.remove(&subscription_id) { + None => false, + Some(handle) => { + handle.abort(); + true + } + } + }; + + let response = + JsonRpcForwardedResponse::from_value(json!(partial_response), response_id.clone()); + + request_metadata.add_response(&response); + + Ok(response.into()) + } + _ => app + .proxy_web3_rpc(authorization, json_request.into()) + .await + .map(|(_, response, _)| response), + }; + + (response_id, response) +} + /// websockets support a few more methods than http clients async fn handle_socket_payload( app: Arc, authorization: &Arc, payload: &str, - response_sender: &mpsc::UnboundedSender, + response_sender: &mpsc::Sender, subscription_count: &AtomicU64, subscriptions: Arc>>, ) -> Web3ProxyResult<(Message, Option)> { @@ -327,89 +417,21 @@ async fn handle_socket_payload( // TODO: handle batched requests let (response_id, response) = match serde_json::from_str::(payload) { Ok(json_request) => { - let response_id = json_request.id.clone(); + // // TODO: move tarpit code to an invidual request, or change this to handle enums + // json_request + // .tarpit_invalid(&app, &authorization, Duration::from_secs(2)) + // .await?; // TODO: move this to a seperate function so we can use the try operator - let response: Web3ProxyResult = match &json_request.method - [..] - { - "eth_subscribe" => { - // TODO: how can we subscribe with proxy_mode? - match app - .eth_subscribe( - authorization.clone(), - json_request, - subscription_count, - response_sender.clone(), - ) - .await - { - Ok((handle, response)) => { - if let Some(subscription_id) = response.result.clone() { - let mut x = subscriptions.write().await; - - let key: U64 = serde_json::from_str(subscription_id.get()).unwrap(); - - x.insert(key, handle); - } - - Ok(response.into()) - } - Err(err) => Err(err), - } - } - "eth_unsubscribe" => { - let request_metadata = - RequestMetadata::new(&app, authorization.clone(), &json_request, None) - .await; - - let subscription_id: U64 = - if let Some(param) = json_request.params.get(0).cloned() { - serde_json::from_value(param) - .context("failed parsing [subscription_id] as a U64")? - } else { - match serde_json::from_value::(json_request.params) { - Ok(x) => x, - Err(err) => { - return Err(Web3ProxyError::BadRequest( - format!( - "unexpected params given for eth_unsubscribe: {:?}", - err - ) - .into(), - )) - } - } - }; - - // TODO: is this the right response? - let partial_response = { - let mut x = subscriptions.write().await; - match x.remove(&subscription_id) { - None => false, - Some(handle) => { - handle.abort(); - true - } - } - }; - - let response = JsonRpcForwardedResponse::from_value( - json!(partial_response), - response_id.clone(), - ); - - request_metadata.add_response(&response); - - Ok(response.into()) - } - _ => app - .proxy_web3_rpc(authorization.clone(), json_request.into()) - .await - .map(|(_, response, _)| response), - }; - - (response_id, response) + websocket_proxy_web3_rpc( + app, + authorization.clone(), + json_request, + response_sender, + subscription_count, + &subscriptions, + ) + .await } Err(err) => { let id = JsonRpcId::None.to_raw_value(); @@ -435,7 +457,7 @@ async fn read_web3_socket( app: Arc, authorization: Arc, mut ws_rx: SplitStream, - response_sender: mpsc::UnboundedSender, + response_sender: mpsc::Sender, ) { let subscriptions = Arc::new(AsyncRwLock::new(HashMap::new())); let subscription_count = Arc::new(AtomicU64::new(1)); @@ -520,7 +542,7 @@ async fn read_web3_socket( } }; - if response_sender.send(response_msg).is_err() { + if response_sender.send(response_msg).await.is_err() { let _ = close_sender.send(true); }; }; @@ -538,7 +560,7 @@ async fn read_web3_socket( } async fn write_web3_socket( - mut response_rx: mpsc::UnboundedReceiver, + mut response_rx: mpsc::Receiver, mut ws_tx: SplitSink, ) { // TODO: increment counter for open websockets diff --git a/web3_proxy/src/jsonrpc.rs b/web3_proxy/src/jsonrpc.rs index c3403738..aabe8c71 100644 --- a/web3_proxy/src/jsonrpc.rs +++ b/web3_proxy/src/jsonrpc.rs @@ -1,4 +1,8 @@ +use crate::app::Web3ProxyApp; +use crate::errors::Web3ProxyError; +use crate::frontend::authorization::{Authorization, RequestMetadata, RequestOrMethod}; use crate::response_cache::JsonRpcResponseEnum; +use axum::response::Response; use derive_more::From; use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor}; use serde::{Deserialize, Serialize}; @@ -6,7 +10,9 @@ use serde_json::json; use serde_json::value::{to_raw_value, RawValue}; use std::borrow::Cow; use std::fmt; -use std::sync::Arc; +use std::sync::{atomic, Arc}; +use std::time::Duration; +use tokio::time::sleep; pub trait JsonRpcParams = fmt::Debug + serde::Serialize + Send + Sync + 'static; pub trait JsonRpcResultData = serde::Serialize + serde::de::DeserializeOwned + fmt::Debug + Send; @@ -44,6 +50,7 @@ impl JsonRpcId { } impl JsonRpcRequest { + // TODO: Web3ProxyResult? can this even fail? pub fn new(id: JsonRpcId, method: String, params: serde_json::Value) -> anyhow::Result { let x = Self { jsonrpc: "2.0".to_string(), @@ -54,6 +61,12 @@ impl JsonRpcRequest { Ok(x) } + + pub fn validate_method(&self) -> bool { + self.method + .chars() + .all(|x| x.is_ascii_alphanumeric() || x == '_' || x == '(' || x == ')') + } } impl fmt::Debug for JsonRpcRequest { @@ -69,7 +82,7 @@ impl fmt::Debug for JsonRpcRequest { } /// Requests can come in multiple formats -#[derive(Debug, From)] +#[derive(Debug, From, Serialize)] pub enum JsonRpcRequestEnum { Batch(Vec), Single(JsonRpcRequest), @@ -82,6 +95,62 @@ impl JsonRpcRequestEnum { Self::Single(x) => Some(x.id.clone()), } } + + /// returns the id of the first invalid result (if any). None is good + pub fn validate(&self) -> Option> { + match self { + Self::Batch(x) => x + .iter() + .find_map(|x| (!x.validate_method()).then_some(x.id.clone())), + Self::Single(x) => { + if x.validate_method() { + None + } else { + Some(x.id.clone()) + } + } + } + } + + /// returns the id of the first invalid result (if any). None is good + pub async fn tarpit_invalid( + &self, + app: &Web3ProxyApp, + authorization: &Arc, + duration: Duration, + ) -> Result<(), Response> { + let err_id = match self.validate() { + None => return Ok(()), + Some(x) => x, + }; + + let size = serde_json::to_string(&self) + .expect("JsonRpcRequestEnum should always serialize") + .len(); + + let request = RequestOrMethod::Method("invalid_method", size); + + // TODO: create a stat so we can penalize + // TODO: what request size + let metadata = RequestMetadata::new(app, authorization.clone(), request, None).await; + + metadata + .user_error_response + .store(true, atomic::Ordering::Release); + + let response = Web3ProxyError::BadRequest("request failed validation".into()); + + metadata.add_response(&response); + + let response = response.into_response_with_id(Some(err_id)); + + // TODO: variable duration depending on the IP + sleep(duration).await; + + let _ = metadata.try_send_arc_stat(); + + Err(response) + } } impl<'de> Deserialize<'de> for JsonRpcRequestEnum { diff --git a/web3_proxy/src/rpcs/many.rs b/web3_proxy/src/rpcs/many.rs index cad744e1..c3df64a9 100644 --- a/web3_proxy/src/rpcs/many.rs +++ b/web3_proxy/src/rpcs/many.rs @@ -411,7 +411,7 @@ impl Web3Rpcs { count_map.insert(s.clone(), partial_response); } - counts.update([s].into_iter()); + counts.update([s]); } // return the most_common success if any. otherwise return the most_common error diff --git a/web3_proxy/src/stats/stat_buffer.rs b/web3_proxy/src/stats/stat_buffer.rs index a09f946d..b9678a3e 100644 --- a/web3_proxy/src/stats/stat_buffer.rs +++ b/web3_proxy/src/stats/stat_buffer.rs @@ -437,3 +437,11 @@ impl StatBuffer { count } } + +#[cfg(test)] +mod tests { + #[test] + fn test_something() { + panic!() + } +}