pass db conn through

This commit is contained in:
Bryan Stitt 2022-09-22 22:10:28 +00:00
parent 3854312674
commit 8d011e0cd1
6 changed files with 62 additions and 37 deletions

@ -288,6 +288,7 @@ impl Web3ProxyApp {
Some(pending_tx_sender.clone()),
pending_transactions.clone(),
open_request_handle_metrics.clone(),
db_conn.clone(),
)
.await
.context("balanced rpcs")?;
@ -315,6 +316,7 @@ impl Web3ProxyApp {
None,
pending_transactions.clone(),
open_request_handle_metrics.clone(),
db_conn.clone(),
)
.await
.context("private_rpcs")?;

@ -6,6 +6,7 @@ use argh::FromArgs;
use derive_more::Constructor;
use ethers::prelude::TxHash;
use hashbrown::HashMap;
use sea_orm::DatabaseConnection;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::broadcast;
@ -124,6 +125,7 @@ impl Web3ConnectionConfig {
block_sender: Option<flume::Sender<BlockAndRpc>>,
tx_id_sender: Option<flume::Sender<TxHashAndRpc>>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
db_conn: Option<DatabaseConnection>,
) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> {
let hard_limit = match (self.hard_limit, redis_pool) {
(None, None) => None,
@ -156,6 +158,7 @@ impl Web3ConnectionConfig {
true,
self.weight,
open_request_handle_metrics,
db_conn,
)
.await
}

@ -5,7 +5,8 @@ use axum::headers::{Referer, UserAgent};
use deferred_rate_limiter::DeferredRateLimitResult;
use entities::user_keys;
use sea_orm::{
ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect,
ColumnTrait, DatabaseConnection, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter,
QuerySelect,
};
use serde::Serialize;
use std::{net::IpAddr, sync::Arc};
@ -53,11 +54,11 @@ impl AuthorizedKey {
#[derive(Debug, Serialize)]
pub enum AuthorizedRequest {
/// Request from the app itself
Internal,
Internal(#[serde(skip)] Option<DatabaseConnection>),
/// Request from an anonymous IP address
Ip(IpAddr),
Ip(#[serde(skip)] Option<DatabaseConnection>, IpAddr),
/// Request from an authenticated and authorized user
User(AuthorizedKey),
User(#[serde(skip)] Option<DatabaseConnection>, AuthorizedKey),
}
pub async fn ip_is_authorized(
@ -74,7 +75,9 @@ pub async fn ip_is_authorized(
x => unimplemented!("rate_limit_by_ip shouldn't ever see these: {:?}", x),
};
Ok(AuthorizedRequest::Ip(ip))
let db = app.db_conn.clone();
Ok(AuthorizedRequest::Ip(db, ip))
}
pub async fn key_is_authorized(
@ -97,7 +100,9 @@ pub async fn key_is_authorized(
let authorized_user = AuthorizedKey::try_new(ip, user_data, referer, user_agent)?;
Ok(AuthorizedRequest::User(authorized_user))
let db = app.db_conn.clone();
Ok(AuthorizedRequest::User(db, authorized_user))
}
impl Web3ProxyApp {

@ -12,6 +12,7 @@ use futures::StreamExt;
use parking_lot::RwLock;
use rand::Rng;
use redis_rate_limiter::{RedisPool, RedisRateLimitResult, RedisRateLimiter};
use sea_orm::DatabaseConnection;
use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use std::cmp::min;
@ -52,6 +53,7 @@ pub struct Web3Connection {
// TODO: async lock?
pub(super) head_block_id: RwLock<Option<BlockId>>,
pub(super) open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
pub(super) db_conn: Option<DatabaseConnection>,
}
impl Web3Connection {
@ -76,6 +78,7 @@ impl Web3Connection {
reconnect: bool,
weight: u32,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
db_conn: Option<DatabaseConnection>,
) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> {
let hard_limit = hard_limit.map(|(hard_rate_limit, redis_pool)| {
// TODO: is cache size 1 okay? i think we need
@ -101,6 +104,7 @@ impl Web3Connection {
head_block_id: RwLock::new(Default::default()),
weight,
open_request_handle_metrics,
db_conn,
};
let new_connection = Arc::new(new_connection);

@ -20,6 +20,7 @@ use futures::StreamExt;
use hashbrown::HashMap;
use moka::future::{Cache, ConcurrentCacheExt};
use petgraph::graphmap::DiGraphMap;
use sea_orm::DatabaseConnection;
use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use serde_json::value::RawValue;
@ -68,6 +69,7 @@ impl Web3Connections {
pending_tx_sender: Option<broadcast::Sender<TxStatus>>,
pending_transactions: Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
db_conn: Option<DatabaseConnection>,
) -> anyhow::Result<(Arc<Self>, AnyhowJoinHandle<()>)> {
let (pending_tx_id_sender, pending_tx_id_receiver) = flume::unbounded();
let (block_sender, block_receiver) = flume::unbounded::<BlockAndRpc>();
@ -125,6 +127,7 @@ impl Web3Connections {
let pending_tx_id_sender = Some(pending_tx_id_sender.clone());
let block_map = block_map.clone();
let open_request_handle_metrics = open_request_handle_metrics.clone();
let db_conn = db_conn.clone();
tokio::spawn(async move {
server_config
@ -138,6 +141,7 @@ impl Web3Connections {
block_sender,
pending_tx_id_sender,
open_request_handle_metrics,
db_conn,
)
.await
})

@ -7,6 +7,7 @@ use metered::metered;
use metered::HitCount;
use metered::ResponseTime;
use metered::Throughput;
use rand::Rng;
use std::fmt;
use std::sync::atomic::{self, AtomicBool, Ordering};
use std::sync::Arc;
@ -77,7 +78,10 @@ impl OpenRequestHandle {
let metrics = conn.open_request_handle_metrics.clone();
let used = false.into();
let authorization = authorization.unwrap_or_else(|| Arc::new(AuthorizedRequest::Internal));
let authorization = authorization.unwrap_or_else(|| {
let db_conn = conn.db_conn.clone();
Arc::new(AuthorizedRequest::Internal(db_conn))
});
Self {
authorization,
@ -156,43 +160,46 @@ impl OpenRequestHandle {
// TODO: only set SaveReverts if this is an eth_call or eth_estimateGas? we'll need eth_sendRawTransaction somewhere else
// TODO: logging every one is going to flood the database
// TODO: have a percent chance to do this. or maybe a "logged reverts per second"
if let ProviderError::JsonRpcClientError(err) = err {
match provider {
Web3Provider::Http(_) => {
if let Some(HttpClientError::JsonRpcError(err)) =
err.downcast_ref::<HttpClientError>()
{
if err.message.starts_with("execution reverted") {
debug!(%method, ?params, "TODO: save the request");
if save_chance == 1.0 || rand::thread_rng().gen_range(0.0..=1.0) <= save_chance
{
if let ProviderError::JsonRpcClientError(err) = err {
match provider {
Web3Provider::Http(_) => {
if let Some(HttpClientError::JsonRpcError(err)) =
err.downcast_ref::<HttpClientError>()
{
if err.message.starts_with("execution reverted") {
debug!(%method, ?params, "TODO: save the request");
let f = self
.authorization
.clone()
.save_revert(method.to_string(), params);
let f = self
.authorization
.clone()
.save_revert(method.to_string(), params);
tokio::spawn(async move { f.await });
tokio::spawn(async move { f.await });
// TODO: don't do this on the hot path. spawn it
} else {
debug!(?err, %method, rpc=%self.conn, "bad response!");
// TODO: don't do this on the hot path. spawn it
} else {
debug!(?err, %method, rpc=%self.conn, "bad response!");
}
}
}
}
Web3Provider::Ws(_) => {
if let Some(WsClientError::JsonRpcError(err)) =
err.downcast_ref::<WsClientError>()
{
if err.message.starts_with("execution reverted") {
debug!(%method, ?params, "TODO: save the request");
Web3Provider::Ws(_) => {
if let Some(WsClientError::JsonRpcError(err)) =
err.downcast_ref::<WsClientError>()
{
if err.message.starts_with("execution reverted") {
debug!(%method, ?params, "TODO: save the request");
let f = self
.authorization
.clone()
.save_revert(method.to_string(), params);
let f = self
.authorization
.clone()
.save_revert(method.to_string(), params);
tokio::spawn(async move { f.await });
} else {
debug!(?err, %method, rpc=%self.conn, "bad response!");
tokio::spawn(async move { f.await });
} else {
debug!(?err, %method, rpc=%self.conn, "bad response!");
}
}
}
}