diff --git a/src/main.rs b/src/main.rs index b760935e..f47dd52e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -269,16 +269,11 @@ impl Web3ProxyApp { async move { // get the client for this rpc server - let provider = connections.read().await.get(&rpc).unwrap().clone_provider(); + let provider = connections.get(&rpc).unwrap().clone_provider(); let response = provider.request(&method, params).await; - connections - .write() - .await - .get_mut(&rpc) - .unwrap() - .dec_active_requests(); + connections.get_mut(&rpc).unwrap().dec_active_requests(); let mut response = response?; diff --git a/src/provider_tiers.rs b/src/provider_tiers.rs index 8a60ce73..7a1b9cfe 100644 --- a/src/provider_tiers.rs +++ b/src/provider_tiers.rs @@ -1,10 +1,10 @@ -///! Communicate with groups of web3 providers +/// Communicate with groups of web3 providers +use dashmap::DashMap; use governor::clock::{QuantaClock, QuantaInstant}; use governor::middleware::NoOpMiddleware; use governor::state::{InMemoryState, NotKeyed}; use governor::NotUntil; use governor::RateLimiter; -use std::collections::HashMap; use std::num::NonZeroU32; use std::sync::Arc; use tokio::sync::RwLock; @@ -15,9 +15,9 @@ use crate::provider::Web3Connection; type Web3RateLimiter = RateLimiter>; -type Web3RateLimiterMap = RwLock>; +type Web3RateLimiterMap = DashMap; -pub type Web3ConnectionMap = RwLock>; +pub type Web3ConnectionMap = DashMap; /// Load balance to the rpc /// TODO: i'm not sure about having 3 locks here. can we share them? @@ -26,7 +26,7 @@ pub struct Web3ProviderTier { /// TODO: what type for the rpc? rpcs: RwLock>, connections: Arc, - ratelimits: Web3RateLimiterMap, + ratelimiters: Web3RateLimiterMap, } impl Web3ProviderTier { @@ -37,8 +37,8 @@ impl Web3ProviderTier { clock: &QuantaClock, ) -> anyhow::Result { let mut rpcs: Vec = vec![]; - let mut connections = HashMap::new(); - let mut ratelimits = HashMap::new(); + let connections = DashMap::new(); + let ratelimits = DashMap::new(); for (s, limit) in servers.into_iter() { rpcs.push(s.to_string()); @@ -63,8 +63,8 @@ impl Web3ProviderTier { Ok(Web3ProviderTier { rpcs: RwLock::new(rpcs), - connections: Arc::new(RwLock::new(connections)), - ratelimits: RwLock::new(ratelimits), + connections: Arc::new(connections), + ratelimiters: ratelimits, }) } @@ -80,10 +80,12 @@ impl Web3ProviderTier { let mut balanced_rpcs = self.rpcs.write().await; // sort rpcs by their active connections - let connections = self.connections.read().await; - - balanced_rpcs - .sort_unstable_by(|a, b| connections.get(a).unwrap().cmp(connections.get(b).unwrap())); + balanced_rpcs.sort_unstable_by(|a, b| { + self.connections + .get(a) + .unwrap() + .cmp(&self.connections.get(b).unwrap()) + }); let mut earliest_not_until = None; @@ -98,35 +100,33 @@ impl Web3ProviderTier { continue; } - let ratelimits = self.ratelimits.write().await; - // check rate limits - match ratelimits.get(selected_rpc).unwrap().check() { - Ok(_) => { - // rate limit succeeded - } - Err(not_until) => { - // rate limit failed - // save the smallest not_until. if nothing succeeds, return an Err with not_until in it - if earliest_not_until.is_none() { - earliest_not_until = Some(not_until); - } else { - let earliest_possible = - earliest_not_until.as_ref().unwrap().earliest_possible(); - let new_earliest_possible = not_until.earliest_possible(); - - if earliest_possible > new_earliest_possible { - earliest_not_until = Some(not_until); - } + if let Some(ratelimiter) = self.ratelimiters.get(selected_rpc) { + match ratelimiter.check() { + Ok(_) => { + // rate limit succeeded + } + Err(not_until) => { + // rate limit failed + // save the smallest not_until. if nothing succeeds, return an Err with not_until in it + if earliest_not_until.is_none() { + earliest_not_until = Some(not_until); + } else { + let earliest_possible = + earliest_not_until.as_ref().unwrap().earliest_possible(); + let new_earliest_possible = not_until.earliest_possible(); + + if earliest_possible > new_earliest_possible { + earliest_not_until = Some(not_until); + } + } + continue; } - continue; } }; // increment our connection counter self.connections - .write() - .await .get_mut(selected_rpc) .unwrap() .inc_active_requests(); @@ -164,14 +164,7 @@ impl Web3ProviderTier { } // check rate limits - match self - .ratelimits - .write() - .await - .get(selected_rpc) - .unwrap() - .check() - { + match self.ratelimiters.get(selected_rpc).unwrap().check() { Ok(_) => { // rate limit succeeded } @@ -195,8 +188,6 @@ impl Web3ProviderTier { // increment our connection counter self.connections - .write() - .await .get_mut(selected_rpc) .unwrap() .inc_active_requests();