From a3d080361873cea3d244d360c2931ab61f9ff3b8 Mon Sep 17 00:00:00 2001
From: Bryan Stitt
Date: Thu, 20 Oct 2022 06:17:20 +0000
Subject: [PATCH] DRYer user queries

---
 web3_proxy/src/app.rs                       |  10 +-
 web3_proxy/src/bin/web3_proxy_cli/main.rs   |   4 +-
 web3_proxy/src/frontend/authorization.rs    |  10 +-
 web3_proxy/src/frontend/mod.rs              |   1 +
 web3_proxy/src/frontend/users.rs            | 258 ++----------------
 web3_proxy/src/lib.rs                       |   2 +-
 .../src/{user_stats.rs => user_queries.rs}  | 225 +++++++++++++--
 7 files changed, 232 insertions(+), 278 deletions(-)
 rename web3_proxy/src/{user_stats.rs => user_queries.rs} (66%)

diff --git a/web3_proxy/src/app.rs b/web3_proxy/src/app.rs
index 6aefc808..627c7170 100644
--- a/web3_proxy/src/app.rs
+++ b/web3_proxy/src/app.rs
@@ -159,12 +159,12 @@ pub async fn get_migrated_db(
         .sqlx_logging(false);
     // .sqlx_logging_level(log::LevelFilter::Info);

-    let db = sea_orm::Database::connect(db_opt).await?;
+    let db_conn = sea_orm::Database::connect(db_opt).await?;

     // TODO: if error, roll back?
-    Migrator::up(&db, None).await?;
+    Migrator::up(&db_conn, None).await?;

-    Ok(db)
+    Ok(db_conn)
 }

 #[metered(registry = Web3ProxyAppMetrics, registry_expr = self.app_metrics, visibility = pub)]
