diff --git a/src/main.rs b/src/main.rs index 6e4afd6a..cd0be722 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,21 @@ use dashmap::DashMap; use futures::future; -use governor::clock::{QuantaClock, QuantaInstant}; +use governor::clock::{Clock, QuantaClock, QuantaInstant}; use governor::middleware::NoOpMiddleware; use governor::state::{InMemoryState, NotKeyed}; use governor::{NotUntil, RateLimiter}; use std::num::NonZeroU32; use std::sync::Arc; use tokio::sync::RwLock; -// use tokio::time::{sleep, Duration}; +use tokio::time::sleep; use warp::Filter; -type RpcRateLimiter = - RateLimiter>; - type RateLimiterMap = DashMap; type ConnectionsMap = DashMap; +type RpcRateLimiter = + RateLimiter>; + /// Load balance to the least-connection rpc struct BalancedRpcs { rpcs: RwLock>, @@ -23,21 +23,20 @@ struct BalancedRpcs { ratelimits: RateLimiterMap, } -// TODO: also pass rate limits to this? -impl Into for Vec<(&str, u32)> { - fn into(self) -> BalancedRpcs { +impl BalancedRpcs { + fn new(servers: Vec<(&str, u32)>, clock: &QuantaClock) -> BalancedRpcs { let mut rpcs: Vec = vec![]; let connections = DashMap::new(); let ratelimits = DashMap::new(); - for (s, limit) in self.into_iter() { + for (s, limit) in servers.into_iter() { rpcs.push(s.to_string()); connections.insert(s.to_string(), 0); if limit > 0 { let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap()); - let rate_limiter = governor::RateLimiter::direct(quota); + let rate_limiter = governor::RateLimiter::direct_with_clock(quota, clock); ratelimits.insert(s.to_string(), rate_limiter); } @@ -49,9 +48,7 @@ impl Into for Vec<(&str, u32)> { ratelimits, } } -} -impl BalancedRpcs { async fn get_upstream_server(&self) -> Result> { let mut balanced_rpcs = self.rpcs.write().await; @@ -114,18 +111,18 @@ struct LoudRpcs { ratelimits: RateLimiterMap, } -impl Into for Vec<(&str, u32)> { - fn into(self) -> LoudRpcs { +impl LoudRpcs { + fn new(servers: Vec<(&str, u32)>, clock: &QuantaClock) -> LoudRpcs { let mut rpcs: Vec = vec![]; let ratelimits = RateLimiterMap::new(); - for (s, limit) in self.into_iter() { + for (s, limit) in servers.into_iter() { rpcs.push(s.to_string()); if limit > 0 { let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap()); - let rate_limiter = governor::RateLimiter::direct(quota); + let rate_limiter = governor::RateLimiter::direct_with_clock(quota, clock); ratelimits.insert(s.to_string(), rate_limiter); } @@ -133,12 +130,50 @@ impl Into for Vec<(&str, u32)> { LoudRpcs { rpcs, ratelimits } } -} -impl LoudRpcs { - async fn get_upstream_servers(&self) -> Vec { - // - self.rpcs.clone() + async fn get_upstream_servers(&self) -> Result, NotUntil> { + let mut earliest_not_until = None; + + let mut selected_rpcs = vec![]; + + for selected_rpc in self.rpcs.iter() { + // check rate limits + match self.ratelimits.get(selected_rpc).unwrap().check() { + Ok(_) => { + // rate limit succeeded + } + Err(not_until) => { + // rate limit failed + // save the smallest not_until. if nothing succeeds, return an Err with not_until in it + if earliest_not_until.is_none() { + earliest_not_until = Some(not_until); + } else { + let earliest_possible = + earliest_not_until.as_ref().unwrap().earliest_possible(); + let new_earliest_possible = not_until.earliest_possible(); + + if earliest_possible > new_earliest_possible { + earliest_not_until = Some(not_until); + } + } + continue; + } + }; + + // return the selected RPC + selected_rpcs.push(selected_rpc.clone()); + } + + if selected_rpcs.len() > 0 { + return Ok(selected_rpcs); + } + + // return the earliest not_until + if let Some(not_until) = earliest_not_until { + return Err(not_until); + } else { + panic!("i don't think this should happen") + } } fn as_bool(&self) -> bool { @@ -147,7 +182,9 @@ impl LoudRpcs { } struct Web3ProxyState { + clock: QuantaClock, client: reqwest::Client, + // TODO: LoudRpcs and BalancedRpcs should probably share a trait or something balanced_rpc_tiers: Vec, private_rpcs: LoudRpcs, } @@ -157,11 +194,17 @@ impl Web3ProxyState { balanced_rpc_tiers: Vec>, private_rpcs: Vec<(&str, u32)>, ) -> Web3ProxyState { + let clock = QuantaClock::default(); + // TODO: warn if no private relays Web3ProxyState { + clock: clock.clone(), client: reqwest::Client::new(), - balanced_rpc_tiers: balanced_rpc_tiers.into_iter().map(Into::into).collect(), - private_rpcs: private_rpcs.into(), + balanced_rpc_tiers: balanced_rpc_tiers + .into_iter() + .map(|servers| BalancedRpcs::new(servers, &clock)) + .collect(), + private_rpcs: LoudRpcs::new(private_rpcs, &clock), } } @@ -175,37 +218,68 @@ impl Web3ProxyState { if self.private_rpcs.as_bool() && json_body.get("method") == Some(ð_send_raw_transaction) { - // there are private rpcs configured and the request is eth_sendSignedTransaction. send to all private rpcs - let upstream_servers = self.private_rpcs.get_upstream_servers().await; - - if let Ok(result) = self - .try_send_requests(upstream_servers, None, &json_body) - .await - { - return Ok(result); + loop { + // there are private rpcs configured and the request is eth_sendSignedTransaction. send to all private rpcs + match self.private_rpcs.get_upstream_servers().await { + Ok(upstream_servers) => { + if let Ok(result) = self + .try_send_requests(upstream_servers, None, &json_body) + .await + { + return Ok(result); + } + } + Err(not_until) => { + // TODO: there should probably be a lock on this so that other queries wait + let deadline = not_until.wait_time_from(self.clock.now()); + sleep(deadline).await; + } + }; } } else { // this is not a private transaction (or no private relays are configured) - for balanced_rpcs in self.balanced_rpc_tiers.iter() { - if let Ok(upstream_server) = balanced_rpcs.get_upstream_server().await { - // TODO: capture any errors. at least log them - if let Ok(result) = self - .try_send_requests( - vec![upstream_server], - Some(&balanced_rpcs.connections), - &json_body, - ) - .await - { - return Ok(result); + loop { + let mut earliest_not_until = None; + + for balanced_rpcs in self.balanced_rpc_tiers.iter() { + match balanced_rpcs.get_upstream_server().await { + Ok(upstream_server) => { + // TODO: capture any errors. at least log them + if let Ok(result) = self + .try_send_requests( + vec![upstream_server], + Some(&balanced_rpcs.connections), + &json_body, + ) + .await + { + return Ok(result); + } + } + Err(not_until) => { + // save the smallest not_until. if nothing succeeds, return an Err with not_until in it + if earliest_not_until.is_none() { + earliest_not_until = Some(not_until); + } else { + // TODO: do we need to unwrap this far? can we just compare the not_untils + let earliest_possible = + earliest_not_until.as_ref().unwrap().earliest_possible(); + let new_earliest_possible = not_until.earliest_possible(); + + if earliest_possible > new_earliest_possible { + earliest_not_until = Some(not_until); + } + } + } } - } else { - // TODO: if we got an error. save the ratelimit NotUntil so we can sleep until then before trying again } + + // we haven't returned an Ok, sleep and try again + // unwrap should be safe since we would have returned if it wasn't set + let deadline = earliest_not_until.unwrap().wait_time_from(self.clock.now()); + sleep(deadline).await; } } - - return Err(anyhow::anyhow!("all servers failed")); } async fn try_send_requests( @@ -287,6 +361,7 @@ async fn main() { // TODO: support multiple chains in one process. then we could just point "chain.stytt.com" at this and caddy wouldn't need anything else // TODO: i kind of want to make use of caddy's load balancing and health checking and such though let listen_port = 8445; + // TODO: be smart about about using archive nodes? let state = Web3ProxyState::new( vec![