From 63428cad6bc9f1b341dd4a214389d9dd37c19670 Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Thu, 28 Apr 2022 22:03:26 +0000 Subject: [PATCH] move rate limits --- src/provider.rs | 32 ++++++++++- src/provider_tiers.rs | 126 ++++++++++++++++-------------------------- 2 files changed, 80 insertions(+), 78 deletions(-) diff --git a/src/provider.rs b/src/provider.rs index a8b3d9d3..f816e9af 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -2,6 +2,11 @@ use derive_more::From; use ethers::prelude::{BlockNumber, Middleware}; use futures::StreamExt; +use governor::clock::{QuantaClock, QuantaInstant}; +use governor::middleware::NoOpMiddleware; +use governor::state::{InMemoryState, NotKeyed}; +use governor::NotUntil; +use governor::RateLimiter; use std::fmt; use std::time::Duration; use std::{cmp::Ordering, sync::Arc}; @@ -10,6 +15,9 @@ use tracing::{info, warn}; use crate::block_watcher::BlockWatcherSender; +type Web3RateLimiter = + RateLimiter>; + // TODO: instead of an enum, I tried to use Box, but hit https://github.com/gakonst/ethers-rs/issues/592 #[derive(From)] pub enum Web3Provider { @@ -85,6 +93,7 @@ pub struct Web3Connection { /// keep track of currently open requests. We sort on this active_requests: u32, provider: Arc, + ratelimiter: Option, } impl Web3Connection { @@ -97,6 +106,7 @@ impl Web3Connection { url_str: String, http_client: Option, block_watcher_sender: BlockWatcherSender, + ratelimiter: Option, ) -> anyhow::Result { let provider = if url_str.starts_with("http") { let url: url::Url = url_str.parse()?; @@ -138,11 +148,31 @@ impl Web3Connection { Ok(Web3Connection { active_requests: 0, provider, + ratelimiter, }) } - pub fn inc_active_requests(&mut self) { + pub fn try_inc_active_requests(&mut self) -> Result<(), NotUntil> { + // check rate limits + if let Some(ratelimiter) = self.ratelimiter.as_ref() { + match ratelimiter.check() { + Ok(_) => { + // rate limit succeeded + } + Err(not_until) => { + // rate limit failed + // save the smallest not_until. if nothing succeeds, return an Err with not_until in it + // TODO: use tracing better + warn!("Exhausted rate limit on {:?}: {}", self, not_until); + + return Err(not_until); + } + } + }; + self.active_requests += 1; + + Ok(()) } pub fn dec_active_requests(&mut self) { diff --git a/src/provider_tiers.rs b/src/provider_tiers.rs index 820617e6..0254afc5 100644 --- a/src/provider_tiers.rs +++ b/src/provider_tiers.rs @@ -2,25 +2,17 @@ use arc_swap::ArcSwap; use dashmap::DashMap; use governor::clock::{QuantaClock, QuantaInstant}; -use governor::middleware::NoOpMiddleware; -use governor::state::{InMemoryState, NotKeyed}; use governor::NotUntil; -use governor::RateLimiter; use std::cmp; use std::collections::HashMap; use std::fmt; use std::num::NonZeroU32; use std::sync::Arc; -use tracing::{info, instrument}; use crate::block_watcher::{BlockWatcher, SyncStatus}; use crate::provider::Web3Connection; -type Web3RateLimiter = - RateLimiter>; - -type Web3RateLimiterMap = DashMap; - +// TODO: move the rate limiter into the connection pub type Web3ConnectionMap = DashMap; /// Load balance to the rpc @@ -30,7 +22,6 @@ pub struct Web3ProviderTier { synced_rpcs: ArcSwap>, rpcs: Vec, connections: Arc, - ratelimiters: Web3RateLimiterMap, } impl fmt::Debug for Web3ProviderTier { @@ -49,34 +40,35 @@ impl Web3ProviderTier { ) -> anyhow::Result { let mut rpcs: Vec = vec![]; let connections = DashMap::new(); - let ratelimits = DashMap::new(); for (s, limit) in servers.into_iter() { rpcs.push(s.to_string()); + let ratelimiter = if limit > 0 { + let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap()); + + let rate_limiter = governor::RateLimiter::direct_with_clock(quota, clock); + + Some(rate_limiter) + } else { + None + }; + let connection = Web3Connection::try_new( s.to_string(), http_client.clone(), block_watcher.clone_sender(), + ratelimiter, ) .await?; connections.insert(s.to_string(), connection); - - if limit > 0 { - let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap()); - - let rate_limiter = governor::RateLimiter::direct_with_clock(quota, clock); - - ratelimits.insert(s.to_string(), rate_limiter); - } } Ok(Web3ProviderTier { synced_rpcs: ArcSwap::from(Arc::new(vec![])), rpcs, connections: Arc::new(connections), - ratelimiters: ratelimits, }) } @@ -167,44 +159,31 @@ impl Web3ProviderTier { } /// get the best available rpc server - #[instrument] pub async fn next_upstream_server(&self) -> Result>> { let mut earliest_not_until = None; for selected_rpc in self.synced_rpcs.load().iter() { - // check rate limits - if let Some(ratelimiter) = self.ratelimiters.get(selected_rpc) { - match ratelimiter.check() { - Ok(_) => { - // rate limit succeeded - } - Err(not_until) => { - // rate limit failed - // save the smallest not_until. if nothing succeeds, return an Err with not_until in it - // TODO: use tracing better - info!("Exhausted rate limit on {}: {}", selected_rpc, not_until); - - if earliest_not_until.is_none() { - earliest_not_until = Some(not_until); - } else { - let earliest_possible = - earliest_not_until.as_ref().unwrap().earliest_possible(); - let new_earliest_possible = not_until.earliest_possible(); - - if earliest_possible > new_earliest_possible { - earliest_not_until = Some(not_until); - } - } - continue; - } - } - }; - // increment our connection counter - self.connections + if let Err(not_until) = self + .connections .get_mut(selected_rpc) .unwrap() - .inc_active_requests(); + .try_inc_active_requests() + { + // TODO: do this better + if earliest_not_until.is_none() { + earliest_not_until = Some(not_until); + } else { + let earliest_possible = + earliest_not_until.as_ref().unwrap().earliest_possible(); + let new_earliest_possible = not_until.earliest_possible(); + + if earliest_possible > new_earliest_possible { + earliest_not_until = Some(not_until); + } + } + continue; + } // return the selected RPC return Ok(selected_rpc.clone()); @@ -221,34 +200,27 @@ impl Web3ProviderTier { let mut earliest_not_until = None; let mut selected_rpcs = vec![]; for selected_rpc in self.synced_rpcs.load().iter() { - // check rate limits - match self.ratelimiters.get(selected_rpc).unwrap().check() { - Ok(_) => { - // rate limit succeeded - } - Err(not_until) => { - // rate limit failed - // save the smallest not_until. if nothing succeeds, return an Err with not_until in it - if earliest_not_until.is_none() { - earliest_not_until = Some(not_until); - } else { - let earliest_possible = - earliest_not_until.as_ref().unwrap().earliest_possible(); - let new_earliest_possible = not_until.earliest_possible(); - - if earliest_possible > new_earliest_possible { - earliest_not_until = Some(not_until); - } - } - continue; - } - }; - - // increment our connection counter - self.connections + // check rate limits and increment our connection counter + // TODO: share code with next_upstream_server + if let Err(not_until) = self + .connections .get_mut(selected_rpc) .unwrap() - .inc_active_requests(); + .try_inc_active_requests() + { + if earliest_not_until.is_none() { + earliest_not_until = Some(not_until); + } else { + let earliest_possible = + earliest_not_until.as_ref().unwrap().earliest_possible(); + let new_earliest_possible = not_until.earliest_possible(); + + if earliest_possible > new_earliest_possible { + earliest_not_until = Some(not_until); + } + } + continue; + } // this is rpc should work selected_rpcs.push(selected_rpc.clone());