@@ -202,9 +202,9 @@ impl Web3ProxyApp {
                 .db_max_connections
                 .unwrap_or(db_min_connections * 2);

-            let db = get_migrated_db(db_url, db_min_connections, db_max_connections).await?;
+            let db_conn = get_migrated_db(db_url, db_min_connections, db_max_connections).await?;

-            Some(db)
+            Some(db_conn)
         } else {
             info!("no database");
             None
diff --git a/web3_proxy/src/bin/web3_proxy_cli/main.rs b/web3_proxy/src/bin/web3_proxy_cli/main.rs
index 7182b6f9..4cdb3593 100644
--- a/web3_proxy/src/bin/web3_proxy_cli/main.rs
+++ b/web3_proxy/src/bin/web3_proxy_cli/main.rs
@@ -50,9 +50,9 @@ async fn main() -> anyhow::Result<()> {

     match cli_config.sub_command {
         SubCommand::CreateUser(x) => {
-            let db = get_migrated_db(cli_config.db_url, 1, 1).await?;
+            let db_conn = get_migrated_db(cli_config.db_url, 1, 1).await?;

-            x.main(&db).await
+            x.main(&db_conn).await
         }
         SubCommand::CheckConfig(x) => x.main().await,
     }
diff --git a/web3_proxy/src/frontend/authorization.rs b/web3_proxy/src/frontend/authorization.rs
index 5a969bbf..2a7c0e9c 100644
--- a/web3_proxy/src/frontend/authorization.rs
+++ b/web3_proxy/src/frontend/authorization.rs
@@ -296,7 +296,7 @@ pub async fn bearer_is_authorized(
         .context("fetching user by id")?
         .context("unknown user id")?;

-    todo!("api_key is wrong. we should check user ids instead")
+    todo!("rewrite this. key_is_authorized is wrong. we should check user ids instead")
     // key_is_authorized(
     //     app,
     //     user_key_data.api_key.into(),
@@ -348,9 +348,9 @@ pub async fn key_is_authorized(

     let authorized_user = AuthorizedKey::try_new(ip, origin, referer, user_agent, user_data)?;

-    let db = app.db_conn.clone();
+    let db_conn = app.db_conn.clone();

-    Ok((AuthorizedRequest::User(db, authorized_user), semaphore))
+    Ok((AuthorizedRequest::User(db_conn, authorized_user), semaphore))
 }

 impl Web3ProxyApp {
@@ -479,7 +479,7 @@ impl Web3ProxyApp {
             .try_get_with(user_key.into(), async move {
                 trace!(?user_key, "user_cache miss");

-                let db = self.db_conn().context("Getting database connection")?;
+                let db_conn = self.db_conn().context("Getting database connection")?;

                 let user_uuid: Uuid = user_key.into();

@@ -487,7 +487,7 @@ impl Web3ProxyApp {
                 match user_keys::Entity::find()
                     .filter(user_keys::Column::ApiKey.eq(user_uuid))
                     .filter(user_keys::Column::Active.eq(true))
-                    .one(&db)
+                    .one(&db_conn)
                     .await?
                 {
                     Some(user_key_model) => {
diff --git a/web3_proxy/src/frontend/mod.rs b/web3_proxy/src/frontend/mod.rs
index 3a1ead80..06da07c9 100644
--- a/web3_proxy/src/frontend/mod.rs
+++ b/web3_proxy/src/frontend/mod.rs
@@ -73,6 +73,7 @@ pub async fn serve(port: u16, proxy_app: Arc) -> anyhow::Result<()
         .route("/user/balance/:txid", post(users::user_balance_post))
         .route("/user/profile", get(users::user_profile_get))
         .route("/user/profile", post(users::user_profile_post))
+        .route("/user/revert_logs", get(users::user_revert_logs_get))
         .route(
             "/user/stats/aggregate",
             get(users::user_stats_aggregate_get),
diff --git a/web3_proxy/src/frontend/users.rs b/web3_proxy/src/frontend/users.rs
index 7ca560a1..58ac42e2 100644
--- a/web3_proxy/src/frontend/users.rs
+++ b/web3_proxy/src/frontend/users.rs
@@ -3,7 +3,7 @@
 use super::authorization::{login_is_authorized, UserKey};
 use super::errors::FrontendResult;
 use crate::app::Web3ProxyApp;
-use crate::user_stats::{get_aggregate_rpc_stats, get_detailed_stats};
+use crate::user_queries::{get_aggregate_rpc_stats_from_params, get_detailed_stats};
 use anyhow::Context;
 use axum::{
     extract::{Path, Query},
@@ -19,13 +19,14 @@ use ethers::{prelude::Address, types::Bytes};
 use hashbrown::HashMap;
 use http::StatusCode;
 use redis_rate_limiter::redis::AsyncCommands;
+use redis_rate_limiter::RedisConnection;
 use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, TransactionTrait};
 use serde::{Deserialize, Serialize};
 use siwe::{Message, VerificationOpts};
 use std::ops::Add;
 use std::sync::Arc;
 use time::{Duration, OffsetDateTime};
-use tracing::{info, warn};
+use tracing::warn;
 use ulid::Ulid;

 /// `GET /user/login/:user_address` or `GET /user/login/:user_address/:message_eip` -- Start the "Sign In with Ethereum" (siwe) login flow.
@@ -202,18 +203,18 @@ pub async fn user_login_post(

     let bearer_token = Ulid::new();

-    let db = app.db_conn().context("Getting database connection")?;
+    let db_conn = app.db_conn().context("Getting database connection")?;

     // TODO: limit columns or load whole user?
     let u = user::Entity::find()
         .filter(user::Column::Address.eq(our_msg.address.as_ref()))
-        .one(&db)
+        .one(&db_conn)
         .await
         .unwrap();

     let (u, _uks, response) = match u {
         None => {
-            let txn = db.begin().await?;
+            let txn = db_conn.begin().await?;

             // the only thing we need from them is an address
             // everything else is optional
@@ -256,7 +257,7 @@ pub async fn user_login_post(
             // the user is already registered
             let uks = user_keys::Entity::find()
                 .filter(user_keys::Column::UserId.eq(u.id))
-                .all(&db)
+                .all(&db_conn)
                 .await
                 .context("failed loading user's key")?;

@@ -343,9 +344,9 @@ pub async fn user_profile_post(
         }
     }

-    let db = app.db_conn().context("Getting database connection")?;
+    let db_conn = app.db_conn().context("Getting database connection")?;

-    user.save(&db).await?;
+    user.save(&db_conn).await?;

     todo!("finish post_user");
 }
@@ -416,6 +417,15 @@ pub async fn user_profile_get(
     todo!("user_profile_get");
 }

+/// `GET /user/revert_logs` -- Use a bearer token to get the user's revert logs.
+#[debug_handler]
+pub async fn user_revert_logs_get(
+    TypedHeader(Authorization(bearer_token)): TypedHeader>,
+    Extension(app): Extension>,
+) -> FrontendResult {
+    todo!("user_revert_logs_get");
+}
+
 /// `GET /user/stats/detailed` -- Use a bearer token to get the user's key stats such as bandwidth used and methods requested.
 ///
 /// If no bearer is provided, detailed stats for all users will be shown.
@@ -433,120 +443,10 @@ pub async fn user_stats_detailed_get(
     Extension(app): Extension>,
     Query(params): Query>,
 ) -> FrontendResult {
-    // TODO: how is db_conn supposed to be used?
-    let db = app.db_conn.clone().context("connecting to db")?;
+    let db_conn = app.db_conn().context("connecting to db")?;
+    let redis_conn = app.redis_conn().await.context("connecting to redis")?;

-    // get the attached address from redis for the given auth_token.
-    let mut redis_conn = app.redis_conn().await.context("connecting to redis")?;
-
-    // TODO: DRY
-    let user_id = match (bearer, params.get("user_id")) {
-        (Some(bearer), Some(params)) => {
-            // check for the bearer cache key
-            // TODO: move this to a helper function
-            let bearer_cache_key = format!("bearer:{}", bearer.token());
-
-            // get the user id that is attached to this bearer token
-            redis_conn
-                .get::<_, u64>(bearer_cache_key)
-                .await
-                // TODO: this should be a 403
-                .context("fetching user_key_id from redis with bearer_cache_key")?
-        }
-        (_, None) => {
-            // they have a bearer token. we don't care about it on public pages
-            // 0 means all
-            0
-        }
-        (None, Some(x)) => {
-            // they do not have a bearer token, but requested a specific id. block
-            // TODO: proper error code
-            // TODO: maybe instead of this sharp edged warn, we have a config value?
-            warn!("this should maybe be an access denied");
-            x.parse().context("Parsing user_id param")?
-        }
-    };
-
-    // only allow user_key to be set if user_id is also set
-    // this will keep people from reading someone else's keys
-    let user_key = if user_id > 0 {
-        params
-            .get("user_key")
-            .map_or_else::, _, _>(
-                || Ok(0),
-                |c| {
-                    let c = c.parse()?;
-
-                    Ok(c)
-                },
-            )?
-    } else {
-        0
-    };
-
-    // TODO: DRY
-    let chain_id = params
-        .get("chain_id")
-        .map_or_else::, _, _>(
-            || Ok(app.config.chain_id),
-            |c| {
-                let c = c.parse()?;
-
-                Ok(c)
-            },
-        )?;
-
-    // TODO: DRY
-    let query_start = params
-        .get("timestamp")
-        .map_or_else::, _, _>(
-            || {
-                // no timestamp in params. set default
-                let x = chrono::Utc::now() - chrono::Duration::days(30);
-
-                Ok(x.naive_utc())
-            },
-            |x: &String| {
-                // parse the given timestamp
-                let x = x.parse::().context("parsing timestamp query param")?;
-
-                // TODO: error code 401
-                let x = NaiveDateTime::from_timestamp_opt(x, 0)
-                    .context("parsing timestamp query param")?;
-
-                Ok(x)
-            },
-        )?;
-
-    let page = params
-        .get("page")
-        .map_or_else::, _, _>(
-            || {
-                // no page in params. set default
-                Ok(0)
-            },
-            |x: &String| {
-                // parse the given timestamp
-                // TODO: error code 401
-                let x = x.parse::().context("parsing page query param")?;
-
-                Ok(x)
-            },
-        )?;
-
-    // TODO: page size from config
-    let page_size = 200;
-
-    let x = get_detailed_stats(
-        chain_id,
-        &db,
-        page,
-        page_size,
-        query_start,
-        user_key,
-        user_id,
-    )
-    .await?;
+    let x = get_detailed_stats(&app, bearer, db_conn, redis_conn, params).await?;

     Ok(Json(x).into_response())
 }
@@ -558,121 +458,7 @@ pub async fn user_stats_aggregate_get(
     Extension(app): Extension>,
     Query(params): Query>,
 ) -> FrontendResult {
-    // TODO: how is db_conn supposed to be used?
-    let db_conn = app.db_conn.clone().context("connecting to db")?;
-
-    // get the attached address from redis for the given auth_token.
-    let mut redis_conn = app.redis_conn().await.context("connecting to redis")?;
-
-    let user_id = match (bearer, params.get("user_id")) {
-        (Some(bearer), Some(params)) => {
-            // check for the bearer cache key
-            // TODO: move this to a helper function
-            let bearer_cache_key = format!("bearer:{}", bearer.token());
-
-            // get the user id that is attached to this bearer token
-            redis_conn
-                .get::<_, u64>(bearer_cache_key)
-                .await
-                // TODO: this should be a 403
-                .context("fetching user_key_id from redis with bearer_cache_key")?
-        }
-        (_, None) => {
-            // they have a bearer token. we don't care about it on public pages
-            // 0 means all
-            0
-        }
-        (None, Some(x)) => {
-            // they do not have a bearer token, but requested a specific id. block
-            // TODO: proper error code
-            // TODO: maybe instead of this sharp edged warn, we have a config value?
-            warn!("this should maybe be an access denied");
-            x.parse().context("Parsing user_id param")?
-        }
-    };
-
-    let chain_id = params
-        .get("chain_id")
-        .map_or_else::, _, _>(
-            || Ok(app.config.chain_id),
-            |c| {
-                let c = c.parse()?;
-
-                Ok(c)
-            },
-        )?;
-
-    let query_start = params
-        .get("timestamp")
-        .map_or_else::, _, _>(
-            || {
-                // no timestamp in params. set default
-                let x = chrono::Utc::now() - chrono::Duration::days(30);
-
-                Ok(x.naive_utc())
-            },
-            |x: &String| {
-                // parse the given timestamp
-                let x = x.parse::().context("parsing timestamp query param")?;
-
-                // TODO: error code 401
-                let x = NaiveDateTime::from_timestamp_opt(x, 0)
-                    .context("parsing timestamp query param")?;
-
-                Ok(x)
-            },
-        )?;
-
-    let query_window_seconds = params
-        .get("query_window_seconds")
-        .map_or_else::>, _, _>(
-            || {
-                // no page in params. set default
-                Ok(None)
-            },
-            |x: &String| {
-                // parse the given timestamp
-                // TODO: error code 401
-                let x = x.parse::().context("parsing page query param")?;
-
-                if x == 0 {
-                    Ok(None)
-                } else {
-                    Ok(Some(x))
-                }
-            },
-        )?;
-
-    let page = params
-        .get("page")
-        .map_or_else::, _, _>(
-            || {
-                // no page in params. set None
-                Ok(0)
-            },
-            |x: &String| {
-                // parse the given timestamp
-                // TODO: error code 401
-                let x = x.parse().context("parsing page query param")?;
-
-                Ok(x)
-            },
-        )?;
-
-    // TODO: page size from config
-    let page_size = 200;
-
-    // TODO: optionally no chain id?
-    let x = get_aggregate_rpc_stats(
-        chain_id,
-        &db_conn,
-        page,
-        page_size,
-        query_start,
-        query_window_seconds,
-        user_id,
-    )
-    .await?;
+    let x = get_aggregate_rpc_stats_from_params(&app, bearer, params).await?;

     Ok(Json(x).into_response())
 }
diff --git a/web3_proxy/src/lib.rs b/web3_proxy/src/lib.rs
index c9ca7910..28bfd96c 100644
--- a/web3_proxy/src/lib.rs
+++ b/web3_proxy/src/lib.rs
@@ -7,4 +7,4 @@ pub mod jsonrpc;
 pub mod metered;
 pub mod metrics_frontend;
 pub mod rpcs;
-pub mod user_stats;
+pub mod user_queries;
diff --git a/web3_proxy/src/user_stats.rs b/web3_proxy/src/user_queries.rs
similarity index 66%
rename from web3_proxy/src/user_stats.rs
rename to web3_proxy/src/user_queries.rs
index 49585a70..ac3ccf34 100644
--- a/web3_proxy/src/user_stats.rs
+++ b/web3_proxy/src/user_queries.rs
@@ -1,24 +1,169 @@
 use anyhow::Context;
+use axum::{
+    headers::{authorization::Bearer, Authorization},
+    TypedHeader,
+};
+use chrono::NaiveDateTime;
 use entities::{rpc_accounting, user_keys};
 use hashbrown::HashMap;
 use migration::Expr;
 use num::Zero;
+use redis_rate_limiter::{redis::AsyncCommands, RedisConnection};
 use sea_orm::{
     ColumnTrait, Condition, DatabaseConnection, EntityTrait, JoinType, PaginatorTrait,
     QueryFilter, QueryOrder, QuerySelect, RelationTrait,
 };
 use tracing::trace;

+use crate::app::Web3ProxyApp;
+
+/// get the attached address from redis for the given auth_token.
+/// 0 means all users
+async fn get_user_from_params(
+    mut redis_conn: RedisConnection,
+    // this is a long type. should we strip it down?
+    bearer: Option>>,
+    params: &HashMap,
+) -> anyhow::Result {
+    match (bearer, params.get("user_id")) {
+        (Some(bearer), Some(user_id)) => {
+            // check for the bearer cache key
+            // TODO: move this to a helper function
+            let bearer_cache_key = format!("bearer:{}", bearer.token());
+
+            // get the user id that is attached to this bearer token
+            redis_conn
+                .get::<_, u64>(bearer_cache_key)
+                .await
+                // TODO: this should be a 403
+                .context("fetching user_key_id from redis with bearer_cache_key")
+        }
+        (_, None) => {
+            // they have a bearer token. we don't care about it on public pages
+            // 0 means all
+            Ok(0)
+        }
+        (None, Some(x)) => {
+            // they do not have a bearer token, but requested a specific id. block
+            // TODO: proper error code
+            // TODO: maybe instead of this sharp edged warn, we have a config value?
+            // TODO: check config for if we should deny or allow this
+            x.parse().context("Parsing user_id param")
+        }
+    }
+}
+
+/// only allow user_key to be set if user_id is also set.
+/// this will keep people from reading someone else's keys.
+/// 0 means none.
+fn get_user_key_from_params(user_id: u64, params: &HashMap) -> anyhow::Result {
+    if user_id > 0 {
+        params.get("user_key").map_or_else(
+            || Ok(0),
+            |c| {
+                let c = c.parse()?;
+
+                Ok(c)
+            },
+        )
+    } else {
+        Ok(0)
+    }
+}
+
+fn get_chain_id_from_params(
+    app: &Web3ProxyApp,
+    params: &HashMap,
+) -> anyhow::Result {
+    params.get("chain_id").map_or_else(
+        || Ok(app.config.chain_id),
+        |c| {
+            let c = c.parse()?;
+
+            Ok(c)
+        },
+    )
+}
+
+fn get_query_start_from_params(
+    params: &HashMap,
+) -> anyhow::Result {
+    params.get("query_start").map_or_else(
+        || {
+            // no timestamp in params. set default
+            let x = chrono::Utc::now() - chrono::Duration::days(30);
+
+            Ok(x.naive_utc())
+        },
+        |x: &String| {
+            // parse the given timestamp
+            let x = x.parse::().context("parsing timestamp query param")?;
+
+            // TODO: error code 401
+            let x =
+                NaiveDateTime::from_timestamp_opt(x, 0).context("parsing timestamp query param")?;
+
+            Ok(x)
+        },
+    )
+}
+
+fn get_page_from_params(params: &HashMap) -> anyhow::Result {
+    params
+        .get("page")
+        .map_or_else::, _, _>(
+            || {
+                // no page in params. set default
+                Ok(0)
+            },
+            |x: &String| {
+                // parse the given timestamp
+                // TODO: error code 401
+                let x = x.parse().context("parsing page query from params")?;
+
+                Ok(x)
+            },
+        )
+}
+
+fn get_query_window_seconds_from_params(params: &HashMap) -> anyhow::Result {
+    params.get("query_window_seconds").map_or_else(
+        || {
+            // no page in params. set default
+            Ok(0)
+        },
+        |x: &String| {
+            // parse the given timestamp
+            // TODO: error code 401
+            let x = x
+                .parse()
+                .context("parsing query window seconds from params")?;
+
+            Ok(x)
+        },
+    )
+}
+
 /// stats aggregated across a large time period
-pub async fn get_aggregate_rpc_stats(
-    chain_id: u64,
-    db_conn: &DatabaseConnection,
-    page: usize,
-    page_size: usize,
-    query_start: chrono::NaiveDateTime,
-    query_window_seconds: Option,
-    user_id: u64,
+pub async fn get_aggregate_rpc_stats_from_params(
+    app: &Web3ProxyApp,
+    bearer: Option>>,
+    params: HashMap,
 ) -> anyhow::Result> {
+    let db_conn = app.db_conn().context("connecting to db")?;
+    let redis_conn = app.redis_conn().await.context("connecting to redis")?;
+
+    let user_id = get_user_from_params(redis_conn, bearer, &params).await?;
+    let chain_id = get_chain_id_from_params(app, &params)?;
+    let query_start = get_query_start_from_params(&params)?;
+    let page = get_page_from_params(&params)?;
+    let query_window_seconds = get_query_window_seconds_from_params(&params)?;
+
+    // TODO: warn if unknown fields in params
+
+    // TODO: page size from config
+    let page_size = 200;
+
     trace!(?chain_id, %query_start, ?user_id, "get_aggregate_stats");

     // TODO: minimum query_start of 90 days?
@@ -28,7 +173,10 @@ pub async fn get_aggregate_rpc_stats(
     response.insert("page", serde_json::to_value(page)?);
     response.insert("page_size", serde_json::to_value(page_size)?);
     response.insert("chain_id", serde_json::to_value(chain_id)?);
-    response.insert("query_start", serde_json::to_value(query_start)?);
+    response.insert(
+        "query_start",
+        serde_json::to_value(query_start.timestamp())?,
+    );

     // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users?
     // TODO: how do we count uptime?
@@ -62,10 +210,15 @@ pub async fn get_aggregate_rpc_stats(
         )
         .order_by_asc(rpc_accounting::Column::PeriodDatetime.min());

-    let q = if let Some(query_window_seconds) = query_window_seconds {
-        debug_assert_ne!(query_window_seconds, 0);
-
+    let q = if query_window_seconds != 0 {
+        /*
+        let query_start_timestamp: u64 = query_start
+            .timestamp()
+            .try_into()
+            .context("query_start to timestamp")?;
+        */
         // TODO: is there a better way to do this? how can we get "period_datetime" into this with types?
+        // TODO: how can we get the first window to start at query_start_timestamp
         let expr = Expr::cust_with_values(
             "FLOOR(UNIX_TIMESTAMP(rpc_accounting.period_datetime) / ?) * ?",
             [query_window_seconds, query_window_seconds],
@@ -130,11 +283,11 @@ pub async fn get_aggregate_rpc_stats(

     let aggregate = q
         .into_json()
-        .paginate(db_conn, page_size)
+        .paginate(&db_conn, page_size)
         .fetch_page(page)
         .await?;

-    response.insert("aggregrate", serde_json::Value::Array(aggregate));
+    response.insert("aggregate", serde_json::Value::Array(aggregate));

     Ok(response)
 }
@@ -147,16 +300,20 @@ pub async fn get_user_stats(chain_id: u64) -> u64 {
 ///
 /// TODO: take a "timebucket" duration in minutes that will make a more advanced
 pub async fn get_detailed_stats(
-    chain_id: u64,
-    db_conn: &DatabaseConnection,
-    page: usize,
-    page_size: usize,
-    query_start: chrono::NaiveDateTime,
-    user_key_id: u64,
-    user_id: u64,
+    app: &Web3ProxyApp,
+    bearer: Option>>,
+    db_conn: DatabaseConnection,
+    redis_conn: RedisConnection,
+    params: HashMap,
 ) -> anyhow::Result> {
-    // aggregate stats, but grouped by method and error
-    trace!(?chain_id, %query_start, ?user_id, "get_aggregate_stats");
+    let user_id = get_user_from_params(redis_conn, bearer, &params).await?;
+    let user_key = get_user_key_from_params(user_id, &params)?;
+    let chain_id = get_chain_id_from_params(app, &params)?;
+    let query_start = get_query_start_from_params(&params)?;
+    let page = get_page_from_params(&params)?;
+
+    // TODO: page size from config
+    let page_size = 200;

     // TODO: minimum query_start of 90 days?

@@ -165,7 +322,10 @@ pub async fn get_detailed_stats(
     response.insert("page", serde_json::to_value(page)?);
     response.insert("page_size", serde_json::to_value(page_size)?);
     response.insert("chain_id", serde_json::to_value(chain_id)?);
-    response.insert("query_start", serde_json::to_value(query_start)?);
+    response.insert(
+        "query_start",
+        serde_json::to_value(query_start.timestamp())?,
+    );

     // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users?
     // TODO: how do we count uptime?
@@ -254,7 +414,7 @@ pub async fn get_detailed_stats(
     // TODO: transform this into a nested hashmap instead of a giant table?
     let r = q
         .into_json()
-        .paginate(db_conn, page_size)
+        .paginate(&db_conn, page_size)
         .fetch_page(page)
         .await?;

@@ -277,19 +437,22 @@ pub async fn get_revert_logs(
     page_size: usize,
     query_start: chrono::NaiveDateTime,
     user_id: u64,
-    key_id: u64,
+    user_key_id: u64,
 ) -> anyhow::Result> {
     // aggregate stats, but grouped by method and error
     trace!(?chain_id, %query_start, ?user_id, "get_aggregate_stats");

-    // TODO: minimum query_start of 90 days?
+    // TODO: minimum query_start of 90 days ago?

     let mut response = HashMap::new();

     response.insert("page", serde_json::to_value(page)?);
     response.insert("page_size", serde_json::to_value(page_size)?);
     response.insert("chain_id", serde_json::to_value(chain_id)?);
-    response.insert("query_start", serde_json::to_value(query_start)?);
+    response.insert(
+        "query_start",
+        serde_json::to_value(query_start.timestamp())?,
+    );

     // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users?
     // TODO: how do we count uptime?
@@ -346,7 +509,7 @@ pub async fn get_revert_logs(
         (condition, q)
     };

-    let (condition, q) = if user_id.is_zero() {
+    let (condition, q) = if user_id == 0 {
         // 0 means everyone. don't filter on user
         (condition, q)
     } else {
@@ -365,6 +528,10 @@ pub async fn get_revert_logs(

         let condition = condition.add(user_keys::Column::UserId.eq(user_id));

+        if user_key_id != 0 {
+            todo!("wip");
+        }
+
         (condition, q)
     };