Use the cache's built-in expiration and race-free get_with

When this was a DashMap, we needed our own expiration logic, and parallel requests for the same key would each run the same query.

With moka, we can use its built-in expiration and get_with, so parallel requests for the same key share a single query.
This commit is contained in:
Bryan Stitt 2022-09-20 01:33:39 +00:00
parent 90fed885bc
commit 6ae24b1ff9
9 changed files with 101 additions and 146 deletions

19
Cargo.lock generated
View File

@ -50,18 +50,6 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "ahash"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57e6e951cfbb2db8de1828d49073a113a29fd7117b1596caa781a258c7e38d72"
dependencies = [
"cfg-if",
"getrandom",
"once_cell",
"version_check",
]
[[package]] [[package]]
name = "aho-corasick" name = "aho-corasick"
version = "0.7.18" version = "0.7.18"
@ -1272,8 +1260,8 @@ dependencies = [
name = "deferred-rate-limiter" name = "deferred-rate-limiter"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"ahash 0.8.0",
"anyhow", "anyhow",
"hashbrown",
"moka", "moka",
"redis-rate-limiter", "redis-rate-limiter",
"tokio", "tokio",
@ -2237,7 +2225,7 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
dependencies = [ dependencies = [
"ahash 0.7.6", "ahash",
"serde", "serde",
] ]
@ -4609,7 +4597,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b69bf218860335ddda60d6ce85ee39f6cf6e5630e300e19757d1de15886a093" checksum = "6b69bf218860335ddda60d6ce85ee39f6cf6e5630e300e19757d1de15886a093"
dependencies = [ dependencies = [
"ahash 0.7.6", "ahash",
"atoi", "atoi",
"bitflags", "bitflags",
"byteorder", "byteorder",
@ -5544,7 +5532,6 @@ dependencies = [
name = "web3_proxy" name = "web3_proxy"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"ahash 0.8.0",
"anyhow", "anyhow",
"arc-swap", "arc-swap",
"argh", "argh",

View File

@ -7,8 +7,8 @@ edition = "2021"
[dependencies] [dependencies]
redis-rate-limiter = { path = "../redis-rate-limiter" } redis-rate-limiter = { path = "../redis-rate-limiter" }
ahash = "0.8.0"
anyhow = "1.0.65" anyhow = "1.0.65"
hashbrown = "*"
moka = { version = "0.9.4", default-features = false, features = ["future"] } moka = { version = "0.9.4", default-features = false, features = ["future"] }
tokio = "1.21.1" tokio = "1.21.1"
tracing = "0.1.36" tracing = "0.1.36"

View File

@ -16,7 +16,7 @@ pub struct DeferredRateLimiter<K>
where where
K: Send + Sync, K: Send + Sync,
{ {
local_cache: Cache<K, Arc<AtomicU64>, ahash::RandomState>, local_cache: Cache<K, Arc<AtomicU64>, hashbrown::hash_map::DefaultHashBuilder>,
prefix: String, prefix: String,
rrl: RedisRateLimiter, rrl: RedisRateLimiter,
} }
@ -39,7 +39,7 @@ where
.time_to_live(Duration::from_secs(ttl)) .time_to_live(Duration::from_secs(ttl))
.max_capacity(cache_size) .max_capacity(cache_size)
.name(prefix) .name(prefix)
.build_with_hasher(ahash::RandomState::new()); .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
Self { Self {
local_cache, local_cache,

View File

@ -19,7 +19,6 @@ entities = { path = "../entities" }
migration = { path = "../migration" } migration = { path = "../migration" }
redis-rate-limiter = { path = "../redis-rate-limiter" } redis-rate-limiter = { path = "../redis-rate-limiter" }
ahash = "0.8.0"
anyhow = { version = "1.0.65", features = ["backtrace"] } anyhow = { version = "1.0.65", features = ["backtrace"] }
arc-swap = "1.5.1" arc-swap = "1.5.1"
argh = "0.1.8" argh = "0.1.8"

View File

@ -39,7 +39,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{broadcast, watch}; use tokio::sync::{broadcast, watch};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::time::{timeout, Instant}; use tokio::time::{timeout};
use tokio_stream::wrappers::{BroadcastStream, WatchStream}; use tokio_stream::wrappers::{BroadcastStream, WatchStream};
use tracing::{error, info, info_span, instrument, trace, warn, Instrument}; use tracing::{error, info, info_span, instrument, trace, warn, Instrument};
use uuid::Uuid; use uuid::Uuid;
@ -55,13 +55,13 @@ static APP_USER_AGENT: &str = concat!(
/// block hash, method, params /// block hash, method, params
// TODO: better name // TODO: better name
type ResponseCacheKey = (H256, String, Option<String>); type ResponseCacheKey = (H256, String, Option<String>);
type ResponseCache = Cache<ResponseCacheKey, JsonRpcForwardedResponse, ahash::RandomState>; type ResponseCache =
Cache<ResponseCacheKey, JsonRpcForwardedResponse, hashbrown::hash_map::DefaultHashBuilder>;
pub type AnyhowJoinHandle<T> = JoinHandle<anyhow::Result<T>>; pub type AnyhowJoinHandle<T> = JoinHandle<anyhow::Result<T>>;
#[derive(Clone, Copy, From)] #[derive(Clone, Copy, From)]
pub struct UserCacheValue { pub struct UserCacheValue {
pub expires_at: Instant,
pub user_id: u64, pub user_id: u64,
/// if None, allow unlimited queries /// if None, allow unlimited queries
pub user_count_per_period: Option<u64>, pub user_count_per_period: Option<u64>,
@ -86,11 +86,11 @@ pub struct Web3ProxyApp {
app_metrics: Arc<Web3ProxyAppMetrics>, app_metrics: Arc<Web3ProxyAppMetrics>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>, open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
/// store pending transactions that we've seen so that we don't send duplicates to subscribers /// store pending transactions that we've seen so that we don't send duplicates to subscribers
pub pending_transactions: Cache<TxHash, TxStatus, ahash::RandomState>, pub pending_transactions: Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
pub frontend_ip_rate_limiter: Option<DeferredRateLimiter<IpAddr>>, pub frontend_ip_rate_limiter: Option<DeferredRateLimiter<IpAddr>>,
pub frontend_key_rate_limiter: Option<DeferredRateLimiter<Uuid>>, pub frontend_key_rate_limiter: Option<DeferredRateLimiter<Uuid>>,
pub redis_pool: Option<RedisPool>, pub redis_pool: Option<RedisPool>,
pub user_cache: Cache<Uuid, UserCacheValue, ahash::RandomState>, pub user_cache: Cache<Uuid, UserCacheValue, hashbrown::hash_map::DefaultHashBuilder>,
} }
/// flatten a JoinError into an anyhow error /// flatten a JoinError into an anyhow error
@ -256,11 +256,13 @@ impl Web3ProxyApp {
// TODO: once a transaction is "Confirmed" we remove it from the map. this should prevent major memory leaks. // TODO: once a transaction is "Confirmed" we remove it from the map. this should prevent major memory leaks.
// TODO: we should still have some sort of expiration or maximum size limit for the map // TODO: we should still have some sort of expiration or maximum size limit for the map
drop(pending_tx_receiver); drop(pending_tx_receiver);
// TODO: capacity from configs // TODO: capacity from configs
// all these are the same size, so no need for a weigher // all these are the same size, so no need for a weigher
// TODO: ttl on this?
let pending_transactions = Cache::builder() let pending_transactions = Cache::builder()
.max_capacity(10_000) .max_capacity(10_000)
.build_with_hasher(ahash::RandomState::new()); .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
// keep 1GB of blocks in the cache // keep 1GB of blocks in the cache
// TODO: limits from config // TODO: limits from config
@ -268,7 +270,7 @@ impl Web3ProxyApp {
let block_map = Cache::builder() let block_map = Cache::builder()
.max_capacity(1024 * 1024 * 1024) .max_capacity(1024 * 1024 * 1024)
.weigher(|_k, v| size_of_val(v) as u32) .weigher(|_k, v| size_of_val(v) as u32)
.build_with_hasher(ahash::RandomState::new()); .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
let (balanced_rpcs, balanced_handle) = Web3Connections::spawn( let (balanced_rpcs, balanced_handle) = Web3Connections::spawn(
top_config.app.chain_id, top_config.app.chain_id,
@ -345,14 +347,14 @@ impl Web3ProxyApp {
let response_cache = Cache::builder() let response_cache = Cache::builder()
.max_capacity(1024 * 1024 * 1024) .max_capacity(1024 * 1024 * 1024)
.weigher(|k, v| (size_of_val(k) + size_of_val(v)) as u32) .weigher(|k, v| (size_of_val(k) + size_of_val(v)) as u32)
.build_with_hasher(ahash::RandomState::new()); .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
// all the users are the same size, so no need for a weigher // all the users are the same size, so no need for a weigher
// TODO: max_capacity from config // TODO: max_capacity from config
let user_cache = Cache::builder() let user_cache = Cache::builder()
.max_capacity(10_000) .max_capacity(10_000)
.time_to_live(Duration::from_secs(60)) .time_to_live(Duration::from_secs(60))
.build_with_hasher(ahash::RandomState::new()); .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
let app = Self { let app = Self {
config: top_config.app, config: top_config.app,

View File

@ -6,7 +6,7 @@ use entities::user_keys;
use sea_orm::{ use sea_orm::{
ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect, ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect,
}; };
use std::{net::IpAddr, time::Duration}; use std::{net::IpAddr, sync::Arc};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{error, trace}; use tracing::{error, trace};
use uuid::Uuid; use uuid::Uuid;
@ -83,134 +83,100 @@ impl Web3ProxyApp {
} }
} }
pub(crate) async fn cache_user_data(&self, user_key: Uuid) -> anyhow::Result<UserCacheValue> { pub(crate) async fn user_data(&self, user_key: Uuid) -> anyhow::Result<UserCacheValue> {
let db = self.db_conn.as_ref().context("no database")?; let db = self.db_conn.as_ref().context("no database")?;
/// helper enum for query just a few columns instead of the entire table let user_data: Result<_, Arc<anyhow::Error>> = self
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] .user_cache
enum QueryAs { .try_get_with(user_key, async move {
UserId, /// helper enum for querying just a few columns instead of the entire table
RequestsPerMinute, #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
} enum QueryAs {
// TODO: join the user table to this to return the User? we don't always need it UserId,
let user_data = match user_keys::Entity::find() RequestsPerMinute,
.select_only() }
.column_as(user_keys::Column::UserId, QueryAs::UserId)
.column_as(
user_keys::Column::RequestsPerMinute,
QueryAs::RequestsPerMinute,
)
.filter(user_keys::Column::ApiKey.eq(user_key))
.filter(user_keys::Column::Active.eq(true))
.into_values::<_, QueryAs>()
.one(db)
.await?
{
Some((user_id, requests_per_minute)) => {
// TODO: add a column here for max, or is u64::MAX fine?
let user_count_per_period = if requests_per_minute == u64::MAX {
None
} else {
Some(requests_per_minute)
};
UserCacheValue::from((
// TODO: how long should this cache last? get this from config
Instant::now() + Duration::from_secs(60),
user_id,
user_count_per_period,
))
}
None => {
// TODO: think about this more
UserCacheValue::from((
// TODO: how long should this cache last? get this from config
Instant::now() + Duration::from_secs(60),
0,
Some(0),
))
}
};
// save for the next run // TODO: join the user table to this to return the User? we don't always need it
self.user_cache.insert(user_key, user_data).await; match user_keys::Entity::find()
.select_only()
.column_as(user_keys::Column::UserId, QueryAs::UserId)
.column_as(
user_keys::Column::RequestsPerMinute,
QueryAs::RequestsPerMinute,
)
.filter(user_keys::Column::ApiKey.eq(user_key))
.filter(user_keys::Column::Active.eq(true))
.into_values::<_, QueryAs>()
.one(db)
.await?
{
Some((user_id, requests_per_minute)) => {
// TODO: add a column here for max, or is u64::MAX fine?
let user_count_per_period = if requests_per_minute == u64::MAX {
None
} else {
Some(requests_per_minute)
};
Ok(UserCacheValue::from((user_id, user_count_per_period)))
}
None => Ok(UserCacheValue::from((0, Some(0)))),
}
})
.await;
Ok(user_data) // TODO: i'm not actually sure about this expect
user_data.map_err(|err| Arc::try_unwrap(err).expect("this should be the only reference"))
} }
pub async fn rate_limit_by_key(&self, user_key: Uuid) -> anyhow::Result<RateLimitResult> { pub async fn rate_limit_by_key(&self, user_key: Uuid) -> anyhow::Result<RateLimitResult> {
// check the local cache fo user data to save a database query // check the local cache fo user data to save a database query
let user_data = if let Some(cached_user) = self.user_cache.get(&user_key) { let user_data = self.user_data(user_key).await?;
// TODO: also include the time this value was last checked! otherwise we cache forever!
if cached_user.expires_at < Instant::now() {
// old record found
None
} else {
// this key was active in the database recently
Some(cached_user)
}
} else {
// cache miss
None
};
// if cache was empty, check the database
let user_data = match user_data {
None => self
.cache_user_data(user_key)
.await
.context("fetching user data for rate limits")?,
Some(user_data) => user_data,
};
if user_data.user_id == 0 { if user_data.user_id == 0 {
return Ok(RateLimitResult::UnknownKey); return Ok(RateLimitResult::UnknownKey);
} }
// TODO: turn back on rate limiting once our alpha test is complete let user_count_per_period = match user_data.user_count_per_period {
// TODO: if user_data.unlimited_queries None => return Ok(RateLimitResult::AllowedUser(user_data.user_id)),
// return Ok(RateLimitResult::AllowedUser(user_data.user_id)); Some(x) => x,
};
// user key is valid. now check rate limits // user key is valid. now check rate limits
if let Some(rate_limiter) = &self.frontend_key_rate_limiter { if let Some(rate_limiter) = &self.frontend_key_rate_limiter {
if user_data.user_count_per_period.is_none() { match rate_limiter
// None means unlimited rate limit .throttle(&user_key, Some(user_count_per_period), 1)
Ok(RateLimitResult::AllowedUser(user_data.user_id)) .await
} else { {
match rate_limiter Ok(DeferredRateLimitResult::Allowed) => {
.throttle(&user_key, user_data.user_count_per_period, 1) Ok(RateLimitResult::AllowedUser(user_data.user_id))
.await }
{ Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
Ok(DeferredRateLimitResult::Allowed) => { // TODO: set headers so they know when they can retry
Ok(RateLimitResult::AllowedUser(user_data.user_id)) // TODO: debug or trace?
} // this is too verbose, but a stat might be good
Ok(DeferredRateLimitResult::RetryAt(retry_at)) => { // TODO: keys are secrets! use the id instead
// TODO: set headers so they know when they can retry trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
// TODO: debug or trace? Ok(RateLimitResult::RateLimitedUser(
// this is too verbose, but a stat might be good user_data.user_id,
// TODO: keys are secrets! use the id instead Some(retry_at),
trace!(?user_key, "rate limit exceeded until {:?}", retry_at); ))
Ok(RateLimitResult::RateLimitedUser( }
user_data.user_id, Ok(DeferredRateLimitResult::RetryNever) => {
Some(retry_at), // TODO: keys are secret. don't log them!
)) trace!(?user_key, "rate limit is 0");
} Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None))
Ok(DeferredRateLimitResult::RetryNever) => { }
// TODO: i don't think we'll get here. maybe if we ban an IP forever? seems unlikely Err(err) => {
// TODO: keys are secret. don't log them! // internal error, not rate limit being hit
trace!(?user_key, "rate limit is 0"); // TODO: i really want axum to do this for us in a single place.
Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None)) error!(?err, "rate limiter is unhappy. allowing ip");
} Ok(RateLimitResult::AllowedUser(user_data.user_id))
Err(err) => {
// internal error, not rate limit being hit
// TODO: i really want axum to do this for us in a single place.
error!(?err, "rate limiter is unhappy. allowing ip");
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
} }
} }
} else { } else {
// TODO: if no redis, rate limit with a local cache? // TODO: if no redis, rate limit with just a local cache?
todo!("no redis. cannot rate limit") // if we don't have redis, we probably don't have a db, so this probably will never happen
Err(anyhow::anyhow!("no redis. cannot rate limit"))
} }
} }
} }

View File

@ -248,7 +248,7 @@ pub async fn post_login(
// save the user data in redis with a short expiry // save the user data in redis with a short expiry
// TODO: we already have uk, so this could be more efficient. it works for now // TODO: we already have uk, so this could be more efficient. it works for now
app.cache_user_data(uk.api_key).await?; app.user_data(uk.api_key).await?;
Ok(response) Ok(response)
} }

