connection pooling

This commit is contained in:
Bryan Stitt 2022-07-07 03:22:09 +00:00
parent 3ddde56665
commit 8cc2fab48e
17 changed files with 251 additions and 154 deletions

69
Cargo.lock generated
View File

@ -212,6 +212,16 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-client-ip"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92ba6eab8967a7b1121e3cd7491a205dff7a33bb05e65df6c199d687a32e4d3b"
dependencies = [
"axum",
"forwarded-header-value",
]
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.2.6" version = "0.2.6"
@ -281,6 +291,30 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a32fd6af2b5827bce66c29053ba0e7c42b9dcab01835835058558c10851a46b" checksum = "8a32fd6af2b5827bce66c29053ba0e7c42b9dcab01835835058558c10851a46b"
[[package]]
name = "bb8"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1627eccf3aa91405435ba240be23513eeca466b5dc33866422672264de061582"
dependencies = [
"async-trait",
"futures-channel",
"futures-util",
"parking_lot 0.12.1",
"tokio",
]
[[package]]
name = "bb8-redis"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b081e2864000416c5a4543fed63fb8dd979c4a0806b071fddab0e548c6cd4c74"
dependencies = [
"async-trait",
"bb8",
"redis",
]
[[package]] [[package]]
name = "bech32" name = "bech32"
version = "0.7.3" version = "0.7.3"
@ -1455,6 +1489,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "forwarded-header-value"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
dependencies = [
"nonempty",
"thiserror",
]
[[package]] [[package]]
name = "fs2" name = "fs2"
version = "0.4.3" version = "0.4.3"
@ -2281,6 +2325,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54" checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54"
[[package]]
name = "nonempty"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
[[package]] [[package]]
name = "notify" name = "notify"
version = "4.0.17" version = "4.0.17"
@ -2824,7 +2874,6 @@ dependencies = [
"itoa 0.4.8", "itoa 0.4.8",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"sha1",
"tokio", "tokio",
"tokio-util 0.6.10", "tokio-util 0.6.10",
"url", "url",
@ -2835,7 +2884,7 @@ name = "redis-cell-client"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"redis", "bb8-redis",
] ]
[[package]] [[package]]
@ -3199,21 +3248,6 @@ dependencies = [
"digest 0.10.3", "digest 0.10.3",
] ]
[[package]]
name = "sha1"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1da05c97445caa12d05e848c4a4fcbbea29e748ac28f7e80e9b010392063770"
dependencies = [
"sha1_smol",
]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.8.2" version = "0.8.2"
@ -4095,6 +4129,7 @@ dependencies = [
"arc-swap", "arc-swap",
"argh", "argh",
"axum", "axum",
"axum-client-ip",
"counter", "counter",
"dashmap", "dashmap",
"derive_more", "derive_more",

11
TODO.md
View File

@ -34,8 +34,13 @@
- [x] if web3 proxy gets an http error back, retry another node - [x] if web3 proxy gets an http error back, retry another node
- [x] endpoint for health checks. if no synced servers, give a 502 error - [x] endpoint for health checks. if no synced servers, give a 502 error
- [x] rpc errors propagate too far. one subscription failing ends the app. isolate the providers more (might already be fixed) - [x] rpc errors propagate too far. one subscription failing ends the app. isolate the providers more (might already be fixed)
- [ ] incoming rate limiting (by ip) - [x] incoming rate limiting (by ip)
- [ ] connection pool for redis
- [ ] automatically route to archive server when necessary - [ ] automatically route to archive server when necessary
- [ ] handle log subscriptions
- [ ] basic request method stats
- [ ] http servers should check block at the very start
- [ ] Got warning: "WARN subscribe_new_heads:send_block: web3_proxy::connection: unable to get block from https://rpc.ethermine.org: Deserialization Error: expected value at line 1 column 1. Response: error code: 1015". this is cloudflare rate limiting on fetching a block, but this is a private rpc. why is there a block subscription?
## V1 ## V1
@ -43,8 +48,8 @@
- create the app without applying any config to it - create the app without applying any config to it
- have a blocking future watching the config file and calling app.apply_config() on first load and on change - have a blocking future watching the config file and calling app.apply_config() on first load and on change
- work started on this in the "config_reloads" branch. because of how we pass channels around during spawn, this requires a larger refactor. - work started on this in the "config_reloads" branch. because of how we pass channels around during spawn, this requires a larger refactor.
- [ ] interval for http subscriptions should be based on block time. load from config is easy, but - [ ] interval for http subscriptions should be based on block time. load from config is easy, but better to query
- [ ] some things that are cached locally should probably be in shared redis caches - [ ] most things that are cached locally should probably be in shared redis caches
- [ ] stats when forks are resolved (and what chain they were on?) - [ ] stats when forks are resolved (and what chain they were on?)
- [ ] incoming rate limiting (by api key) - [ ] incoming rate limiting (by api key)
- [ ] failsafe. if no blocks or transactions in the last second, warn and reset the connection - [ ] failsafe. if no blocks or transactions in the last second, warn and reset the connection

View File

@ -1,5 +1,6 @@
[shared] [shared]
chain_id = 1 chain_id = 1
public_rate_limit_per_minute = 60_000
[balanced_rpcs] [balanced_rpcs]

View File

@ -1,6 +1,8 @@
[shared] [shared]
chain_id = 1 chain_id = 1
# in prod, do `rate_limit_redis = "redis://redis:6379/"` # in prod, do `rate_limit_redis = "redis://redis:6379/"`
rate_limit_redis = "redis://dev-redis:6379/"
public_rate_limit_per_minute = 60_000
[balanced_rpcs] [balanced_rpcs]

View File

@ -6,4 +6,4 @@ edition = "2018"
[dependencies] [dependencies]
anyhow = "1.0.58" anyhow = "1.0.58"
redis = { version = "0.21.5", features = ["aio", "tokio", "tokio-comp"] } bb8-redis = "0.11.0"

View File

@ -1,85 +0,0 @@
use std::time::Duration;
use redis::aio::MultiplexedConnection;
// TODO: take this as an argument to open?
const KEY_PREFIX: &str = "rate-limit";
pub struct RedisCellClient {
conn: MultiplexedConnection,
key: String,
max_burst: u32,
count_per_period: u32,
period: u32,
}
impl RedisCellClient {
// todo: seems like this could be derived
// TODO: take something generic for conn
// TODO: use r2d2 for connection pooling?
pub fn new(
conn: MultiplexedConnection,
key: String,
max_burst: u32,
count_per_period: u32,
period: u32,
) -> Self {
let key = format!("{}:{}", KEY_PREFIX, key);
Self {
conn,
key,
max_burst,
count_per_period,
period,
}
}
#[inline]
pub async fn throttle(&self) -> Result<(), Duration> {
self.throttle_quantity(1).await
}
#[inline]
pub async fn throttle_quantity(&self, quantity: u32) -> Result<(), Duration> {
/*
https://github.com/brandur/redis-cell#response
CL.THROTTLE <key> <max_burst> <count per period> <period> [<quantity>]
0. Whether the action was limited:
0 indicates the action is allowed.
1 indicates that the action was limited/blocked.
1. The total limit of the key (max_burst + 1). This is equivalent to the common X-RateLimit-Limit HTTP header.
2. The remaining limit of the key. Equivalent to X-RateLimit-Remaining.
3. The number of seconds until the user should retry, and always -1 if the action was allowed. Equivalent to Retry-After.
4. The number of seconds until the limit will reset to its maximum capacity. Equivalent to X-RateLimit-Reset.
*/
// TODO: don't unwrap. maybe return anyhow::Result<Result<(), Duration>>
// TODO: should we return more error info?
let x: Vec<isize> = redis::cmd("CL.THROTTLE")
.arg(&(
&self.key,
self.max_burst,
self.count_per_period,
self.period,
quantity,
))
.query_async(&mut self.conn.clone())
.await
.unwrap();
assert_eq!(x.len(), 5);
// TODO: trace log the result
// TODO: maybe we should do #4
let retry_after = *x.get(3).unwrap();
if retry_after == -1 {
Ok(())
} else {
Err(Duration::from_secs(retry_after as u64))
}
}
}

View File

@ -1,7 +1,102 @@
mod client; use bb8_redis::redis::cmd;
pub use client::RedisCellClient; pub use bb8_redis::{bb8, RedisConnectionManager};
pub use redis;
// TODO: don't hard code MultiplexedConnection use std::time::Duration;
pub use redis::aio::MultiplexedConnection;
pub use redis::Client; // TODO: take this as an argument to open?
const KEY_PREFIX: &str = "rate-limit";
pub type RedisClientPool = bb8::Pool<RedisConnectionManager>;
pub struct RedisCellClient {
pool: RedisClientPool,
key: String,
max_burst: u32,
count_per_period: u32,
period: u32,
}
impl RedisCellClient {
// todo: seems like this could be derived
// TODO: take something generic for conn
// TODO: use r2d2 for connection pooling?
pub fn new(
pool: bb8::Pool<RedisConnectionManager>,
default_key: String,
max_burst: u32,
count_per_period: u32,
period: u32,
) -> Self {
let default_key = format!("{}:{}", KEY_PREFIX, default_key);
Self {
pool,
key: default_key,
max_burst,
count_per_period,
period,
}
}
#[inline]
async fn _throttle(&self, key: &str, quantity: u32) -> Result<(), Duration> {
let mut conn = self.pool.get().await.unwrap();
/*
https://github.com/brandur/redis-cell#response
CL.THROTTLE <key> <max_burst> <count per period> <period> [<quantity>]
0. Whether the action was limited:
0 indicates the action is allowed.
1 indicates that the action was limited/blocked.
1. The total limit of the key (max_burst + 1). This is equivalent to the common X-RateLimit-Limit HTTP header.
2. The remaining limit of the key. Equivalent to X-RateLimit-Remaining.
3. The number of seconds until the user should retry, and always -1 if the action was allowed. Equivalent to Retry-After.
4. The number of seconds until the limit will reset to its maximum capacity. Equivalent to X-RateLimit-Reset.
*/
// TODO: don't unwrap. maybe return anyhow::Result<Result<(), Duration>>
// TODO: should we return more error info?
let x: Vec<isize> = cmd("CL.THROTTLE")
.arg(&(
key,
self.max_burst,
self.count_per_period,
self.period,
quantity,
))
.query_async(&mut *conn)
.await
.unwrap();
assert_eq!(x.len(), 5);
// TODO: trace log the result
let retry_after = *x.get(3).unwrap();
if retry_after == -1 {
Ok(())
} else {
Err(Duration::from_secs(retry_after as u64))
}
}
#[inline]
pub async fn throttle(&self) -> Result<(), Duration> {
self._throttle(&self.key, 1).await
}
#[inline]
pub async fn throttle_key(&self, key: &str) -> Result<(), Duration> {
let key = format!("{}:{}", KEY_PREFIX, key);
self._throttle(key.as_ref(), 1).await
}
#[inline]
pub async fn throttle_quantity(&self, quantity: u32) -> Result<(), Duration> {
self._throttle(&self.key, quantity).await
}
}

View File

@ -14,6 +14,7 @@ anyhow = { version = "1.0.58", features = ["backtrace"] }
arc-swap = "1.5.0" arc-swap = "1.5.0"
argh = "0.1.8" argh = "0.1.8"
axum = { version = "0.5.11", features = ["serde_json", "tokio-tungstenite", "ws"] } axum = { version = "0.5.11", features = ["serde_json", "tokio-tungstenite", "ws"] }
axum-client-ip = "0.2.0"
counter = "0.5.5" counter = "0.5.5"
dashmap = "5.3.4" dashmap = "5.3.4"
derive_more = "0.99.17" derive_more = "0.99.17"

View File

@ -10,6 +10,7 @@ use futures::stream::StreamExt;
use futures::Future; use futures::Future;
use linkedhashmap::LinkedHashMap; use linkedhashmap::LinkedHashMap;
use parking_lot::RwLock; use parking_lot::RwLock;
use redis_cell_client::{bb8, RedisCellClient, RedisClientPool, RedisConnectionManager};
use serde_json::json; use serde_json::json;
use std::fmt; use std::fmt;
use std::pin::Pin; use std::pin::Pin;
@ -20,7 +21,7 @@ use tokio::sync::{broadcast, watch};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::time::timeout; use tokio::time::timeout;
use tokio_stream::wrappers::{BroadcastStream, WatchStream}; use tokio_stream::wrappers::{BroadcastStream, WatchStream};
use tracing::{debug, info, info_span, instrument, trace, warn, Instrument}; use tracing::{info, info_span, instrument, trace, warn, Instrument};
use crate::config::Web3ConnectionConfig; use crate::config::Web3ConnectionConfig;
use crate::connections::Web3Connections; use crate::connections::Web3Connections;
@ -94,6 +95,7 @@ pub struct Web3ProxyApp {
head_block_receiver: watch::Receiver<Block<TxHash>>, head_block_receiver: watch::Receiver<Block<TxHash>>,
pending_tx_sender: broadcast::Sender<TxState>, pending_tx_sender: broadcast::Sender<TxState>,
pending_transactions: Arc<DashMap<TxHash, TxState>>, pending_transactions: Arc<DashMap<TxHash, TxState>>,
public_rate_limiter: Option<RedisCellClient>,
next_subscription_id: AtomicUsize, next_subscription_id: AtomicUsize,
} }
@ -109,11 +111,16 @@ impl Web3ProxyApp {
&self.pending_transactions &self.pending_transactions
} }
pub fn get_public_rate_limiter(&self) -> Option<&RedisCellClient> {
self.public_rate_limiter.as_ref()
}
pub async fn spawn( pub async fn spawn(
chain_id: usize, chain_id: usize,
redis_address: Option<String>, redis_address: Option<String>,
balanced_rpcs: Vec<Web3ConnectionConfig>, balanced_rpcs: Vec<Web3ConnectionConfig>,
private_rpcs: Vec<Web3ConnectionConfig>, private_rpcs: Vec<Web3ConnectionConfig>,
public_rate_limit_per_minute: u32,
) -> anyhow::Result<( ) -> anyhow::Result<(
Arc<Web3ProxyApp>, Arc<Web3ProxyApp>,
Pin<Box<dyn Future<Output = anyhow::Result<()>>>>, Pin<Box<dyn Future<Output = anyhow::Result<()>>>>,
@ -132,15 +139,14 @@ impl Web3ProxyApp {
.build()?, .build()?,
); );
let rate_limiter = match redis_address { let rate_limiter_pool = match redis_address {
Some(redis_address) => { Some(redis_address) => {
info!("Connecting to redis on {}", redis_address); info!("Connecting to redis on {}", redis_address);
let redis_client = redis_cell_client::Client::open(redis_address)?;
// TODO: r2d2 connection pool? let manager = RedisConnectionManager::new(redis_address)?;
let redis_conn = redis_client.get_multiplexed_tokio_connection().await?; let pool = bb8::Pool::builder().build(manager).await?;
Some(redis_conn) Some(pool)
} }
None => { None => {
info!("No redis address"); info!("No redis address");
@ -164,7 +170,7 @@ impl Web3ProxyApp {
chain_id, chain_id,
balanced_rpcs, balanced_rpcs,
http_client.as_ref(), http_client.as_ref(),
rate_limiter.as_ref(), rate_limiter_pool.as_ref(),
Some(head_block_sender), Some(head_block_sender),
Some(pending_tx_sender.clone()), Some(pending_tx_sender.clone()),
pending_transactions.clone(), pending_transactions.clone(),
@ -182,7 +188,7 @@ impl Web3ProxyApp {
chain_id, chain_id,
private_rpcs, private_rpcs,
http_client.as_ref(), http_client.as_ref(),
rate_limiter.as_ref(), rate_limiter_pool.as_ref(),
// subscribing to new heads here won't work well // subscribing to new heads here won't work well
None, None,
// TODO: subscribe to pending transactions on the private rpcs? // TODO: subscribe to pending transactions on the private rpcs?
@ -199,7 +205,20 @@ impl Web3ProxyApp {
// TODO: use this? it could listen for confirmed transactions and then clear pending_transactions, but the head_block_sender is doing that // TODO: use this? it could listen for confirmed transactions and then clear pending_transactions, but the head_block_sender is doing that
drop(pending_tx_receiver); drop(pending_tx_receiver);
let app = Web3ProxyApp { // TODO: how much should we allow?
let public_max_burst = public_rate_limit_per_minute / 3;
let public_rate_limiter = rate_limiter_pool.as_ref().map(|redis_client_pool| {
RedisCellClient::new(
redis_client_pool.clone(),
"public".to_string(),
public_max_burst,
public_rate_limit_per_minute,
60,
)
});
let app = Self {
balanced_rpcs, balanced_rpcs,
private_rpcs, private_rpcs,
incoming_requests: Default::default(), incoming_requests: Default::default(),
@ -207,6 +226,7 @@ impl Web3ProxyApp {
head_block_receiver, head_block_receiver,
pending_tx_sender, pending_tx_sender,
pending_transactions, pending_transactions,
public_rate_limiter,
next_subscription_id: 1.into(), next_subscription_id: 1.into(),
}; };
@ -431,7 +451,7 @@ impl Web3ProxyApp {
request: JsonRpcRequestEnum, request: JsonRpcRequestEnum,
) -> anyhow::Result<JsonRpcForwardedResponseEnum> { ) -> anyhow::Result<JsonRpcForwardedResponseEnum> {
// TODO: i don't always see this in the logs. why? // TODO: i don't always see this in the logs. why?
debug!("Received request: {:?}", request); trace!("Received request: {:?}", request);
// even though we have timeouts on the requests to our backend providers, // even though we have timeouts on the requests to our backend providers,
// we need a timeout for the incoming request so that delays from // we need a timeout for the incoming request so that delays from
@ -447,7 +467,7 @@ impl Web3ProxyApp {
}; };
// TODO: i don't always see this in the logs. why? // TODO: i don't always see this in the logs. why?
debug!("Forwarding response: {:?}", response); trace!("Forwarding response: {:?}", response);
Ok(response) Ok(response)
} }

View File

@ -40,6 +40,9 @@ pub struct RpcSharedConfig {
/// TODO: what type for chain_id? TODO: this isn't at the right level. this is inside a "Config" /// TODO: what type for chain_id? TODO: this isn't at the right level. this is inside a "Config"
pub chain_id: usize, pub chain_id: usize,
pub rate_limit_redis: Option<String>, pub rate_limit_redis: Option<String>,
// TODO: serde default for development?
// TODO: allow no limit?
pub public_rate_limit_per_minute: u32,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -71,6 +74,7 @@ impl RpcConfig {
self.shared.rate_limit_redis, self.shared.rate_limit_redis,
balanced_rpcs, balanced_rpcs,
private_rpcs, private_rpcs,
self.shared.public_rate_limit_per_minute,
) )
.await .await
} }
@ -81,14 +85,14 @@ impl Web3ConnectionConfig {
// #[instrument(name = "try_build_Web3ConnectionConfig", skip_all)] // #[instrument(name = "try_build_Web3ConnectionConfig", skip_all)]
pub async fn spawn( pub async fn spawn(
self, self,
rate_limiter: Option<&redis_cell_client::MultiplexedConnection>, redis_client_pool: Option<&redis_cell_client::RedisClientPool>,
chain_id: usize, chain_id: usize,
http_client: Option<&reqwest::Client>, http_client: Option<&reqwest::Client>,
http_interval_sender: Option<Arc<broadcast::Sender<()>>>, http_interval_sender: Option<Arc<broadcast::Sender<()>>>,
block_sender: Option<flume::Sender<(Block<TxHash>, Arc<Web3Connection>)>>, block_sender: Option<flume::Sender<(Block<TxHash>, Arc<Web3Connection>)>>,
tx_id_sender: Option<flume::Sender<(TxHash, Arc<Web3Connection>)>>, tx_id_sender: Option<flume::Sender<(TxHash, Arc<Web3Connection>)>>,
) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> { ) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> {
let hard_rate_limit = self.hard_limit.map(|x| (x, rate_limiter.unwrap())); let hard_rate_limit = self.hard_limit.map(|x| (x, redis_client_pool.unwrap()));
Web3Connection::spawn( Web3Connection::spawn(
chain_id, chain_id,

View File

@ -131,7 +131,7 @@ impl Web3Connection {
// optional because this is only used for http providers. websocket providers don't use it // optional because this is only used for http providers. websocket providers don't use it
http_client: Option<&reqwest::Client>, http_client: Option<&reqwest::Client>,
http_interval_sender: Option<Arc<broadcast::Sender<()>>>, http_interval_sender: Option<Arc<broadcast::Sender<()>>>,
hard_limit: Option<(u32, &redis_cell_client::MultiplexedConnection)>, hard_limit: Option<(u32, &redis_cell_client::RedisClientPool)>,
// TODO: think more about this type // TODO: think more about this type
soft_limit: u32, soft_limit: u32,
block_sender: Option<flume::Sender<(Block<TxHash>, Arc<Self>)>>, block_sender: Option<flume::Sender<(Block<TxHash>, Arc<Self>)>>,
@ -356,10 +356,6 @@ impl Web3Connection {
let mut last_hash = Default::default(); let mut last_hash = Default::default();
loop { loop {
// wait for the interval
// TODO: if error or rate limit, increase interval?
http_interval_receiver.recv().await.unwrap();
match self.try_request_handle().await { match self.try_request_handle().await {
Ok(active_request_handle) => { Ok(active_request_handle) => {
// TODO: i feel like this should be easier. there is a provider.getBlock, but i don't know how to give it "latest" // TODO: i feel like this should be easier. there is a provider.getBlock, but i don't know how to give it "latest"
@ -384,6 +380,10 @@ impl Web3Connection {
warn!("Failed getting latest block from {}: {:?}", self, e); warn!("Failed getting latest block from {}: {:?}", self, e);
} }
} }
// wait for the interval
// TODO: if error or rate limit, increase interval?
http_interval_receiver.recv().await.unwrap();
} }
} }
Web3Provider::Ws(provider) => { Web3Provider::Ws(provider) => {
@ -447,10 +447,6 @@ impl Web3Connection {
// TODO: create a filter // TODO: create a filter
loop { loop {
// wait for the interval
// TODO: if error or rate limit, increase interval?
interval.tick().await;
// TODO: actually do something here // TODO: actually do something here
/* /*
match self.try_request_handle().await { match self.try_request_handle().await {
@ -463,6 +459,10 @@ impl Web3Connection {
} }
} }
*/ */
// wait for the interval
// TODO: if error or rate limit, increase interval?
interval.tick().await;
} }
} }
Web3Provider::Ws(provider) => { Web3Provider::Ws(provider) => {

View File

@ -89,7 +89,7 @@ impl Web3Connections {
chain_id: usize, chain_id: usize,
server_configs: Vec<Web3ConnectionConfig>, server_configs: Vec<Web3ConnectionConfig>,
http_client: Option<&reqwest::Client>, http_client: Option<&reqwest::Client>,
rate_limiter: Option<&redis_cell_client::MultiplexedConnection>, redis_client_pool: Option<&redis_cell_client::RedisClientPool>,
head_block_sender: Option<watch::Sender<Block<TxHash>>>, head_block_sender: Option<watch::Sender<Block<TxHash>>>,
pending_tx_sender: Option<broadcast::Sender<TxState>>, pending_tx_sender: Option<broadcast::Sender<TxState>>,
pending_transactions: Arc<DashMap<TxHash, TxState>>, pending_transactions: Arc<DashMap<TxHash, TxState>>,
@ -141,7 +141,7 @@ impl Web3Connections {
for server_config in server_configs.into_iter() { for server_config in server_configs.into_iter() {
match server_config match server_config
.spawn( .spawn(
rate_limiter, redis_client_pool,
chain_id, chain_id,
http_client, http_client,
http_interval_sender.clone(), http_interval_sender.clone(),

View File

@ -7,15 +7,15 @@ use crate::jsonrpc::JsonRpcForwardedResponse;
pub async fn handler_404() -> impl IntoResponse { pub async fn handler_404() -> impl IntoResponse {
let err = anyhow::anyhow!("nothing to see here"); let err = anyhow::anyhow!("nothing to see here");
handle_anyhow_error(err, Some(StatusCode::NOT_FOUND)).await handle_anyhow_error(Some(StatusCode::NOT_FOUND), err).await
} }
/// handle errors by converting them into something that implements `IntoResponse` /// handle errors by converting them into something that implements `IntoResponse`
/// TODO: use this. i can't get https://docs.rs/axum/latest/axum/error_handling/index.html to work /// TODO: use this. i can't get https://docs.rs/axum/latest/axum/error_handling/index.html to work
/// TODO: i think we want a custom result type instead. put the anyhow result inside. then `impl IntoResponse for CustomResult` /// TODO: i think we want a custom result type instead. put the anyhow result inside. then `impl IntoResponse for CustomResult`
pub async fn handle_anyhow_error( pub async fn handle_anyhow_error(
err: anyhow::Error,
code: Option<StatusCode>, code: Option<StatusCode>,
err: anyhow::Error,
) -> impl IntoResponse { ) -> impl IntoResponse {
let id = RawValue::from_string("null".to_string()).unwrap(); let id = RawValue::from_string("null".to_string()).unwrap();

View File

@ -5,7 +5,7 @@ use std::sync::Arc;
use crate::app::Web3ProxyApp; use crate::app::Web3ProxyApp;
/// Health check page for load balancers to use /// Health check page for load balancers to use
pub async fn health(app: Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse { pub async fn health(Extension(app): Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
if app.get_balanced_rpcs().has_synced_rpcs() { if app.get_balanced_rpcs().has_synced_rpcs() {
(StatusCode::OK, "OK") (StatusCode::OK, "OK")
} else { } else {
@ -14,7 +14,7 @@ pub async fn health(app: Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
} }
/// Very basic status page /// Very basic status page
pub async fn status(app: Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse { pub async fn status(Extension(app): Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
// TODO: what else should we include? uptime? prometheus? // TODO: what else should we include? uptime? prometheus?
let balanced_rpcs = app.get_balanced_rpcs(); let balanced_rpcs = app.get_balanced_rpcs();
let private_rpcs = app.get_private_rpcs(); let private_rpcs = app.get_private_rpcs();

View File

@ -1,15 +1,34 @@
use axum::{http::StatusCode, response::IntoResponse, Extension, Json}; use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
use axum_client_ip::ClientIp;
use std::sync::Arc; use std::sync::Arc;
use super::errors::handle_anyhow_error; use super::errors::handle_anyhow_error;
use crate::{app::Web3ProxyApp, jsonrpc::JsonRpcRequestEnum}; use crate::{app::Web3ProxyApp, jsonrpc::JsonRpcRequestEnum};
pub async fn proxy_web3_rpc( pub async fn proxy_web3_rpc(
payload: Json<JsonRpcRequestEnum>, Json(payload): Json<JsonRpcRequestEnum>,
app: Extension<Arc<Web3ProxyApp>>, Extension(app): Extension<Arc<Web3ProxyApp>>,
ClientIp(ip): ClientIp,
) -> impl IntoResponse { ) -> impl IntoResponse {
match app.proxy_web3_rpc(payload.0).await { if let Some(rate_limiter) = app.get_public_rate_limiter() {
let rate_limiter_key = format!("{}", ip);
if rate_limiter.throttle_key(&rate_limiter_key).await.is_err() {
// TODO: set headers so they know when they can retry
// warn!(?ip, "public rate limit exceeded");
return handle_anyhow_error(
Some(StatusCode::TOO_MANY_REQUESTS),
anyhow::anyhow!("too many requests"),
)
.await
.into_response();
}
} else {
// TODO: if no redis, rate limit with a local cache?
}
match app.proxy_web3_rpc(payload).await {
Ok(response) => (StatusCode::OK, Json(&response)).into_response(), Ok(response) => (StatusCode::OK, Json(&response)).into_response(),
Err(err) => handle_anyhow_error(err, None).await.into_response(), Err(err) => handle_anyhow_error(None, err).await.into_response(),
} }
} }

View File

@ -3,6 +3,7 @@ mod errors;
mod http; mod http;
mod http_proxy; mod http_proxy;
mod ws_proxy; mod ws_proxy;
use axum::{ use axum::{
handler::Handler, handler::Handler,
routing::{get, post}, routing::{get, post},
@ -35,8 +36,9 @@ pub async fn run(port: u16, proxy_app: Arc<Web3ProxyApp>) -> anyhow::Result<()>
// `axum::Server` is a re-export of `hyper::Server` // `axum::Server` is a re-export of `hyper::Server`
let addr = SocketAddr::from(([0, 0, 0, 0], port)); let addr = SocketAddr::from(([0, 0, 0, 0], port));
debug!("listening on port {}", port); debug!("listening on port {}", port);
// TODO: into_make_service is enough if we always run behind a proxy. make into_make_service_with_connect_info optional?
axum::Server::bind(&addr) axum::Server::bind(&addr)
.serve(app.into_make_service()) .serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await .await
.map_err(Into::into) .map_err(Into::into)
} }

View File

@ -12,7 +12,7 @@ use hashbrown::HashMap;
use serde_json::value::RawValue; use serde_json::value::RawValue;
use std::str::from_utf8_mut; use std::str::from_utf8_mut;
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, error, info, warn}; use tracing::{error, info, trace, warn};
use crate::{ use crate::{
app::Web3ProxyApp, app::Web3ProxyApp,
@ -20,13 +20,13 @@ use crate::{
}; };
pub async fn websocket_handler( pub async fn websocket_handler(
app: Extension<Arc<Web3ProxyApp>>, Extension(app): Extension<Arc<Web3ProxyApp>>,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.on_upgrade(|socket| proxy_web3_socket(app, socket)) ws.on_upgrade(|socket| proxy_web3_socket(app, socket))
} }
async fn proxy_web3_socket(app: Extension<Arc<Web3ProxyApp>>, socket: WebSocket) { async fn proxy_web3_socket(app: Arc<Web3ProxyApp>, socket: WebSocket) {
// split the websocket so we can read and write concurrently // split the websocket so we can read and write concurrently
let (ws_tx, ws_rx) = socket.split(); let (ws_tx, ws_rx) = socket.split();
@ -109,7 +109,7 @@ async fn handle_socket_payload(
} }
async fn read_web3_socket( async fn read_web3_socket(
app: Extension<Arc<Web3ProxyApp>>, app: Arc<Web3ProxyApp>,
mut ws_rx: SplitStream<WebSocket>, mut ws_rx: SplitStream<WebSocket>,
response_tx: flume::Sender<Message>, response_tx: flume::Sender<Message>,
) { ) {
@ -119,12 +119,11 @@ async fn read_web3_socket(
// new message from our client. forward to a backend and then send it through response_tx // new message from our client. forward to a backend and then send it through response_tx
let response_msg = match msg { let response_msg = match msg {
Message::Text(payload) => { Message::Text(payload) => {
handle_socket_payload(app.0.clone(), &payload, &response_tx, &mut subscriptions) handle_socket_payload(app.clone(), &payload, &response_tx, &mut subscriptions).await
.await
} }
Message::Ping(x) => Message::Pong(x), Message::Ping(x) => Message::Pong(x),
Message::Pong(x) => { Message::Pong(x) => {
debug!("pong: {:?}", x); trace!("pong: {:?}", x);
continue; continue;
} }
Message::Close(_) => { Message::Close(_) => {
@ -134,8 +133,7 @@ async fn read_web3_socket(
Message::Binary(mut payload) => { Message::Binary(mut payload) => {
let payload = from_utf8_mut(&mut payload).unwrap(); let payload = from_utf8_mut(&mut payload).unwrap();
handle_socket_payload(app.0.clone(), payload, &response_tx, &mut subscriptions) handle_socket_payload(app.clone(), payload, &response_tx, &mut subscriptions).await
.await
} }
}; };