From 0387492df8b6708e7c8ced69b3719a60591e99ed Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Sun, 29 May 2022 17:28:41 +0000 Subject: [PATCH] add websocket server --- Cargo.lock | 2 + web3-proxy/Cargo.toml | 2 +- web3-proxy/src/app.rs | 8 +-- web3-proxy/src/frontend.rs | 112 ++++++++++++++++++++++++++++++------- web3-proxy/src/jsonrpc.rs | 17 ++++++ 5 files changed, 115 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 15d40a9d..5aef90d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -183,6 +183,7 @@ checksum = "ab2504b827a8bef941ba3dd64bdffe9cf56ca182908a147edd6189c95fbcae7d" dependencies = [ "async-trait", "axum-core", + "base64 0.13.0", "bitflags", "bytes", "futures-util", @@ -198,6 +199,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "sha-1", "sync_wrapper", "tokio", "tokio-tungstenite", diff --git a/web3-proxy/Cargo.toml b/web3-proxy/Cargo.toml index 0bf59ab0..14289968 100644 --- a/web3-proxy/Cargo.toml +++ b/web3-proxy/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" anyhow = "1.0.57" arc-swap = "1.5.0" argh = "0.1.7" -axum = { version = "0.5.6", features = ["serde_json", "tokio-tungstenite"] } +axum = { version = "0.5.6", features = ["serde_json", "tokio-tungstenite", "ws"] } counter = "0.5.5" dashmap = "5.3.4" derive_more = "0.99.17" diff --git a/web3-proxy/src/app.rs b/web3-proxy/src/app.rs index 9f3d04df..ce6543f1 100644 --- a/web3-proxy/src/app.rs +++ b/web3-proxy/src/app.rs @@ -141,7 +141,7 @@ impl Web3ProxyApp { /// TODO: dry this up #[instrument(skip_all)] pub async fn proxy_web3_rpc( - self: Arc, + &self, request: JsonRpcRequestEnum, ) -> anyhow::Result { // TODO: i don't always see this in the logs. why? @@ -168,7 +168,7 @@ impl Web3ProxyApp { // #[instrument(skip_all)] async fn proxy_web3_rpc_requests( - self: Arc, + &self, requests: Vec, ) -> anyhow::Result> { // TODO: we should probably change ethers-rs to support this directly @@ -179,7 +179,7 @@ impl Web3ProxyApp { let responses = join_all( requests .into_iter() - .map(|request| self.clone().proxy_web3_rpc_request(request)) + .map(|request| self.proxy_web3_rpc_request(request)) .collect::>(), ) .await; @@ -232,7 +232,7 @@ impl Web3ProxyApp { // #[instrument(skip_all)] async fn proxy_web3_rpc_request( - self: Arc, + &self, request: JsonRpcRequest, ) -> anyhow::Result { trace!("Received request: {:?}", request); diff --git a/web3-proxy/src/frontend.rs b/web3-proxy/src/frontend.rs index 1cad1d15..70afe576 100644 --- a/web3-proxy/src/frontend.rs +++ b/web3-proxy/src/frontend.rs @@ -1,14 +1,14 @@ /// this should move into web3-proxy once the basics are working use axum::{ - // error_handling::HandleError, + extract::ws::{Message, WebSocket, WebSocketUpgrade}, handler::Handler, http::StatusCode, response::IntoResponse, routing::{get, post}, - Extension, - Json, - Router, + Extension, Json, Router, }; +use futures::stream::{SplitSink, SplitStream, StreamExt}; +use futures::SinkExt; use serde_json::json; use serde_json::value::RawValue; use std::net::SocketAddr; @@ -17,7 +17,7 @@ use tracing::warn; use crate::{ app::Web3ProxyApp, - jsonrpc::{JsonRpcErrorData, JsonRpcForwardedResponse, JsonRpcRequestEnum}, + jsonrpc::{JsonRpcForwardedResponse, JsonRpcRequest, JsonRpcRequestEnum}, }; pub async fn run(port: u16, proxy_app: Arc) -> anyhow::Result<()> { @@ -27,6 +27,8 @@ pub async fn run(port: u16, proxy_app: Arc) -> anyhow::Result<()> .route("/", get(root)) // `POST /` goes to `proxy_web3_rpc` .route("/", post(proxy_web3_rpc)) + // `websocket /` goes to `proxy_web3_ws` + .route("/ws", get(websocket_handler)) // `GET /status` goes to `status` .route("/status", get(status)) .layer(Extension(proxy_app)); @@ -61,14 +63,92 @@ async fn proxy_web3_rpc( } } +async fn websocket_handler( + app: Extension>, + ws: WebSocketUpgrade, +) -> impl IntoResponse { + ws.on_upgrade(|socket| proxy_web3_socket(app, socket)) +} + +async fn proxy_web3_socket(app: Extension>, socket: WebSocket) { + // split the websocket so we can read and write concurrently + let (ws_tx, ws_rx) = socket.split(); + + // create a channel for our reader and writer can communicate. todo: benchmark different channels + let (response_tx, response_rx) = flume::unbounded::(); + + tokio::spawn(write_web3_socket(response_rx, ws_tx)); + tokio::spawn(read_web3_socket(app, ws_rx, response_tx)); +} + +async fn read_web3_socket( + app: Extension>, + mut ws_rx: SplitStream, + response_tx: flume::Sender, +) { + while let Some(Ok(msg)) = ws_rx.next().await { + // new message from our client. forward to a backend and then send it through response_tx + // TODO: spawn this processing? + let response_msg = match msg { + Message::Text(payload) => { + let (id, response) = match serde_json::from_str(&payload) { + Ok(payload) => { + let payload: JsonRpcRequest = payload; + + let id = payload.id.clone(); + + let payload = JsonRpcRequestEnum::Single(payload); + + (id, app.0.proxy_web3_rpc(payload).await) + } + Err(err) => { + let id = RawValue::from_string("-1".to_string()).unwrap(); + (id, Err(err.into())) + } + }; + + let response_str = match response { + Ok(x) => serde_json::to_string(&x), + Err(err) => { + // we have an anyhow error. turn it into + let response = JsonRpcForwardedResponse::from_anyhow_error(err, id); + serde_json::to_string(&response) + } + } + .unwrap(); + + Message::Text(response_str) + } + Message::Ping(x) => Message::Pong(x), + _ => unimplemented!(), + }; + + if response_tx.send_async(response_msg).await.is_err() { + // TODO: log the error + break; + }; + } +} + +async fn write_web3_socket( + response_rx: flume::Receiver, + mut ws_tx: SplitSink, +) { + while let Ok(msg) = response_rx.recv_async().await { + // a response is ready. write it to ws_tx + if ws_tx.send(msg).await.is_err() { + // TODO: log the error + break; + }; + } +} + /// Very basic status page async fn status(app: Extension>) -> impl IntoResponse { let app = app.0.as_ref(); let balanced_rpcs = app.get_balanced_rpcs(); - let private_rpcs = app.get_private_rpcs(); - let num_active_requests = app.get_active_requests().len(); // TODO: what else should we include? uptime? prometheus? @@ -91,22 +171,12 @@ async fn handler_404() -> impl IntoResponse { /// handle errors by converting them into something that implements `IntoResponse` /// TODO: use this. i can't get https://docs.rs/axum/latest/axum/error_handling/index.html to work async fn _handle_anyhow_error(err: anyhow::Error, code: Option) -> impl IntoResponse { - let err = format!("{:?}", err); + // TODO: what id can we use? how do we make sure the incoming id gets attached to this? + let id = RawValue::from_string("0".to_string()).unwrap(); - warn!("Responding with error: {}", err); + let err = JsonRpcForwardedResponse::from_anyhow_error(err, id); - let err = JsonRpcForwardedResponse { - jsonrpc: "2.0".to_string(), - // TODO: what id can we use? how do we make sure the incoming id gets attached to this? - id: RawValue::from_string("0".to_string()).unwrap(), - result: None, - error: Some(JsonRpcErrorData { - // TODO: set this jsonrpc error code to match the http status code - code: -32099, - message: err, - data: None, - }), - }; + warn!("Responding with error: {:?}", err); let code = code.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); diff --git a/web3-proxy/src/jsonrpc.rs b/web3-proxy/src/jsonrpc.rs index e1c226bc..3e1dd6e5 100644 --- a/web3-proxy/src/jsonrpc.rs +++ b/web3-proxy/src/jsonrpc.rs @@ -162,6 +162,23 @@ impl fmt::Debug for JsonRpcForwardedResponse { } impl JsonRpcForwardedResponse { + pub fn from_anyhow_error(err: anyhow::Error, id: Box) -> Self { + let err = format!("{:?}", err); + + JsonRpcForwardedResponse { + jsonrpc: "2.0".to_string(), + // TODO: what id can we use? how do we make sure the incoming id gets attached to this? + id, + result: None, + error: Some(JsonRpcErrorData { + // TODO: set this jsonrpc error code to match the http status code + code: -32099, + message: err, + data: None, + }), + } + } + pub fn from_response_result( result: Result, ProviderError>, id: Box,