View File

@ -19,7 +19,7 @@ use tracing::{debug, trace, warn};
// TODO: type for Hydrated Blocks with their full transactions? // TODO: type for Hydrated Blocks with their full transactions?
pub type ArcBlock = Arc<Block<TxHash>>; pub type ArcBlock = Arc<Block<TxHash>>;
pub type BlockHashesCache = Cache<H256, ArcBlock, ahash::RandomState>; pub type BlockHashesCache = Cache<H256, ArcBlock, hashbrown::hash_map::DefaultHashBuilder>;
/// A block's hash and number. /// A block's hash and number.
#[derive(Clone, Debug, Default, From, Serialize)] #[derive(Clone, Debug, Default, From, Serialize)]

View File

@ -36,12 +36,13 @@ pub struct Web3Connections {
pub(super) conns: HashMap<String, Arc<Web3Connection>>, pub(super) conns: HashMap<String, Arc<Web3Connection>>,
/// any requests will be forwarded to one (or more) of these connections /// any requests will be forwarded to one (or more) of these connections
pub(super) synced_connections: ArcSwap<SyncedConnections>, pub(super) synced_connections: ArcSwap<SyncedConnections>,
pub(super) pending_transactions: Cache<TxHash, TxStatus, ahash::RandomState>, pub(super) pending_transactions:
Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
/// TODO: this map is going to grow forever unless we do some sort of pruning. maybe store pruned in redis? /// TODO: this map is going to grow forever unless we do some sort of pruning. maybe store pruned in redis?
/// all blocks, including orphans /// all blocks, including orphans
pub(super) block_hashes: BlockHashesCache, pub(super) block_hashes: BlockHashesCache,
/// blocks on the heaviest chain /// blocks on the heaviest chain
pub(super) block_numbers: Cache<U64, H256, ahash::RandomState>, pub(super) block_numbers: Cache<U64, H256, hashbrown::hash_map::DefaultHashBuilder>,
/// TODO: this map is going to grow forever unless we do some sort of pruning. maybe store pruned in redis? /// TODO: this map is going to grow forever unless we do some sort of pruning. maybe store pruned in redis?
/// TODO: what should we use for edges? /// TODO: what should we use for edges?
pub(super) blockchain_graphmap: AsyncRwLock<DiGraphMap<H256, u32>>, pub(super) blockchain_graphmap: AsyncRwLock<DiGraphMap<H256, u32>>,
@ -62,7 +63,7 @@ impl Web3Connections {
min_sum_soft_limit: u32, min_sum_soft_limit: u32,
min_synced_rpcs: usize, min_synced_rpcs: usize,
pending_tx_sender: Option<broadcast::Sender<TxStatus>>, pending_tx_sender: Option<broadcast::Sender<TxStatus>>,
pending_transactions: Cache<TxHash, TxStatus, ahash::RandomState>, pending_transactions: Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>, open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
) -> anyhow::Result<(Arc<Self>, AnyhowJoinHandle<()>)> { ) -> anyhow::Result<(Arc<Self>, AnyhowJoinHandle<()>)> {
let (pending_tx_id_sender, pending_tx_id_receiver) = flume::unbounded(); let (pending_tx_id_sender, pending_tx_id_receiver) = flume::unbounded();
@ -180,12 +181,12 @@ impl Web3Connections {
let block_hashes = Cache::builder() let block_hashes = Cache::builder()
.time_to_idle(Duration::from_secs(600)) .time_to_idle(Duration::from_secs(600))
.max_capacity(10_000) .max_capacity(10_000)
.build_with_hasher(ahash::RandomState::new()); .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
// all block numbers are the same size, so no need for weigher // all block numbers are the same size, so no need for weigher
let block_numbers = Cache::builder() let block_numbers = Cache::builder()
.time_to_idle(Duration::from_secs(600)) .time_to_idle(Duration::from_secs(600))
.max_capacity(10_000) .max_capacity(10_000)
.build_with_hasher(ahash::RandomState::new()); .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
let connections = Arc::new(Self { let connections = Arc::new(Self {
conns: connections, conns: connections,