From 49c60ac1b5d92aa9dbe0e1348c201bc517294ef5 Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Thu, 15 Jun 2023 09:50:21 -0700 Subject: [PATCH] improve websocket error handling --- web3_proxy/src/errors.rs | 16 ++++ web3_proxy/src/frontend/rpc_proxy_ws.rs | 102 ++++++++---------------- 2 files changed, 51 insertions(+), 67 deletions(-) diff --git a/web3_proxy/src/errors.rs b/web3_proxy/src/errors.rs index ce9c461e..45fc5eb2 100644 --- a/web3_proxy/src/errors.rs +++ b/web3_proxy/src/errors.rs @@ -4,6 +4,7 @@ use crate::frontend::authorization::Authorization; use crate::jsonrpc::{JsonRpcErrorData, JsonRpcForwardedResponse}; use crate::response_cache::JsonRpcResponseEnum; use crate::rpcs::provider::EthersHttpProvider; +use axum::extract::ws::Message; use axum::{ headers, http::StatusCode, @@ -1062,3 +1063,18 @@ where self.map_err(|err| Web3ProxyError::WithContext(Some(Box::new(err.into())), msg.into())) } } + +impl Web3ProxyError { + pub fn into_message(self, id: Option>) -> Message { + let (_, err) = self.as_response_parts(); + + let id = id.unwrap_or_default(); + + let err = JsonRpcForwardedResponse::from_response_data(err, id); + + let msg = serde_json::to_string(&err).expect("errors should always serialize to json"); + + // TODO: what about a binary message? + Message::Text(msg) + } +} diff --git a/web3_proxy/src/frontend/rpc_proxy_ws.rs b/web3_proxy/src/frontend/rpc_proxy_ws.rs index 1f37585c..3ba1a020 100644 --- a/web3_proxy/src/frontend/rpc_proxy_ws.rs +++ b/web3_proxy/src/frontend/rpc_proxy_ws.rs @@ -10,7 +10,6 @@ use crate::{ errors::Web3ProxyResult, jsonrpc::{JsonRpcForwardedResponse, JsonRpcForwardedResponseEnum, JsonRpcRequest}, }; -use anyhow::Context; use axum::headers::{Origin, Referer, UserAgent}; use axum::{ extract::ws::{Message, WebSocket, WebSocketUpgrade}, @@ -20,8 +19,6 @@ use axum::{ }; use axum_client_ip::InsecureClientIp; use axum_macros::debug_handler; -use ethers::types::U64; -use fstrings::{f, format_args_f}; use futures::SinkExt; use futures::{ future::AbortHandle, @@ -254,9 +251,6 @@ async fn _websocket_handler_with_key( } None => { // if no websocket upgrade, this is probably a user loading the url with their browser - - // TODO: rate limit here? key_is_authorized might be enough - match ( &app.config.redirect_public_url, &app.config.redirect_rpc_key_url, @@ -313,23 +307,11 @@ async fn handle_socket_payload( payload: &str, response_sender: &flume::Sender, subscription_count: &AtomicU64, - subscriptions: Arc>>, + subscriptions: Arc>>, ) -> Web3ProxyResult<(Message, Option)> { - let (authorization, semaphore) = match authorization.check_again(&app).await { - Ok((a, s)) => (a, s), - Err(err) => { - let (_, err) = err.as_response_parts(); + let (authorization, semaphore) = authorization.check_again(&app).await?; - let err = JsonRpcForwardedResponse::from_response_data(err, Default::default()); - - let err = serde_json::to_string(&err)?; - - return Ok((Message::Text(err), None)); - } - }; - - // TODO: do any clients send batches over websockets? - // TODO: change response into response_data + // TODO: handle batched requests let (response_id, response) = match serde_json::from_str::(payload) { Ok(json_request) => { let response_id = json_request.id.clone(); @@ -350,20 +332,10 @@ async fn handle_socket_payload( .await { Ok((handle, response)) => { - { + if let Some(subscription_id) = response.result.clone() { let mut x = subscriptions.write().await; - - let result: &serde_json::value::RawValue = response - .result - .as_ref() - .context("there should be a result here")?; - - // TODO: there must be a better way to turn a RawValue - let k: U64 = serde_json::from_str(result.get()) - .context("subscription ids must be U64s")?; - - x.insert(k, handle); - }; + x.insert(subscription_id.get().to_string(), handle); + } Ok(response.into()) } @@ -375,34 +347,25 @@ async fn handle_socket_payload( RequestMetadata::new(&app, authorization.clone(), &json_request, None) .await; - let subscription_id: U64 = if json_request.params.is_array() { - if let Some(params) = json_request.params.get(0) { - serde_json::from_value(params.clone()).map_err(|err| { - Web3ProxyError::BadRequest( - format!("invalid params for eth_unsubscribe: {}", err).into(), - ) - })? + let subscription_id = + if let Some(params) = json_request.params.get(0).and_then(|x| x.as_str()) { + params + } else if let Some(param) = json_request.params.as_str() { + param } else { return Err(Web3ProxyError::BadRequest( - f!("no params for eth_unsubscribe").into(), + format!( + "unexpected params given for eth_unsubscribe ({:?})", + json_request.params + ) + .into(), )); - } - } else if json_request.params.is_string() { - serde_json::from_value(json_request.params).map_err(|err| { - Web3ProxyError::BadRequest( - format!("invalid params for eth_unsubscribe: {}", err).into(), - ) - })? - } else { - return Err(Web3ProxyError::BadRequest( - "unexpected params given for eth_unsubscribe".into(), - )); - }; + }; // TODO: is this the right response? let partial_response = { let mut x = subscriptions.write().await; - match x.remove(&subscription_id) { + match x.remove(subscription_id) { None => false, Some(handle) => { handle.abort(); @@ -411,7 +374,6 @@ async fn handle_socket_payload( } }; - // TODO: don't create the response here. use a JsonRpcResponseData instead let response = JsonRpcForwardedResponse::from_value( json!(partial_response), response_id.clone(), @@ -477,9 +439,7 @@ async fn read_web3_socket( // new message from our client. forward to a backend and then send it through response_sender let (response_msg, _semaphore) = match msg { Message::Text(ref payload) => { - // TODO: do not unwrap! turn errors into a jsonrpc response and send that instead - // TODO: some providers close the connection on error. i don't like that - let (m, s) = handle_socket_payload( + match handle_socket_payload( app.clone(), &authorization, payload, @@ -487,9 +447,14 @@ async fn read_web3_socket( &subscription_count, subscriptions, ) - .await.unwrap(); - - (m, Some(s)) + .await { + Ok((m, s)) => (m, Some(s)), + Err(err) => { + // TODO: how can we get the id out of the payload? + let m = err.into_message(None); + (m, None) + } + } } Message::Ping(x) => { trace!("ping: {:?}", x); @@ -508,8 +473,7 @@ async fn read_web3_socket( Message::Binary(mut payload) => { let payload = from_utf8_mut(&mut payload).unwrap(); - // TODO: do not unwrap! turn errors into a jsonrpc response and send that instead - let (m, s) = handle_socket_payload( + match handle_socket_payload( app.clone(), &authorization, payload, @@ -517,9 +481,13 @@ async fn read_web3_socket( &subscription_count, subscriptions, ) - .await.unwrap(); - - (m, Some(s)) + .await { + Ok((m, s)) => (m, Some(s)), + Err(err) => { + let m = err.into_message(None); + (m, None) + } + } } };