pass db conn through

This commit is contained in:
Bryan Stitt 2022-09-22 22:10:28 +00:00
parent 3854312674
commit 8d011e0cd1
6 changed files with 62 additions and 37 deletions

View File

@ -288,6 +288,7 @@ impl Web3ProxyApp {
Some(pending_tx_sender.clone()), Some(pending_tx_sender.clone()),
pending_transactions.clone(), pending_transactions.clone(),
open_request_handle_metrics.clone(), open_request_handle_metrics.clone(),
db_conn.clone(),
) )
.await .await
.context("balanced rpcs")?; .context("balanced rpcs")?;
@ -315,6 +316,7 @@ impl Web3ProxyApp {
None, None,
pending_transactions.clone(), pending_transactions.clone(),
open_request_handle_metrics.clone(), open_request_handle_metrics.clone(),
db_conn.clone(),
) )
.await .await
.context("private_rpcs")?; .context("private_rpcs")?;

View File

@ -6,6 +6,7 @@ use argh::FromArgs;
use derive_more::Constructor; use derive_more::Constructor;
use ethers::prelude::TxHash; use ethers::prelude::TxHash;
use hashbrown::HashMap; use hashbrown::HashMap;
use sea_orm::DatabaseConnection;
use serde::Deserialize; use serde::Deserialize;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
@ -124,6 +125,7 @@ impl Web3ConnectionConfig {
block_sender: Option<flume::Sender<BlockAndRpc>>, block_sender: Option<flume::Sender<BlockAndRpc>>,
tx_id_sender: Option<flume::Sender<TxHashAndRpc>>, tx_id_sender: Option<flume::Sender<TxHashAndRpc>>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>, open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
db_conn: Option<DatabaseConnection>,
) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> { ) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> {
let hard_limit = match (self.hard_limit, redis_pool) { let hard_limit = match (self.hard_limit, redis_pool) {
(None, None) => None, (None, None) => None,
@ -156,6 +158,7 @@ impl Web3ConnectionConfig {
true, true,
self.weight, self.weight,
open_request_handle_metrics, open_request_handle_metrics,
db_conn,
) )
.await .await
} }

View File

