use the cache's expiration and race-free get_with

when this was a dashmap, we needed our own expiration logic, and parallel requests for the same key would each run the same query.

with moka, we can use its built-in expiration and get_with
Bryan Stitt 2022-09-20 01:33:39 +00:00
parent 90fed885bc
commit 6ae24b1ff9
9 changed files with 101 additions and 146 deletions
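For context, a minimal sketch (not part of this commit) of the pattern moka enables, assuming moka 0.9 with the "future" feature: time_to_live replaces the hand-rolled expiration, and get_with coalesces concurrent misses so parallel requests for the same key only run the query once.

use moka::future::Cache;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // entries expire 60 seconds after insertion; no manual expiry checks needed
    let cache: Cache<u64, String> = Cache::builder()
        .time_to_live(Duration::from_secs(60))
        .max_capacity(10_000)
        .build();

    // get_with coalesces concurrent misses: if several tasks request the same
    // missing key at once, the init future runs exactly once and every caller
    // receives a clone of the resulting value
    let value = cache
        .get_with(1, async {
            // stand-in for the expensive database query
            "queried value".to_string()
        })
        .await;

    assert_eq!(value, "queried value");
}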

Cargo.lock (generated)

@@ -50,18 +50,6 @@ dependencies = [
"version_check",
]
[[package]]
name = "ahash"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57e6e951cfbb2db8de1828d49073a113a29fd7117b1596caa781a258c7e38d72"
dependencies = [
"cfg-if",
"getrandom",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "0.7.18"
@@ -1272,8 +1260,8 @@ dependencies = [
name = "deferred-rate-limiter"
version = "0.2.0"
dependencies = [
"ahash 0.8.0",
"anyhow",
"hashbrown",
"moka",
"redis-rate-limiter",
"tokio",
@@ -2237,7 +2225,7 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
dependencies = [
"ahash 0.7.6",
"ahash",
"serde",
]
@@ -4609,7 +4597,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b69bf218860335ddda60d6ce85ee39f6cf6e5630e300e19757d1de15886a093"
dependencies = [
"ahash 0.7.6",
"ahash",
"atoi",
"bitflags",
"byteorder",
@@ -5544,7 +5532,6 @@ dependencies = [
name = "web3_proxy"
version = "0.1.0"
dependencies = [
"ahash 0.8.0",
"anyhow",
"arc-swap",
"argh",

@@ -7,8 +7,8 @@ edition = "2021"
[dependencies]
redis-rate-limiter = { path = "../redis-rate-limiter" }
ahash = "0.8.0"
anyhow = "1.0.65"
hashbrown = "*"
moka = { version = "0.9.4", default-features = false, features = ["future"] }
tokio = "1.21.1"
tracing = "0.1.36"

@@ -16,7 +16,7 @@ pub struct DeferredRateLimiter<K>
where
K: Send + Sync,
{
local_cache: Cache<K, Arc<AtomicU64>, ahash::RandomState>,
local_cache: Cache<K, Arc<AtomicU64>, hashbrown::hash_map::DefaultHashBuilder>,
prefix: String,
rrl: RedisRateLimiter,
}
@@ -39,7 +39,7 @@ where
.time_to_live(Duration::from_secs(ttl))
.max_capacity(cache_size)
.name(prefix)
.build_with_hasher(ahash::RandomState::new());
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
Self {
local_cache,
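As a reference for the hasher swap above, a minimal sketch assuming moka 0.9 and hashbrown with its default ahash feature: hashbrown::hash_map::DefaultHashBuilder is a re-export of ahash's RandomState, so the cache hashes exactly as before while the crate drops its direct ahash 0.8 dependency.

use moka::future::Cache;

fn build_cache() -> Cache<String, u64, hashbrown::hash_map::DefaultHashBuilder> {
    // DefaultHashBuilder is ahash::RandomState under the hood, so this
    // behaves the same as build_with_hasher(ahash::RandomState::new())
    Cache::builder()
        .max_capacity(10_000)
        .build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new())
}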

@@ -19,7 +19,6 @@ entities = { path = "../entities" }
migration = { path = "../migration" }
redis-rate-limiter = { path = "../redis-rate-limiter" }
ahash = "0.8.0"
anyhow = { version = "1.0.65", features = ["backtrace"] }
arc-swap = "1.5.1"
argh = "0.1.8"

@@ -39,7 +39,7 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, watch};
use tokio::task::JoinHandle;
use tokio::time::{timeout, Instant};
use tokio::time::{timeout};
use tokio_stream::wrappers::{BroadcastStream, WatchStream};
use tracing::{error, info, info_span, instrument, trace, warn, Instrument};
use uuid::Uuid;
@@ -55,13 +55,13 @@ static APP_USER_AGENT: &str = concat!(
/// block hash, method, params
// TODO: better name
type ResponseCacheKey = (H256, String, Option<String>);
type ResponseCache = Cache<ResponseCacheKey, JsonRpcForwardedResponse, ahash::RandomState>;
type ResponseCache =
Cache<ResponseCacheKey, JsonRpcForwardedResponse, hashbrown::hash_map::DefaultHashBuilder>;
pub type AnyhowJoinHandle<T> = JoinHandle<anyhow::Result<T>>;
#[derive(Clone, Copy, From)]
pub struct UserCacheValue {
pub expires_at: Instant,
pub user_id: u64,
/// if None, allow unlimited queries
pub user_count_per_period: Option<u64>,
@@ -86,11 +86,11 @@ pub struct Web3ProxyApp {
app_metrics: Arc<Web3ProxyAppMetrics>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
/// store pending transactions that we've seen so that we don't send duplicates to subscribers
pub pending_transactions: Cache<TxHash, TxStatus, ahash::RandomState>,
pub pending_transactions: Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
pub frontend_ip_rate_limiter: Option<DeferredRateLimiter<IpAddr>>,
pub frontend_key_rate_limiter: Option<DeferredRateLimiter<Uuid>>,
pub redis_pool: Option<RedisPool>,
pub user_cache: Cache<Uuid, UserCacheValue, ahash::RandomState>,
pub user_cache: Cache<Uuid, UserCacheValue, hashbrown::hash_map::DefaultHashBuilder>,
}
/// flatten a JoinError into an anyhow error
@@ -256,11 +256,13 @@ impl Web3ProxyApp {
// TODO: once a transaction is "Confirmed" we remove it from the map. this should prevent major memory leaks.
// TODO: we should still have some sort of expiration or maximum size limit for the map
drop(pending_tx_receiver);
// TODO: capacity from configs
// all these are the same size, so no need for a weigher
// TODO: ttl on this?
let pending_transactions = Cache::builder()
.max_capacity(10_000)
.build_with_hasher(ahash::RandomState::new());
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
// keep 1GB of blocks in the cache
// TODO: limits from config
@@ -268,7 +270,7 @@
let block_map = Cache::builder()
.max_capacity(1024 * 1024 * 1024)
.weigher(|_k, v| size_of_val(v) as u32)
.build_with_hasher(ahash::RandomState::new());
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
let (balanced_rpcs, balanced_handle) = Web3Connections::spawn(
top_config.app.chain_id,
@@ -345,14 +347,14 @@ impl Web3ProxyApp {
let response_cache = Cache::builder()
.max_capacity(1024 * 1024 * 1024)
.weigher(|k, v| (size_of_val(k) + size_of_val(v)) as u32)
.build_with_hasher(ahash::RandomState::new());
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
// all the users are the same size, so no need for a weigher
// TODO: max_capacity from config
let user_cache = Cache::builder()
.max_capacity(10_000)
.time_to_live(Duration::from_secs(60))
.build_with_hasher(ahash::RandomState::new());
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
let app = Self {
config: top_config.app,
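A note on the weigher used above, with a standalone sketch assuming moka 0.9: once a weigher is set, max_capacity is counted in weigher units (bytes here) rather than entries, and size_of_val only measures the shallow size of a value, not heap data behind pointers inside it.

use moka::future::Cache;
use std::mem::size_of_val;

fn build_response_cache() -> Cache<u64, Vec<u8>> {
    // cap the cache at roughly 1 GiB of (shallowly measured) key + value bytes
    Cache::builder()
        .max_capacity(1024 * 1024 * 1024)
        .weigher(|k: &u64, v: &Vec<u8>| (size_of_val(k) + size_of_val(v)) as u32)
        .build()
}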

@@ -6,7 +6,7 @@ use entities::user_keys;
use sea_orm::{
ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect,
};
use std::{net::IpAddr, time::Duration};
use std::{net::IpAddr, sync::Arc};
use tokio::time::Instant;
use tracing::{error, trace};
use uuid::Uuid;
@@ -83,134 +83,100 @@ impl Web3ProxyApp {
}
}
pub(crate) async fn cache_user_data(&self, user_key: Uuid) -> anyhow::Result<UserCacheValue> {
pub(crate) async fn user_data(&self, user_key: Uuid) -> anyhow::Result<UserCacheValue> {
let db = self.db_conn.as_ref().context("no database")?;
/// helper enum for query just a few columns instead of the entire table
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
UserId,
RequestsPerMinute,
}
// TODO: join the user table to this to return the User? we don't always need it
let user_data = match user_keys::Entity::find()
.select_only()
.column_as(user_keys::Column::UserId, QueryAs::UserId)
.column_as(
user_keys::Column::RequestsPerMinute,
QueryAs::RequestsPerMinute,
)
.filter(user_keys::Column::ApiKey.eq(user_key))
.filter(user_keys::Column::Active.eq(true))
.into_values::<_, QueryAs>()
.one(db)
.await?
{
Some((user_id, requests_per_minute)) => {
// TODO: add a column here for max, or is u64::MAX fine?
let user_count_per_period = if requests_per_minute == u64::MAX {
None
} else {
Some(requests_per_minute)
};
UserCacheValue::from((
// TODO: how long should this cache last? get this from config
Instant::now() + Duration::from_secs(60),
user_id,
user_count_per_period,
))
}
None => {
// TODO: think about this more
UserCacheValue::from((
// TODO: how long should this cache last? get this from config
Instant::now() + Duration::from_secs(60),
0,
Some(0),
))
}
};
let user_data: Result<_, Arc<anyhow::Error>> = self
.user_cache
.try_get_with(user_key, async move {
/// helper enum for querying just a few columns instead of the entire table
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
UserId,
RequestsPerMinute,
}
// save for the next run
self.user_cache.insert(user_key, user_data).await;
// TODO: join the user table to this to return the User? we don't always need it
match user_keys::Entity::find()
.select_only()
.column_as(user_keys::Column::UserId, QueryAs::UserId)
.column_as(
user_keys::Column::RequestsPerMinute,
QueryAs::RequestsPerMinute,
)
.filter(user_keys::Column::ApiKey.eq(user_key))
.filter(user_keys::Column::Active.eq(true))
.into_values::<_, QueryAs>()
.one(db)
.await?
{
Some((user_id, requests_per_minute)) => {
// TODO: add a column here for max, or is u64::MAX fine?
let user_count_per_period = if requests_per_minute == u64::MAX {
None
} else {
Some(requests_per_minute)
};
Ok(UserCacheValue::from((user_id, user_count_per_period)))
}
None => Ok(UserCacheValue::from((0, Some(0)))),
}
})
.await;
Ok(user_data)
// TODO: i'm not actually sure about this expect
user_data.map_err(|err| Arc::try_unwrap(err).expect("this should be the only reference"))
}
pub async fn rate_limit_by_key(&self, user_key: Uuid) -> anyhow::Result<RateLimitResult> {
// check the local cache for user data to save a database query
let user_data = if let Some(cached_user) = self.user_cache.get(&user_key) {
// TODO: also include the time this value was last checked! otherwise we cache forever!
if cached_user.expires_at < Instant::now() {
// old record found
None
} else {
// this key was active in the database recently
Some(cached_user)
}
} else {
// cache miss
None
};
// if cache was empty, check the database
let user_data = match user_data {
None => self
.cache_user_data(user_key)
.await
.context("fetching user data for rate limits")?,
Some(user_data) => user_data,
};
let user_data = self.user_data(user_key).await?;
if user_data.user_id == 0 {
return Ok(RateLimitResult::UnknownKey);
}
// TODO: turn back on rate limiting once our alpha test is complete
// TODO: if user_data.unlimited_queries
// return Ok(RateLimitResult::AllowedUser(user_data.user_id));
let user_count_per_period = match user_data.user_count_per_period {
None => return Ok(RateLimitResult::AllowedUser(user_data.user_id)),
Some(x) => x,
};
// user key is valid. now check rate limits
if let Some(rate_limiter) = &self.frontend_key_rate_limiter {
if user_data.user_count_per_period.is_none() {
// None means unlimited rate limit
Ok(RateLimitResult::AllowedUser(user_data.user_id))
} else {
match rate_limiter
.throttle(&user_key, user_data.user_count_per_period, 1)
.await
{
Ok(DeferredRateLimitResult::Allowed) => {
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
// TODO: set headers so they know when they can retry
// TODO: debug or trace?
// this is too verbose, but a stat might be good
// TODO: keys are secrets! use the id instead
trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
Ok(RateLimitResult::RateLimitedUser(
user_data.user_id,
Some(retry_at),
))
}
Ok(DeferredRateLimitResult::RetryNever) => {
// TODO: i don't think we'll get here. maybe if we ban an IP forever? seems unlikely
// TODO: keys are secret. don't log them!
trace!(?user_key, "rate limit is 0");
Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None))
}
Err(err) => {
// internal error, not rate limit being hit
// TODO: i really want axum to do this for us in a single place.
error!(?err, "rate limiter is unhappy. allowing ip");
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
match rate_limiter
.throttle(&user_key, Some(user_count_per_period), 1)
.await
{
Ok(DeferredRateLimitResult::Allowed) => {
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
Ok(DeferredRateLimitResult::RetryAt(retry_at)) => {
// TODO: set headers so they know when they can retry
// TODO: debug or trace?
// this is too verbose, but a stat might be good
// TODO: keys are secrets! use the id instead
trace!(?user_key, "rate limit exceeded until {:?}", retry_at);
Ok(RateLimitResult::RateLimitedUser(
user_data.user_id,
Some(retry_at),
))
}
Ok(DeferredRateLimitResult::RetryNever) => {
// TODO: keys are secret. don't log them!
trace!(?user_key, "rate limit is 0");
Ok(RateLimitResult::RateLimitedUser(user_data.user_id, None))
}
Err(err) => {
// internal error, not rate limit being hit
// TODO: i really want axum to do this for us in a single place.
error!(?err, "rate limiter is unhappy. allowing ip");
Ok(RateLimitResult::AllowedUser(user_data.user_id))
}
}
} else {
// TODO: if no redis, rate limit with a local cache?
todo!("no redis. cannot rate limit")
// TODO: if no redis, rate limit with just a local cache?
// if we don't have redis, we probably don't have a db, so this probably will never happen
Err(anyhow::anyhow!("no redis. cannot rate limit"))
}
}
}
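The user_data rewrite above leans on try_get_with's error semantics. A standalone sketch, assuming moka 0.9: the init future runs once per key even under concurrent misses, and if it fails, every coalesced waiter receives the same error wrapped in an Arc, which is why the code unwraps Arc<anyhow::Error> before returning.

use moka::future::Cache;
use std::sync::Arc;

async fn load(cache: &Cache<u64, String>) -> Result<String, Arc<anyhow::Error>> {
    cache
        .try_get_with(1, async {
            // stand-in for the database lookup; on failure, all concurrent
            // callers for this key get a clone of the same Arc'd error
            Err::<String, anyhow::Error>(anyhow::anyhow!("database unavailable"))
        })
        .await
}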

@@ -248,7 +248,7 @@ pub async fn post_login(
// save the user data in redis with a short expiry
// TODO: we already have uk, so this could be more efficient. it works for now
app.cache_user_data(uk.api_key).await?;
app.user_data(uk.api_key).await?;
Ok(response)
}

@@ -19,7 +19,7 @@ use tracing::{debug, trace, warn};
// TODO: type for Hydrated Blocks with their full transactions?
pub type ArcBlock = Arc<Block<TxHash>>;
pub type BlockHashesCache = Cache<H256, ArcBlock, ahash::RandomState>;
pub type BlockHashesCache = Cache<H256, ArcBlock, hashbrown::hash_map::DefaultHashBuilder>;
/// A block's hash and number.
#[derive(Clone, Debug, Default, From, Serialize)]

@@ -36,12 +36,13 @@ pub struct Web3Connections {
pub(super) conns: HashMap<String, Arc<Web3Connection>>,
/// any requests will be forwarded to one (or more) of these connections
pub(super) synced_connections: ArcSwap<SyncedConnections>,
pub(super) pending_transactions: Cache<TxHash, TxStatus, ahash::RandomState>,
pub(super) pending_transactions:
Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
/// TODO: this map is going to grow forever unless we do some sort of pruning. maybe store pruned in redis?
/// all blocks, including orphans
pub(super) block_hashes: BlockHashesCache,
/// blocks on the heaviest chain
pub(super) block_numbers: Cache<U64, H256, ahash::RandomState>,
pub(super) block_numbers: Cache<U64, H256, hashbrown::hash_map::DefaultHashBuilder>,
/// TODO: this map is going to grow forever unless we do some sort of pruning. maybe store pruned in redis?
/// TODO: what should we use for edges?
pub(super) blockchain_graphmap: AsyncRwLock<DiGraphMap<H256, u32>>,
@@ -62,7 +63,7 @@ impl Web3Connections {
min_sum_soft_limit: u32,
min_synced_rpcs: usize,
pending_tx_sender: Option<broadcast::Sender<TxStatus>>,
pending_transactions: Cache<TxHash, TxStatus, ahash::RandomState>,
pending_transactions: Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
) -> anyhow::Result<(Arc<Self>, AnyhowJoinHandle<()>)> {
let (pending_tx_id_sender, pending_tx_id_receiver) = flume::unbounded();
@@ -180,12 +181,12 @@ impl Web3Connections {
let block_hashes = Cache::builder()
.time_to_idle(Duration::from_secs(600))
.max_capacity(10_000)
.build_with_hasher(ahash::RandomState::new());
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
// all block numbers are the same size, so no need for weigher
let block_numbers = Cache::builder()
.time_to_idle(Duration::from_secs(600))
.max_capacity(10_000)
.build_with_hasher(ahash::RandomState::new());
.build_with_hasher(hashbrown::hash_map::DefaultHashBuilder::new());
let connections = Arc::new(Self {
conns: connections,
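For context on the two expiry policies used across these caches, a minimal sketch assuming moka 0.9: time_to_live (the user cache) evicts a fixed interval after insertion no matter how hot the entry is, while time_to_idle (the block caches above) evicts only once an entry has gone untouched for the interval.

use moka::future::Cache;
use std::time::Duration;

fn main() {
    // evicted 60s after insertion, regardless of how often it is read
    let user_cache: Cache<u64, String> = Cache::builder()
        .time_to_live(Duration::from_secs(60))
        .build();

    // evicted only after 600s without the entry being read or written
    let block_cache: Cache<u64, String> = Cache::builder()
        .time_to_idle(Duration::from_secs(600))
        .build();

    // both caches start empty
    assert_eq!(user_cache.entry_count(), 0);
    assert_eq!(block_cache.entry_count(), 0);
}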