improve websocket error handling

This commit is contained in:
Bryan Stitt 2023-06-15 09:50:21 -07:00
parent 5859cd8a8d
commit 49c60ac1b5
2 changed files with 51 additions and 67 deletions

@ -4,6 +4,7 @@ use crate::frontend::authorization::Authorization;
use crate::jsonrpc::{JsonRpcErrorData, JsonRpcForwardedResponse}; use crate::jsonrpc::{JsonRpcErrorData, JsonRpcForwardedResponse};
use crate::response_cache::JsonRpcResponseEnum; use crate::response_cache::JsonRpcResponseEnum;
use crate::rpcs::provider::EthersHttpProvider; use crate::rpcs::provider::EthersHttpProvider;
use axum::extract::ws::Message;
use axum::{ use axum::{
headers, headers,
http::StatusCode, http::StatusCode,
@ -1062,3 +1063,18 @@ where
self.map_err(|err| Web3ProxyError::WithContext(Some(Box::new(err.into())), msg.into())) self.map_err(|err| Web3ProxyError::WithContext(Some(Box::new(err.into())), msg.into()))
} }
} }
impl Web3ProxyError {
pub fn into_message(self, id: Option<Box<RawValue>>) -> Message {
let (_, err) = self.as_response_parts();
let id = id.unwrap_or_default();
let err = JsonRpcForwardedResponse::from_response_data(err, id);
let msg = serde_json::to_string(&err).expect("errors should always serialize to json");
// TODO: what about a binary message?
Message::Text(msg)
}
}

