use channels to return early

This commit is contained in:
Bryan Stitt 2022-04-26 06:54:24 +00:00
parent f8ff0370d2
commit 710cef5da3
3 changed files with 344 additions and 156 deletions

2
Cargo.lock generated
View File

@ -3739,9 +3739,9 @@ dependencies = [
"governor", "governor",
"regex", "regex",
"reqwest", "reqwest",
"serde",
"serde_json", "serde_json",
"tokio", "tokio",
"tokio-tungstenite 0.17.1",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"url", "url",

View File

@ -16,8 +16,8 @@ governor = { version = "0.4.2", features = ["dashmap", "std"] }
tokio = { version = "1.17.0", features = ["full"] } tokio = { version = "1.17.0", features = ["full"] }
regex = "1.5.5" regex = "1.5.5"
reqwest = { version = "0.11.10", features = ["json"] } reqwest = { version = "0.11.10", features = ["json"] }
serde = {version = "1.0"}
serde_json = { version = "1.0.79", default-features = false, features = ["alloc"] } serde_json = { version = "1.0.79", default-features = false, features = ["alloc"] }
tokio-tungstenite = "0.17.1"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
url = "2.2.2" url = "2.2.2"

View File

@ -1,47 +1,196 @@
use dashmap::DashMap; // TODO: don't use dashmap. we need something for async
use ethers::prelude::{Block, TxHash};
use ethers::providers::Middleware;
use futures::future; use futures::future;
use futures::future::{AbortHandle, Abortable};
use futures::SinkExt;
use futures::StreamExt; use futures::StreamExt;
use governor::clock::{Clock, QuantaClock, QuantaInstant}; use governor::clock::{Clock, QuantaClock, QuantaInstant};
use governor::middleware::NoOpMiddleware; use governor::middleware::NoOpMiddleware;
use governor::state::{InMemoryState, NotKeyed}; use governor::state::{InMemoryState, NotKeyed};
use governor::{NotUntil, RateLimiter}; use governor::{NotUntil, RateLimiter};
use regex::Regex; use std::cmp::Ordering;
use std::collections::HashMap;
use std::num::NonZeroU32; use std::num::NonZeroU32;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::RwLock; use tokio::sync::{mpsc, RwLock};
use tokio::time::sleep; use tokio::time::sleep;
use tokio_tungstenite::{connect_async, tungstenite};
use warp::Filter; use warp::Filter;
type RateLimiterMap = DashMap<String, RpcRateLimiter>; static APP_USER_AGENT: &str = concat!(
"satoshiandkin/",
env!("CARGO_PKG_NAME"),
"/",
env!("CARGO_PKG_VERSION"),
);
// TODO: i'm not sure we need this. i think we can use dyn
/// A JSON-RPC provider over one of the two supported transports.
///
/// Wrapping both concrete `ethers` provider types in one enum lets the rest of
/// the file hold a single type regardless of how the upstream is reached.
enum EthersProvider {
/// Provider that reaches the upstream over HTTP.
Http(ethers::providers::Provider<ethers::providers::Http>),
/// Provider that reaches the upstream over a WebSocket.
Ws(ethers::providers::Provider<ethers::providers::Ws>),
}
// TODO: seems like this should be derivable
impl From<ethers::providers::Provider<ethers::providers::Http>> for EthersProvider {
    /// Wrap an HTTP-backed provider in the `Http` variant.
    fn from(http: ethers::providers::Provider<ethers::providers::Http>) -> Self {
        Self::Http(http)
    }
}
// TODO: seems like this should be derivable
impl From<ethers::providers::Provider<ethers::providers::Ws>> for EthersProvider {
    /// Wrap a WebSocket-backed provider in the `Ws` variant.
    fn from(ws: ethers::providers::Provider<ethers::providers::Ws>) -> Self {
        Self::Ws(ws)
    }
}
impl EthersProvider {
pub async fn request(
&self,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value, ethers::prelude::ProviderError> {
match self {
Self::Http(provider) => provider.request(method, params).await,
Self::Ws(provider) => provider.request(method, params).await,
}
}
}
/// One upstream rpc server: its provider plus a count of in-flight requests.
///
/// The comparison impls for this type look only at `active_requests`.
struct EthersConnection {
/// Number of requests currently open against this server. Connections are
/// compared (and therefore sorted) on this value alone.
active_requests: u32,
/// Shared handle to the provider used to send requests to this server.
provider: Arc<EthersProvider>,
}
impl EthersConnection {
async fn try_new(
url_str: &str,
http_client: Option<reqwest::Client>,
blocks: Arc<BlockMap>,
) -> anyhow::Result<EthersConnection> {
// TODO: create an ethers-rs rpc client and subscribe/watch new heads in a spawned task
let provider = if url_str.starts_with("http") {
let url: url::Url = url_str.try_into()?;
let http_client = http_client.ok_or_else(|| anyhow::anyhow!("no http_client"))?;
ethers::providers::Http::new_with_client(url, http_client)
} else if url_str.starts_with("ws") {
// ethers::providers::Ws::connect(s.to_string()).await?
// TODO: make sure this survives disconnects
unimplemented!();
} else {
return Err(anyhow::anyhow!("only http and ws servers are supported"));
};
let provider = ethers::providers::Provider::new(provider)
.interval(Duration::from_secs(1))
.into();
match &provider {
EthersProvider::Http(provider) => {
let mut stream = provider.watch_blocks().await?.take(3);
while let Some(block_number) = stream.next().await {
let block = provider.get_block(block_number).await?.unwrap();
println!(
"{:?} = Ts: {:?}, block number: {}",
block.hash.unwrap(),
block.timestamp,
block.number.unwrap(),
);
let mut blocks = blocks.write().await;
blocks.insert(url_str.to_string(), block);
}
}
EthersProvider::Ws(provider) => {
let mut stream = provider.subscribe_blocks().await?;
while let Some(block) = stream.next().await {
// TODO: save the block into a dashmap on
println!(
"{:?} = Ts: {:?}, block number: {}",
block.hash.unwrap(),
block.timestamp,
block.number.unwrap(),
);
}
}
}
// TODO: subscribe to new_heads
// TODO: if http, maybe we should check them all on the same interval. and if there is at least one websocket, use them for the interval
Ok(EthersConnection {
active_requests: 0,
provider: Arc::new(provider),
})
}
fn inc(&mut self) {
self.active_requests += 1;
}
fn dec(&mut self) {
self.active_requests -= 1;
}
}
/// Two connections are equal when they have the same number of open requests.
impl PartialEq for EthersConnection {
    fn eq(&self, rhs: &Self) -> bool {
        self.active_requests == rhs.active_requests
    }
}

impl Eq for EthersConnection {}

/// Always defers to the total order defined by `Ord`.
impl PartialOrd for EthersConnection {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

/// Connections order by open request count, so the least busy one sorts first.
impl Ord for EthersConnection {
    fn cmp(&self, rhs: &Self) -> Ordering {
        let ours = &self.active_requests;
        let theirs = &rhs.active_requests;

        ours.cmp(theirs)
    }
}
/// Blocks reported by each rpc server, behind an async lock. Entries are
/// inserted under the server's url string, so a later block replaces the
/// earlier one for the same server.
type BlockMap = RwLock<HashMap<String, Block<TxHash>>>;
/// Rate limiters behind an async lock, keyed by a string.
/// NOTE(review): the key looks like the rpc server url, matching `BlockMap` — confirm at the insert site.
type RateLimiterMap = RwLock<HashMap<String, RpcRateLimiter>>;
// TODO: include the ethers client on this map // TODO: include the ethers client on this map
type ConnectionsMap = DashMap<String, u32>; type ConnectionsMap = RwLock<HashMap<String, EthersConnection>>;
type RpcRateLimiter = type RpcRateLimiter =
RateLimiter<NotKeyed, InMemoryState, QuantaClock, NoOpMiddleware<QuantaInstant>>; RateLimiter<NotKeyed, InMemoryState, QuantaClock, NoOpMiddleware<QuantaInstant>>;
/// Load balance to the least-connection rpc /// Load balance to the least-connection rpc
struct BalancedRpcs { struct RpcTier {
rpcs: RwLock<Vec<String>>, rpcs: RwLock<Vec<String>>,
connections: ConnectionsMap, connections: Arc<ConnectionsMap>,
ratelimits: RateLimiterMap, ratelimits: RateLimiterMap,
} }
impl BalancedRpcs { impl RpcTier {
fn new(servers: Vec<(&str, u32)>, clock: &QuantaClock) -> BalancedRpcs { async fn try_new(
servers: Vec<(&str, u32)>,
http_client: Option<reqwest::Client>,
blocks: Arc<BlockMap>,
clock: &QuantaClock,
) -> anyhow::Result<RpcTier> {
let mut rpcs: Vec<String> = vec![]; let mut rpcs: Vec<String> = vec![];
let connections = DashMap::new(); let mut connections = HashMap::new();
let ratelimits = DashMap::new(); let mut ratelimits = HashMap::new();
for (s, limit) in servers.into_iter() { for (s, limit) in servers.into_iter() {
rpcs.push(s.to_string()); rpcs.push(s.to_string());
// TODO: subscribe to new_heads. if websocket, this is easy. otherwise we let connection =
EthersConnection::try_new(s, http_client.clone(), blocks.clone()).await?;
connections.insert(s.to_string(), 0); connections.insert(s.to_string(), connection);
if limit > 0 { if limit > 0 {
let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap()); let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap());
@ -52,6 +201,7 @@ impl BalancedRpcs {
} }
} }
/*
let new_heads_handles = rpcs let new_heads_handles = rpcs
.clone() .clone()
.into_iter() .into_iter()
@ -108,31 +258,33 @@ impl BalancedRpcs {
abort_handle abort_handle
}) })
.collect(); .collect();
*/
BalancedRpcs { Ok(RpcTier {
rpcs: RwLock::new(rpcs), rpcs: RwLock::new(rpcs),
connections, connections: Arc::new(RwLock::new(connections)),
ratelimits, ratelimits: RwLock::new(ratelimits),
new_heads_handles, })
}
} }
/// get the best available rpc server /// get the best available rpc server
async fn get_upstream_server(&self) -> Result<String, NotUntil<QuantaInstant>> { async fn next_upstream_server(&self) -> Result<String, NotUntil<QuantaInstant>> {
let mut balanced_rpcs = self.rpcs.write().await; let mut balanced_rpcs = self.rpcs.write().await;
balanced_rpcs.sort_unstable_by(|a, b| { // sort rpcs by their active connections
self.connections let connections = self.connections.read().await;
.get(a)
.unwrap() balanced_rpcs
.cmp(&self.connections.get(b).unwrap()) .sort_unstable_by(|a, b| connections.get(a).unwrap().cmp(connections.get(b).unwrap()));
});
let mut earliest_not_until = None; let mut earliest_not_until = None;
for selected_rpc in balanced_rpcs.iter() { for selected_rpc in balanced_rpcs.iter() {
// TODO: check current block number. if behind, make our own NotUntil here
let ratelimits = self.ratelimits.write().await;
// check rate limits // check rate limits
match self.ratelimits.get(selected_rpc).unwrap().check() { match ratelimits.get(selected_rpc).unwrap().check() {
Ok(_) => { Ok(_) => {
// rate limit succeeded // rate limit succeeded
} }
@ -155,8 +307,12 @@ impl BalancedRpcs {
}; };
// increment our connection counter // increment our connection counter
let mut connections = self.connections.get_mut(selected_rpc).unwrap(); self.connections
*connections += 1; .write()
.await
.get_mut(selected_rpc)
.unwrap()
.inc();
// return the selected RPC // return the selected RPC
return Ok(selected_rpc.clone()); return Ok(selected_rpc.clone());
@ -169,36 +325,6 @@ impl BalancedRpcs {
unimplemented!(); unimplemented!();
} }
} }
}
/// Send to all the Rpcs
/// Unlike BalancedRpcs, there is no tracking of connections
/// We do still track rate limits
struct LoudRpcs {
rpcs: Vec<String>,
// TODO: what type? store with connections?
ratelimits: RateLimiterMap,
}
impl LoudRpcs {
fn new(servers: Vec<(&str, u32)>, clock: &QuantaClock) -> LoudRpcs {
let mut rpcs: Vec<String> = vec![];
let ratelimits = RateLimiterMap::new();
for (s, limit) in servers.into_iter() {
rpcs.push(s.to_string());
if limit > 0 {
let quota = governor::Quota::per_second(NonZeroU32::new(limit).unwrap());
let rate_limiter = governor::RateLimiter::direct_with_clock(quota, clock);
ratelimits.insert(s.to_string(), rate_limiter);
}
}
LoudRpcs { rpcs, ratelimits }
}
/// get all available rpc servers /// get all available rpc servers
async fn get_upstream_servers(&self) -> Result<Vec<String>, NotUntil<QuantaInstant>> { async fn get_upstream_servers(&self) -> Result<Vec<String>, NotUntil<QuantaInstant>> {
@ -206,9 +332,16 @@ impl LoudRpcs {
let mut selected_rpcs = vec![]; let mut selected_rpcs = vec![];
for selected_rpc in self.rpcs.iter() { for selected_rpc in self.rpcs.read().await.iter() {
// check rate limits // check rate limits
match self.ratelimits.get(selected_rpc).unwrap().check() { match self
.ratelimits
.write()
.await
.get(selected_rpc)
.unwrap()
.check()
{
Ok(_) => { Ok(_) => {
// rate limit succeeded // rate limit succeeded
} }
@ -230,6 +363,14 @@ impl LoudRpcs {
} }
}; };
// increment our connection counter
self.connections
.write()
.await
.get_mut(selected_rpc)
.unwrap()
.inc();
// this is rpc should work // this is rpc should work
selected_rpcs.push(selected_rpc.clone()); selected_rpcs.push(selected_rpc.clone());
} }
@ -242,52 +383,71 @@ impl LoudRpcs {
if let Some(not_until) = earliest_not_until { if let Some(not_until) = earliest_not_until {
Err(not_until) Err(not_until)
} else { } else {
panic!("i don't think this should happen") // TODO: is this right?
Ok(vec![])
} }
} }
fn as_bool(&self) -> bool {
!self.rpcs.is_empty()
}
} }
struct Web3ProxyState { struct Web3ProxyState {
clock: QuantaClock, clock: QuantaClock,
client: reqwest::Client, balanced_rpc_tiers: Arc<Vec<RpcTier>>,
// TODO: LoudRpcs and BalancedRpcs should probably share a trait or something private_rpcs: Option<Arc<RpcTier>>,
balanced_rpc_tiers: Vec<BalancedRpcs>, /// write lock on these when all rate limits are hit
private_rpcs: LoudRpcs,
/// lock this when all rate limiters are hit
balanced_rpc_ratelimiter_lock: RwLock<()>, balanced_rpc_ratelimiter_lock: RwLock<()>,
private_rpcs_ratelimiter_lock: RwLock<()>, private_rpcs_ratelimiter_lock: RwLock<()>,
} }
impl Web3ProxyState { impl Web3ProxyState {
fn new( async fn try_new(
balanced_rpc_tiers: Vec<Vec<(&str, u32)>>, balanced_rpc_tiers: Vec<Vec<(&str, u32)>>,
private_rpcs: Vec<(&str, u32)>, private_rpcs: Vec<(&str, u32)>,
) -> Web3ProxyState { ) -> anyhow::Result<Web3ProxyState> {
let clock = QuantaClock::default(); let clock = QuantaClock::default();
let balanced_rpc_tiers = balanced_rpc_tiers let blocks = Arc::new(BlockMap::default());
.into_iter()
.map(|servers| BalancedRpcs::new(servers, &clock))
.collect();
let private_rpcs = LoudRpcs::new(private_rpcs, &clock); // TODO: 5 minutes is probably long enough. unlimited is a bad idea if something
let http_client = reqwest::ClientBuilder::new()
.timeout(Duration::from_secs(300))
.user_agent(APP_USER_AGENT)
.build()?;
// TODO: i'm sure we s
let balanced_rpc_tiers = Arc::new(
future::join_all(balanced_rpc_tiers.into_iter().map(|balanced_rpc_tier| {
RpcTier::try_new(
balanced_rpc_tier,
Some(http_client.clone()),
blocks.clone(),
&clock,
)
}))
.await
.into_iter()
.collect::<anyhow::Result<Vec<RpcTier>>>()?,
);
let private_rpcs = if private_rpcs.is_empty() {
None
} else {
Some(Arc::new(
RpcTier::try_new(private_rpcs, Some(http_client), blocks.clone(), &clock).await?,
))
};
// TODO: warn if no private relays // TODO: warn if no private relays
Web3ProxyState { Ok(Web3ProxyState {
clock, clock,
client: reqwest::Client::new(),
balanced_rpc_tiers, balanced_rpc_tiers,
private_rpcs, private_rpcs,
balanced_rpc_ratelimiter_lock: Default::default(), balanced_rpc_ratelimiter_lock: Default::default(),
private_rpcs_ratelimiter_lock: Default::default(), private_rpcs_ratelimiter_lock: Default::default(),
} })
} }
/// send the request to the approriate RPCs /// send the request to the approriate RPCs
/// TODO: dry this up
async fn proxy_web3_rpc( async fn proxy_web3_rpc(
self: Arc<Web3ProxyState>, self: Arc<Web3ProxyState>,
json_body: serde_json::Value, json_body: serde_json::Value,
@ -295,20 +455,34 @@ impl Web3ProxyState {
let eth_send_raw_transaction = let eth_send_raw_transaction =
serde_json::Value::String("eth_sendRawTransaction".to_string()); serde_json::Value::String("eth_sendRawTransaction".to_string());
if self.private_rpcs.as_bool() && json_body.get("method") == Some(&eth_send_raw_transaction) if self.private_rpcs.is_some() && json_body.get("method") == Some(&eth_send_raw_transaction)
{ {
let private_rpcs = self.private_rpcs.clone().unwrap();
// there are private rpcs configured and the request is eth_sendSignedTransaction. send to all private rpcs // there are private rpcs configured and the request is eth_sendSignedTransaction. send to all private rpcs
loop { loop {
let read_lock = self.private_rpcs_ratelimiter_lock.read().await; let read_lock = self.private_rpcs_ratelimiter_lock.read().await;
match self.private_rpcs.get_upstream_servers().await { match private_rpcs.get_upstream_servers().await {
Ok(upstream_servers) => { Ok(upstream_servers) => {
if let Ok(result) = self let (tx, mut rx) = mpsc::unbounded_channel::<serde_json::Value>();
.try_send_requests(upstream_servers, None, &json_body)
let clone = self.clone();
let connections = private_rpcs.connections.clone();
let json_body = json_body.clone();
tokio::spawn(async move {
clone
.try_send_requests(upstream_servers, connections, json_body, tx)
.await
});
let response = rx
.recv()
.await .await
{ .ok_or_else(|| anyhow::anyhow!("no response"))?;
return Ok(result);
} return Ok(warp::reply::json(&response));
} }
Err(not_until) => { Err(not_until) => {
// TODO: move this to a helper function // TODO: move this to a helper function
@ -333,19 +507,31 @@ impl Web3ProxyState {
let mut earliest_not_until = None; let mut earliest_not_until = None;
for balanced_rpcs in self.balanced_rpc_tiers.iter() { for balanced_rpcs in self.balanced_rpc_tiers.iter() {
match balanced_rpcs.get_upstream_server().await { match balanced_rpcs.next_upstream_server().await {
Ok(upstream_server) => { Ok(upstream_server) => {
// TODO: capture any errors. at least log them let (tx, mut rx) = mpsc::unbounded_channel::<serde_json::Value>();
if let Ok(result) = self
.try_send_requests( let clone = self.clone();
vec![upstream_server], let connections = balanced_rpcs.connections.clone();
Some(&balanced_rpcs.connections), let json_body = json_body.clone();
&json_body,
) tokio::spawn(async move {
clone
.try_send_requests(
vec![upstream_server],
connections,
json_body,
tx,
)
.await
});
let response = rx
.recv()
.await .await
{ .ok_or_else(|| anyhow::anyhow!("no response"))?;
return Ok(result);
} return Ok(warp::reply::json(&response));
} }
Err(not_until) => { Err(not_until) => {
// save the smallest not_until. if nothing succeeds, return an Err with not_until in it // save the smallest not_until. if nothing succeeds, return an Err with not_until in it
@ -382,59 +568,62 @@ impl Web3ProxyState {
async fn try_send_requests( async fn try_send_requests(
&self, &self,
upstream_servers: Vec<String>, rpc_servers: Vec<String>,
connections: Option<&ConnectionsMap>, connections: Arc<ConnectionsMap>,
json_body: &serde_json::Value, json_request_body: serde_json::Value,
) -> anyhow::Result<String> { tx: mpsc::UnboundedSender<serde_json::Value>,
) -> anyhow::Result<()> {
// {"jsonrpc":"2.0","method":"eth_syncing","params":[],"id":1}
let incoming_id = json_request_body
.get("id")
.ok_or_else(|| anyhow::anyhow!("bad id"))?
.to_owned();
let method = json_request_body
.get("method")
.and_then(|x| x.as_str())
.ok_or_else(|| anyhow::anyhow!("bad id"))?
.to_string();
let params = json_request_body
.get("params")
.ok_or_else(|| anyhow::anyhow!("no params"))?
.to_owned();
// send the query to all the servers // send the query to all the servers
let bodies = future::join_all(upstream_servers.into_iter().map(|url| { let bodies = future::join_all(rpc_servers.into_iter().map(|rpc| {
let client = self.client.clone(); let incoming_id = incoming_id.clone();
let json_body = json_body.clone(); let connections = connections.clone();
tokio::spawn(async move { let method = method.clone();
let params = params.clone();
let tx = tx.clone();
async move {
// get the client for this rpc server
let provider = connections.read().await.get(&rpc).unwrap().provider.clone();
// TODO: there has to be a better way to attach the url to the result // TODO: there has to be a better way to attach the url to the result
client let mut response = provider.request(&method, params).await?;
.post(&url)
.json(&json_body) connections.write().await.get_mut(&rpc).unwrap().dec();
.send()
.await if let Some(response_id) = response.get_mut("id") {
// add the url to the error so that we can reduce connection counters *response_id = incoming_id;
.map_err(|e| (url.clone(), e))? }
.text()
.await // send the first good response to a one shot channel. that way we respond quickly
// add the url to the result so that we can reduce connection counters // drop the result because errors are expected after the first send
.map(|t| (url.clone(), t)) // TODO: if "no block with that header" or some other jsonrpc errors, skip this response
// add the url to the error so that we can reduce connection counters let _ = tx.send(response);
.map_err(|e| (url, e))
}) Ok::<(), anyhow::Error>(())
}
})) }))
.await; .await;
// we are going to collect successes and failures // TODO: use iterators instead of pushing into a vec
let mut oks = vec![];
let mut errs = vec![]; let mut errs = vec![];
for x in bodies {
// TODO: parallel? match x {
for b in bodies { Ok(_) => {}
match b {
Ok(Ok((url, b))) => {
// reduce connection counter
if let Some(connections) = connections {
*connections.get_mut(&url).unwrap() -= 1;
}
// TODO: if "no block with that header" or some other jsonrpc errors, skip this response
oks.push(b);
}
Ok(Err((url, e))) => {
// reduce connection counter
if let Some(connections) = connections {
*connections.get_mut(&url).unwrap() -= 1;
}
// TODO: better errors
eprintln!("Got a reqwest::Error: {}", e);
errs.push(anyhow::anyhow!("Got a reqwest::Error"));
}
Err(e) => { Err(e) => {
// TODO: better errors // TODO: better errors
eprintln!("Got a tokio::JoinError: {}", e); eprintln!("Got a tokio::JoinError: {}", e);
@ -443,10 +632,7 @@ impl Web3ProxyState {
} }
} }
// TODO: which response should we use? if !errs.is_empty() {
if !oks.is_empty() {
Ok(oks.pop().unwrap())
} else if !errs.is_empty() {
Err(errs.pop().unwrap()) Err(errs.pop().unwrap())
} else { } else {
return Err(anyhow::anyhow!("no successful responses")); return Err(anyhow::anyhow!("no successful responses"));
@ -462,7 +648,7 @@ async fn main() {
let listen_port = 8445; let listen_port = 8445;
// TODO: be smart about about using archive nodes? // TODO: be smart about about using archive nodes?
let state = Web3ProxyState::new( let state = Web3ProxyState::try_new(
vec![ vec![
// local nodes // local nodes
vec![("ws://10.11.12.16:8545", 0), ("ws://10.11.12.16:8946", 0)], vec![("ws://10.11.12.16:8545", 0), ("ws://10.11.12.16:8946", 0)],
@ -479,7 +665,9 @@ async fn main() {
("https://api.edennetwork.io/v1/beta", 0), ("https://api.edennetwork.io/v1/beta", 0),
("https://api.edennetwork.io/v1/", 0), ("https://api.edennetwork.io/v1/", 0),
], ],
); )
.await
.unwrap();
let state: Arc<Web3ProxyState> = Arc::new(state); let state: Arc<Web3ProxyState> = Arc::new(state);