diff --git a/Cargo.lock b/Cargo.lock
index a54b243c..714c1b33 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3764,6 +3764,7 @@ dependencies = [
  "ethers",
  "futures",
  "governor",
+ "parking_lot 0.12.0",
  "regex",
  "reqwest",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index 8be50d8a..2ad455f7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,7 @@ ethers = { git = "https://github.com/gakonst/ethers-rs", features = ["rustls", "
 futures = { version = "0.3.21", features = ["thread-pool"] }
 governor = { version = "0.4.2", features = ["dashmap", "std"] }
 tokio = { version = "1.17.0", features = ["full"] }
+parking_lot = "0.12"
 regex = "1.5.5"
 reqwest = { version = "0.11.10", features = ["json"] }
 serde = {version = "1.0"}
diff --git a/src/block_watcher.rs b/src/block_watcher.rs
index f4ef5bcd..faa6d4ac 100644
--- a/src/block_watcher.rs
+++ b/src/block_watcher.rs
@@ -5,7 +5,7 @@ use std::cmp;
 use std::sync::atomic::{self, AtomicU64};
 use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
-use tokio::sync::{mpsc, Mutex};
+use tokio::sync::{mpsc, watch, Mutex};
 use tracing::info;
 
 // TODO: what type for the Item? String url works, but i don't love it
@@ -14,7 +14,6 @@ pub type NewHead = (String, Block);
 pub type BlockWatcherSender = mpsc::UnboundedSender<NewHead>;
 pub type BlockWatcherReceiver = mpsc::UnboundedReceiver<NewHead>;
 
-#[derive(Eq)]
 // TODO: ethers has a similar SyncingStatus
 pub enum SyncStatus {
     Synced(u64),
@@ -22,33 +21,10 @@ pub enum SyncStatus {
     Unknown,
 }
 
-// impl Ord for SyncStatus {
-//     fn cmp(&self, other: &Self) -> cmp::Ordering {
-//         self.height.cmp(&other.height)
-//     }
-// }
-
-// impl PartialOrd for SyncStatus {
-//     fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
-//         Some(self.cmp(other))
-//     }
-// }
-
-impl PartialEq for SyncStatus {
-    fn eq(&self, other: &Self) -> bool {
-        match (self, other) {
-            (Self::Synced(a), Self::Synced(b)) => a == b,
-            (Self::Unknown, Self::Unknown) => true,
-            (Self::Behind(a), Self::Behind(b)) => a == b,
-            _ => false,
-        }
-    }
-}
-
 #[derive(Debug)]
 pub struct BlockWatcher {
     sender: BlockWatcherSender,
-    /// parking_lot::Mutex is supposed to be faster, but we only lock this once, so its fine
+    /// this Mutex is locked over awaits, so we want an async lock
     receiver: Mutex<BlockWatcherReceiver>,
     block_numbers: DashMap<String, u64>,
     head_block_number: AtomicU64,
@@ -100,7 +76,10 @@ impl BlockWatcher {
         }
     }
 
-    pub async fn run(self: Arc<Self>) -> anyhow::Result<()> {
+    pub async fn run(
+        self: Arc<Self>,
+        new_block_sender: watch::Sender<String>,
+    ) -> anyhow::Result<()> {
         let mut receiver = self.receiver.lock().await;
 
         while let Some((rpc, new_block)) = receiver.recv().await {
@@ -160,6 +139,9 @@ impl BlockWatcher {
                 }
             };
 
+            // have the provider tiers update_synced_rpcs
+            new_block_sender.send(rpc.clone())?;
+
             // TODO: include time since last update?
             info!(
                 "{:?} = {}, {}, {} sec{}",
diff --git a/src/main.rs b/src/main.rs
index a2d5a1a2..fa8b3c8e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,11 +5,12 @@ mod provider_tiers;
 use futures::future;
 use governor::clock::{Clock, QuantaClock};
 use serde_json::json;
+use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
-use tokio::sync::{mpsc, RwLock};
+use tokio::sync::{mpsc, watch, RwLock};
 use tokio::time::sleep;
-use tracing::{info, warn};
+use tracing::{instrument, warn};
 use warp::Filter;
 
 // use crate::types::{BlockMap, ConnectionsMap, RpcRateLimiterMap};
@@ -24,8 +25,8 @@ static APP_USER_AGENT: &str = concat!(
 );
 
 /// The application
+#[derive(Debug)]
 struct Web3ProxyApp {
-    block_watcher: Arc<BlockWatcher>,
     /// clock used for rate limiting
     /// TODO: use tokio's clock (will require a different ratelimiting crate)
     clock: QuantaClock,
@@ -34,12 +35,16 @@ struct Web3ProxyApp {
     /// Send private requests (like eth_sendRawTransaction) to all these servers
     private_rpcs: Option<Arc<Web3ProviderTier>>,
     /// write lock on these when all rate limits are hit
+    /// this lock will be held open over an await, so use async locking
     balanced_rpc_ratelimiter_lock: RwLock<()>,
+    /// this lock will be held open over an await, so use async locking
     private_rpcs_ratelimiter_lock: RwLock<()>,
 }
 
 impl Web3ProxyApp {
+    #[instrument]
     async fn try_new(
+        allowed_lag: u64,
         balanced_rpc_tiers: Vec<Vec<(&str, u32)>>,
         private_rpcs: Vec<(&str, u32)>,
     ) -> anyhow::Result<Web3ProxyApp> {
@@ -55,10 +60,14 @@ impl Web3ProxyApp {
 
         let block_watcher = Arc::new(BlockWatcher::new());
 
-        let block_watcher_clone = Arc::clone(&block_watcher);
+        let (new_block_sender, mut new_block_receiver) = watch::channel::<String>("".to_string());
 
-        // start the block_watcher
-        tokio::spawn(async move { block_watcher_clone.run().await });
+        {
+            // TODO: spawn this later?
+            // spawn a future for the block_watcher
+            let block_watcher = block_watcher.clone();
+            tokio::spawn(async move { block_watcher.run(new_block_sender).await });
+        }
 
         let balanced_rpc_tiers = Arc::new(
             future::join_all(balanced_rpc_tiers.into_iter().map(|balanced_rpc_tier| {
@@ -89,8 +98,46 @@ impl Web3ProxyApp {
             ))
         };
 
+        {
+            // spawn a future for sorting our synced rpcs
+            // TODO: spawn this later?
+            let balanced_rpc_tiers = balanced_rpc_tiers.clone();
+            let private_rpcs = private_rpcs.clone();
+            let block_watcher = block_watcher.clone();
+
+            tokio::spawn(async move {
+                let mut tier_map = HashMap::new();
+                let mut private_map = HashMap::new();
+
+                for balanced_rpc_tier in balanced_rpc_tiers.iter() {
+                    for rpc in balanced_rpc_tier.clone_rpcs() {
+                        tier_map.insert(rpc, balanced_rpc_tier);
+                    }
+                }
+
+                if let Some(private_rpcs) = private_rpcs {
+                    for rpc in private_rpcs.clone_rpcs() {
+                        private_map.insert(rpc, private_rpcs.clone());
+                    }
+                }
+
+                while new_block_receiver.changed().await.is_ok() {
+                    let updated_rpc = new_block_receiver.borrow().clone();
+
+                    if let Some(tier) = tier_map.get(&updated_rpc) {
+                        tier.update_synced_rpcs(block_watcher.clone(), allowed_lag)
+                            .unwrap();
+                    } else if let Some(tier) = private_map.get(&updated_rpc) {
+                        tier.update_synced_rpcs(block_watcher.clone(), allowed_lag)
+                            .unwrap();
+                    } else {
+                        panic!("howd this happen");
+                    }
+                }
+            });
+        }
+
         Ok(Web3ProxyApp {
-            block_watcher,
             clock,
             balanced_rpc_tiers,
             private_rpcs,
@@ -101,6 +148,7 @@ impl Web3ProxyApp {
 
     /// send the request to the approriate RPCs
     /// TODO: dry this up
+    #[instrument]
     async fn proxy_web3_rpc(
         self: Arc<Self>,
         json_body: serde_json::Value,
@@ -118,10 +166,7 @@ impl Web3ProxyApp {
 
             let json_body_clone = json_body.clone();
 
-            match private_rpcs
-                .get_upstream_servers(1, self.block_watcher.clone())
-                .await
-            {
+            match private_rpcs.get_upstream_servers().await {
                 Ok(upstream_servers) => {
                     let (tx, mut rx) =
                         mpsc::unbounded_channel::>();
@@ -184,10 +229,8 @@ impl Web3ProxyApp {
         let incoming_id = json_body.as_object().unwrap().get("id").unwrap();
 
         for balanced_rpcs in self.balanced_rpc_tiers.iter() {
-            match balanced_rpcs
-                .next_upstream_server(1, self.block_watcher.clone())
-                .await
-            {
+            // TODO: what allowed lag?
+            match balanced_rpcs.next_upstream_server().await {
                 Ok(upstream_server) => {
                     // TODO: better type for this. right now its request (the full jsonrpc object), response (just the inner result)
                     let (tx, mut rx) =
@@ -218,7 +261,8 @@ impl Web3ProxyApp {
                     .ok_or_else(|| anyhow::anyhow!("no successful response"))?;
 
                 if let Ok(partial_response) = response {
-                    info!("forwarding request from {}", upstream_server);
+                    // TODO: trace
+                    // info!("forwarding request from {}", upstream_server);
 
                     let response = json!({
                         "jsonrpc": "2.0",
@@ -354,6 +398,7 @@ async fn main() {
     let listen_port = 8445;
 
     let state = Web3ProxyApp::try_new(
+        1,
        vec![
            // local nodes
            vec![("ws://10.11.12.16:8545", 0), ("ws://10.11.12.16:8946", 0)],
@@ -368,7 +413,7 @@
            ],
            // free nodes
            vec![
-               // ("https://main-rpc.linkpool.io", 0), // linkpool is slow
+               // ("https://main-rpc.linkpool.io", 0), // linkpool is slow and often offline
                ("https://rpc.ankr.com/eth", 0),
            ],
        ],
diff --git a/src/provider_tiers.rs b/src/provider_tiers.rs
index 0c428d10..aaebbc3c 100644
--- a/src/provider_tiers.rs
+++ b/src/provider_tiers.rs
@@ -1,4 +1,5 @@
-/// Communicate with groups of web3 providers
+///! Communicate with groups of web3 providers
+use arc_swap::ArcSwap;
 use dashmap::DashMap;
 use governor::clock::{QuantaClock, QuantaInstant};
 use governor::middleware::NoOpMiddleware;
@@ -6,9 +7,9 @@ use governor::state::{InMemoryState, NotKeyed};
 use governor::NotUntil;
 use governor::RateLimiter;
 use std::cmp;
+use std::collections::HashMap;
 use std::num::NonZeroU32;
 use std::sync::Arc;
-use tokio::sync::RwLock;
 use tracing::{info, instrument};
 
 use crate::block_watcher::{BlockWatcher, SyncStatus};
@@ -24,9 +25,10 @@ pub type Web3ConnectionMap = DashMap;
 /// Load balance to the rpc
 #[derive(Debug)]
 pub struct Web3ProviderTier {
-    /// RPC urls sorted by active requests
-    /// TODO: what type for the rpc? i think we want this to be the key for the provider and not the provider itself
-    rpcs: RwLock<Vec<String>>,
+    /// TODO: what type for the rpc? Vec isn't great. i think we want this to be the key for the provider and not the provider itself
+    /// TODO: we probably want a better lock
+    synced_rpcs: ArcSwap<Vec<String>>,
+    rpcs: Vec<String>,
     connections: Arc<Web3ConnectionMap>,
     ratelimiters: Web3RateLimiterMap,
 }
@@ -64,7 +66,8 @@ impl Web3ProviderTier {
         }
 
         Ok(Web3ProviderTier {
-            rpcs: RwLock::new(rpcs),
+            synced_rpcs: ArcSwap::from(Arc::new(vec![])),
+            rpcs,
             connections: Arc::new(connections),
             ratelimiters: ratelimits,
         })
@@ -74,32 +77,36 @@ impl Web3ProviderTier {
         self.connections.clone()
     }
 
-    /// get the best available rpc server
-    #[instrument]
-    pub async fn next_upstream_server(
+    pub fn clone_rpcs(&self) -> Vec<String> {
+        self.rpcs.clone()
+    }
+
+    pub fn update_synced_rpcs(
         &self,
-        allowed_lag: u64,
         block_watcher: Arc<BlockWatcher>,
-    ) -> Result<String, NotUntil<QuantaInstant>> {
-        let mut available_rpcs = self.rpcs.write().await;
+        allowed_lag: u64,
+    ) -> anyhow::Result<()> {
+        let mut available_rpcs = self.rpcs.clone();
 
-        // sort rpcs by their active connections
-        available_rpcs.sort_unstable_by(|a, b| {
-            self.connections
-                .get(a)
-                .unwrap()
-                .cmp(&self.connections.get(b).unwrap())
-        });
+        // collect sync status for all the rpcs
+        let sync_status: HashMap<String, SyncStatus> = available_rpcs
+            .clone()
+            .into_iter()
+            .map(|rpc| {
+                let status = block_watcher.sync_status(&rpc, allowed_lag);
+                (rpc, status)
+            })
+            .collect();
 
-        // sort rpcs by their block height
+        // sort rpcs by their sync status and active connections
         available_rpcs.sort_unstable_by(|a, b| {
-            let a_synced = block_watcher.sync_status(a, allowed_lag);
-            let b_synced = block_watcher.sync_status(b, allowed_lag);
+            let a_synced = sync_status.get(a).unwrap();
+            let b_synced = sync_status.get(b).unwrap();
 
             match (a_synced, b_synced) {
                 (SyncStatus::Synced(a), SyncStatus::Synced(b)) => {
                     if a != b {
-                        return a.cmp(&b);
+                        return a.cmp(b);
                     }
                     // else they are equal and we want to compare on active connections
                 }
@@ -125,7 +132,7 @@
                 }
                 (SyncStatus::Behind(a), SyncStatus::Behind(b)) => {
                     if a != b {
-                        return a.cmp(&b);
+                        return a.cmp(b);
                     }
                     // else they are equal and we want to compare on active connections
                 }
@@ -141,23 +148,23 @@
                 .cmp(&self.connections.get(b).unwrap())
         });
 
+        // filter out
+        let synced_rpcs: Vec<String> = available_rpcs
+            .into_iter()
+            .take_while(|rpc| matches!(sync_status.get(rpc).unwrap(), SyncStatus::Synced(_)))
+            .collect();
+
+        self.synced_rpcs.swap(Arc::new(synced_rpcs));
+
+        Ok(())
+    }
+
+    /// get the best available rpc server
+    #[instrument]
+    pub async fn next_upstream_server(&self) -> Result<String, NotUntil<QuantaInstant>> {
         let mut earliest_not_until = None;
 
-        for selected_rpc in available_rpcs.iter() {
-            // check current block number
-            // TODO: i don't like that we fetched sync_status above and then do it again here. cache?
-            if let SyncStatus::Synced(_) = block_watcher.sync_status(selected_rpc, allowed_lag) {
-                // rpc is synced
-            } else {
-                // skip this rpc because it is not synced
-                // TODO: make a NotUntil here?
-                // TODO: include how many blocks behind
-                // TODO: better log
-                info!("{} is not synced", selected_rpc);
-                // we sorted on block height. so if this one isn't synced, none of the later ones will be either
-                break;
-            }
-
+        for selected_rpc in self.synced_rpcs.load().iter() {
             // check rate limits
             if let Some(ratelimiter) = self.ratelimiters.get(selected_rpc) {
                 match ratelimiter.check() {
@@ -205,23 +212,10 @@
     }
 
     /// get all available rpc servers
-    pub async fn get_upstream_servers(
-        &self,
-        allowed_lag: u64,
-        block_watcher: Arc<BlockWatcher>,
-    ) -> Result<Vec<String>, NotUntil<QuantaInstant>> {
+    pub async fn get_upstream_servers(&self) -> Result<Vec<String>, NotUntil<QuantaInstant>> {
         let mut earliest_not_until = None;
 
-        let mut selected_rpcs = vec![];
-
-        for selected_rpc in self.rpcs.read().await.iter() {
-            if let SyncStatus::Synced(_) = block_watcher.sync_status(selected_rpc, allowed_lag) {
-                // rpc is synced
-            } else {
-                // skip this rpc because it is not synced
-                continue;
-            }
-
+        for selected_rpc in self.synced_rpcs.load().iter() {
             // check rate limits
             match self.ratelimiters.get(selected_rpc).unwrap().check() {
                 Ok(_) => {
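
Editor's note, not part of the patch: the new_block_sender / new_block_receiver pair threaded through BlockWatcher::run and Web3ProxyApp::try_new is a tokio::sync::watch channel, which only ever exposes the most recent value to receivers. Below is a minimal, self-contained sketch of that notification pattern; the consumer task and the printed message are illustrative stand-ins for the task that calls update_synced_rpcs, not code taken from the patch.

```rust
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // the block-watcher side owns the sender; receivers only ever see the latest value
    let (new_block_sender, mut new_block_receiver) = watch::channel(String::new());

    // hypothetical consumer, standing in for the task that re-sorts the synced rpcs
    let consumer = tokio::spawn(async move {
        while new_block_receiver.changed().await.is_ok() {
            // clone the value out so the borrow guard is dropped right away
            let updated_rpc = new_block_receiver.borrow().clone();
            println!("{updated_rpc} has a new head block");
        }
    });

    // the watcher side publishes which rpc just advanced
    new_block_sender
        .send("ws://10.11.12.16:8545".to_string())
        .unwrap();

    // dropping the sender closes the channel, which ends the consumer loop
    drop(new_block_sender);
    consumer.await.unwrap();
}
```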
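Also an editor's note rather than patch content: provider_tiers.rs swaps the RwLock-wrapped rpc list for an arc_swap::ArcSwap, so request handlers can read the synced list without locking while update_synced_rpcs publishes a freshly sorted Vec. The Cargo.toml hunk above only adds parking_lot, so arc-swap is presumably already a dependency or added elsewhere. A rough sketch of the load/swap pattern, using a hypothetical struct name in place of Web3ProviderTier:

```rust
use std::sync::Arc;

use arc_swap::ArcSwap;

/// hypothetical stand-in for Web3ProviderTier's synced server list
struct SyncedRpcs {
    synced_rpcs: ArcSwap<Vec<String>>,
}

impl SyncedRpcs {
    fn new() -> Self {
        Self {
            synced_rpcs: ArcSwap::from(Arc::new(vec![])),
        }
    }

    /// writer path: build a new Vec and swap it in; readers are never blocked
    fn update(&self, rpcs: Vec<String>) {
        self.synced_rpcs.swap(Arc::new(rpcs));
    }

    /// reader path: load() returns a cheap guard around the current Arc<Vec<String>>
    fn best(&self) -> Option<String> {
        self.synced_rpcs.load().first().cloned()
    }
}

fn main() {
    let list = SyncedRpcs::new();
    list.update(vec!["ws://10.11.12.16:8545".to_string()]);
    assert_eq!(list.best().as_deref(), Some("ws://10.11.12.16:8545"));
}
```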