From d54196f868d60e57774431331909a754f6331b4e Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Wed, 27 Apr 2022 01:25:01 +0000 Subject: [PATCH] pass block_watcher around --- src/block_watcher.rs | 61 +++++++++++++++++++++++++------------------ src/main.rs | 24 ++++++++++++----- src/provider_tiers.rs | 19 +++++++++----- 3 files changed, 66 insertions(+), 38 deletions(-) diff --git a/src/block_watcher.rs b/src/block_watcher.rs index e408ab56..ab7855eb 100644 --- a/src/block_watcher.rs +++ b/src/block_watcher.rs @@ -2,8 +2,9 @@ use ethers::prelude::{Block, TxHash}; use std::cmp::Ordering; use std::collections::HashMap; +use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, RwLock}; use tracing::info; // TODO: what type for the Item? String url works, but i don't love it @@ -13,51 +14,62 @@ pub type BlockWatcherSender = mpsc::UnboundedSender; pub type BlockWatcherReceiver = mpsc::UnboundedReceiver; pub struct BlockWatcher { - receiver: BlockWatcherReceiver, + sender: BlockWatcherSender, + receiver: RwLock, /// TODO: i don't think we want a hashmap. we want a left-right or some other concurrent map - blocks: HashMap>, - latest_block: Option>, + blocks: RwLock>>, + latest_block: RwLock>>, } impl BlockWatcher { - pub fn new() -> (BlockWatcher, BlockWatcherSender) { + pub fn new() -> Self { // TODO: this also needs to return a reader for blocks let (sender, receiver) = mpsc::unbounded_channel(); - let watcher = Self { - receiver, + Self { + sender, + receiver: RwLock::new(receiver), blocks: Default::default(), - latest_block: None, - }; - - (watcher, sender) + latest_block: RwLock::new(None), + } } - pub async fn run(&mut self) -> anyhow::Result<()> { - while let Some((rpc, block)) = self.receiver.recv().await { + pub fn clone_sender(&self) -> BlockWatcherSender { + self.sender.clone() + } + + pub async fn run(self: Arc) -> anyhow::Result<()> { + let mut receiver = self.receiver.write().await; + + while let Some((rpc, block)) = receiver.recv().await { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("Time went backwards") .as_secs() as i64; - let current_block = self.blocks.get(&rpc); - - if current_block == Some(&block) { - // we already have this block - continue; + { + let blocks = self.blocks.read().await; + if blocks.get(&rpc) == Some(&block) { + // we already have this block + continue; + } } - let label_slow_blocks = if self.latest_block.is_none() { - self.latest_block = Some(block.clone()); + // save the block for this rpc + self.blocks.write().await.insert(rpc.clone(), block.clone()); + + // TODO: we don't always need this to have a write lock + let mut latest_block = self.latest_block.write().await; + + let label_slow_blocks = if latest_block.is_none() { + *latest_block = Some(block.clone()); "+" } else { - let latest_block = self.latest_block.as_ref().unwrap(); - // TODO: what if they have the same number but different hashes? or aren't on the same chain? - match block.number.cmp(&latest_block.number) { + match block.number.cmp(&latest_block.as_ref().unwrap().number) { Ordering::Equal => "", Ordering::Greater => { - self.latest_block = Some(block.clone()); + *latest_block = Some(block.clone()); "+" } Ordering::Less => { @@ -78,7 +90,6 @@ impl BlockWatcher { now - block.timestamp.as_u64() as i64, label_slow_blocks ); - self.blocks.insert(rpc, block); } Ok(()) diff --git a/src/main.rs b/src/main.rs index 908cd7d0..ef99037d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,6 +24,7 @@ static APP_USER_AGENT: &str = concat!( /// The application struct Web3ProxyApp { + block_watcher: Arc, /// clock used for rate limiting /// TODO: use tokio's clock (will require a different ratelimiting crate) clock: QuantaClock, @@ -43,8 +44,6 @@ impl Web3ProxyApp { ) -> anyhow::Result { let clock = QuantaClock::default(); - let (mut block_watcher, block_watcher_sender) = BlockWatcher::new(); - // make a http shared client // TODO: how should we configure the connection pool? // TODO: 5 minutes is probably long enough. unlimited is a bad idea if something @@ -53,15 +52,19 @@ impl Web3ProxyApp { .user_agent(APP_USER_AGENT) .build()?; + let block_watcher = Arc::new(BlockWatcher::new()); + + let block_watcher_clone = Arc::clone(&block_watcher); + // start the block_watcher - tokio::spawn(async move { block_watcher.run().await }); + tokio::spawn(async move { block_watcher_clone.run().await }); let balanced_rpc_tiers = Arc::new( future::join_all(balanced_rpc_tiers.into_iter().map(|balanced_rpc_tier| { Web3ProviderTier::try_new( balanced_rpc_tier, Some(http_client.clone()), - block_watcher_sender.clone(), + block_watcher.clone(), &clock, ) })) @@ -77,7 +80,7 @@ impl Web3ProxyApp { Web3ProviderTier::try_new( private_rpcs, Some(http_client), - block_watcher_sender, + block_watcher.clone(), &clock, ) .await?, @@ -86,6 +89,7 @@ impl Web3ProxyApp { // TODO: warn if no private relays Ok(Web3ProxyApp { + block_watcher, clock, balanced_rpc_tiers, private_rpcs, @@ -111,7 +115,10 @@ impl Web3ProxyApp { loop { let read_lock = self.private_rpcs_ratelimiter_lock.read().await; - match private_rpcs.get_upstream_servers().await { + match private_rpcs + .get_upstream_servers(self.block_watcher.clone()) + .await + { Ok(upstream_servers) => { let (tx, mut rx) = mpsc::unbounded_channel::>(); @@ -160,7 +167,10 @@ impl Web3ProxyApp { let mut earliest_not_until = None; for balanced_rpcs in self.balanced_rpc_tiers.iter() { - match balanced_rpcs.next_upstream_server().await { + match balanced_rpcs + .next_upstream_server(self.block_watcher.clone()) + .await + { Ok(upstream_server) => { let (tx, mut rx) = mpsc::unbounded_channel::>(); diff --git a/src/provider_tiers.rs b/src/provider_tiers.rs index eb281fa8..8edeed69 100644 --- a/src/provider_tiers.rs +++ b/src/provider_tiers.rs @@ -9,7 +9,7 @@ use std::num::NonZeroU32; use std::sync::Arc; use tokio::sync::RwLock; -use crate::block_watcher::BlockWatcherSender; +use crate::block_watcher::BlockWatcher; use crate::provider::Web3Connection; type Web3RateLimiter = @@ -33,7 +33,7 @@ impl Web3ProviderTier { pub async fn try_new( servers: Vec<(&str, u32)>, http_client: Option, - block_watcher_sender: BlockWatcherSender, + block_watcher: Arc, clock: &QuantaClock, ) -> anyhow::Result { let mut rpcs: Vec = vec![]; @@ -46,7 +46,7 @@ impl Web3ProviderTier { let connection = Web3Connection::try_new( s.to_string(), http_client.clone(), - block_watcher_sender.clone(), + block_watcher.clone_sender(), ) .await?; @@ -73,7 +73,10 @@ impl Web3ProviderTier { } /// get the best available rpc server - pub async fn next_upstream_server(&self) -> Result> { + pub async fn next_upstream_server( + &self, + block_watcher: Arc, + ) -> Result> { let mut balanced_rpcs = self.rpcs.write().await; // sort rpcs by their active connections @@ -85,7 +88,8 @@ impl Web3ProviderTier { let mut earliest_not_until = None; for selected_rpc in balanced_rpcs.iter() { - // TODO: check current block number. if behind, make our own NotUntil here + // TODO: check current block number. if too far behind, make our own NotUntil here + let ratelimits = self.ratelimits.write().await; // check rate limits @@ -132,7 +136,10 @@ impl Web3ProviderTier { } /// get all available rpc servers - pub async fn get_upstream_servers(&self) -> Result, NotUntil> { + pub async fn get_upstream_servers( + &self, + block_watcher: Arc, + ) -> Result, NotUntil> { let mut earliest_not_until = None; let mut selected_rpcs = vec![];