From 710cef5da3b67f65863b40566bc8acd93711ccca Mon Sep 17 00:00:00 2001
From: Bryan Stitt
Date: Tue, 26 Apr 2022 06:54:24 +0000
Subject: [PATCH] use channels to return early

---
 Cargo.lock  |   2 +-
 Cargo.toml  |   2 +-
 src/main.rs | 496 ++++++++++++++++++++++++++++++++++++----------------
 3 files changed, 344 insertions(+), 156 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 41596019..e0f332ad 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3739,9 +3739,9 @@ dependencies = [
  "governor",
  "regex",
  "reqwest",
+ "serde",
  "serde_json",
  "tokio",
- "tokio-tungstenite 0.17.1",
  "tracing",
  "tracing-subscriber",
  "url",
diff --git a/Cargo.toml b/Cargo.toml
index 737abc7f..dab05462 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,8 +16,8 @@ governor = { version = "0.4.2", features = ["dashmap", "std"] }
 tokio = { version = "1.17.0", features = ["full"] }
 regex = "1.5.5"
 reqwest = { version = "0.11.10", features = ["json"] }
+serde = {version = "1.0"}
 serde_json = { version = "1.0.79", default-features = false, features = ["alloc"] }
-tokio-tungstenite = "0.17.1"
 tracing = "0.1"
 tracing-subscriber = "0.3"
 url = "2.2.2"
diff --git a/src/main.rs b/src/main.rs
index 3949b260..a4ef4802 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,47 +1,196 @@
-use dashmap::DashMap;
+// TODO: don't use dashmap. we need something for async
+
+use ethers::prelude::{Block, TxHash};
+use ethers::providers::Middleware;
 use futures::future;
-use futures::future::{AbortHandle, Abortable};
-use futures::SinkExt;
 use futures::StreamExt;
 use governor::clock::{Clock, QuantaClock, QuantaInstant};
 use governor::middleware::NoOpMiddleware;
 use governor::state::{InMemoryState, NotKeyed};
 use governor::{NotUntil, RateLimiter};
-use regex::Regex;
+use std::cmp::Ordering;
+use std::collections::HashMap;
 use std::num::NonZeroU32;
 use std::sync::Arc;
 use std::time::Duration;
-use tokio::sync::RwLock;
+use tokio::sync::{mpsc, RwLock};
 use tokio::time::sleep;
-use tokio_tungstenite::{connect_async, tungstenite};
 use warp::Filter;
 
-type RateLimiterMap = DashMap<String, RpcRateLimiter>;
+static APP_USER_AGENT: &str = concat!(
+    "satoshiandkin/",
+    env!("CARGO_PKG_NAME"),
+    "/",
+    env!("CARGO_PKG_VERSION"),
+);
+
+// TODO: i'm not sure we need this. i think we can use dyn
+enum EthersProvider {
+    Http(ethers::providers::Provider<ethers::providers::Http>),
+    Ws(ethers::providers::Provider<ethers::providers::Ws>),
+}
+
+// TODO: seems like this should be derivable
+impl From<ethers::providers::Provider<ethers::providers::Http>> for EthersProvider {
+    fn from(item: ethers::providers::Provider<ethers::providers::Http>) -> Self {
+        EthersProvider::Http(item)
+    }
+}
+
+// TODO: seems like this should be derivable
+impl From<ethers::providers::Provider<ethers::providers::Ws>> for EthersProvider {
+    fn from(item: ethers::providers::Provider<ethers::providers::Ws>) -> Self {
+        EthersProvider::Ws(item)
+    }
+}
+
+impl EthersProvider {
+    pub async fn request(
+        &self,
+        method: &str,
+        params: serde_json::Value,
+    ) -> Result<serde_json::Value, ethers::prelude::ProviderError> {
+        match self {
+            Self::Http(provider) => provider.request(method, params).await,
+            Self::Ws(provider) => provider.request(method, params).await,
+        }
+    }
+}
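Everything below funnels through that single request() surface, so callers never branch on the transport. A minimal usage sketch, assuming an EthersProvider already built the way EthersConnection::try_new does below (the block_number helper is hypothetical, not part of the patch):

use serde_json::json;

// any JSON-RPC method goes through the one request() method; the caller
// never needs to know which transport variant the enum is holding
async fn block_number(provider: &EthersProvider) -> anyhow::Result<u64> {
    // eth_blockNumber takes no params and returns a hex quantity like "0x10d4f"
    let result = provider.request("eth_blockNumber", json!([])).await?;

    let hex = result
        .as_str()
        .ok_or_else(|| anyhow::anyhow!("expected a hex string"))?;

    Ok(u64::from_str_radix(hex.trim_start_matches("0x"), 16)?)
}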
+
+struct EthersConnection {
+    /// keep track of currently open requests. We sort on this
+    active_requests: u32,
+    provider: Arc<EthersProvider>,
+}
+
+impl EthersConnection {
+    async fn try_new(
+        url_str: &str,
+        http_client: Option<reqwest::Client>,
+        blocks: Arc<BlockMap>,
+    ) -> anyhow::Result<EthersConnection> {
+        // TODO: create an ethers-rs rpc client and subscribe/watch new heads in a spawned task
+        let provider = if url_str.starts_with("http") {
+            let url: url::Url = url_str.try_into()?;
+
+            let http_client = http_client.ok_or_else(|| anyhow::anyhow!("no http_client"))?;
+
+            ethers::providers::Http::new_with_client(url, http_client)
+        } else if url_str.starts_with("ws") {
+            // ethers::providers::Ws::connect(s.to_string()).await?
+            // TODO: make sure this survives disconnects
+            unimplemented!();
+        } else {
+            return Err(anyhow::anyhow!("only http and ws servers are supported"));
+        };
+
+        let provider = ethers::providers::Provider::new(provider)
+            .interval(Duration::from_secs(1))
+            .into();
+
+        match &provider {
+            EthersProvider::Http(provider) => {
+                let mut stream = provider.watch_blocks().await?.take(3);
+                while let Some(block_number) = stream.next().await {
+                    let block = provider.get_block(block_number).await?.unwrap();
+
+                    println!(
+                        "{:?} = Ts: {:?}, block number: {}",
+                        block.hash.unwrap(),
+                        block.timestamp,
+                        block.number.unwrap(),
+                    );
+
+                    let mut blocks = blocks.write().await;
+
+                    blocks.insert(url_str.to_string(), block);
+                }
+            }
+            EthersProvider::Ws(provider) => {
+                let mut stream = provider.subscribe_blocks().await?;
+                while let Some(block) = stream.next().await {
+                    // TODO: save the block into a dashmap on
+                    println!(
+                        "{:?} = Ts: {:?}, block number: {}",
+                        block.hash.unwrap(),
+                        block.timestamp,
+                        block.number.unwrap(),
+                    );
+                }
+            }
+        }
+
+        // TODO: subscribe to new_heads
+        // TODO: if http, maybe we should check them all on the same interval. and if there is at least one websocket, use them for the interval
+
+        Ok(EthersConnection {
+            active_requests: 0,
+            provider: Arc::new(provider),
+        })
+    }
+
+    fn inc(&mut self) {
+        self.active_requests += 1;
+    }
+
+    fn dec(&mut self) {
+        self.active_requests -= 1;
+    }
+}
+
+impl Eq for EthersConnection {}
+
+impl Ord for EthersConnection {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.active_requests.cmp(&other.active_requests)
+    }
+}
+
+impl PartialOrd for EthersConnection {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl PartialEq for EthersConnection {
+    fn eq(&self, other: &Self) -> bool {
+        self.active_requests == other.active_requests
+    }
+}
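Those comparison impls look at nothing but active_requests, which is what lets next_upstream_server further down pick the least-loaded server with a plain sort. In isolation (least_busy is a hypothetical helper, not in the patch):

// because Ord only inspects active_requests, an ordinary sort ranks the
// least-busy connection first; the selector below relies on exactly this
fn least_busy(conns: &mut [EthersConnection]) -> Option<&EthersConnection> {
    conns.sort_unstable(); // uses the Ord impl above
    conns.first()
}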
+
+type BlockMap = RwLock<HashMap<String, Block<TxHash>>>;
+type RateLimiterMap = RwLock<HashMap<String, RpcRateLimiter>>;
 // TODO: include the ethers client on this map
-type ConnectionsMap = DashMap<String, u32>;
+type ConnectionsMap = RwLock<HashMap<String, EthersConnection>>;
 
 type RpcRateLimiter =
     RateLimiter<NotKeyed, InMemoryState, QuantaClock, NoOpMiddleware<QuantaInstant>>;
 
 /// Load balance to the least-connection rpc
-struct BalancedRpcs {
+struct RpcTier {
     rpcs: RwLock<Vec<String>>,
-    connections: ConnectionsMap,
+    connections: Arc<ConnectionsMap>,
     ratelimits: RateLimiterMap,
 }
 
-impl BalancedRpcs {
-    fn new(servers: Vec<(&str, u32)>, clock: &QuantaClock) -> BalancedRpcs {
+impl RpcTier {
+    async fn try_new(
+        servers: Vec<(&str, u32)>,
+        http_client: Option<reqwest::Client>,
+        blocks: Arc<BlockMap>,
+        clock: &QuantaClock,
+    ) -> anyhow::Result<RpcTier> {
         let mut rpcs: Vec<String> = vec![];
-        let connections = DashMap::new();
-        let ratelimits = DashMap::new();
+        let mut connections = HashMap::new();
+        let mut ratelimits = HashMap::new();
 
         for (s, limit) in servers.into_iter() {
             rpcs.push(s.to_string());
 
-            // TODO: subscribe to new_heads. if websocket, this is easy. otherwise we
+            let connection =
+                EthersConnection::try_new(s, http_client.clone(), blocks.clone()).await?;
 
-            connections.insert(s.to_string(), 0);
+            connections.insert(s.to_string(), connection);
 
             if limit > 0 {
                 let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap());
@@ -52,6 +201,7 @@ impl BalancedRpcs {
             }
         }
 
+        /*
         let new_heads_handles = rpcs
             .clone()
             .into_iter()
@@ -108,31 +258,33 @@ impl BalancedRpcs {
                 abort_handle
             })
             .collect();
+        */
 
-        BalancedRpcs {
+        Ok(RpcTier {
             rpcs: RwLock::new(rpcs),
-            connections,
-            ratelimits,
-            new_heads_handles,
-        }
+            connections: Arc::new(RwLock::new(connections)),
+            ratelimits: RwLock::new(ratelimits),
+        })
     }
 
     /// get the best available rpc server
-    async fn get_upstream_server(&self) -> Result<String, NotUntil<QuantaInstant>> {
+    async fn next_upstream_server(&self) -> Result<String, NotUntil<QuantaInstant>> {
         let mut balanced_rpcs = self.rpcs.write().await;
 
-        balanced_rpcs.sort_unstable_by(|a, b| {
-            self.connections
-                .get(a)
-                .unwrap()
-                .cmp(&self.connections.get(b).unwrap())
-        });
+        // sort rpcs by their active connections
+        let connections = self.connections.read().await;
+
+        balanced_rpcs
+            .sort_unstable_by(|a, b| connections.get(a).unwrap().cmp(connections.get(b).unwrap()));
 
         let mut earliest_not_until = None;
 
         for selected_rpc in balanced_rpcs.iter() {
+            // TODO: check current block number. if behind, make our own NotUntil here
+            let ratelimits = self.ratelimits.write().await;
+
             // check rate limits
-            match self.ratelimits.get(selected_rpc).unwrap().check() {
+            match ratelimits.get(selected_rpc).unwrap().check() {
                 Ok(_) => {
                     // rate limit succeeded
                 }
@@ -155,8 +307,12 @@ impl BalancedRpcs {
             };
 
             // increment our connection counter
-            let mut connections = self.connections.get_mut(selected_rpc).unwrap();
-            *connections += 1;
+            self.connections
+                .write()
+                .await
+                .get_mut(selected_rpc)
+                .unwrap()
+                .inc();
 
             // return the selected RPC
             return Ok(selected_rpc.clone());
@@ -169,36 +325,6 @@ impl BalancedRpcs {
             unimplemented!();
         }
     }
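The check() calls above come from governor's direct rate limiters: a passing check consumes quota, and an exhausted one returns a NotUntil carrying the earliest instant a retry can succeed. A standalone sketch of that API, with a made-up two-per-second quota:

use governor::clock::{Clock, QuantaClock};
use governor::{Quota, RateLimiter};
use std::num::NonZeroU32;

fn main() {
    // a direct (un-keyed) limiter like the ones stored in RateLimiterMap
    let clock = QuantaClock::default();
    let limiter =
        RateLimiter::direct_with_clock(Quota::per_second(NonZeroU32::new(2).unwrap()), &clock);

    assert!(limiter.check().is_ok()); // first request passes
    assert!(limiter.check().is_ok()); // second request passes

    // the third check fails and reports how long to wait before retrying
    if let Err(not_until) = limiter.check() {
        println!("rate limited for {:?}", not_until.wait_time_from(clock.now()));
    }
}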
-}
-
-/// Send to all the Rpcs
-/// Unlike BalancedRpcs, there is no tracking of connections
-/// We do still track rate limits
-struct LoudRpcs {
-    rpcs: Vec<String>,
-    // TODO: what type? store with connections?
-    ratelimits: RateLimiterMap,
-}
-
-impl LoudRpcs {
-    fn new(servers: Vec<(&str, u32)>, clock: &QuantaClock) -> LoudRpcs {
-        let mut rpcs: Vec<String> = vec![];
-        let ratelimits = RateLimiterMap::new();
-
-        for (s, limit) in servers.into_iter() {
-            rpcs.push(s.to_string());
-
-            if limit > 0 {
-                let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap());
-
-                let rate_limiter = governor::RateLimiter::direct_with_clock(quota, clock);
-
-                ratelimits.insert(s.to_string(), rate_limiter);
-            }
-        }
-
-        LoudRpcs { rpcs, ratelimits }
-    }
 
     /// get all available rpc servers
     async fn get_upstream_servers(&self) -> Result<Vec<String>, NotUntil<QuantaInstant>> {
@@ -206,9 +332,16 @@
 
         let mut selected_rpcs = vec![];
 
-        for selected_rpc in self.rpcs.iter() {
+        for selected_rpc in self.rpcs.read().await.iter() {
             // check rate limits
-            match self.ratelimits.get(selected_rpc).unwrap().check() {
+            match self
+                .ratelimits
+                .write()
+                .await
+                .get(selected_rpc)
+                .unwrap()
+                .check()
+            {
                 Ok(_) => {
                     // rate limit succeeded
                 }
@@ -230,6 +363,14 @@
                 }
             };
 
+            // increment our connection counter
+            self.connections
+                .write()
+                .await
+                .get_mut(selected_rpc)
+                .unwrap()
+                .inc();
+
             // this rpc should work
             selected_rpcs.push(selected_rpc.clone());
         }
@@ -242,52 +383,71 @@
         if let Some(not_until) = earliest_not_until {
             Err(not_until)
         } else {
-            panic!("i don't think this should happen")
+            // TODO: is this right?
+            Ok(vec![])
        }
     }
-
-    fn as_bool(&self) -> bool {
-        !self.rpcs.is_empty()
-    }
 }
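Both selectors fold every rate-limit rejection into the single earliest NotUntil, so a caller that sleeps on it wakes as soon as any server can accept traffic again. That accumulation step in isolation (remember_earliest is a hypothetical helper; the patch inlines this logic in the loops above):

use governor::clock::QuantaInstant;
use governor::NotUntil;

// keep only the soonest retry instant across all rate-limited servers
fn remember_earliest(
    earliest: &mut Option<NotUntil<QuantaInstant>>,
    candidate: NotUntil<QuantaInstant>,
) {
    let sooner = match earliest {
        Some(current) => candidate.earliest_possible() < current.earliest_possible(),
        None => true,
    };

    if sooner {
        *earliest = Some(candidate);
    }
}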
 
 struct Web3ProxyState {
     clock: QuantaClock,
-    client: reqwest::Client,
-    // TODO: LoudRpcs and BalancedRpcs should probably share a trait or something
-    balanced_rpc_tiers: Vec<BalancedRpcs>,
-    private_rpcs: LoudRpcs,
-    /// lock this when all rate limiters are hit
+    balanced_rpc_tiers: Arc<Vec<RpcTier>>,
+    private_rpcs: Option<Arc<RpcTier>>,
+    /// write lock on these when all rate limits are hit
     balanced_rpc_ratelimiter_lock: RwLock<()>,
     private_rpcs_ratelimiter_lock: RwLock<()>,
 }
 
 impl Web3ProxyState {
-    fn new(
+    async fn try_new(
         balanced_rpc_tiers: Vec<Vec<(&str, u32)>>,
         private_rpcs: Vec<(&str, u32)>,
-    ) -> Web3ProxyState {
+    ) -> anyhow::Result<Web3ProxyState> {
         let clock = QuantaClock::default();
 
-        let balanced_rpc_tiers = balanced_rpc_tiers
-            .into_iter()
-            .map(|servers| BalancedRpcs::new(servers, &clock))
-            .collect();
+        let blocks = Arc::new(BlockMap::default());
 
-        let private_rpcs = LoudRpcs::new(private_rpcs, &clock);
+        // TODO: 5 minutes is probably long enough. unlimited is a bad idea if something
+        let http_client = reqwest::ClientBuilder::new()
+            .timeout(Duration::from_secs(300))
+            .user_agent(APP_USER_AGENT)
+            .build()?;
+
+        // TODO: i'm sure we s
+        let balanced_rpc_tiers = Arc::new(
+            future::join_all(balanced_rpc_tiers.into_iter().map(|balanced_rpc_tier| {
+                RpcTier::try_new(
+                    balanced_rpc_tier,
+                    Some(http_client.clone()),
+                    blocks.clone(),
+                    &clock,
+                )
+            }))
+            .await
+            .into_iter()
+            .collect::<anyhow::Result<Vec<RpcTier>>>()?,
+        );
+
+        let private_rpcs = if private_rpcs.is_empty() {
+            None
+        } else {
+            Some(Arc::new(
+                RpcTier::try_new(private_rpcs, Some(http_client), blocks.clone(), &clock).await?,
+            ))
+        };
 
         // TODO: warn if no private relays
-        Web3ProxyState {
+        Ok(Web3ProxyState {
             clock,
-            client: reqwest::Client::new(),
             balanced_rpc_tiers,
             private_rpcs,
             balanced_rpc_ratelimiter_lock: Default::default(),
             private_rpcs_ratelimiter_lock: Default::default(),
-        }
+        })
     }
 
-    /// send the request to the appropriate RPCs
+    /// send the request to the appropriate RPCs
+    /// TODO: dry this up
     async fn proxy_web3_rpc(
         self: Arc<Web3ProxyState>,
         json_body: serde_json::Value,
@@ -295,20 +455,34 @@ impl Web3ProxyState {
         let eth_send_raw_transaction =
             serde_json::Value::String("eth_sendRawTransaction".to_string());
 
-        if self.private_rpcs.as_bool() && json_body.get("method") == Some(&eth_send_raw_transaction)
+        if self.private_rpcs.is_some() && json_body.get("method") == Some(&eth_send_raw_transaction)
         {
+            let private_rpcs = self.private_rpcs.clone().unwrap();
+
             // there are private rpcs configured and the request is eth_sendRawTransaction. send to all private rpcs
             loop {
                 let read_lock = self.private_rpcs_ratelimiter_lock.read().await;
 
-                match self.private_rpcs.get_upstream_servers().await {
+                match private_rpcs.get_upstream_servers().await {
                     Ok(upstream_servers) => {
-                        if let Ok(result) = self
-                            .try_send_requests(upstream_servers, None, &json_body)
+                        let (tx, mut rx) = mpsc::unbounded_channel::<serde_json::Value>();
+
+                        let clone = self.clone();
+                        let connections = private_rpcs.connections.clone();
+                        let json_body = json_body.clone();
+
+                        tokio::spawn(async move {
+                            clone
+                                .try_send_requests(upstream_servers, connections, json_body, tx)
+                                .await
+                        });
+
+                        let response = rx
+                            .recv()
                             .await
-                        {
-                            return Ok(result);
-                        }
+                            .ok_or_else(|| anyhow::anyhow!("no response"))?;
+
+                        return Ok(warp::reply::json(&response));
                     }
                     Err(not_until) => {
                         // TODO: move this to a helper function
@@ -333,19 +507,31 @@
         let mut earliest_not_until = None;
 
         for balanced_rpcs in self.balanced_rpc_tiers.iter() {
-            match balanced_rpcs.get_upstream_server().await {
+            match balanced_rpcs.next_upstream_server().await {
                 Ok(upstream_server) => {
-                    // TODO: capture any errors. at least log them
-                    if let Ok(result) = self
-                        .try_send_requests(
-                            vec![upstream_server],
-                            Some(&balanced_rpcs.connections),
-                            &json_body,
-                        )
+                    let (tx, mut rx) = mpsc::unbounded_channel::<serde_json::Value>();
+
+                    let clone = self.clone();
+                    let connections = balanced_rpcs.connections.clone();
+                    let json_body = json_body.clone();
+
+                    tokio::spawn(async move {
+                        clone
+                            .try_send_requests(
+                                vec![upstream_server],
+                                connections,
+                                json_body,
+                                tx,
+                            )
+                            .await
+                    });
+
+                    let response = rx
+                        .recv()
                         .await
-                    {
-                        return Ok(result);
-                    }
+                        .ok_or_else(|| anyhow::anyhow!("no response"))?;
+
+                    return Ok(warp::reply::json(&response));
                 }
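This block is the commit's namesake: the fan-out moves into a spawned task that owns the UnboundedSender, and the handler answers with the first value out of the receiver instead of waiting for every backend. The same race in miniature, detached from the RPC types (names and delays are invented):

use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::sleep;

// race several workers and answer with whichever sends first; later sends
// fail once the receiver is dropped, which is why send() results are ignored
async fn first_response() -> Option<&'static str> {
    let (tx, mut rx) = mpsc::unbounded_channel();

    for (name, delay) in [("slow", 200u64), ("fast", 10)] {
        let tx = tx.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(delay)).await;
            let _ = tx.send(name); // drop the error, as the patch does
        });
    }
    drop(tx); // let rx end if every worker finishes without sending

    rx.recv().await // resolves on the first (fastest) send
}

Here first_response() resolves to Some("fast") while the slow worker's later send is silently discarded.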
                Err(not_until) => {
                    // save the smallest not_until. if nothing succeeds, return an Err with not_until in it
@@ -382,59 +568,62 @@
 
     async fn try_send_requests(
         &self,
-        upstream_servers: Vec<String>,
-        connections: Option<&ConnectionsMap>,
-        json_body: &serde_json::Value,
-    ) -> anyhow::Result<String> {
+        rpc_servers: Vec<String>,
+        connections: Arc<ConnectionsMap>,
+        json_request_body: serde_json::Value,
+        tx: mpsc::UnboundedSender<serde_json::Value>,
+    ) -> anyhow::Result<()> {
+        // {"jsonrpc":"2.0","method":"eth_syncing","params":[],"id":1}
+        let incoming_id = json_request_body
+            .get("id")
+            .ok_or_else(|| anyhow::anyhow!("bad id"))?
+            .to_owned();
+        let method = json_request_body
+            .get("method")
+            .and_then(|x| x.as_str())
+            .ok_or_else(|| anyhow::anyhow!("bad method"))?
+            .to_string();
+        let params = json_request_body
+            .get("params")
+            .ok_or_else(|| anyhow::anyhow!("no params"))?
+            .to_owned();
+
         // send the query to all the servers
-        let bodies = future::join_all(upstream_servers.into_iter().map(|url| {
-            let client = self.client.clone();
-            let json_body = json_body.clone();
-            tokio::spawn(async move {
+        let bodies = future::join_all(rpc_servers.into_iter().map(|rpc| {
+            let incoming_id = incoming_id.clone();
+            let connections = connections.clone();
+            let method = method.clone();
+            let params = params.clone();
+            let tx = tx.clone();
+
+            async move {
+                // get the client for this rpc server
+                let provider = connections.read().await.get(&rpc).unwrap().provider.clone();
+
                 // TODO: there has to be a better way to attach the url to the result
-                client
-                    .post(&url)
-                    .json(&json_body)
-                    .send()
-                    .await
-                    // add the url to the error so that we can reduce connection counters
-                    .map_err(|e| (url.clone(), e))?
-                    .text()
-                    .await
-                    // add the url to the result so that we can reduce connection counters
-                    .map(|t| (url.clone(), t))
-                    // add the url to the error so that we can reduce connection counters
-                    .map_err(|e| (url, e))
-            })
+                let mut response = provider.request(&method, params).await?;
+
+                connections.write().await.get_mut(&rpc).unwrap().dec();
+
+                if let Some(response_id) = response.get_mut("id") {
+                    *response_id = incoming_id;
+                }
+
+                // send the first good response to a one shot channel. that way we respond quickly
+                // drop the result because errors are expected after the first send
+                // TODO: if "no block with that header" or some other jsonrpc errors, skip this response
+                let _ = tx.send(response);
+
+                Ok::<(), anyhow::Error>(())
+            }
         }))
        .await;
 
-        // we are going to collect successes and failures
-        let mut oks = vec![];
+        // TODO: use iterators instead of pushing into a vec
         let mut errs = vec![];
-
-        // TODO: parallel?
-        for b in bodies {
-            match b {
-                Ok(Ok((url, b))) => {
-                    // reduce connection counter
-                    if let Some(connections) = connections {
-                        *connections.get_mut(&url).unwrap() -= 1;
-                    }
-
-                    // TODO: if "no block with that header" or some other jsonrpc errors, skip this response
-                    oks.push(b);
-                }
-                Ok(Err((url, e))) => {
-                    // reduce connection counter
-                    if let Some(connections) = connections {
-                        *connections.get_mut(&url).unwrap() -= 1;
-                    }
-
-                    // TODO: better errors
-                    eprintln!("Got a reqwest::Error: {}", e);
-                    errs.push(anyhow::anyhow!("Got a reqwest::Error"));
-                }
+        for x in bodies {
+            match x {
+                Ok(_) => {}
                 Err(e) => {
                     // TODO: better errors
                     eprintln!("Got a tokio::JoinError: {}", e);
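The id rewrite above matters because the proxy's upstream request carries its own id while the client expects the id it sent to be echoed back. The same fix-up in miniature (values are invented):

use serde_json::json;

fn main() {
    // the upstream response carries whatever id was used on the wire; the
    // client expects its original id back, so overwrite it before replying
    let incoming_id = json!(42);
    let mut response = json!({ "jsonrpc": "2.0", "id": 1, "result": "0x10d4f" });

    if let Some(response_id) = response.get_mut("id") {
        *response_id = incoming_id;
    }

    assert_eq!(response["id"], json!(42));
}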
@@ -443,10 +632,7 @@
             }
         }
 
-        // TODO: which response should we use?
-        if !oks.is_empty() {
-            Ok(oks.pop().unwrap())
-        } else if !errs.is_empty() {
+        if !errs.is_empty() {
             Err(errs.pop().unwrap())
         } else {
             return Err(anyhow::anyhow!("no successful responses"));
@@ -462,7 +648,7 @@ async fn main() {
     let listen_port = 8445;
 
     // TODO: be smart about using archive nodes?
-    let state = Web3ProxyState::new(
+    let state = Web3ProxyState::try_new(
         vec![
             // local nodes
             vec![("ws://10.11.12.16:8545", 0), ("ws://10.11.12.16:8946", 0)],
@@ -479,7 +665,9 @@
             ("https://api.edennetwork.io/v1/beta", 0),
             ("https://api.edennetwork.io/v1/", 0),
         ],
-    );
+    )
+    .await
+    .unwrap();
 
     let state: Arc<Web3ProxyState> = Arc::new(state);
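Assuming the warp filter (not shown in this hunk) forwards POSTed JSON bodies to proxy_web3_rpc on listen_port 8445, a hypothetical client exercising the proxy could look like:

use serde_json::json;

// fire one JSON-RPC call at the proxy; the proxy fans it out and replies
// with whichever upstream answers first
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = reqwest::Client::new();

    let response: serde_json::Value = client
        .post("http://127.0.0.1:8445")
        .json(&json!({
            "jsonrpc": "2.0",
            "method": "eth_blockNumber",
            "params": [],
            "id": 1
        }))
        .send()
        .await?
        .json()
        .await?;

    println!("{}", response);

    Ok(())
}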