From 8cc2fab48e6300969d12511aa093c37f198ece0f Mon Sep 17 00:00:00 2001
From: Bryan Stitt
Date: Thu, 7 Jul 2022 03:22:09 +0000
Subject: [PATCH] connection pooling

---
 Cargo.lock                            |  69 +++++++++++++----
 TODO.md                               |  11 ++-
 config/example.bac                    |   1 +
 config/example.toml                   |   2 +
 redis-cell-client/Cargo.toml          |   2 +-
 redis-cell-client/src/client.rs       |  85 --------------------
 redis-cell-client/src/lib.rs          | 107 ++++++++++++++++++++++++--
 web3-proxy/Cargo.toml                 |   1 +
 web3-proxy/src/app.rs                 |  42 +++++++---
 web3-proxy/src/config.rs              |   8 +-
 web3-proxy/src/connection.rs          |  18 ++---
 web3-proxy/src/connections.rs         |   4 +-
 web3-proxy/src/frontend/errors.rs     |   4 +-
 web3-proxy/src/frontend/http.rs       |   4 +-
 web3-proxy/src/frontend/http_proxy.rs |  27 ++++++-
 web3-proxy/src/frontend/mod.rs        |   4 +-
 web3-proxy/src/frontend/ws_proxy.rs   |  16 ++--
 17 files changed, 251 insertions(+), 154 deletions(-)
 delete mode 100644 redis-cell-client/src/client.rs

diff --git a/Cargo.lock b/Cargo.lock
index bbb69905..620eab24 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -212,6 +212,16 @@ dependencies = [
  "tower-service",
 ]
 
+[[package]]
+name = "axum-client-ip"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "92ba6eab8967a7b1121e3cd7491a205dff7a33bb05e65df6c199d687a32e4d3b"
+dependencies = [
+ "axum",
+ "forwarded-header-value",
+]
+
 [[package]]
 name = "axum-core"
 version = "0.2.6"
@@ -281,6 +291,30 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a32fd6af2b5827bce66c29053ba0e7c42b9dcab01835835058558c10851a46b"
 
+[[package]]
+name = "bb8"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1627eccf3aa91405435ba240be23513eeca466b5dc33866422672264de061582"
+dependencies = [
+ "async-trait",
+ "futures-channel",
+ "futures-util",
+ "parking_lot 0.12.1",
+ "tokio",
+]
+
+[[package]]
+name = "bb8-redis"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b081e2864000416c5a4543fed63fb8dd979c4a0806b071fddab0e548c6cd4c74"
+dependencies = [
+ "async-trait",
+ "bb8",
+ "redis",
+]
+
 [[package]]
 name = "bech32"
 version = "0.7.3"
@@ -1455,6 +1489,16 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "forwarded-header-value"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
+dependencies = [
+ "nonempty",
+ "thiserror",
+]
+
 [[package]]
 name = "fs2"
 version = "0.4.3"
@@ -2281,6 +2325,12 @@ version = "1.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54"
 
+[[package]]
+name = "nonempty"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
+
 [[package]]
 name = "notify"
 version = "4.0.17"
@@ -2824,7 +2874,6 @@ dependencies = [
  "itoa 0.4.8",
  "percent-encoding",
  "pin-project-lite",
- "sha1",
  "tokio",
  "tokio-util 0.6.10",
  "url",
@@ -2835,7 +2884,7 @@ name = "redis-cell-client"
 version = "0.2.0"
 dependencies = [
  "anyhow",
- "redis",
+ "bb8-redis",
 ]
 
@@ -3199,21 +3248,6 @@ dependencies = [
  "digest 0.10.3",
 ]
 
-[[package]]
-name = "sha1"
-version = "0.6.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c1da05c97445caa12d05e848c4a4fcbbea29e748ac28f7e80e9b010392063770"
-dependencies = [
- "sha1_smol",
-]
-
-[[package]]
-name = "sha1_smol"
-version = "1.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
-
 [[package]]
 name = "sha2"
 version = "0.8.2"
@@ -4095,6 +4129,7 @@ dependencies = [
  "arc-swap",
  "argh",
  "axum",
+ "axum-client-ip",
 "counter",
  "dashmap",
  "derive_more",
diff --git a/TODO.md b/TODO.md
index bae3cca4..afec16aa 100644
--- a/TODO.md
+++ b/TODO.md
@@ -34,8 +34,13 @@
 - [x] if web3 proxy gets an http error back, retry another node
 - [x] endpoint for health checks. if no synced servers, give a 502 error
 - [x] rpc errors propagate too far. one subscription failing ends the app. isolate the providers more (might already be fixed)
-- [ ] incoming rate limiting (by ip)
+- [x] incoming rate limiting (by ip)
+- [ ] connection pool for redis
 - [ ] automatically route to archive server when necessary
+- [ ] handle log subscriptions
+- [ ] basic request method stats
+- [ ] http servers should check block at the very start
+- [ ] Got warning: "WARN subscribe_new_heads:send_block: web3_proxy::connection: unable to get block from https://rpc.ethermine.org: Deserialization Error: expected value at line 1 column 1. Response: error code: 1015". this is cloudflare rate limiting on fetching a block, but this is a private rpc. why is there a block subscription?
 
 ## V1
 
@@ -43,8 +48,8 @@
   - create the app without applying any config to it
   - have a blocking future watching the config file and calling app.apply_config() on first load and on change
   - work started on this in the "config_reloads" branch. because of how we pass channels around during spawn, this requires a larger refactor.
-- [ ] interval for http subscriptions should be based on block time. load from config is easy, but
-- [ ] some things that are cached locally should probably be in shared redis caches
+- [ ] interval for http subscriptions should be based on block time. load from config is easy, but better to query
+- [ ] most things that are cached locally should probably be in shared redis caches
 - [ ] stats when forks are resolved (and what chain they were on?)
 - [ ] incoming rate limiting (by api key)
 - [ ] failsafe. if no blocks or transactions in the last second, warn and reset the connection
diff --git a/config/example.bac b/config/example.bac
index 0a0ad1a9..4dbbb0e2 100644
--- a/config/example.bac
+++ b/config/example.bac
@@ -1,5 +1,6 @@
 [shared]
 chain_id = 1
+public_rate_limit_per_minute = 60_000
 
 [balanced_rpcs]
 
diff --git a/config/example.toml b/config/example.toml
index d963691a..7e4b81d4 100644
--- a/config/example.toml
+++ b/config/example.toml
@@ -1,6 +1,8 @@
 [shared]
 chain_id = 1
 # in prod, do `rate_limit_redis = "redis://redis:6379/"`
+rate_limit_redis = "redis://dev-redis:6379/"
+public_rate_limit_per_minute = 60_000
 
 [balanced_rpcs]
 
diff --git a/redis-cell-client/Cargo.toml b/redis-cell-client/Cargo.toml
index 9ce816c3..756f8121 100644
--- a/redis-cell-client/Cargo.toml
+++ b/redis-cell-client/Cargo.toml
@@ -6,4 +6,4 @@ edition = "2018"
 
 [dependencies]
 anyhow = "1.0.58"
-redis = { version = "0.21.5", features = ["aio", "tokio", "tokio-comp"] }
+bb8-redis = "0.11.0"
diff --git a/redis-cell-client/src/client.rs b/redis-cell-client/src/client.rs
deleted file mode 100644
index 397bc95b..00000000
--- a/redis-cell-client/src/client.rs
+++ /dev/null
@@ -1,85 +0,0 @@
-use std::time::Duration;
-
-use redis::aio::MultiplexedConnection;
-
-// TODO: take this as an argument to open?
-const KEY_PREFIX: &str = "rate-limit";
-
-pub struct RedisCellClient {
-    conn: MultiplexedConnection,
-    key: String,
-    max_burst: u32,
-    count_per_period: u32,
-    period: u32,
-}
-
-impl RedisCellClient {
-    // todo: seems like this could be derived
-    // TODO: take something generic for conn
-    // TODO: use r2d2 for connection pooling?
-    pub fn new(
-        conn: MultiplexedConnection,
-        key: String,
-        max_burst: u32,
-        count_per_period: u32,
-        period: u32,
-    ) -> Self {
-        let key = format!("{}:{}", KEY_PREFIX, key);
-
-        Self {
-            conn,
-            key,
-            max_burst,
-            count_per_period,
-            period,
-        }
-    }
-
-    #[inline]
-    pub async fn throttle(&self) -> Result<(), Duration> {
-        self.throttle_quantity(1).await
-    }
-
-    #[inline]
-    pub async fn throttle_quantity(&self, quantity: u32) -> Result<(), Duration> {
-        /*
-        https://github.com/brandur/redis-cell#response
-
-        CL.THROTTLE <key> <max_burst> <count per period> <period> [<quantity>]
-
-        0. Whether the action was limited:
-            0 indicates the action is allowed.
-            1 indicates that the action was limited/blocked.
-        1. The total limit of the key (max_burst + 1). This is equivalent to the common X-RateLimit-Limit HTTP header.
-        2. The remaining limit of the key. Equivalent to X-RateLimit-Remaining.
-        3. The number of seconds until the user should retry, and always -1 if the action was allowed. Equivalent to Retry-After.
-        4. The number of seconds until the limit will reset to its maximum capacity. Equivalent to X-RateLimit-Reset.
-        */
-        // TODO: don't unwrap. maybe return anyhow::Result<Option<Duration>>
-        // TODO: should we return more error info?
-        let x: Vec<isize> = redis::cmd("CL.THROTTLE")
-            .arg(&(
-                &self.key,
-                self.max_burst,
-                self.count_per_period,
-                self.period,
-                quantity,
-            ))
-            .query_async(&mut self.conn.clone())
-            .await
-            .unwrap();
-
-        assert_eq!(x.len(), 5);
-
-        // TODO: trace log the result
-
-        // TODO: maybe we should do #4
-        let retry_after = *x.get(3).unwrap();
-
-        if retry_after == -1 {
-            Ok(())
-        } else {
-            Err(Duration::from_secs(retry_after as u64))
-        }
-    }
-}
diff --git a/redis-cell-client/src/lib.rs b/redis-cell-client/src/lib.rs
index d28c9f93..47a9da43 100644
--- a/redis-cell-client/src/lib.rs
+++ b/redis-cell-client/src/lib.rs
@@ -1,7 +1,102 @@
-mod client;
+use bb8_redis::redis::cmd;
 
-pub use client::RedisCellClient;
-pub use redis;
-// TODO: don't hard code MultiplexedConnection
-pub use redis::aio::MultiplexedConnection;
-pub use redis::Client;
+pub use bb8_redis::{bb8, RedisConnectionManager};
+
+use std::time::Duration;
+
+// TODO: take this as an argument to open?
+const KEY_PREFIX: &str = "rate-limit";
+
+pub type RedisClientPool = bb8::Pool<RedisConnectionManager>;
+
+pub struct RedisCellClient {
+    pool: RedisClientPool,
+    key: String,
+    max_burst: u32,
+    count_per_period: u32,
+    period: u32,
+}
+
+impl RedisCellClient {
+    // todo: seems like this could be derived
+    // TODO: take something generic for conn
+    // TODO: use r2d2 for connection pooling?
+    pub fn new(
+        pool: bb8::Pool<RedisConnectionManager>,
+        default_key: String,
+        max_burst: u32,
+        count_per_period: u32,
+        period: u32,
+    ) -> Self {
+        let default_key = format!("{}:{}", KEY_PREFIX, default_key);
+
+        Self {
+            pool,
+            key: default_key,
+            max_burst,
+            count_per_period,
+            period,
+        }
+    }
+
+    #[inline]
+    async fn _throttle(&self, key: &str, quantity: u32) -> Result<(), Duration> {
+        let mut conn = self.pool.get().await.unwrap();
+
+        /*
+        https://github.com/brandur/redis-cell#response
+
+        CL.THROTTLE <key> <max_burst> <count per period> <period> [<quantity>]
+
+        0. Whether the action was limited:
+            0 indicates the action is allowed.
+            1 indicates that the action was limited/blocked.
+        1. The total limit of the key (max_burst + 1). This is equivalent to the common X-RateLimit-Limit HTTP header.
+        2. The remaining limit of the key. Equivalent to X-RateLimit-Remaining.
+        3. The number of seconds until the user should retry, and always -1 if the action was allowed. Equivalent to Retry-After.
+        4. The number of seconds until the limit will reset to its maximum capacity. Equivalent to X-RateLimit-Reset.
+        */
+        // TODO: don't unwrap. maybe return anyhow::Result<Option<Duration>>
+        // TODO: should we return more error info?
+        let x: Vec<isize> = cmd("CL.THROTTLE")
+            .arg(&(
+                key,
+                self.max_burst,
+                self.count_per_period,
+                self.period,
+                quantity,
+            ))
+            .query_async(&mut *conn)
+            .await
+            .unwrap();
+
+        assert_eq!(x.len(), 5);
+
+        // TODO: trace log the result
+
+        let retry_after = *x.get(3).unwrap();
+
+        if retry_after == -1 {
+            Ok(())
+        } else {
+            Err(Duration::from_secs(retry_after as u64))
+        }
+    }
+
+    #[inline]
+    pub async fn throttle(&self) -> Result<(), Duration> {
+        self._throttle(&self.key, 1).await
+    }
+
+    #[inline]
+    pub async fn throttle_key(&self, key: &str) -> Result<(), Duration> {
+        let key = format!("{}:{}", KEY_PREFIX, key);
+
+        self._throttle(key.as_ref(), 1).await
+    }
+
+    #[inline]
+    pub async fn throttle_quantity(&self, quantity: u32) -> Result<(), Duration> {
+        self._throttle(&self.key, quantity).await
+    }
+}
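
[review note] Every bucket the new wrapper creates is keyed under "rate-limit:". A sketch of how a caller might exercise the new API, assuming a reachable redis with redis-cell loaded (the address and limits are invented for illustration):

    use redis_cell_client::{bb8, RedisCellClient, RedisConnectionManager};

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let manager = RedisConnectionManager::new("redis://127.0.0.1:6379/")?;
        let pool = bb8::Pool::builder().build(manager).await?;

        // 60 requests per minute with a burst of 20, stored under "rate-limit:public"
        let limiter = RedisCellClient::new(pool, "public".to_string(), 20, 60, 60);

        // throttle() uses the default key; throttle_key() makes a bucket per caller,
        // here "rate-limit:10.1.2.3"
        match limiter.throttle_key("10.1.2.3").await {
            Ok(()) => println!("allowed"),
            Err(retry_after) => println!("limited; retry in {:?}", retry_after),
        }

        Ok(())
    }
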
diff --git a/web3-proxy/Cargo.toml b/web3-proxy/Cargo.toml
index 5419d210..9b5adda7 100644
--- a/web3-proxy/Cargo.toml
+++ b/web3-proxy/Cargo.toml
@@ -14,6 +14,7 @@ anyhow = { version = "1.0.58", features = ["backtrace"] }
 arc-swap = "1.5.0"
 argh = "0.1.8"
 axum = { version = "0.5.11", features = ["serde_json", "tokio-tungstenite", "ws"] }
+axum-client-ip = "0.2.0"
 counter = "0.5.5"
 dashmap = "5.3.4"
 derive_more = "0.99.17"
diff --git a/web3-proxy/src/app.rs b/web3-proxy/src/app.rs
index a1d80c2f..19556718 100644
--- a/web3-proxy/src/app.rs
+++ b/web3-proxy/src/app.rs
@@ -10,6 +10,7 @@ use futures::stream::StreamExt;
 use futures::Future;
 use linkedhashmap::LinkedHashMap;
 use parking_lot::RwLock;
+use redis_cell_client::{bb8, RedisCellClient, RedisClientPool, RedisConnectionManager};
 use serde_json::json;
 use std::fmt;
 use std::pin::Pin;
@@ -20,7 +21,7 @@ use tokio::sync::{broadcast, watch};
 use tokio::task::JoinHandle;
 use tokio::time::timeout;
 use tokio_stream::wrappers::{BroadcastStream, WatchStream};
-use tracing::{debug, info, info_span, instrument, trace, warn, Instrument};
+use tracing::{info, info_span, instrument, trace, warn, Instrument};
 
 use crate::config::Web3ConnectionConfig;
 use crate::connections::Web3Connections;
@@ -94,6 +95,7 @@ pub struct Web3ProxyApp {
     head_block_receiver: watch::Receiver<Block<TxHash>>,
     pending_tx_sender: broadcast::Sender<TxState>,
     pending_transactions: Arc<DashMap<TxHash, TxState>>,
+    public_rate_limiter: Option<RedisCellClient>,
     next_subscription_id: AtomicUsize,
 }
@@ -109,11 +111,16 @@ impl Web3ProxyApp {
         &self.pending_transactions
     }
 
+    pub fn get_public_rate_limiter(&self) -> Option<&RedisCellClient> {
+        self.public_rate_limiter.as_ref()
+    }
+
     pub async fn spawn(
         chain_id: usize,
         redis_address: Option<String>,
         balanced_rpcs: Vec<Web3ConnectionConfig>,
         private_rpcs: Vec<Web3ConnectionConfig>,
+        public_rate_limit_per_minute: u32,
     ) -> anyhow::Result<(
         Arc<Web3ProxyApp>,
         Pin<Box<dyn Future<Output = anyhow::Result<()>>>>,
@@ -132,15 +139,14 @@ impl Web3ProxyApp {
             .build()?,
         );
 
-        let rate_limiter = match redis_address {
+        let rate_limiter_pool = match redis_address {
             Some(redis_address) => {
                 info!("Connecting to redis on {}", redis_address);
 
-                let redis_client = redis_cell_client::Client::open(redis_address)?;
-
-                // TODO: r2d2 connection pool?
-                let redis_conn = redis_client.get_multiplexed_tokio_connection().await?;
+                let manager = RedisConnectionManager::new(redis_address)?;
+                let pool = bb8::Pool::builder().build(manager).await?;
 
-                Some(redis_conn)
+                Some(pool)
             }
             None => {
                 info!("No redis address");
@@ -164,7 +170,7 @@ impl Web3ProxyApp {
             chain_id,
             balanced_rpcs,
             http_client.as_ref(),
-            rate_limiter.as_ref(),
+            rate_limiter_pool.as_ref(),
             Some(head_block_sender),
             Some(pending_tx_sender.clone()),
             pending_transactions.clone(),
@@ -182,7 +188,7 @@ impl Web3ProxyApp {
             chain_id,
             private_rpcs,
             http_client.as_ref(),
-            rate_limiter.as_ref(),
+            rate_limiter_pool.as_ref(),
             // subscribing to new heads here won't work well
             None,
             // TODO: subscribe to pending transactions on the private rpcs?
@@ -199,7 +205,20 @@ impl Web3ProxyApp {
         // TODO: use this? it could listen for confirmed transactions and then clear pending_transactions, but the head_block_sender is doing that
         drop(pending_tx_receiver);
 
-        let app = Web3ProxyApp {
+        // TODO: how much should we allow?
+        let public_max_burst = public_rate_limit_per_minute / 3;
+
+        let public_rate_limiter = rate_limiter_pool.as_ref().map(|redis_client_pool| {
+            RedisCellClient::new(
+                redis_client_pool.clone(),
+                "public".to_string(),
+                public_max_burst,
+                public_rate_limit_per_minute,
+                60,
+            )
+        });
+
+        let app = Self {
             balanced_rpcs,
             private_rpcs,
             incoming_requests: Default::default(),
@@ -207,6 +226,7 @@ impl Web3ProxyApp {
             head_block_receiver,
             pending_tx_sender,
             pending_transactions,
+            public_rate_limiter,
             next_subscription_id: 1.into(),
         };
@@ -431,7 +451,7 @@ impl Web3ProxyApp {
         request: JsonRpcRequestEnum,
     ) -> anyhow::Result<JsonRpcForwardedResponseEnum> {
         // TODO: i don't always see this in the logs. why?
-        debug!("Received request: {:?}", request);
+        trace!("Received request: {:?}", request);
 
         // even though we have timeouts on the requests to our backend providers,
         // we need a timeout for the incoming request so that delays from
@@ -447,7 +467,7 @@ impl Web3ProxyApp {
         };
 
         // TODO: i don't always see this in the logs. why?
-        debug!("Forwarding response: {:?}", response);
+        trace!("Forwarding response: {:?}", response);
 
         Ok(response)
     }
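
[review note] spawn() builds the pool with bb8's defaults. If those prove too tight under load, the builder exposes sizing and timeout knobs; a sketch with illustrative numbers that are not part of this patch:

    use redis_cell_client::{bb8, RedisConnectionManager};
    use std::time::Duration;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let manager = RedisConnectionManager::new("redis://dev-redis:6379/")?;

        let pool = bb8::Pool::builder()
            .max_size(16)
            .connection_timeout(Duration::from_secs(2))
            .build(manager)
            .await?;

        // bb8 reports how many connections exist and how many are idle
        println!("{:?}", pool.state());

        Ok(())
    }
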
diff --git a/web3-proxy/src/config.rs b/web3-proxy/src/config.rs
index 913f2a98..de794954 100644
--- a/web3-proxy/src/config.rs
+++ b/web3-proxy/src/config.rs
@@ -40,6 +40,9 @@ pub struct RpcSharedConfig {
     /// TODO: what type for chain_id? TODO: this isn't at the right level. this is inside a "Config"
     pub chain_id: usize,
     pub rate_limit_redis: Option<String>,
+    // TODO: serde default for development?
+    // TODO: allow no limit?
+    pub public_rate_limit_per_minute: u32,
 }
 
 #[derive(Debug, Deserialize)]
@@ -71,6 +74,7 @@ impl RpcConfig {
             self.shared.rate_limit_redis,
             balanced_rpcs,
             private_rpcs,
+            self.shared.public_rate_limit_per_minute,
         )
         .await
     }
@@ -81,14 +85,14 @@ impl Web3ConnectionConfig {
     // #[instrument(name = "try_build_Web3ConnectionConfig", skip_all)]
     pub async fn spawn(
         self,
-        rate_limiter: Option<&redis_cell_client::MultiplexedConnection>,
+        redis_client_pool: Option<&redis_cell_client::RedisClientPool>,
         chain_id: usize,
         http_client: Option<&reqwest::Client>,
         http_interval_sender: Option<Arc<broadcast::Sender<()>>>,
         block_sender: Option<flume::Sender<(Block<TxHash>, Arc<Web3Connection>)>>,
         tx_id_sender: Option<flume::Sender<(TxHash, Arc<Web3Connection>)>>,
     ) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> {
-        let hard_rate_limit = self.hard_limit.map(|x| (x, rate_limiter.unwrap()));
+        let hard_rate_limit = self.hard_limit.map(|x| (x, redis_client_pool.unwrap()));
 
         Web3Connection::spawn(
             chain_id,
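
[review note] Because public_rate_limit_per_minute is a plain u32 with no serde default yet (see the TODOs above), configs that omit it will fail to parse. A quick round-trip sketch, assuming the toml crate that the proxy's config loading already relies on:

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct RpcSharedConfig {
        chain_id: usize,
        rate_limit_redis: Option<String>,
        public_rate_limit_per_minute: u32,
    }

    fn main() -> anyhow::Result<()> {
        // mirrors config/example.toml; TOML reads 60_000 the same as 60000
        let cfg: RpcSharedConfig = toml::from_str(
            r#"
            chain_id = 1
            rate_limit_redis = "redis://dev-redis:6379/"
            public_rate_limit_per_minute = 60_000
            "#,
        )?;

        assert_eq!(cfg.public_rate_limit_per_minute, 60_000);
        println!("{:?}", cfg);

        Ok(())
    }
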
diff --git a/web3-proxy/src/connection.rs b/web3-proxy/src/connection.rs
index 272ff9a0..b143a1bb 100644
--- a/web3-proxy/src/connection.rs
+++ b/web3-proxy/src/connection.rs
@@ -131,7 +131,7 @@ impl Web3Connection {
         // optional because this is only used for http providers. websocket providers don't use it
         http_client: Option<&reqwest::Client>,
         http_interval_sender: Option<Arc<broadcast::Sender<()>>>,
-        hard_limit: Option<(u32, &redis_cell_client::MultiplexedConnection)>,
+        hard_limit: Option<(u32, &redis_cell_client::RedisClientPool)>,
         // TODO: think more about this type
         soft_limit: u32,
         block_sender: Option<flume::Sender<(Block<TxHash>, Arc<Web3Connection>)>>,
@@ -356,10 +356,6 @@ impl Web3Connection {
                     let mut last_hash = Default::default();
 
                     loop {
-                        // wait for the interval
-                        // TODO: if error or rate limit, increase interval?
-                        http_interval_receiver.recv().await.unwrap();
-
                         match self.try_request_handle().await {
                             Ok(active_request_handle) => {
                                 // TODO: i feel like this should be easier. there is a provider.getBlock, but i don't know how to give it "latest"
@@ -384,6 +380,10 @@ impl Web3Connection {
                                 warn!("Failed getting latest block from {}: {:?}", self, e);
                             }
                         }
+
+                        // wait for the interval
+                        // TODO: if error or rate limit, increase interval?
+                        http_interval_receiver.recv().await.unwrap();
                     }
                 }
                 Web3Provider::Ws(provider) => {
@@ -447,10 +447,6 @@ impl Web3Connection {
                     // TODO: create a filter
 
                     loop {
-                        // wait for the interval
-                        // TODO: if error or rate limit, increase interval?
-                        interval.tick().await;
-
                        // TODO: actually do something here
                         /*
                         match self.try_request_handle().await {
@@ -463,6 +459,10 @@ impl Web3Connection {
                             }
                         }
                         */
+
+                        // wait for the interval
+                        // TODO: if error or rate limit, increase interval?
+                        interval.tick().await;
                     }
                 }
                 Web3Provider::Ws(provider) => {
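
[review note] Both loops above were reordered so the work happens before the wait. Moving the wait to the bottom guarantees the first iteration runs immediately regardless of the wait primitive (a broadcast receiver blocks until the next send; tokio's interval happens to fire its first tick right away). A toy sketch of the pattern:

    use std::time::Duration;
    use tokio::time::interval;

    #[tokio::main]
    async fn main() {
        // made-up cadence; the TODO above wants this derived from block time
        let mut ticker = interval(Duration::from_secs(13));

        for n in 0..3 {
            // do the work first so iteration 0 is never delayed by a wait
            println!("polling for a new head block (iteration {})", n);

            // then wait at the bottom of the loop for the next round
            ticker.tick().await;
        }
    }
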
diff --git a/web3-proxy/src/connections.rs b/web3-proxy/src/connections.rs
index bae3271f..ffd311aa 100644
--- a/web3-proxy/src/connections.rs
+++ b/web3-proxy/src/connections.rs
@@ -89,7 +89,7 @@ impl Web3Connections {
         chain_id: usize,
         server_configs: Vec<Web3ConnectionConfig>,
         http_client: Option<&reqwest::Client>,
-        rate_limiter: Option<&redis_cell_client::MultiplexedConnection>,
+        redis_client_pool: Option<&redis_cell_client::RedisClientPool>,
         head_block_sender: Option<watch::Sender<Block<TxHash>>>,
         pending_tx_sender: Option<broadcast::Sender<TxState>>,
         pending_transactions: Arc<DashMap<TxHash, TxState>>,
@@ -141,7 +141,7 @@ impl Web3Connections {
         for server_config in server_configs.into_iter() {
             match server_config
                 .spawn(
-                    rate_limiter,
+                    redis_client_pool,
                     chain_id,
                     http_client,
                     http_interval_sender.clone(),
diff --git a/web3-proxy/src/frontend/errors.rs b/web3-proxy/src/frontend/errors.rs
index 3e08f4c5..24e16f30 100644
--- a/web3-proxy/src/frontend/errors.rs
+++ b/web3-proxy/src/frontend/errors.rs
@@ -7,15 +7,15 @@ use crate::jsonrpc::JsonRpcForwardedResponse;
 pub async fn handler_404() -> impl IntoResponse {
     let err = anyhow::anyhow!("nothing to see here");
 
-    handle_anyhow_error(err, Some(StatusCode::NOT_FOUND)).await
+    handle_anyhow_error(Some(StatusCode::NOT_FOUND), err).await
 }
 
 /// handle errors by converting them into something that implements `IntoResponse`
 /// TODO: use this. i can't get https://docs.rs/axum/latest/axum/error_handling/index.html to work
 /// TODO: i think we want a custom result type instead. put the anyhow result inside. then `impl IntoResponse for CustomResult`
 pub async fn handle_anyhow_error(
-    err: anyhow::Error,
     code: Option<StatusCode>,
+    err: anyhow::Error,
 ) -> impl IntoResponse {
     let id = RawValue::from_string("null".to_string()).unwrap();
diff --git a/web3-proxy/src/frontend/http.rs b/web3-proxy/src/frontend/http.rs
index 2ae951ee..eed6437d 100644
--- a/web3-proxy/src/frontend/http.rs
+++ b/web3-proxy/src/frontend/http.rs
@@ -5,7 +5,7 @@ use std::sync::Arc;
 use crate::app::Web3ProxyApp;
 
 /// Health check page for load balancers to use
-pub async fn health(app: Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
+pub async fn health(Extension(app): Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
     if app.get_balanced_rpcs().has_synced_rpcs() {
         (StatusCode::OK, "OK")
     } else {
@@ -14,7 +14,7 @@ pub async fn health(app: Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
 }
 
 /// Very basic status page
-pub async fn status(app: Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
+pub async fn status(Extension(app): Extension<Arc<Web3ProxyApp>>) -> impl IntoResponse {
     // TODO: what else should we include? uptime? prometheus?
     let balanced_rpcs = app.get_balanced_rpcs();
 
     let private_rpcs = app.get_private_rpcs();
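
[review note] The handler changes above are one mechanical pattern: destructuring the extractor in the signature (Extension(app) instead of app: Extension<...>) removes the app.0 noise from the bodies. A self-contained sketch of the pattern with a stand-in state type, not the proxy's real Web3ProxyApp:

    use axum::{http::StatusCode, response::IntoResponse, routing::get, Extension, Router};
    use std::sync::Arc;

    struct AppState {
        healthy: bool,
    }

    // destructured: `app` is the Arc itself, no `.0` needed
    async fn health(Extension(app): Extension<Arc<AppState>>) -> impl IntoResponse {
        if app.healthy {
            (StatusCode::OK, "OK")
        } else {
            (StatusCode::SERVICE_UNAVAILABLE, ":(")
        }
    }

    fn main() {
        let state = Arc::new(AppState { healthy: true });

        // the Extension layer is what makes the extractor succeed at runtime
        let _app: Router = Router::new()
            .route("/health", get(health))
            .layer(Extension(state));
    }
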
diff --git a/web3-proxy/src/frontend/http_proxy.rs b/web3-proxy/src/frontend/http_proxy.rs
index 6656b0c2..0543431f 100644
--- a/web3-proxy/src/frontend/http_proxy.rs
+++ b/web3-proxy/src/frontend/http_proxy.rs
@@ -1,15 +1,34 @@
 use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
+use axum_client_ip::ClientIp;
 use std::sync::Arc;
 
 use super::errors::handle_anyhow_error;
 use crate::{app::Web3ProxyApp, jsonrpc::JsonRpcRequestEnum};
 
 pub async fn proxy_web3_rpc(
-    payload: Json<JsonRpcRequestEnum>,
-    app: Extension<Arc<Web3ProxyApp>>,
+    Json(payload): Json<JsonRpcRequestEnum>,
+    Extension(app): Extension<Arc<Web3ProxyApp>>,
+    ClientIp(ip): ClientIp,
 ) -> impl IntoResponse {
-    match app.proxy_web3_rpc(payload.0).await {
+    if let Some(rate_limiter) = app.get_public_rate_limiter() {
+        let rate_limiter_key = format!("{}", ip);
+
+        if rate_limiter.throttle_key(&rate_limiter_key).await.is_err() {
+            // TODO: set headers so they know when they can retry
+            // warn!(?ip, "public rate limit exceeded");
+            return handle_anyhow_error(
+                Some(StatusCode::TOO_MANY_REQUESTS),
+                anyhow::anyhow!("too many requests"),
+            )
+            .await
+            .into_response();
+        }
+    } else {
+        // TODO: if no redis, rate limit with a local cache?
+    }
+
+    match app.proxy_web3_rpc(payload).await {
         Ok(response) => (StatusCode::OK, Json(&response)).into_response(),
-        Err(err) => handle_anyhow_error(err, None).await.into_response(),
+        Err(err) => handle_anyhow_error(None, err).await.into_response(),
     }
 }
diff --git a/web3-proxy/src/frontend/mod.rs b/web3-proxy/src/frontend/mod.rs
index cc62fd4f..a9e1feb9 100644
--- a/web3-proxy/src/frontend/mod.rs
+++ b/web3-proxy/src/frontend/mod.rs
@@ -3,6 +3,7 @@ mod errors;
 mod http;
 mod http_proxy;
 mod ws_proxy;
+
 use axum::{
     handler::Handler,
     routing::{get, post},
@@ -35,8 +36,9 @@ pub async fn run(port: u16, proxy_app: Arc<Web3ProxyApp>) -> anyhow::Result<()>
     // `axum::Server` is a re-export of `hyper::Server`
     let addr = SocketAddr::from(([0, 0, 0, 0], port));
     debug!("listening on port {}", port);
+    // TODO: into_make_service is enough if we always run behind a proxy. make into_make_service_with_connect_info optional?
     axum::Server::bind(&addr)
-        .serve(app.into_make_service())
+        .serve(app.into_make_service_with_connect_info::<SocketAddr>())
         .await
         .map_err(Into::into)
 }
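
[review note] into_make_service_with_connect_info is what lets axum-client-ip fall back to the socket's peer address when no Forwarded / X-Forwarded-For header is present; without it, ClientIp extraction fails for direct connections. A minimal server sketch (the port and handler are invented for illustration):

    use axum::{routing::post, Router};
    use axum_client_ip::ClientIp;
    use std::net::SocketAddr;

    async fn handler(ClientIp(ip): ClientIp) -> String {
        format!("hello, {}", ip)
    }

    #[tokio::main]
    async fn main() {
        let app = Router::new().route("/", post(handler));

        let addr = SocketAddr::from(([0, 0, 0, 0], 8544));

        // ConnectInfo<SocketAddr> is recorded per connection and used as the
        // fallback source for the client ip
        axum::Server::bind(&addr)
            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
            .await
            .unwrap();
    }
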
diff --git a/web3-proxy/src/frontend/ws_proxy.rs b/web3-proxy/src/frontend/ws_proxy.rs
index 1486e1ab..01346b41 100644
--- a/web3-proxy/src/frontend/ws_proxy.rs
+++ b/web3-proxy/src/frontend/ws_proxy.rs
@@ -12,7 +12,7 @@ use hashbrown::HashMap;
 use serde_json::value::RawValue;
 use std::str::from_utf8_mut;
 use std::sync::Arc;
-use tracing::{debug, error, info, warn};
+use tracing::{error, info, trace, warn};
 
 use crate::{
     app::Web3ProxyApp,
@@ -20,13 +20,13 @@ use crate::{
 };
 
 pub async fn websocket_handler(
-    app: Extension<Arc<Web3ProxyApp>>,
+    Extension(app): Extension<Arc<Web3ProxyApp>>,
     ws: WebSocketUpgrade,
 ) -> impl IntoResponse {
     ws.on_upgrade(|socket| proxy_web3_socket(app, socket))
 }
 
-async fn proxy_web3_socket(app: Extension<Arc<Web3ProxyApp>>, socket: WebSocket) {
+async fn proxy_web3_socket(app: Arc<Web3ProxyApp>, socket: WebSocket) {
     // split the websocket so we can read and write concurrently
     let (ws_tx, ws_rx) = socket.split();
 
@@ -109,7 +109,7 @@ async fn handle_socket_payload(
 }
 
 async fn read_web3_socket(
-    app: Extension<Arc<Web3ProxyApp>>,
+    app: Arc<Web3ProxyApp>,
     mut ws_rx: SplitStream<WebSocket>,
     response_tx: flume::Sender<Message>,
 ) {
@@ -119,12 +119,11 @@ async fn read_web3_socket(
             // new message from our client. forward to a backend and then send it through response_tx
             let response_msg = match msg {
                 Message::Text(payload) => {
-                    handle_socket_payload(app.0.clone(), &payload, &response_tx, &mut subscriptions)
-                        .await
+                    handle_socket_payload(app.clone(), &payload, &response_tx, &mut subscriptions).await
                 }
                 Message::Ping(x) => Message::Pong(x),
                 Message::Pong(x) => {
-                    debug!("pong: {:?}", x);
+                    trace!("pong: {:?}", x);
                     continue;
                 }
                 Message::Close(_) => {
@@ -134,8 +133,7 @@ async fn read_web3_socket(
                 Message::Binary(mut payload) => {
                     let payload = from_utf8_mut(&mut payload).unwrap();
 
-                    handle_socket_payload(app.0.clone(), payload, &response_tx, &mut subscriptions)
-                        .await
+                    handle_socket_payload(app.clone(), payload, &response_tx, &mut subscriptions).await
                 }
             };
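
[review note] For completeness, the split-socket pattern that proxy_web3_socket relies on, reduced to a sketch: reads and writes run concurrently over the two halves, and pings are answered locally instead of being forwarded to a backend. The function name echo_pongs is hypothetical, not part of this patch:

    use axum::extract::ws::{Message, WebSocket};
    use futures::{SinkExt, StreamExt};

    async fn echo_pongs(socket: WebSocket) {
        // split the websocket so we can read and write concurrently
        let (mut ws_tx, mut ws_rx) = socket.split();

        while let Some(Ok(msg)) = ws_rx.next().await {
            match msg {
                // answer pings ourselves; the backends never see them
                Message::Ping(payload) => {
                    if ws_tx.send(Message::Pong(payload)).await.is_err() {
                        // the client is gone
                        break;
                    }
                }
                Message::Close(_) => break,
                _ => continue,
            }
        }
    }
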