diff --git a/src/main.rs b/src/main.rs index 7748e440..aa9b903f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,12 +17,12 @@ const PARALLEL_REQUESTS: usize = 4; type RpcRateLimiter = RateLimiter>; +type ConnectionsMap = DashMap; + /// Load balance to the least-connection rpc struct BalancedRpcs { rpcs: RwLock>, - connections: DashMap, - // TODO: what type? store with connections? - // ratelimits: RateLimiter, dyn governor::clock::Clock>, + connections: ConnectionsMap, ratelimits: DashMap, } @@ -95,7 +95,6 @@ impl BalancedRpcs { }; // increment our connection counter - // TODO: need to change this to be an atomic counter! let mut connections = self.connections.get_mut(selected_rpc).unwrap(); *connections += 1; @@ -175,7 +174,10 @@ impl Web3ProxyState { // there are private rpcs configured and the request is eth_sendSignedTransaction. send to all private rpcs let upstream_servers = self.private_rpcs.get_upstream_servers().await; - if let Ok(result) = self.try_send_requests(upstream_servers, &json_body).await { + if let Ok(result) = self + .try_send_requests(upstream_servers, None, &json_body) + .await + { return Ok(result); } } else { @@ -184,7 +186,11 @@ impl Web3ProxyState { if let Ok(upstream_server) = balanced_rpcs.get_upstream_server().await { // TODO: capture any errors. at least log them if let Ok(result) = self - .try_send_requests(vec![upstream_server], &json_body) + .try_send_requests( + vec![upstream_server], + Some(&balanced_rpcs.connections), + &json_body, + ) .await { return Ok(result); @@ -201,6 +207,7 @@ impl Web3ProxyState { async fn try_send_requests( &self, upstream_servers: Vec, + connections: Option<&ConnectionsMap>, json_body: &serde_json::Value, ) -> anyhow::Result { // send the query to all the servers @@ -209,8 +216,16 @@ impl Web3ProxyState { let client = self.client.clone(); let json_body = json_body.clone(); tokio::spawn(async move { - let resp = client.post(url).json(&json_body).send().await?; - resp.text().await + let resp = client + .post(&url) + .json(&json_body) + .send() + .await + .map_err(|e| (url.clone(), e))?; + resp.text() + .await + .map(|t| (url.clone(), t)) + .map_err(|e| (url, e)) }) }) .buffer_unordered(PARALLEL_REQUESTS); @@ -220,12 +235,21 @@ impl Web3ProxyState { while let Some(b) = bodies.next().await { // TODO: reduce connection counter + match b { - Ok(Ok(b)) => { + Ok(Ok((url, b))) => { + if let Some(connections) = connections { + *connections.get_mut(&url).unwrap() -= 1; + } + // TODO: if "no block with that header", skip this response (maybe retry) oks.push(b); } - Ok(Err(e)) => { + Ok(Err((url, e))) => { + if let Some(connections) = connections { + *connections.get_mut(&url).unwrap() -= 1; + } + // TODO: better errors eprintln!("Got a reqwest::Error: {}", e); errs.push(anyhow::anyhow!("Got a reqwest::Error"));