From 80a3c7412018e4ec3d8243cfa7c9f49194106413 Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Wed, 10 Aug 2022 02:37:34 +0000 Subject: [PATCH] cache db data in a map --- Cargo.lock | 21 +- Cargo.toml | 1 + TODO.md | 1 + docker-compose.common.yml | 2 +- entities/Cargo.toml | 2 +- fifomap/Cargo.toml | 9 + fifomap/src/fifo_count_map.rs | 58 ++++++ fifomap/src/fifo_sized_map.rs | 92 +++++++++ fifomap/src/lib.rs | 5 + redis-cell-client/src/lib.rs | 10 +- web3_proxy/Cargo.toml | 6 +- web3_proxy/src/app.rs | 257 ++++++------------------ web3_proxy/src/block_helpers.rs | 163 ++++++++++++++++ web3_proxy/src/config.rs | 2 +- web3_proxy/src/connection.rs | 19 +- web3_proxy/src/connections.rs | 3 +- web3_proxy/src/frontend/errors.rs | 16 +- web3_proxy/src/frontend/http_proxy.rs | 19 +- web3_proxy/src/frontend/rate_limit.rs | 268 ++++++++++++++------------ web3_proxy/src/frontend/users.rs | 14 +- web3_proxy/src/frontend/ws_proxy.rs | 19 +- web3_proxy/src/lib.rs | 1 + 22 files changed, 610 insertions(+), 378 deletions(-) create mode 100644 fifomap/Cargo.toml create mode 100644 fifomap/src/fifo_count_map.rs create mode 100644 fifomap/src/fifo_sized_map.rs create mode 100644 fifomap/src/lib.rs create mode 100644 web3_proxy/src/block_helpers.rs diff --git a/Cargo.lock b/Cargo.lock index 72dbf0d6..912ff19d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -335,9 +335,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.5.14" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c943a505c17b494638a38a9af129067f760c9c06794b9f57d499266909be8e72" +checksum = "9de18bc5f2e9df8f52da03856bf40e29b747de5a84e43aefff90e3dc4a21529b" dependencies = [ "async-trait", "axum-core", @@ -1769,6 +1769,13 @@ dependencies = [ "subtle", ] +[[package]] +name = "fifomap" +version = "0.1.0" +dependencies = [ + "linkedhashmap", +] + [[package]] name = "filetime" version = "0.2.17" @@ -4023,9 +4030,9 @@ checksum = "930c0acf610d3fdb5e2ab6213019aaa04e227ebe9547b0649ba599b16d788bd7" [[package]] name = "serde" -version = "1.0.142" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e590c437916fb6b221e1d00df6e3294f3fccd70ca7e92541c475d6ed6ef5fee2" +checksum = "53e8e5d5b70924f74ff5c6d64d9a5acd91422117c60f48c4e07855238a254553" dependencies = [ "serde_derive", ] @@ -4052,9 +4059,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.142" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34b5b8d809babe02f538c2cfec6f2c1ed10804c0e5a6a041a049a4f5588ccc2e" +checksum = "d3d8e8de557aee63c26b85b947f5e59b690d0454c753f3adeb5cd7835ab88391" dependencies = [ "proc-macro2", "quote", @@ -5209,12 +5216,12 @@ dependencies = [ "entities", "ethers", "fdlimit", + "fifomap", "flume", "fstrings", "futures", "hashbrown", "indexmap", - "linkedhashmap", "migration", "notify", "num", diff --git a/Cargo.toml b/Cargo.toml index 63f29dcf..9cd9e09b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "entities", "migration", + "fifomap", "linkedhashmap", "redis-cell-client", "web3_proxy", diff --git a/TODO.md b/TODO.md index e9117cb2..7297a6ae 100644 --- a/TODO.md +++ b/TODO.md @@ -64,6 +64,7 @@ - [x] refactor result type on active handlers to use a cleaner success/error so we can use the try operator - [x] give users different rate limits looked up from the database - [x] Add a "weight" key to the servers. Sort on that after block. 
keep most requests local
+- [ ] cache db query results for user data. db is a big bottleneck right now
 - [ ] allow blocking public requests
 - [ ] use siwe messages and signatures for sign up and login
 - [ ] basic request method stats
diff --git a/docker-compose.common.yml b/docker-compose.common.yml
index 0c5766f0..610f9982 100644
--- a/docker-compose.common.yml
+++ b/docker-compose.common.yml
@@ -3,7 +3,7 @@ services:
     # TODO: build in dev but use docker hub in prod?
     build: .
     restart: unless-stopped
-    command: --config /config.toml --workers 8
+    command: --config /config.toml --workers 32
     environment:
       #RUST_LOG: "info,web3_proxy=debug"
       RUST_LOG: info
diff --git a/entities/Cargo.toml b/entities/Cargo.toml
index fd24bb99..11ad63a5 100644
--- a/entities/Cargo.toml
+++ b/entities/Cargo.toml
@@ -11,5 +11,5 @@ path = "src/mod.rs"
 
 [dependencies]
 sea-orm = "0.9.1"
-serde = "1.0.142"
+serde = "1.0.143"
 uuid = "1.1.2"
diff --git a/fifomap/Cargo.toml b/fifomap/Cargo.toml
new file mode 100644
index 00000000..d351684f
--- /dev/null
+++ b/fifomap/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "fifomap"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+linkedhashmap = { path = "../linkedhashmap", features = ["inline-more"] }
diff --git a/fifomap/src/fifo_count_map.rs b/fifomap/src/fifo_count_map.rs
new file mode 100644
index 00000000..461d25f2
--- /dev/null
+++ b/fifomap/src/fifo_count_map.rs
@@ -0,0 +1,58 @@
+use linkedhashmap::LinkedHashMap;
+use std::{
+    borrow::Borrow,
+    collections::hash_map::RandomState,
+    hash::{BuildHasher, Hash},
+};
+
+pub struct FifoCountMap<K, V, S = RandomState>
+where
+    K: Hash + Eq + Clone,
+    S: BuildHasher + Default,
+{
+    /// size limit for the map
+    max_count: usize,
+    /// FIFO
+    map: LinkedHashMap<K, V, S>,
+}
+
+impl<K, V, S> FifoCountMap<K, V, S>
+where
+    K: Hash + Eq + Clone,
+    S: BuildHasher + Default,
+{
+    pub fn new(max_count: usize) -> Self {
+        Self {
+            max_count,
+            map: Default::default(),
+        }
+    }
+}
+
+impl<K, V, S> FifoCountMap<K, V, S>
+where
+    K: Hash + Eq + Clone,
+    S: BuildHasher + Default,
+{
+    /// if the map is at `self.max_count`, drop the oldest item (first in, first out)
+    /// the limit here is a count of items, not bytes
+    pub fn insert(&mut self, key: K, value: V) {
+        // drop items until the cache has enough room for the new data
+        // TODO: this probably has wildly variable timings
+        if self.map.len() > self.max_count {
+            // TODO: this isn't an LRU. it's a "least recently created". does that have a fancy name? should we make it an lru? these caches only live for one block
+            self.map.pop_front();
+        }
+
+        self.map.insert(key, value);
+    }
+
+    /// get an item from the cache, or None
+    pub fn get<Q>(&self, key: &Q) -> Option<&V>
+    where
+        K: Borrow<Q>,
+        Q: Hash + Eq,
+    {
+        self.map.get(key)
+    }
+}
diff --git a/fifomap/src/fifo_sized_map.rs b/fifomap/src/fifo_sized_map.rs
new file mode 100644
index 00000000..584e8602
--- /dev/null
+++ b/fifomap/src/fifo_sized_map.rs
@@ -0,0 +1,92 @@
+use linkedhashmap::LinkedHashMap;
+use std::{
+    borrow::Borrow,
+    collections::hash_map::RandomState,
+    hash::{BuildHasher, Hash},
+    mem::size_of_val,
+};
+
+// TODO: if values have wildly different sizes, this is good. but if they are all about the same, this could be simpler
+pub struct FifoSizedMap<K, V, S = RandomState>
+where
+    K: Hash + Eq + Clone,
+    S: BuildHasher + Default,
+{
+    /// size limit in bytes for the map
+    max_size_bytes: usize,
+    /// size limit in bytes for a single item in the map
+    max_item_bytes: usize,
+    /// FIFO
+    map: LinkedHashMap<K, V, S>,
+}
+
+impl<K, V, S> FifoSizedMap<K, V, S>
+where
+    K: Hash + Eq + Clone,
+    S: BuildHasher + Default,
+{
+    pub fn new(max_size_bytes: usize, max_share: usize) -> Self {
+        let max_item_bytes = max_size_bytes / max_share;
+
+        Self {
+            max_size_bytes,
+            max_item_bytes,
+            map: Default::default(),
+        }
+    }
+}
+
+impl<K, V, S> Default for FifoSizedMap<K, V, S>
+where
+    K: Hash + Eq + Clone,
+    S: BuildHasher + Default,
+{
+    fn default() -> Self {
+        Self::new(
+            // 100 MB default cache
+            100_000_000,
+            // items cannot take more than 1% of the cache
+            100,
+        )
+    }
+}
+
+impl<K, V, S> FifoSizedMap<K, V, S>
+where
+    K: Hash + Eq + Clone,
+    S: BuildHasher + Default,
+{
+    /// if the size is larger than `self.max_size_bytes`, drop items (first in, first out)
+    /// no item is allowed to take more than `1/max_share` of the cache
+    pub fn insert(&mut self, key: K, value: V) -> bool {
+        // TODO: this might be too naive. not sure how much overhead the object has
+        let new_size = size_of_val(&key) + size_of_val(&value);
+
+        // no item is allowed to take more than 1% of the cache
+        // TODO: get this from config?
+        // TODO: trace logging
+        if new_size > self.max_item_bytes {
+            return false;
+        }
+
+        // drop items until the cache has enough room for the new data
+        // TODO: this probably has wildly variable timings
+        while size_of_val(&self.map) + new_size > self.max_size_bytes {
+            // TODO: this isn't an LRU. it's a "least recently created". does that have a fancy name? should we make it an lru? these caches only live for one block
+            self.map.pop_front();
+        }
+
+        self.map.insert(key, value);
+
+        true
+    }
+
+    /// get an item from the cache, or None
+    pub fn get<Q>(&self, key: &Q) -> Option<&V>
+    where
+        K: Borrow<Q>,
+        Q: Hash + Eq,
+    {
+        self.map.get(key)
+    }
+}
diff --git a/fifomap/src/lib.rs b/fifomap/src/lib.rs
new file mode 100644
index 00000000..8b2655bc
--- /dev/null
+++ b/fifomap/src/lib.rs
@@ -0,0 +1,5 @@
+mod fifo_count_map;
+mod fifo_sized_map;
+
+pub use fifo_count_map::FifoCountMap;
+pub use fifo_sized_map::FifoSizedMap;
diff --git a/redis-cell-client/src/lib.rs b/redis-cell-client/src/lib.rs
index fd90a59c..059db113 100644
--- a/redis-cell-client/src/lib.rs
+++ b/redis-cell-client/src/lib.rs
@@ -8,10 +8,10 @@ pub use bb8_redis::{bb8, RedisConnectionManager};
 use std::ops::Add;
 use std::time::{Duration, Instant};
 
-pub type RedisClientPool = bb8::Pool<RedisConnectionManager>;
+pub type RedisPool = bb8::Pool<RedisConnectionManager>;
 
-pub struct RedisCellClient {
-    pool: RedisClientPool,
+pub struct RedisCell {
+    pool: RedisPool,
     key: String,
     default_max_burst: u32,
     default_count_per_period: u32,
@@ -23,12 +23,12 @@ pub enum ThrottleResult {
     RetryAt(Instant),
 }
 
-impl RedisCellClient {
+impl RedisCell {
     // todo: seems like this could be derived
     // TODO: take something generic for conn
     // TODO: use r2d2 for connection pooling?
pub fn new( - pool: RedisClientPool, + pool: RedisPool, app: &str, key: &str, default_max_burst: u32, diff --git a/web3_proxy/Cargo.toml b/web3_proxy/Cargo.toml index ef5c51af..66d7c1c5 100644 --- a/web3_proxy/Cargo.toml +++ b/web3_proxy/Cargo.toml @@ -19,7 +19,7 @@ migration = { path = "../migration" } anyhow = { version = "1.0.60", features = ["backtrace"] } arc-swap = "1.5.1" argh = "0.1.8" -axum = { version = "0.5.13", features = ["serde_json", "tokio-tungstenite", "ws"] } +axum = { version = "0.5.15", features = ["serde_json", "tokio-tungstenite", "ws"] } axum-client-ip = "0.2.0" counter = "0.5.6" dashmap = "5.3.4" @@ -32,7 +32,7 @@ futures = { version = "0.3.21", features = ["thread-pool"] } fstrings = "0.2.3" hashbrown = { version = "0.12.3", features = ["serde"] } indexmap = "1.9.1" -linkedhashmap = { path = "../linkedhashmap", features = ["inline-more"] } +fifomap = { path = "../fifomap" } notify = "4.0.17" num = "0.4.0" parking_lot = { version = "0.12.1", features = ["arc_lock"] } @@ -45,7 +45,7 @@ reqwest = { version = "0.11.11", default-features = false, features = ["json", " rustc-hash = "1.1.0" siwe = "0.4.1" sea-orm = { version = "0.9.1", features = ["macros"] } -serde = { version = "1.0.142", features = [] } +serde = { version = "1.0.143", features = [] } serde_json = { version = "1.0.83", default-features = false, features = ["alloc", "raw_value"] } tokio = { version = "1.20.1", features = ["full", "tracing"] } # TODO: make sure this uuid version matches what is in sea orm. PR on sea orm to put builder into prelude diff --git a/web3_proxy/src/app.rs b/web3_proxy/src/app.rs index 41c72329..b1979bfc 100644 --- a/web3_proxy/src/app.rs +++ b/web3_proxy/src/app.rs @@ -1,35 +1,40 @@ +// TODO: this file is way too big now. move things into other modules + use anyhow::Context; use axum::extract::ws::Message; use dashmap::mapref::entry::Entry as DashMapEntry; use dashmap::DashMap; +use derive_more::From; use ethers::core::utils::keccak256; -use ethers::prelude::{Address, Block, BlockNumber, Bytes, Transaction, TxHash, H256, U64}; +use ethers::prelude::{Address, Block, Bytes, Transaction, TxHash, H256, U64}; +use fifomap::{FifoCountMap, FifoSizedMap}; use futures::future::Abortable; use futures::future::{join_all, AbortHandle}; use futures::stream::FuturesUnordered; use futures::stream::StreamExt; use futures::Future; -use linkedhashmap::LinkedHashMap; use migration::{Migrator, MigratorTrait}; use parking_lot::RwLock; use redis_cell_client::bb8::ErrorSink; -use redis_cell_client::{bb8, RedisCellClient, RedisConnectionManager}; +use redis_cell_client::{bb8, RedisCell, RedisConnectionManager, RedisPool}; use sea_orm::DatabaseConnection; use serde_json::json; use std::fmt; -use std::mem::size_of_val; use std::pin::Pin; use std::str::FromStr; use std::sync::atomic::{self, AtomicUsize}; use std::sync::Arc; use std::time::Duration; +use tokio::sync::RwLock as AsyncRwLock; use tokio::sync::{broadcast, watch}; use tokio::task::JoinHandle; -use tokio::time::timeout; +use tokio::time::{timeout, Instant}; use tokio_stream::wrappers::{BroadcastStream, WatchStream}; use tracing::{info, info_span, instrument, trace, warn, Instrument}; +use uuid::Uuid; use crate::bb8_helpers; +use crate::block_helpers::block_needed; use crate::config::AppConfig; use crate::connections::Web3Connections; use crate::jsonrpc::JsonRpcForwardedResponse; @@ -48,14 +53,14 @@ static APP_USER_AGENT: &str = concat!( // block hash, method, params type CacheKey = (H256, String, Option); -// TODO: make something more advanced 
that keeps track of cache size in bytes -type ResponseLrcCache = RwLock>; +type ResponseLrcCache = RwLock>; type ActiveRequestsMap = DashMap>; pub type AnyhowJoinHandle = JoinHandle>; /// flatten a JoinError into an anyhow error +/// Useful when joining multiple futures. pub async fn flatten_handle(handle: AnyhowJoinHandle) -> anyhow::Result { match handle.await { Ok(Ok(result)) => Ok(result), @@ -79,173 +84,17 @@ pub async fn flatten_handles( Ok(()) } -fn block_num_to_u64(block_num: BlockNumber, latest_block: U64) -> (bool, U64) { - match block_num { - BlockNumber::Earliest => (false, U64::zero()), - BlockNumber::Latest => { - // change "latest" to a number - (true, latest_block) - } - BlockNumber::Number(x) => (false, x), - // TODO: think more about how to handle Pending - BlockNumber::Pending => (false, latest_block), - } -} - -fn clean_block_number( - params: &mut serde_json::Value, - block_param_id: usize, - latest_block: U64, -) -> anyhow::Result { - match params.as_array_mut() { - None => Err(anyhow::anyhow!("params not an array")), - Some(params) => match params.get_mut(block_param_id) { - None => { - if params.len() != block_param_id - 1 { - return Err(anyhow::anyhow!("unexpected params length")); - } - - // add the latest block number to the end of the params - params.push(serde_json::to_value(latest_block)?); - - Ok(latest_block) - } - Some(x) => { - // convert the json value to a BlockNumber - let block_num: BlockNumber = serde_json::from_value(x.clone())?; - - let (modified, block_num) = block_num_to_u64(block_num, latest_block); - - // if we changed "latest" to a number, update the params to match - if modified { - *x = serde_json::to_value(block_num)?; - } - - Ok(block_num) - } - }, - } -} - -// TODO: change this to return also return the hash needed -fn block_needed( - method: &str, - params: Option<&mut serde_json::Value>, - head_block: U64, -) -> Option { - let params = params?; - - // TODO: double check these. i think some of the getBlock stuff will never need archive - let block_param_id = match method { - "eth_call" => 1, - "eth_estimateGas" => 1, - "eth_getBalance" => 1, - "eth_getBlockByHash" => { - // TODO: double check that any node can serve this - return None; - } - "eth_getBlockByNumber" => { - // TODO: double check that any node can serve this - return None; - } - "eth_getBlockTransactionCountByHash" => { - // TODO: double check that any node can serve this - return None; - } - "eth_getBlockTransactionCountByNumber" => 0, - "eth_getCode" => 1, - "eth_getLogs" => { - let obj = params[0].as_object_mut().unwrap(); - - if let Some(x) = obj.get_mut("fromBlock") { - let block_num: BlockNumber = serde_json::from_value(x.clone()).ok()?; - - let (modified, block_num) = block_num_to_u64(block_num, head_block); - - if modified { - *x = serde_json::to_value(block_num).unwrap(); - } - - return Some(block_num); - } - - if let Some(x) = obj.get_mut("toBlock") { - let block_num: BlockNumber = serde_json::from_value(x.clone()).ok()?; - - let (modified, block_num) = block_num_to_u64(block_num, head_block); - - if modified { - *x = serde_json::to_value(block_num).unwrap(); - } - - return Some(block_num); - } - - if let Some(x) = obj.get("blockHash") { - // TODO: check a linkedhashmap of recent hashes - // TODO: error if fromBlock or toBlock were set - todo!("handle blockHash {}", x); - } - - return None; - } - "eth_getStorageAt" => 2, - "eth_getTransactionByHash" => { - // TODO: not sure how best to look these up - // try full nodes first. 
retry will use archive - return None; - } - "eth_getTransactionByBlockHashAndIndex" => { - // TODO: check a linkedhashmap of recent hashes - // try full nodes first. retry will use archive - return None; - } - "eth_getTransactionByBlockNumberAndIndex" => 0, - "eth_getTransactionCount" => 1, - "eth_getTransactionReceipt" => { - // TODO: not sure how best to look these up - // try full nodes first. retry will use archive - return None; - } - "eth_getUncleByBlockHashAndIndex" => { - // TODO: check a linkedhashmap of recent hashes - // try full nodes first. retry will use archive - return None; - } - "eth_getUncleByBlockNumberAndIndex" => 0, - "eth_getUncleCountByBlockHash" => { - // TODO: check a linkedhashmap of recent hashes - // try full nodes first. retry will use archive - return None; - } - "eth_getUncleCountByBlockNumber" => 0, - _ => { - // some other command that doesn't take block numbers as an argument - return None; - } - }; - - match clean_block_number(params, block_param_id, head_block) { - Ok(block) => Some(block), - Err(err) => { - // TODO: seems unlikely that we will get here - // if this is incorrect, it should retry on an archive server - warn!(?err, "could not get block from params"); - None - } - } -} - +/// Connect to the database and run migrations pub async fn get_migrated_db( db_url: String, min_connections: u32, ) -> anyhow::Result { let mut db_opt = sea_orm::ConnectOptions::new(db_url); - // TODO: load all these options from the config file + // TODO: load all these options from the config file. i think mysql default max is 100 // TODO: sqlx logging only in debug. way too verbose for production db_opt - .max_connections(100) + .max_connections(99) .min_connections(min_connections) .connect_timeout(Duration::from_secs(8)) .idle_timeout(Duration::from_secs(8)) @@ -269,6 +118,13 @@ pub enum TxState { Orphaned(Transaction), } +#[derive(Clone, Copy, From)] +pub struct UserCacheValue { + pub expires_at: Instant, + pub user_id: i64, + pub user_count_per_period: u32, +} + /// The application // TODO: this debug impl is way too verbose. make something smaller // TODO: if Web3ProxyApp is always in an Arc, i think we can avoid having at least some of these internal things in arcs @@ -278,18 +134,17 @@ pub struct Web3ProxyApp { balanced_rpcs: Arc, /// Send private requests (like eth_sendRawTransaction) to all these servers private_rpcs: Arc, - /// Track active requests so that we don't 66 - /// + /// Track active requests so that we don't send the same query to multiple backends active_requests: ActiveRequestsMap, - /// bytes available to response_cache (it will be slightly larger than this) - response_cache_max_bytes: AtomicUsize, response_cache: ResponseLrcCache, // don't drop this or the sender will stop working // TODO: broadcast channel instead? 
head_block_receiver: watch::Receiver>>, pending_tx_sender: broadcast::Sender, pending_transactions: Arc>, - rate_limiter: Option, + user_cache: AsyncRwLock>, + redis_pool: Option, + rate_limiter: Option, db_conn: Option, } @@ -301,18 +156,26 @@ impl fmt::Debug for Web3ProxyApp { } impl Web3ProxyApp { - pub fn db_conn(&self) -> &sea_orm::DatabaseConnection { - self.db_conn.as_ref().unwrap() + pub fn db_conn(&self) -> Option<&sea_orm::DatabaseConnection> { + self.db_conn.as_ref() } pub fn pending_transactions(&self) -> &DashMap { &self.pending_transactions } - pub fn rate_limiter(&self) -> Option<&RedisCellClient> { + pub fn rate_limiter(&self) -> Option<&RedisCell> { self.rate_limiter.as_ref() } + pub fn redis_pool(&self) -> Option<&RedisPool> { + self.redis_pool.as_ref() + } + + pub fn user_cache(&self) -> &AsyncRwLock> { + &self.user_cache + } + // TODO: should we just take the rpc config as the only arg instead? pub async fn spawn( app_config: AppConfig, @@ -355,7 +218,7 @@ impl Web3ProxyApp { .build()?, ); - let redis_client_pool = match app_config.shared.redis_url { + let redis_pool = match app_config.shared.redis_url { Some(redis_url) => { info!("Connecting to redis on {}", redis_url); @@ -399,7 +262,7 @@ impl Web3ProxyApp { app_config.shared.chain_id, balanced_rpcs, http_client.clone(), - redis_client_pool.clone(), + redis_pool.clone(), Some(head_block_sender), Some(pending_tx_sender.clone()), pending_transactions.clone(), @@ -418,7 +281,7 @@ impl Web3ProxyApp { app_config.shared.chain_id, private_rpcs, http_client.clone(), - redis_client_pool.clone(), + redis_pool.clone(), // subscribing to new heads here won't work well None, // TODO: subscribe to pending transactions on the private rpcs? @@ -436,9 +299,9 @@ impl Web3ProxyApp { // TODO: how much should we allow? let public_max_burst = app_config.shared.public_rate_limit_per_minute / 3; - let frontend_rate_limiter = redis_client_pool.as_ref().map(|redis_client_pool| { - RedisCellClient::new( - redis_client_pool.clone(), + let frontend_rate_limiter = redis_pool.as_ref().map(|redis_pool| { + RedisCell::new( + redis_pool.clone(), "web3_proxy", "frontend", public_max_burst, @@ -451,13 +314,20 @@ impl Web3ProxyApp { balanced_rpcs, private_rpcs, active_requests: Default::default(), - response_cache_max_bytes: AtomicUsize::new(app_config.shared.response_cache_max_bytes), - response_cache: Default::default(), + // TODO: make the share configurable + response_cache: RwLock::new(FifoSizedMap::new( + app_config.shared.response_cache_max_bytes, + 100, + )), head_block_receiver, pending_tx_sender, pending_transactions, rate_limiter: frontend_rate_limiter, db_conn, + redis_pool, + // TODO: make the size configurable + // TODO: why does this need to be async but the other one doesn't? + user_cache: AsyncRwLock::new(FifoCountMap::new(1_000)), }; let app = Arc::new(app); @@ -907,6 +777,7 @@ impl Web3ProxyApp { // returns Keccak-256 (not the standardized SHA3-256) of the given data. match &request.params { Some(serde_json::Value::Array(params)) => { + // TODO: make a struct and use serde conversion to clean this up if params.len() != 1 || !params[0].is_string() { return Err(anyhow::anyhow!("invalid request")); } @@ -1009,26 +880,10 @@ impl Web3ProxyApp { { let mut response_cache = response_cache.write(); - let response_cache_max_bytes = self - .response_cache_max_bytes - .load(atomic::Ordering::Acquire); - - // TODO: this might be too naive. 
not sure how much overhead the object has - let new_size = size_of_val(&cache_key) + size_of_val(&response); - - // no item is allowed to take more than 1% of the cache - // TODO: get this from config? - if new_size < response_cache_max_bytes / 100 { - // TODO: this probably has wildly variable timings - while size_of_val(&response_cache) + new_size >= response_cache_max_bytes { - // TODO: this isn't an LRU. it's a "least recently created". does that have a fancy name? should we make it an lru? these caches only live for one block - response_cache.pop_front(); - } - - response_cache.insert(cache_key.clone(), response.clone()); + if response_cache.insert(cache_key.clone(), response.clone()) { } else { - // TODO: emit a stat instead? - warn!(?new_size, "value too large for caching"); + // TODO: emit a stat instead? what else should be in the log + trace!(?cache_key, "value too large for caching"); } } diff --git a/web3_proxy/src/block_helpers.rs b/web3_proxy/src/block_helpers.rs new file mode 100644 index 00000000..7920efc7 --- /dev/null +++ b/web3_proxy/src/block_helpers.rs @@ -0,0 +1,163 @@ +use ethers::prelude::{BlockNumber, U64}; +use tracing::warn; + +pub fn block_num_to_u64(block_num: BlockNumber, latest_block: U64) -> (bool, U64) { + match block_num { + BlockNumber::Earliest => (false, U64::zero()), + BlockNumber::Latest => { + // change "latest" to a number + (true, latest_block) + } + BlockNumber::Number(x) => (false, x), + BlockNumber::Pending => { + // TODO: think more about how to handle Pending + // modified is false because we probably want the backend to see "pending" + (false, latest_block) + } + } +} + +/// modify params to always have a block number and not "latest" +pub fn clean_block_number( + params: &mut serde_json::Value, + block_param_id: usize, + latest_block: U64, +) -> anyhow::Result { + match params.as_array_mut() { + None => Err(anyhow::anyhow!("params not an array")), + Some(params) => match params.get_mut(block_param_id) { + None => { + if params.len() != block_param_id - 1 { + return Err(anyhow::anyhow!("unexpected params length")); + } + + // add the latest block number to the end of the params + params.push(serde_json::to_value(latest_block)?); + + Ok(latest_block) + } + Some(x) => { + // convert the json value to a BlockNumber + let block_num: BlockNumber = serde_json::from_value(x.clone())?; + + let (modified, block_num) = block_num_to_u64(block_num, latest_block); + + // if we changed "latest" to a number, update the params to match + if modified { + *x = serde_json::to_value(block_num)?; + } + + Ok(block_num) + } + }, + } +} + +// TODO: change this to also return the hash needed +pub fn block_needed( + method: &str, + params: Option<&mut serde_json::Value>, + head_block: U64, +) -> Option { + let params = params?; + + // TODO: double check these. 
i think some of the getBlock stuff will never need archive + let block_param_id = match method { + "eth_call" => 1, + "eth_estimateGas" => 1, + "eth_getBalance" => 1, + "eth_getBlockByHash" => { + // TODO: double check that any node can serve this + return None; + } + "eth_getBlockByNumber" => { + // TODO: double check that any node can serve this + return None; + } + "eth_getBlockTransactionCountByHash" => { + // TODO: double check that any node can serve this + return None; + } + "eth_getBlockTransactionCountByNumber" => 0, + "eth_getCode" => 1, + "eth_getLogs" => { + let obj = params[0].as_object_mut().unwrap(); + + if let Some(x) = obj.get_mut("fromBlock") { + let block_num: BlockNumber = serde_json::from_value(x.clone()).ok()?; + + let (modified, block_num) = block_num_to_u64(block_num, head_block); + + if modified { + *x = serde_json::to_value(block_num).unwrap(); + } + + return Some(block_num); + } + + if let Some(x) = obj.get_mut("toBlock") { + let block_num: BlockNumber = serde_json::from_value(x.clone()).ok()?; + + let (modified, block_num) = block_num_to_u64(block_num, head_block); + + if modified { + *x = serde_json::to_value(block_num).unwrap(); + } + + return Some(block_num); + } + + if let Some(x) = obj.get("blockHash") { + // TODO: check a linkedhashmap of recent hashes + // TODO: error if fromBlock or toBlock were set + todo!("handle blockHash {}", x); + } + + return None; + } + "eth_getStorageAt" => 2, + "eth_getTransactionByHash" => { + // TODO: not sure how best to look these up + // try full nodes first. retry will use archive + return None; + } + "eth_getTransactionByBlockHashAndIndex" => { + // TODO: check a linkedhashmap of recent hashes + // try full nodes first. retry will use archive + return None; + } + "eth_getTransactionByBlockNumberAndIndex" => 0, + "eth_getTransactionCount" => 1, + "eth_getTransactionReceipt" => { + // TODO: not sure how best to look these up + // try full nodes first. retry will use archive + return None; + } + "eth_getUncleByBlockHashAndIndex" => { + // TODO: check a linkedhashmap of recent hashes + // try full nodes first. retry will use archive + return None; + } + "eth_getUncleByBlockNumberAndIndex" => 0, + "eth_getUncleCountByBlockHash" => { + // TODO: check a linkedhashmap of recent hashes + // try full nodes first. 
retry will use archive + return None; + } + "eth_getUncleCountByBlockNumber" => 0, + _ => { + // some other command that doesn't take block numbers as an argument + return None; + } + }; + + match clean_block_number(params, block_param_id, head_block) { + Ok(block) => Some(block), + Err(err) => { + // TODO: seems unlikely that we will get here + // if this is incorrect, it should retry on an archive server + warn!(?err, "could not get block from params"); + None + } + } +} diff --git a/web3_proxy/src/config.rs b/web3_proxy/src/config.rs index f92b0f2a..96962d7e 100644 --- a/web3_proxy/src/config.rs +++ b/web3_proxy/src/config.rs @@ -70,7 +70,7 @@ impl Web3ConnectionConfig { // #[instrument(name = "try_build_Web3ConnectionConfig", skip_all)] pub async fn spawn( self, - redis_client_pool: Option, + redis_client_pool: Option, chain_id: u64, http_client: Option, http_interval_sender: Option>>, diff --git a/web3_proxy/src/connection.rs b/web3_proxy/src/connection.rs index be227309..d343f738 100644 --- a/web3_proxy/src/connection.rs +++ b/web3_proxy/src/connection.rs @@ -4,7 +4,8 @@ use derive_more::From; use ethers::prelude::{Block, Bytes, Middleware, ProviderError, TxHash, H256, U64}; use futures::future::try_join_all; use futures::StreamExt; -use redis_cell_client::{RedisCellClient, ThrottleResult}; +use parking_lot::RwLock; +use redis_cell_client::{RedisCell, ThrottleResult}; use serde::ser::{SerializeStruct, Serializer}; use serde::Serialize; use std::fmt; @@ -12,7 +13,7 @@ use std::hash::{Hash, Hasher}; use std::sync::atomic::{self, AtomicU32, AtomicU64}; use std::{cmp::Ordering, sync::Arc}; use tokio::sync::broadcast; -use tokio::sync::RwLock; +use tokio::sync::RwLock as AsyncRwLock; use tokio::time::{interval, sleep, sleep_until, Duration, Instant, MissedTickBehavior}; use tracing::{error, info, info_span, instrument, trace, warn, Instrument}; @@ -77,14 +78,14 @@ pub struct Web3Connection { /// keep track of currently open requests. We sort on this active_requests: AtomicU32, /// provider is in a RwLock so that we can replace it if re-connecting - provider: RwLock>>, + provider: AsyncRwLock>>, /// rate limits are stored in a central redis so that multiple proxies can share their rate limits - hard_limit: Option, + hard_limit: Option, /// used for load balancing to the least loaded server soft_limit: u32, block_data_limit: AtomicU64, weight: u32, - head_block: parking_lot::RwLock<(H256, U64)>, + head_block: RwLock<(H256, U64)>, } impl Serialize for Web3Connection { @@ -146,7 +147,7 @@ impl Web3Connection { // optional because this is only used for http providers. 
websocket providers don't use it http_client: Option, http_interval_sender: Option>>, - hard_limit: Option<(u32, redis_cell_client::RedisClientPool)>, + hard_limit: Option<(u32, redis_cell_client::RedisPool)>, // TODO: think more about this type soft_limit: u32, block_sender: Option>, @@ -157,7 +158,7 @@ impl Web3Connection { let hard_limit = hard_limit.map(|(hard_rate_limit, redis_conection)| { // TODO: allow configurable period and max_burst let period = 1; - RedisCellClient::new( + RedisCell::new( redis_conection, "web3_proxy", &format!("{}:{}", chain_id, url_str), @@ -172,11 +173,11 @@ impl Web3Connection { let new_connection = Self { url: url_str.clone(), active_requests: 0.into(), - provider: RwLock::new(Some(Arc::new(provider))), + provider: AsyncRwLock::new(Some(Arc::new(provider))), hard_limit, soft_limit, block_data_limit: Default::default(), - head_block: parking_lot::RwLock::new((H256::zero(), 0isize.into())), + head_block: RwLock::new((H256::zero(), 0isize.into())), weight, }; diff --git a/web3_proxy/src/connections.rs b/web3_proxy/src/connections.rs index 22b938bd..f4904995 100644 --- a/web3_proxy/src/connections.rs +++ b/web3_proxy/src/connections.rs @@ -37,7 +37,6 @@ struct SyncedConnections { head_block_num: u64, head_block_hash: H256, // TODO: this should be able to serialize, but it isn't - // TODO: use linkedhashmap? #[serde(skip_serializing)] conns: IndexSet>, } @@ -147,7 +146,7 @@ impl Web3Connections { chain_id: u64, server_configs: Vec, http_client: Option, - redis_client_pool: Option, + redis_client_pool: Option, head_block_sender: Option>>>, pending_tx_sender: Option>, pending_transactions: Arc>, diff --git a/web3_proxy/src/frontend/errors.rs b/web3_proxy/src/frontend/errors.rs index b9ef5e4f..89441f20 100644 --- a/web3_proxy/src/frontend/errors.rs +++ b/web3_proxy/src/frontend/errors.rs @@ -1,22 +1,26 @@ -use axum::{http::StatusCode, response::IntoResponse, Json}; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; use serde_json::value::RawValue; use crate::jsonrpc::JsonRpcForwardedResponse; -pub async fn handler_404() -> impl IntoResponse { +pub async fn handler_404() -> Response { let err = anyhow::anyhow!("nothing to see here"); - handle_anyhow_error(Some(StatusCode::NOT_FOUND), None, err).await + handle_anyhow_error(Some(StatusCode::NOT_FOUND), None, err) } /// handle errors by converting them into something that implements `IntoResponse` /// TODO: use this. i can't get to work /// TODO: i think we want a custom result type instead. put the anyhow result inside. then `impl IntoResponse for CustomResult` -pub async fn handle_anyhow_error( +pub fn handle_anyhow_error( http_code: Option, id: Option>, err: anyhow::Error, -) -> impl IntoResponse { +) -> Response { // TODO: we might have an id. 
like if this is for rate limiting, we can use it let id = id.unwrap_or_else(|| RawValue::from_string("null".to_string()).unwrap()); @@ -27,5 +31,5 @@ pub async fn handle_anyhow_error( let code = http_code.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - (code, Json(err)) + (code, Json(err)).into_response() } diff --git a/web3_proxy/src/frontend/http_proxy.rs b/web3_proxy/src/frontend/http_proxy.rs index 325e5cdc..000c7037 100644 --- a/web3_proxy/src/frontend/http_proxy.rs +++ b/web3_proxy/src/frontend/http_proxy.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use uuid::Uuid; use super::errors::handle_anyhow_error; -use super::rate_limit::{rate_limit_by_ip, rate_limit_by_key}; +use super::rate_limit::handle_rate_limit_error_response; use crate::{app::Web3ProxyApp, jsonrpc::JsonRpcRequestEnum}; pub async fn public_proxy_web3_rpc( @@ -13,13 +13,15 @@ pub async fn public_proxy_web3_rpc( Extension(app): Extension>, ClientIp(ip): ClientIp, ) -> impl IntoResponse { - if let Err(x) = rate_limit_by_ip(&app, &ip).await { - return x.into_response(); + if let Some(err_response) = + handle_rate_limit_error_response(app.rate_limit_by_ip(&ip).await).await + { + return err_response.into_response(); } match app.proxy_web3_rpc(payload).await { Ok(response) => (StatusCode::OK, Json(&response)).into_response(), - Err(err) => handle_anyhow_error(None, None, err).await.into_response(), + Err(err) => handle_anyhow_error(None, None, err).into_response(), } } @@ -28,12 +30,15 @@ pub async fn user_proxy_web3_rpc( Extension(app): Extension>, Path(user_key): Path, ) -> impl IntoResponse { - if let Err(x) = rate_limit_by_key(&app, user_key).await { - return x.into_response(); + // TODO: add a helper on this that turns RateLimitResult into error if its not allowed + if let Some(err_response) = + handle_rate_limit_error_response(app.rate_limit_by_key(user_key).await).await + { + return err_response.into_response(); } match app.proxy_web3_rpc(payload).await { Ok(response) => (StatusCode::OK, Json(&response)).into_response(), - Err(err) => handle_anyhow_error(None, None, err).await.into_response(), + Err(err) => handle_anyhow_error(None, None, err), } } diff --git a/web3_proxy/src/frontend/rate_limit.rs b/web3_proxy/src/frontend/rate_limit.rs index 41e3f8c0..937513f4 100644 --- a/web3_proxy/src/frontend/rate_limit.rs +++ b/web3_proxy/src/frontend/rate_limit.rs @@ -1,146 +1,168 @@ -use axum::response::IntoResponse; +use axum::response::Response; use entities::user_keys; use redis_cell_client::ThrottleResult; use reqwest::StatusCode; use sea_orm::{ ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect, }; -use std::net::IpAddr; +use std::{net::IpAddr, time::Duration}; +use tokio::time::Instant; use tracing::{debug, warn}; use uuid::Uuid; -use crate::app::Web3ProxyApp; +use crate::app::{UserCacheValue, Web3ProxyApp}; use super::errors::handle_anyhow_error; -pub async fn rate_limit_by_ip(app: &Web3ProxyApp, ip: &IpAddr) -> Result<(), impl IntoResponse> { - let rate_limiter_key = format!("ip-{}", ip); - - // TODO: dry this up with rate_limit_by_key - if let Some(rate_limiter) = app.rate_limiter() { - match rate_limiter - .throttle_key(&rate_limiter_key, None, None, None) - .await - { - Ok(ThrottleResult::Allowed) => {} - Ok(ThrottleResult::RetryAt(_retry_at)) => { - // TODO: set headers so they know when they can retry - debug!(?rate_limiter_key, "rate limit exceeded"); // this is too verbose, but a stat might be good - // TODO: use their id if possible - return Err(handle_anyhow_error( - 
Some(StatusCode::TOO_MANY_REQUESTS), - None, - anyhow::anyhow!(format!("too many requests from this ip: {}", ip)), - ) - .await - .into_response()); - } - Err(err) => { - // internal error, not rate limit being hit - // TODO: i really want axum to do this for us in a single place. - return Err(handle_anyhow_error( - Some(StatusCode::INTERNAL_SERVER_ERROR), - None, - err, - ) - .await - .into_response()); - } - } - } else { - // TODO: if no redis, rate limit with a local cache? - warn!("no rate limiter!"); - } - - Ok(()) +pub enum RateLimitResult { + Allowed, + RateLimitExceeded, + UnknownKey, } -/// if Ok(()), rate limits are acceptable -/// if Err(response), rate limits exceeded -pub async fn rate_limit_by_key( - app: &Web3ProxyApp, - user_key: Uuid, -) -> Result<(), impl IntoResponse> { - let db = app.db_conn(); +impl Web3ProxyApp { + pub async fn rate_limit_by_ip(&self, ip: &IpAddr) -> anyhow::Result { + let rate_limiter_key = format!("ip-{}", ip); - /// query just a few columns instead of the entire table - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { - UserId, - RequestsPerMinute, - } - - // query the db to make sure this key is active - // TODO: probably want a cache on this - match user_keys::Entity::find() - .select_only() - .column_as(user_keys::Column::UserId, QueryAs::UserId) - .column_as( - user_keys::Column::RequestsPerMinute, - QueryAs::RequestsPerMinute, - ) - .filter(user_keys::Column::ApiKey.eq(user_key)) - .filter(user_keys::Column::Active.eq(true)) - .into_values::<_, QueryAs>() - .one(db) - .await - { - Ok::, _>(Some((_user_id, user_count_per_period))) => { - // user key is valid - if let Some(rate_limiter) = app.rate_limiter() { - // TODO: how does max burst actually work? what should it be? - let user_max_burst = user_count_per_period / 3; - let user_period = 60; - - if rate_limiter - .throttle_key( - &user_key.to_string(), - Some(user_max_burst), - Some(user_count_per_period), - Some(user_period), - ) - .await - .is_err() - { + // TODO: dry this up with rate_limit_by_key + if let Some(rate_limiter) = self.rate_limiter() { + match rate_limiter + .throttle_key(&rate_limiter_key, None, None, None) + .await + { + Ok(ThrottleResult::Allowed) => {} + Ok(ThrottleResult::RetryAt(_retry_at)) => { // TODO: set headers so they know when they can retry - // warn!(?ip, "public rate limit exceeded"); // this is too verbose, but a stat might be good - // TODO: use their id if possible - return Err(handle_anyhow_error( - Some(StatusCode::TOO_MANY_REQUESTS), - None, - // TODO: include the user id (NOT THE API KEY!) here - anyhow::anyhow!("too many requests from this key"), - ) - .await - .into_response()); + debug!(?rate_limiter_key, "rate limit exceeded"); // this is too verbose, but a stat might be good + // TODO: use their id if possible + return Ok(RateLimitResult::RateLimitExceeded); + } + Err(err) => { + // internal error, not rate limit being hit + // TODO: i really want axum to do this for us in a single place. + return Err(err); } - } else { - // TODO: if no redis, rate limit with a local cache? } + } else { + // TODO: if no redis, rate limit with a local cache? + warn!("no rate limiter!"); } - Ok(None) => { - // invalid user key - // TODO: rate limit by ip here, too? maybe tarpit? 
- return Err(handle_anyhow_error( - Some(StatusCode::FORBIDDEN), - None, - anyhow::anyhow!("unknown api key"), - ) - .await - .into_response()); - } - Err(err) => { - let err: anyhow::Error = err.into(); - return Err(handle_anyhow_error( - Some(StatusCode::INTERNAL_SERVER_ERROR), - None, - err.context("failed checking database for user key"), - ) - .await - .into_response()); - } + Ok(RateLimitResult::Allowed) } - Ok(()) + pub async fn rate_limit_by_key(&self, user_key: Uuid) -> anyhow::Result { + let user_cache = self.user_cache(); + + // check the local cache + let user_data = if let Some(cached_user) = user_cache.read().await.get(&user_key) { + // TODO: also include the time this value was last checked! otherwise we cache forever! + if cached_user.expires_at < Instant::now() { + // old record found + None + } else { + // this key was active in the database recently + Some(*cached_user) + } + } else { + // cache miss + None + }; + + // if cache was empty, check the database + let user_data = if user_data.is_none() { + if let Some(db) = self.db_conn() { + /// helper enum for query just a few columns instead of the entire table + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + UserId, + RequestsPerMinute, + } + let user_data = match user_keys::Entity::find() + .select_only() + .column_as(user_keys::Column::UserId, QueryAs::UserId) + .column_as( + user_keys::Column::RequestsPerMinute, + QueryAs::RequestsPerMinute, + ) + .filter(user_keys::Column::ApiKey.eq(user_key)) + .filter(user_keys::Column::Active.eq(true)) + .into_values::<_, QueryAs>() + .one(db) + .await? + { + Some((user_id, requests_per_minute)) => { + UserCacheValue::from(( + // TODO: how long should this cache last? get this from config + Instant::now() + Duration::from_secs(60), + user_id, + requests_per_minute, + )) + } + None => { + return Err(anyhow::anyhow!("unknown api key")); + } + }; + + // save for the next run + user_cache.write().await.insert(user_key, user_data); + + user_data + } else { + // TODO: rate limit with only local caches? + unimplemented!("no cache hit and no database connection") + } + } else { + // unwrap the cache's result + user_data.unwrap() + }; + + // user key is valid. now check rate limits + if let Some(rate_limiter) = self.rate_limiter() { + // TODO: how does max burst actually work? what should it be? + let user_max_burst = user_data.user_count_per_period / 3; + let user_period = 60; + + if rate_limiter + .throttle_key( + &user_key.to_string(), + Some(user_max_burst), + Some(user_data.user_count_per_period), + Some(user_period), + ) + .await + .is_err() + { + // TODO: set headers so they know when they can retry + // warn!(?ip, "public rate limit exceeded"); // this is too verbose, but a stat might be good + // TODO: use their id if possible + // TODO: StatusCode::TOO_MANY_REQUESTS + return Err(anyhow::anyhow!("too many requests from this key")); + } + } else { + // TODO: if no redis, rate limit with a local cache? + unimplemented!("no redis. 
cannot rate limit") + } + + Ok(RateLimitResult::Allowed) + } +} + +pub async fn handle_rate_limit_error_response( + r: anyhow::Result, +) -> Option { + match r { + Ok(RateLimitResult::Allowed) => None, + Ok(RateLimitResult::RateLimitExceeded) => Some(handle_anyhow_error( + Some(StatusCode::TOO_MANY_REQUESTS), + None, + anyhow::anyhow!("rate limit exceeded"), + )), + Ok(RateLimitResult::UnknownKey) => Some(handle_anyhow_error( + Some(StatusCode::FORBIDDEN), + None, + anyhow::anyhow!("unknown key"), + )), + Err(err) => Some(handle_anyhow_error(None, None, err)), + } } diff --git a/web3_proxy/src/frontend/users.rs b/web3_proxy/src/frontend/users.rs index c957c21d..4ab2027a 100644 --- a/web3_proxy/src/frontend/users.rs +++ b/web3_proxy/src/frontend/users.rs @@ -15,7 +15,9 @@ use sea_orm::ActiveModelTrait; use serde::Deserialize; use std::sync::Arc; -use crate::{app::Web3ProxyApp, frontend::rate_limit::rate_limit_by_ip}; +use crate::app::Web3ProxyApp; + +use super::rate_limit::handle_rate_limit_error_response; pub async fn create_user( // this argument tells axum to parse the request body @@ -24,11 +26,13 @@ pub async fn create_user( Extension(app): Extension>, ClientIp(ip): ClientIp, ) -> impl IntoResponse { - if let Err(x) = rate_limit_by_ip(&app, &ip).await { - return x; + if let Some(err_response) = + handle_rate_limit_error_response(app.rate_limit_by_ip(&ip).await).await + { + return err_response.into_response(); } - // TODO: check invite_code against the app's config + // TODO: check invite_code against the app's config or database if payload.invite_code != "llam4n0des!" { todo!("proper error message") } @@ -49,7 +53,7 @@ pub async fn create_user( ..Default::default() }; - let db = app.db_conn(); + let db = app.db_conn().unwrap(); // TODO: proper error message let user = user.insert(db).await.unwrap(); diff --git a/web3_proxy/src/frontend/ws_proxy.rs b/web3_proxy/src/frontend/ws_proxy.rs index 32638acf..e9ed7ebb 100644 --- a/web3_proxy/src/frontend/ws_proxy.rs +++ b/web3_proxy/src/frontend/ws_proxy.rs @@ -22,17 +22,19 @@ use crate::{ jsonrpc::{JsonRpcForwardedResponse, JsonRpcForwardedResponseEnum, JsonRpcRequest}, }; -use super::rate_limit::{rate_limit_by_ip, rate_limit_by_key}; +use super::rate_limit::handle_rate_limit_error_response; pub async fn public_websocket_handler( Extension(app): Extension>, ClientIp(ip): ClientIp, - ws: Option, + ws_upgrade: Option, ) -> Response { - match ws { + match ws_upgrade { Some(ws) => { - if let Err(x) = rate_limit_by_ip(&app, &ip).await { - return x.into_response(); + if let Some(err_response) = + handle_rate_limit_error_response(app.rate_limit_by_ip(&ip).await).await + { + return err_response.into_response(); } ws.on_upgrade(|socket| proxy_web3_socket(app, socket)) @@ -41,6 +43,7 @@ pub async fn public_websocket_handler( None => { // this is not a websocket. give a friendly page // TODO: make a friendly page + // TODO: rate limit this? 
"hello, world".into_response() } } @@ -51,8 +54,10 @@ pub async fn user_websocket_handler( ws: WebSocketUpgrade, Path(user_key): Path, ) -> Response { - if let Err(x) = rate_limit_by_key(&app, user_key).await { - return x.into_response(); + if let Some(err_response) = + handle_rate_limit_error_response(app.rate_limit_by_key(user_key).await).await + { + return err_response; } ws.on_upgrade(|socket| proxy_web3_socket(app, socket)) diff --git a/web3_proxy/src/lib.rs b/web3_proxy/src/lib.rs index 1e2b5eaf..37812ad3 100644 --- a/web3_proxy/src/lib.rs +++ b/web3_proxy/src/lib.rs @@ -1,5 +1,6 @@ pub mod app; pub mod bb8_helpers; +pub mod block_helpers; pub mod config; pub mod connection; pub mod connections;