From 1da730daa2d3c6a62ce19c0d64dab6133d5b9c5c Mon Sep 17 00:00:00 2001 From: Bryan Stitt Date: Thu, 20 Oct 2022 06:54:45 +0000 Subject: [PATCH] less unwrap --- web3_proxy/src/frontend/users.rs | 28 ++++---- web3_proxy/src/rpcs/connection.rs | 2 +- web3_proxy/src/user_queries.rs | 102 ++++++++++++++++++++---------- 3 files changed, 84 insertions(+), 48 deletions(-) diff --git a/web3_proxy/src/frontend/users.rs b/web3_proxy/src/frontend/users.rs index 58ac42e2..3e1515f9 100644 --- a/web3_proxy/src/frontend/users.rs +++ b/web3_proxy/src/frontend/users.rs @@ -13,20 +13,17 @@ use axum::{ }; use axum_client_ip::ClientIp; use axum_macros::debug_handler; -use chrono::NaiveDateTime; use entities::{user, user_keys}; use ethers::{prelude::Address, types::Bytes}; use hashbrown::HashMap; use http::StatusCode; use redis_rate_limiter::redis::AsyncCommands; -use redis_rate_limiter::RedisConnection; use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, TransactionTrait}; use serde::{Deserialize, Serialize}; use siwe::{Message, VerificationOpts}; use std::ops::Add; use std::sync::Arc; use time::{Duration, OffsetDateTime}; -use tracing::warn; use ulid::Ulid; /// `GET /user/login/:user_address` or `GET /user/login/:user_address/:message_eip` -- Start the "Sign In with Ethereum" (siwe) login flow. @@ -60,8 +57,7 @@ pub async fn user_login_get( // create a message and save it in redis // TODO: how many seconds? get from config? - // TODO: while developing, we put a giant number here - let expire_seconds: usize = 28800; + let expire_seconds: usize = 20 * 60; let nonce = Ulid::new(); @@ -94,13 +90,12 @@ pub async fn user_login_get( resources: vec![], }; - let session_key = format!("pending:{}", nonce); - // TODO: if no redis server, store in local cache? at least give a better error. right now this seems to be a 502 // the address isn't enough. we need to save the actual message so we can read the nonce // TODO: what message format is the most efficient to store in redis? probably eip191_bytes // we add 1 to expire_seconds just to be sure redis has the key for the full expiration_time // TODO: store a maximum number of attempted logins? anyone can request so we don't want to allow DOS attacks + let session_key = format!("login_nonce:{}", nonce); app.redis_conn() .await? .set_ex(session_key, message.to_string(), expire_seconds + 1) @@ -179,15 +174,24 @@ pub async fn user_login_post( } // we can't trust that they didn't tamper with the message in some way - let their_msg: siwe::Message = payload.msg.parse().unwrap(); + let their_msg: siwe::Message = payload.msg.parse().context("parsing user's message")?; - let their_sig: [u8; 65] = payload.sig.as_ref().try_into().unwrap(); + let their_sig: [u8; 65] = payload + .sig + .as_ref() + .try_into() + .context("parsing signature")?; + + // TODO: this is fragile + let login_nonce_key = format!("login_nonce:{}", &their_msg.nonce); // fetch the message we gave them from our redis // TODO: use getdel - let our_msg: String = app.redis_conn().await?.get(&their_msg.nonce).await?; + let our_msg: Option = app.redis_conn().await?.get(&login_nonce_key).await?; - let our_msg: siwe::Message = our_msg.parse().unwrap(); + let our_msg: String = our_msg.context("login nonce not found")?; + + let our_msg: siwe::Message = our_msg.parse().context("parsing siwe message")?; let verify_config = VerificationOpts { domain: Some(our_msg.domain), @@ -446,7 +450,7 @@ pub async fn user_stats_detailed_get( let db_conn = app.db_conn().context("connecting to db")?; let redis_conn = app.redis_conn().await.context("connecting to redis")?; - let x = get_detailed_stats(&app, bearer, db_conn, redis_conn, params).await?; + let x = get_detailed_stats(&app, bearer, params).await?; Ok(Json(x).into_response()) } diff --git a/web3_proxy/src/rpcs/connection.rs b/web3_proxy/src/rpcs/connection.rs index d32f8c84..9186426f 100644 --- a/web3_proxy/src/rpcs/connection.rs +++ b/web3_proxy/src/rpcs/connection.rs @@ -130,7 +130,7 @@ impl Web3Connection { // TODO: there has to be a cleaner way to do this if chain_id != found_chain_id.as_u64() { return Err(anyhow::anyhow!( - "incorrect chain id! Expected {}. Found {}", + "incorrect chain id! Config has {}, but RPC has {}", chain_id, found_chain_id ) diff --git a/web3_proxy/src/user_queries.rs b/web3_proxy/src/user_queries.rs index ac3ccf34..dfd9bb5c 100644 --- a/web3_proxy/src/user_queries.rs +++ b/web3_proxy/src/user_queries.rs @@ -10,8 +10,8 @@ use migration::Expr; use num::Zero; use redis_rate_limiter::{redis::AsyncCommands, RedisConnection}; use sea_orm::{ - ColumnTrait, Condition, DatabaseConnection, EntityTrait, JoinType, PaginatorTrait, QueryFilter, - QueryOrder, QuerySelect, RelationTrait, + ColumnTrait, Condition, EntityTrait, JoinType, PaginatorTrait, QueryFilter, QueryOrder, + QuerySelect, RelationTrait, }; use tracing::trace; @@ -19,7 +19,7 @@ use crate::app::Web3ProxyApp; /// get the attached address from redis for the given auth_token. /// 0 means all users -async fn get_user_from_params( +async fn get_user_id_from_params( mut redis_conn: RedisConnection, // this is a long type. should we strip it down? bearer: Option>>, @@ -56,7 +56,10 @@ async fn get_user_from_params( /// only allow user_key to be set if user_id is also set. /// this will keep people from reading someone else's keys. /// 0 means none. -fn get_user_key_from_params(user_id: u64, params: &HashMap) -> anyhow::Result { +fn get_user_key_id_from_params( + user_id: u64, + params: &HashMap, +) -> anyhow::Result { if user_id > 0 { params.get("user_key").map_or_else( || Ok(0), @@ -153,11 +156,11 @@ pub async fn get_aggregate_rpc_stats_from_params( let db_conn = app.db_conn().context("connecting to db")?; let redis_conn = app.redis_conn().await.context("connecting to redis")?; - let user_id = get_user_from_params(redis_conn, bearer, ¶ms).await?; + let user_id = get_user_id_from_params(redis_conn, bearer, ¶ms).await?; let chain_id = get_chain_id_from_params(app, ¶ms)?; let query_start = get_query_start_from_params(¶ms)?; - let page = get_page_from_params(¶ms)?; let query_window_seconds = get_query_window_seconds_from_params(¶ms)?; + let page = get_page_from_params(¶ms)?; // TODO: warn if unknown fields in params @@ -175,9 +178,15 @@ pub async fn get_aggregate_rpc_stats_from_params( response.insert("chain_id", serde_json::to_value(chain_id)?); response.insert( "query_start", - serde_json::to_value(query_start.timestamp())?, + serde_json::to_value(query_start.timestamp() as u64)?, ); + if query_window_seconds != 0 { + response.insert( + "query_window_seconds", + serde_json::to_value(query_window_seconds)?, + ); + } // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users? // TODO: how do we count uptime? let q = rpc_accounting::Entity::find() @@ -210,6 +219,7 @@ pub async fn get_aggregate_rpc_stats_from_params( ) .order_by_asc(rpc_accounting::Column::PeriodDatetime.min()); + // TODO: DRYer let q = if query_window_seconds != 0 { /* let query_start_timestamp: u64 = query_start @@ -233,7 +243,7 @@ pub async fn get_aggregate_rpc_stats_from_params( .group_by(Expr::cust("query_window")) } else { // TODO: order by more than this? - // query_window_seconds + // query_window_seconds is not set so we aggregate all records q }; @@ -297,20 +307,21 @@ pub async fn get_user_stats(chain_id: u64) -> u64 { } /// stats grouped by key_id and error_repsponse and method and key -/// -/// TODO: take a "timebucket" duration in minutes that will make a more advanced pub async fn get_detailed_stats( app: &Web3ProxyApp, bearer: Option>>, - db_conn: DatabaseConnection, - redis_conn: RedisConnection, params: HashMap, ) -> anyhow::Result> { - let user_id = get_user_from_params(redis_conn, bearer, ¶ms).await?; - let user_key = get_user_key_from_params(user_id, ¶ms)?; + let db_conn = app.db_conn().context("connecting to db")?; + let redis_conn = app.redis_conn().await.context("connecting to redis")?; + + let user_id = get_user_id_from_params(redis_conn, bearer, ¶ms).await?; + let user_key_id = get_user_key_id_from_params(user_id, ¶ms)?; let chain_id = get_chain_id_from_params(app, ¶ms)?; let query_start = get_query_start_from_params(¶ms)?; + let query_window_seconds = get_query_window_seconds_from_params(¶ms)?; let page = get_page_from_params(¶ms)?; + // TODO: handle secondary users, too // TODO: page size from config let page_size = 200; @@ -324,9 +335,16 @@ pub async fn get_detailed_stats( response.insert("chain_id", serde_json::to_value(chain_id)?); response.insert( "query_start", - serde_json::to_value(query_start.timestamp())?, + serde_json::to_value(query_start.timestamp() as u64)?, ); + if query_window_seconds != 0 { + response.insert( + "query_window_seconds", + serde_json::to_value(query_window_seconds)?, + ); + } + // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users? // TODO: how do we count uptime? let q = rpc_accounting::Entity::find() @@ -382,7 +400,7 @@ pub async fn get_detailed_stats( (condition, q) }; - let (condition, q) = if user_id.is_zero() { + let (condition, q) = if user_id == 0 { // 0 means everyone. don't filter on user (condition, q) } else { @@ -394,21 +412,26 @@ pub async fn get_detailed_stats( rpc_accounting::Relation::UserKeys.def(), ) .column(user_keys::Column::UserId) - // no need to group_by user_id when we are grouping by key_id - // .group_by(user_keys::Column::UserId) - .column(user_keys::Column::Id) - .group_by(user_keys::Column::Id); + .group_by(user_keys::Column::UserId); let condition = condition.add(user_keys::Column::UserId.eq(user_id)); + let q = if user_key_id == 0 { + q.column(user_keys::Column::UserId) + .group_by(user_keys::Column::UserId) + } else { + response.insert("user_key_id", serde_json::to_value(user_key_id)?); + + // no need to group_by user_id when we are grouping by key_id + q.column(user_keys::Column::Id) + .group_by(user_keys::Column::Id) + }; + (condition, q) }; let q = q.filter(condition); - // TODO: enum between searching on user_key_id on user_id - // TODO: handle secondary users, too - // log query here. i think sea orm has a useful log level for this // TODO: transform this into a nested hashmap instead of a giant table? @@ -431,18 +454,20 @@ pub async fn get_detailed_stats( /// /// TODO: take a "timebucket" duration in minutes that will make a more advanced pub async fn get_revert_logs( - chain_id: u64, - db_conn: &DatabaseConnection, - page: usize, - page_size: usize, - query_start: chrono::NaiveDateTime, - user_id: u64, - user_key_id: u64, + app: &Web3ProxyApp, + bearer: Option>>, + params: HashMap, ) -> anyhow::Result> { - // aggregate stats, but grouped by method and error - trace!(?chain_id, %query_start, ?user_id, "get_aggregate_stats"); + let db_conn = app.db_conn().context("connecting to db")?; + let redis_conn = app.redis_conn().await.context("connecting to redis")?; - // TODO: minimum query_start of 90 days ago? + let user_id = get_user_id_from_params(redis_conn, bearer, ¶ms).await?; + let user_key_id = get_user_key_id_from_params(user_id, ¶ms)?; + let chain_id = get_chain_id_from_params(app, ¶ms)?; + let query_start = get_query_start_from_params(¶ms)?; + let query_window_seconds = get_query_window_seconds_from_params(¶ms)?; + let page = get_page_from_params(¶ms)?; + let page_size = get_page_from_params(¶ms)?; let mut response = HashMap::new(); @@ -451,9 +476,16 @@ pub async fn get_revert_logs( response.insert("chain_id", serde_json::to_value(chain_id)?); response.insert( "query_start", - serde_json::to_value(query_start.timestamp())?, + serde_json::to_value(query_start.timestamp() as u64)?, ); + if query_window_seconds != 0 { + response.insert( + "query_window_seconds", + serde_json::to_value(query_window_seconds)?, + ); + } + // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users? // TODO: how do we count uptime? let q = rpc_accounting::Entity::find() @@ -545,7 +577,7 @@ pub async fn get_revert_logs( // TODO: transform this into a nested hashmap instead of a giant table? let r = q .into_json() - .paginate(db_conn, page_size) + .paginate(&db_conn, page_size) .fetch_page(page) .await?;