@ -5,7 +5,8 @@ use axum::headers::{Referer, UserAgent};
use deferred_rate_limiter::DeferredRateLimitResult; use deferred_rate_limiter::DeferredRateLimitResult;
use entities::user_keys; use entities::user_keys;
use sea_orm::{ use sea_orm::{
ColumnTrait, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter, QuerySelect, ColumnTrait, DatabaseConnection, DeriveColumn, EntityTrait, EnumIter, IdenStatic, QueryFilter,
QuerySelect,
}; };
use serde::Serialize; use serde::Serialize;
use std::{net::IpAddr, sync::Arc}; use std::{net::IpAddr, sync::Arc};
@ -53,11 +54,11 @@ impl AuthorizedKey {
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub enum AuthorizedRequest { pub enum AuthorizedRequest {
/// Request from the app itself /// Request from the app itself
Internal, Internal(#[serde(skip)] Option<DatabaseConnection>),
/// Request from an anonymous IP address /// Request from an anonymous IP address
Ip(IpAddr), Ip(#[serde(skip)] Option<DatabaseConnection>, IpAddr),
/// Request from an authenticated and authorized user /// Request from an authenticated and authorized user
User(AuthorizedKey), User(#[serde(skip)] Option<DatabaseConnection>, AuthorizedKey),
} }
pub async fn ip_is_authorized( pub async fn ip_is_authorized(
@ -74,7 +75,9 @@ pub async fn ip_is_authorized(
x => unimplemented!("rate_limit_by_ip shouldn't ever see these: {:?}", x), x => unimplemented!("rate_limit_by_ip shouldn't ever see these: {:?}", x),
}; };
Ok(AuthorizedRequest::Ip(ip)) let db = app.db_conn.clone();
Ok(AuthorizedRequest::Ip(db, ip))
} }
pub async fn key_is_authorized( pub async fn key_is_authorized(
@ -97,7 +100,9 @@ pub async fn key_is_authorized(
let authorized_user = AuthorizedKey::try_new(ip, user_data, referer, user_agent)?; let authorized_user = AuthorizedKey::try_new(ip, user_data, referer, user_agent)?;
Ok(AuthorizedRequest::User(authorized_user)) let db = app.db_conn.clone();
Ok(AuthorizedRequest::User(db, authorized_user))
} }
impl Web3ProxyApp { impl Web3ProxyApp {

View File

@ -12,6 +12,7 @@ use futures::StreamExt;
use parking_lot::RwLock; use parking_lot::RwLock;
use rand::Rng; use rand::Rng;
use redis_rate_limiter::{RedisPool, RedisRateLimitResult, RedisRateLimiter}; use redis_rate_limiter::{RedisPool, RedisRateLimitResult, RedisRateLimiter};
use sea_orm::DatabaseConnection;
use serde::ser::{SerializeStruct, Serializer}; use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize; use serde::Serialize;
use std::cmp::min; use std::cmp::min;
@ -52,6 +53,7 @@ pub struct Web3Connection {
// TODO: async lock? // TODO: async lock?
pub(super) head_block_id: RwLock<Option<BlockId>>, pub(super) head_block_id: RwLock<Option<BlockId>>,
pub(super) open_request_handle_metrics: Arc<OpenRequestHandleMetrics>, pub(super) open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
pub(super) db_conn: Option<DatabaseConnection>,
} }
impl Web3Connection { impl Web3Connection {
@ -76,6 +78,7 @@ impl Web3Connection {
reconnect: bool, reconnect: bool,
weight: u32, weight: u32,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>, open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
db_conn: Option<DatabaseConnection>,
) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> { ) -> anyhow::Result<(Arc<Web3Connection>, AnyhowJoinHandle<()>)> {
let hard_limit = hard_limit.map(|(hard_rate_limit, redis_pool)| { let hard_limit = hard_limit.map(|(hard_rate_limit, redis_pool)| {
// TODO: is cache size 1 okay? i think we need // TODO: is cache size 1 okay? i think we need
@ -101,6 +104,7 @@ impl Web3Connection {
head_block_id: RwLock::new(Default::default()), head_block_id: RwLock::new(Default::default()),
weight, weight,
open_request_handle_metrics, open_request_handle_metrics,
db_conn,
}; };
let new_connection = Arc::new(new_connection); let new_connection = Arc::new(new_connection);

View File

@ -20,6 +20,7 @@ use futures::StreamExt;
use hashbrown::HashMap; use hashbrown::HashMap;
use moka::future::{Cache, ConcurrentCacheExt}; use moka::future::{Cache, ConcurrentCacheExt};
use petgraph::graphmap::DiGraphMap; use petgraph::graphmap::DiGraphMap;
use sea_orm::DatabaseConnection;
use serde::ser::{SerializeStruct, Serializer}; use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize; use serde::Serialize;
use serde_json::value::RawValue; use serde_json::value::RawValue;
@ -68,6 +69,7 @@ impl Web3Connections {
pending_tx_sender: Option<broadcast::Sender<TxStatus>>, pending_tx_sender: Option<broadcast::Sender<TxStatus>>,
pending_transactions: Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>, pending_transactions: Cache<TxHash, TxStatus, hashbrown::hash_map::DefaultHashBuilder>,
open_request_handle_metrics: Arc<OpenRequestHandleMetrics>, open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
db_conn: Option<DatabaseConnection>,
) -> anyhow::Result<(Arc<Self>, AnyhowJoinHandle<()>)> { ) -> anyhow::Result<(Arc<Self>, AnyhowJoinHandle<()>)> {
let (pending_tx_id_sender, pending_tx_id_receiver) = flume::unbounded(); let (pending_tx_id_sender, pending_tx_id_receiver) = flume::unbounded();
let (block_sender, block_receiver) = flume::unbounded::<BlockAndRpc>(); let (block_sender, block_receiver) = flume::unbounded::<BlockAndRpc>();
@ -125,6 +127,7 @@ impl Web3Connections {
let pending_tx_id_sender = Some(pending_tx_id_sender.clone()); let pending_tx_id_sender = Some(pending_tx_id_sender.clone());
let block_map = block_map.clone(); let block_map = block_map.clone();
let open_request_handle_metrics = open_request_handle_metrics.clone(); let open_request_handle_metrics = open_request_handle_metrics.clone();
let db_conn = db_conn.clone();
tokio::spawn(async move { tokio::spawn(async move {
server_config server_config
@ -138,6 +141,7 @@ impl Web3Connections {
block_sender, block_sender,
pending_tx_id_sender, pending_tx_id_sender,
open_request_handle_metrics, open_request_handle_metrics,
db_conn,
) )
.await .await
}) })

View File

@ -7,6 +7,7 @@ use metered::metered;
use metered::HitCount; use metered::HitCount;
use metered::ResponseTime; use metered::ResponseTime;
use metered::Throughput; use metered::Throughput;
use rand::Rng;
use std::fmt; use std::fmt;
use std::sync::atomic::{self, AtomicBool, Ordering}; use std::sync::atomic::{self, AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
@ -77,7 +78,10 @@ impl OpenRequestHandle {
let metrics = conn.open_request_handle_metrics.clone(); let metrics = conn.open_request_handle_metrics.clone();
let used = false.into(); let used = false.into();
let authorization = authorization.unwrap_or_else(|| Arc::new(AuthorizedRequest::Internal)); let authorization = authorization.unwrap_or_else(|| {
let db_conn = conn.db_conn.clone();
Arc::new(AuthorizedRequest::Internal(db_conn))
});
Self { Self {
authorization, authorization,
@ -156,43 +160,46 @@ impl OpenRequestHandle {
// TODO: only set SaveReverts if this is an eth_call or eth_estimateGas? we'll need eth_sendRawTransaction somewhere else // TODO: only set SaveReverts if this is an eth_call or eth_estimateGas? we'll need eth_sendRawTransaction somewhere else
// TODO: logging every one is going to flood the database // TODO: logging every one is going to flood the database
// TODO: have a percent chance to do this. or maybe a "logged reverts per second" // TODO: have a percent chance to do this. or maybe a "logged reverts per second"
if let ProviderError::JsonRpcClientError(err) = err { if save_chance == 1.0 || rand::thread_rng().gen_range(0.0..=1.0) <= save_chance
match provider { {
Web3Provider::Http(_) => { if let ProviderError::JsonRpcClientError(err) = err {
if let Some(HttpClientError::JsonRpcError(err)) = match provider {
err.downcast_ref::<HttpClientError>() Web3Provider::Http(_) => {
{ if let Some(HttpClientError::JsonRpcError(err)) =
if err.message.starts_with("execution reverted") { err.downcast_ref::<HttpClientError>()
debug!(%method, ?params, "TODO: save the request"); {
if err.message.starts_with("execution reverted") {
debug!(%method, ?params, "TODO: save the request");
let f = self let f = self
.authorization .authorization
.clone() .clone()
.save_revert(method.to_string(), params); .save_revert(method.to_string(), params);
tokio::spawn(async move { f.await }); tokio::spawn(async move { f.await });
// TODO: don't do this on the hot path. spawn it // TODO: don't do this on the hot path. spawn it
} else { } else {
debug!(?err, %method, rpc=%self.conn, "bad response!"); debug!(?err, %method, rpc=%self.conn, "bad response!");
}
} }
} }
} Web3Provider::Ws(_) => {
Web3Provider::Ws(_) => { if let Some(WsClientError::JsonRpcError(err)) =
if let Some(WsClientError::JsonRpcError(err)) = err.downcast_ref::<WsClientError>()
err.downcast_ref::<WsClientError>() {
{ if err.message.starts_with("execution reverted") {
if err.message.starts_with("execution reverted") { debug!(%method, ?params, "TODO: save the request");
debug!(%method, ?params, "TODO: save the request");
let f = self let f = self
.authorization .authorization
.clone() .clone()
.save_revert(method.to_string(), params); .save_revert(method.to_string(), params);
tokio::spawn(async move { f.await }); tokio::spawn(async move { f.await });
} else { } else {
debug!(?err, %method, rpc=%self.conn, "bad response!"); debug!(?err, %method, rpc=%self.conn, "bad response!");
}
} }
} }
} }