add websocket server

This commit is contained in:
Bryan Stitt 2022-05-29 17:28:41 +00:00
parent ca4b757911
commit 0387492df8
5 changed files with 115 additions and 26 deletions

2
Cargo.lock generated
View File

@ -183,6 +183,7 @@ checksum = "ab2504b827a8bef941ba3dd64bdffe9cf56ca182908a147edd6189c95fbcae7d"
dependencies = [
"async-trait",
"axum-core",
"base64 0.13.0",
"bitflags",
"bytes",
"futures-util",
@ -198,6 +199,7 @@ dependencies = [
"serde",
"serde_json",
"serde_urlencoded",
"sha-1",
"sync_wrapper",
"tokio",
"tokio-tungstenite",

View File

@ -9,7 +9,7 @@ edition = "2021"
anyhow = "1.0.57"
arc-swap = "1.5.0"
argh = "0.1.7"
axum = { version = "0.5.6", features = ["serde_json", "tokio-tungstenite"] }
axum = { version = "0.5.6", features = ["serde_json", "tokio-tungstenite", "ws"] }
counter = "0.5.5"
dashmap = "5.3.4"
derive_more = "0.99.17"

View File

@ -141,7 +141,7 @@ impl Web3ProxyApp {
/// TODO: dry this up
#[instrument(skip_all)]
pub async fn proxy_web3_rpc(
self: Arc<Web3ProxyApp>,
&self,
request: JsonRpcRequestEnum,
) -> anyhow::Result<JsonRpcForwardedResponseEnum> {
// TODO: i don't always see this in the logs. why?
@ -168,7 +168,7 @@ impl Web3ProxyApp {
// #[instrument(skip_all)]
async fn proxy_web3_rpc_requests(
self: Arc<Web3ProxyApp>,
&self,
requests: Vec<JsonRpcRequest>,
) -> anyhow::Result<Vec<JsonRpcForwardedResponse>> {
// TODO: we should probably change ethers-rs to support this directly
@ -179,7 +179,7 @@ impl Web3ProxyApp {
let responses = join_all(
requests
.into_iter()
.map(|request| self.clone().proxy_web3_rpc_request(request))
.map(|request| self.proxy_web3_rpc_request(request))
.collect::<Vec<_>>(),
)
.await;
@ -232,7 +232,7 @@ impl Web3ProxyApp {
// #[instrument(skip_all)]
async fn proxy_web3_rpc_request(
self: Arc<Web3ProxyApp>,
&self,
request: JsonRpcRequest,
) -> anyhow::Result<JsonRpcForwardedResponse> {
trace!("Received request: {:?}", request);

View File

@ -1,14 +1,14 @@
/// this should move into web3-proxy once the basics are working
use axum::{
// error_handling::HandleError,
extract::ws::{Message, WebSocket, WebSocketUpgrade},
handler::Handler,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Extension,
Json,
Router,
Extension, Json, Router,
};
use futures::stream::{SplitSink, SplitStream, StreamExt};
use futures::SinkExt;
use serde_json::json;
use serde_json::value::RawValue;
use std::net::SocketAddr;
@ -17,7 +17,7 @@ use tracing::warn;
use crate::{
app::Web3ProxyApp,
jsonrpc::{JsonRpcErrorData, JsonRpcForwardedResponse, JsonRpcRequestEnum},
jsonrpc::{JsonRpcForwardedResponse, JsonRpcRequest, JsonRpcRequestEnum},
};
pub async fn run(port: u16, proxy_app: Arc<Web3ProxyApp>) -> anyhow::Result<()> {
@ -27,6 +27,8 @@ pub async fn run(port: u16, proxy_app: Arc<Web3ProxyApp>) -> anyhow::Result<()>
.route("/", get(root))
// `POST /` goes to `proxy_web3_rpc`
.route("/", post(proxy_web3_rpc))
// `websocket /` goes to `proxy_web3_ws`
.route("/ws", get(websocket_handler))
// `GET /status` goes to `status`
.route("/status", get(status))
.layer(Extension(proxy_app));
@ -61,14 +63,92 @@ async fn proxy_web3_rpc(
}
}
async fn websocket_handler(
app: Extension<Arc<Web3ProxyApp>>,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
ws.on_upgrade(|socket| proxy_web3_socket(app, socket))
}
async fn proxy_web3_socket(app: Extension<Arc<Web3ProxyApp>>, socket: WebSocket) {
// split the websocket so we can read and write concurrently
let (ws_tx, ws_rx) = socket.split();
// create a channel for our reader and writer can communicate. todo: benchmark different channels
let (response_tx, response_rx) = flume::unbounded::<Message>();
tokio::spawn(write_web3_socket(response_rx, ws_tx));
tokio::spawn(read_web3_socket(app, ws_rx, response_tx));
}
async fn read_web3_socket(
app: Extension<Arc<Web3ProxyApp>>,
mut ws_rx: SplitStream<WebSocket>,
response_tx: flume::Sender<Message>,
) {
while let Some(Ok(msg)) = ws_rx.next().await {
// new message from our client. forward to a backend and then send it through response_tx
// TODO: spawn this processing?
let response_msg = match msg {
Message::Text(payload) => {
let (id, response) = match serde_json::from_str(&payload) {
Ok(payload) => {
let payload: JsonRpcRequest = payload;
let id = payload.id.clone();
let payload = JsonRpcRequestEnum::Single(payload);
(id, app.0.proxy_web3_rpc(payload).await)
}
Err(err) => {
let id = RawValue::from_string("-1".to_string()).unwrap();
(id, Err(err.into()))
}
};
let response_str = match response {
Ok(x) => serde_json::to_string(&x),
Err(err) => {
// we have an anyhow error. turn it into
let response = JsonRpcForwardedResponse::from_anyhow_error(err, id);
serde_json::to_string(&response)
}
}
.unwrap();
Message::Text(response_str)
}
Message::Ping(x) => Message::Pong(x),
_ => unimplemented!(),
};
if response_tx.send_async(response_msg).await.is_err() {
// TODO: log the error
break;
};
}
}
async fn write_web3_socket(
response_rx: flume::Receiver<Message>,
mut ws_tx: SplitSink<WebSocket, Message>,
) {
while let Ok(msg) = response_rx.recv_async().await {
// a response is ready. write it to ws_tx
if ws_tx.send(msg).await.is_err() {
// TODO: log the error
break;
};
}
}
/// Very basic status page
async fn status(app: Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
let app = app.0.as_ref();
let balanced_rpcs = app.get_balanced_rpcs();
let private_rpcs = app.get_private_rpcs();
let num_active_requests = app.get_active_requests().len();
// TODO: what else should we include? uptime? prometheus?
@ -91,22 +171,12 @@ async fn handler_404() -> impl IntoResponse {
/// handle errors by converting them into something that implements `IntoResponse`
/// TODO: use this. i can't get https://docs.rs/axum/latest/axum/error_handling/index.html to work
async fn _handle_anyhow_error(err: anyhow::Error, code: Option<StatusCode>) -> impl IntoResponse {
let err = format!("{:?}", err);
// TODO: what id can we use? how do we make sure the incoming id gets attached to this?
let id = RawValue::from_string("0".to_string()).unwrap();
warn!("Responding with error: {}", err);
let err = JsonRpcForwardedResponse::from_anyhow_error(err, id);
let err = JsonRpcForwardedResponse {
jsonrpc: "2.0".to_string(),
// TODO: what id can we use? how do we make sure the incoming id gets attached to this?
id: RawValue::from_string("0".to_string()).unwrap(),
result: None,
error: Some(JsonRpcErrorData {
// TODO: set this jsonrpc error code to match the http status code
code: -32099,
message: err,
data: None,
}),
};
warn!("Responding with error: {:?}", err);
let code = code.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

View File

@ -162,6 +162,23 @@ impl fmt::Debug for JsonRpcForwardedResponse {
}
impl JsonRpcForwardedResponse {
pub fn from_anyhow_error(err: anyhow::Error, id: Box<RawValue>) -> Self {
let err = format!("{:?}", err);
JsonRpcForwardedResponse {
jsonrpc: "2.0".to_string(),
// TODO: what id can we use? how do we make sure the incoming id gets attached to this?
id,
result: None,
error: Some(JsonRpcErrorData {
// TODO: set this jsonrpc error code to match the http status code
code: -32099,
message: err,
data: None,
}),
}
}
pub fn from_response_result(
result: Result<Box<RawValue>, ProviderError>,
id: Box<RawValue>,