@ -10,7 +10,6 @@ use crate::{
errors::Web3ProxyResult, errors::Web3ProxyResult,
jsonrpc::{JsonRpcForwardedResponse, JsonRpcForwardedResponseEnum, JsonRpcRequest}, jsonrpc::{JsonRpcForwardedResponse, JsonRpcForwardedResponseEnum, JsonRpcRequest},
}; };
use anyhow::Context;
use axum::headers::{Origin, Referer, UserAgent}; use axum::headers::{Origin, Referer, UserAgent};
use axum::{ use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade}, extract::ws::{Message, WebSocket, WebSocketUpgrade},
@ -20,8 +19,6 @@ use axum::{
}; };
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use axum_macros::debug_handler; use axum_macros::debug_handler;
use ethers::types::U64;
use fstrings::{f, format_args_f};
use futures::SinkExt; use futures::SinkExt;
use futures::{ use futures::{
future::AbortHandle, future::AbortHandle,
@ -254,9 +251,6 @@ async fn _websocket_handler_with_key(
} }
None => { None => {
// if no websocket upgrade, this is probably a user loading the url with their browser // if no websocket upgrade, this is probably a user loading the url with their browser
// TODO: rate limit here? key_is_authorized might be enough
match ( match (
&app.config.redirect_public_url, &app.config.redirect_public_url,
&app.config.redirect_rpc_key_url, &app.config.redirect_rpc_key_url,
@ -313,23 +307,11 @@ async fn handle_socket_payload(
payload: &str, payload: &str,
response_sender: &flume::Sender<Message>, response_sender: &flume::Sender<Message>,
subscription_count: &AtomicU64, subscription_count: &AtomicU64,
subscriptions: Arc<RwLock<HashMap<U64, AbortHandle>>>, subscriptions: Arc<RwLock<HashMap<String, AbortHandle>>>,
) -> Web3ProxyResult<(Message, Option<OwnedSemaphorePermit>)> { ) -> Web3ProxyResult<(Message, Option<OwnedSemaphorePermit>)> {
let (authorization, semaphore) = match authorization.check_again(&app).await { let (authorization, semaphore) = authorization.check_again(&app).await?;
Ok((a, s)) => (a, s),
Err(err) => {
let (_, err) = err.as_response_parts();
let err = JsonRpcForwardedResponse::from_response_data(err, Default::default()); // TODO: handle batched requests
let err = serde_json::to_string(&err)?;
return Ok((Message::Text(err), None));
}
};
// TODO: do any clients send batches over websockets?
// TODO: change response into response_data
let (response_id, response) = match serde_json::from_str::<JsonRpcRequest>(payload) { let (response_id, response) = match serde_json::from_str::<JsonRpcRequest>(payload) {
Ok(json_request) => { Ok(json_request) => {
let response_id = json_request.id.clone(); let response_id = json_request.id.clone();
@ -350,20 +332,10 @@ async fn handle_socket_payload(
.await .await
{ {
Ok((handle, response)) => { Ok((handle, response)) => {
{ if let Some(subscription_id) = response.result.clone() {
let mut x = subscriptions.write().await; let mut x = subscriptions.write().await;
x.insert(subscription_id.get().to_string(), handle);
let result: &serde_json::value::RawValue = response }
.result
.as_ref()
.context("there should be a result here")?;
// TODO: there must be a better way to turn a RawValue
let k: U64 = serde_json::from_str(result.get())
.context("subscription ids must be U64s")?;
x.insert(k, handle);
};
Ok(response.into()) Ok(response.into())
} }
@ -375,34 +347,25 @@ async fn handle_socket_payload(
RequestMetadata::new(&app, authorization.clone(), &json_request, None) RequestMetadata::new(&app, authorization.clone(), &json_request, None)
.await; .await;
let subscription_id: U64 = if json_request.params.is_array() { let subscription_id =
if let Some(params) = json_request.params.get(0) { if let Some(params) = json_request.params.get(0).and_then(|x| x.as_str()) {
serde_json::from_value(params.clone()).map_err(|err| { params
Web3ProxyError::BadRequest( } else if let Some(param) = json_request.params.as_str() {
format!("invalid params for eth_unsubscribe: {}", err).into(), param
)
})?
} else { } else {
return Err(Web3ProxyError::BadRequest( return Err(Web3ProxyError::BadRequest(
f!("no params for eth_unsubscribe").into(), format!(
"unexpected params given for eth_unsubscribe ({:?})",
json_request.params
)
.into(),
)); ));
} };
} else if json_request.params.is_string() {
serde_json::from_value(json_request.params).map_err(|err| {
Web3ProxyError::BadRequest(
format!("invalid params for eth_unsubscribe: {}", err).into(),
)
})?
} else {
return Err(Web3ProxyError::BadRequest(
"unexpected params given for eth_unsubscribe".into(),
));
};
// TODO: is this the right response? // TODO: is this the right response?
let partial_response = { let partial_response = {
let mut x = subscriptions.write().await; let mut x = subscriptions.write().await;
match x.remove(&subscription_id) { match x.remove(subscription_id) {
None => false, None => false,
Some(handle) => { Some(handle) => {
handle.abort(); handle.abort();
@ -411,7 +374,6 @@ async fn handle_socket_payload(
} }
}; };
// TODO: don't create the response here. use a JsonRpcResponseData instead
let response = JsonRpcForwardedResponse::from_value( let response = JsonRpcForwardedResponse::from_value(
json!(partial_response), json!(partial_response),
response_id.clone(), response_id.clone(),
@ -477,9 +439,7 @@ async fn read_web3_socket(
// new message from our client. forward to a backend and then send it through response_sender // new message from our client. forward to a backend and then send it through response_sender
let (response_msg, _semaphore) = match msg { let (response_msg, _semaphore) = match msg {
Message::Text(ref payload) => { Message::Text(ref payload) => {
// TODO: do not unwrap! turn errors into a jsonrpc response and send that instead match handle_socket_payload(
// TODO: some providers close the connection on error. i don't like that
let (m, s) = handle_socket_payload(
app.clone(), app.clone(),
&authorization, &authorization,
payload, payload,
@ -487,9 +447,14 @@ async fn read_web3_socket(
&subscription_count, &subscription_count,
subscriptions, subscriptions,
) )
.await.unwrap(); .await {
Ok((m, s)) => (m, Some(s)),
(m, Some(s)) Err(err) => {
// TODO: how can we get the id out of the payload?
let m = err.into_message(None);
(m, None)
}
}
} }
Message::Ping(x) => { Message::Ping(x) => {
trace!("ping: {:?}", x); trace!("ping: {:?}", x);
@ -508,8 +473,7 @@ async fn read_web3_socket(
Message::Binary(mut payload) => { Message::Binary(mut payload) => {
let payload = from_utf8_mut(&mut payload).unwrap(); let payload = from_utf8_mut(&mut payload).unwrap();
// TODO: do not unwrap! turn errors into a jsonrpc response and send that instead match handle_socket_payload(
let (m, s) = handle_socket_payload(
app.clone(), app.clone(),
&authorization, &authorization,
payload, payload,
@ -517,9 +481,13 @@ async fn read_web3_socket(
&subscription_count, &subscription_count,
subscriptions, subscriptions,
) )
.await.unwrap(); .await {
Ok((m, s)) => (m, Some(s)),
(m, Some(s)) Err(err) => {
let m = err.into_message(None);
(m, None)
}
}
} }
}; };