less unwrap

This commit is contained in:
Bryan Stitt 2022-10-20 06:54:45 +00:00
parent a3d0803618
commit 1da730daa2
3 changed files with 84 additions and 48 deletions

@ -13,20 +13,17 @@ use axum::{
}; };
use axum_client_ip::ClientIp; use axum_client_ip::ClientIp;
use axum_macros::debug_handler; use axum_macros::debug_handler;
use chrono::NaiveDateTime;
use entities::{user, user_keys}; use entities::{user, user_keys};
use ethers::{prelude::Address, types::Bytes}; use ethers::{prelude::Address, types::Bytes};
use hashbrown::HashMap; use hashbrown::HashMap;
use http::StatusCode; use http::StatusCode;
use redis_rate_limiter::redis::AsyncCommands; use redis_rate_limiter::redis::AsyncCommands;
use redis_rate_limiter::RedisConnection;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, TransactionTrait}; use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, TransactionTrait};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use siwe::{Message, VerificationOpts}; use siwe::{Message, VerificationOpts};
use std::ops::Add; use std::ops::Add;
use std::sync::Arc; use std::sync::Arc;
use time::{Duration, OffsetDateTime}; use time::{Duration, OffsetDateTime};
use tracing::warn;
use ulid::Ulid; use ulid::Ulid;
/// `GET /user/login/:user_address` or `GET /user/login/:user_address/:message_eip` -- Start the "Sign In with Ethereum" (siwe) login flow. /// `GET /user/login/:user_address` or `GET /user/login/:user_address/:message_eip` -- Start the "Sign In with Ethereum" (siwe) login flow.
@ -60,8 +57,7 @@ pub async fn user_login_get(
// create a message and save it in redis // create a message and save it in redis
// TODO: how many seconds? get from config? // TODO: how many seconds? get from config?
// TODO: while developing, we put a giant number here let expire_seconds: usize = 20 * 60;
let expire_seconds: usize = 28800;
let nonce = Ulid::new(); let nonce = Ulid::new();
@ -94,13 +90,12 @@ pub async fn user_login_get(
resources: vec![], resources: vec![],
}; };
let session_key = format!("pending:{}", nonce);
// TODO: if no redis server, store in local cache? at least give a better error. right now this seems to be a 502 // TODO: if no redis server, store in local cache? at least give a better error. right now this seems to be a 502
// the address isn't enough. we need to save the actual message so we can read the nonce // the address isn't enough. we need to save the actual message so we can read the nonce
// TODO: what message format is the most efficient to store in redis? probably eip191_bytes // TODO: what message format is the most efficient to store in redis? probably eip191_bytes
// we add 1 to expire_seconds just to be sure redis has the key for the full expiration_time // we add 1 to expire_seconds just to be sure redis has the key for the full expiration_time
// TODO: store a maximum number of attempted logins? anyone can request so we don't want to allow DOS attacks // TODO: store a maximum number of attempted logins? anyone can request so we don't want to allow DOS attacks
let session_key = format!("login_nonce:{}", nonce);
app.redis_conn() app.redis_conn()
.await? .await?
.set_ex(session_key, message.to_string(), expire_seconds + 1) .set_ex(session_key, message.to_string(), expire_seconds + 1)
@ -179,15 +174,24 @@ pub async fn user_login_post(
} }
// we can't trust that they didn't tamper with the message in some way // we can't trust that they didn't tamper with the message in some way
let their_msg: siwe::Message = payload.msg.parse().unwrap(); let their_msg: siwe::Message = payload.msg.parse().context("parsing user's message")?;
let their_sig: [u8; 65] = payload.sig.as_ref().try_into().unwrap(); let their_sig: [u8; 65] = payload
.sig
.as_ref()
.try_into()
.context("parsing signature")?;
// TODO: this is fragile
let login_nonce_key = format!("login_nonce:{}", &their_msg.nonce);
// fetch the message we gave them from our redis // fetch the message we gave them from our redis
// TODO: use getdel // TODO: use getdel
let our_msg: String = app.redis_conn().await?.get(&their_msg.nonce).await?; let our_msg: Option<String> = app.redis_conn().await?.get(&login_nonce_key).await?;
let our_msg: siwe::Message = our_msg.parse().unwrap(); let our_msg: String = our_msg.context("login nonce not found")?;
let our_msg: siwe::Message = our_msg.parse().context("parsing siwe message")?;
let verify_config = VerificationOpts { let verify_config = VerificationOpts {
domain: Some(our_msg.domain), domain: Some(our_msg.domain),
@ -446,7 +450,7 @@ pub async fn user_stats_detailed_get(
let db_conn = app.db_conn().context("connecting to db")?; let db_conn = app.db_conn().context("connecting to db")?;
let redis_conn = app.redis_conn().await.context("connecting to redis")?; let redis_conn = app.redis_conn().await.context("connecting to redis")?;
let x = get_detailed_stats(&app, bearer, db_conn, redis_conn, params).await?; let x = get_detailed_stats(&app, bearer, params).await?;
Ok(Json(x).into_response()) Ok(Json(x).into_response())
} }

@ -130,7 +130,7 @@ impl Web3Connection {
// TODO: there has to be a cleaner way to do this // TODO: there has to be a cleaner way to do this
if chain_id != found_chain_id.as_u64() { if chain_id != found_chain_id.as_u64() {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"incorrect chain id! Expected {}. Found {}", "incorrect chain id! Config has {}, but RPC has {}",
chain_id, chain_id,
found_chain_id found_chain_id
) )

@ -10,8 +10,8 @@ use migration::Expr;
use num::Zero; use num::Zero;
use redis_rate_limiter::{redis::AsyncCommands, RedisConnection}; use redis_rate_limiter::{redis::AsyncCommands, RedisConnection};
use sea_orm::{ use sea_orm::{
ColumnTrait, Condition, DatabaseConnection, EntityTrait, JoinType, PaginatorTrait, QueryFilter, ColumnTrait, Condition, EntityTrait, JoinType, PaginatorTrait, QueryFilter, QueryOrder,
QueryOrder, QuerySelect, RelationTrait, QuerySelect, RelationTrait,
}; };
use tracing::trace; use tracing::trace;
@ -19,7 +19,7 @@ use crate::app::Web3ProxyApp;
/// get the attached address from redis for the given auth_token. /// get the attached address from redis for the given auth_token.
/// 0 means all users /// 0 means all users
async fn get_user_from_params( async fn get_user_id_from_params(
mut redis_conn: RedisConnection, mut redis_conn: RedisConnection,
// this is a long type. should we strip it down? // this is a long type. should we strip it down?
bearer: Option<TypedHeader<Authorization<Bearer>>>, bearer: Option<TypedHeader<Authorization<Bearer>>>,
@ -56,7 +56,10 @@ async fn get_user_from_params(
/// only allow user_key to be set if user_id is also set. /// only allow user_key to be set if user_id is also set.
/// this will keep people from reading someone else's keys. /// this will keep people from reading someone else's keys.
/// 0 means none. /// 0 means none.
fn get_user_key_from_params(user_id: u64, params: &HashMap<String, String>) -> anyhow::Result<u64> { fn get_user_key_id_from_params(
user_id: u64,
params: &HashMap<String, String>,
) -> anyhow::Result<u64> {
if user_id > 0 { if user_id > 0 {
params.get("user_key").map_or_else( params.get("user_key").map_or_else(
|| Ok(0), || Ok(0),
@ -153,11 +156,11 @@ pub async fn get_aggregate_rpc_stats_from_params(
let db_conn = app.db_conn().context("connecting to db")?; let db_conn = app.db_conn().context("connecting to db")?;
let redis_conn = app.redis_conn().await.context("connecting to redis")?; let redis_conn = app.redis_conn().await.context("connecting to redis")?;
let user_id = get_user_from_params(redis_conn, bearer, &params).await?; let user_id = get_user_id_from_params(redis_conn, bearer, &params).await?;
let chain_id = get_chain_id_from_params(app, &params)?; let chain_id = get_chain_id_from_params(app, &params)?;
let query_start = get_query_start_from_params(&params)?; let query_start = get_query_start_from_params(&params)?;
let page = get_page_from_params(&params)?;
let query_window_seconds = get_query_window_seconds_from_params(&params)?; let query_window_seconds = get_query_window_seconds_from_params(&params)?;
let page = get_page_from_params(&params)?;
// TODO: warn if unknown fields in params // TODO: warn if unknown fields in params
@ -175,9 +178,15 @@ pub async fn get_aggregate_rpc_stats_from_params(
response.insert("chain_id", serde_json::to_value(chain_id)?); response.insert("chain_id", serde_json::to_value(chain_id)?);
response.insert( response.insert(
"query_start", "query_start",
serde_json::to_value(query_start.timestamp())?, serde_json::to_value(query_start.timestamp() as u64)?,
); );
if query_window_seconds != 0 {
response.insert(
"query_window_seconds",
serde_json::to_value(query_window_seconds)?,
);
}
// TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users? // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users?
// TODO: how do we count uptime? // TODO: how do we count uptime?
let q = rpc_accounting::Entity::find() let q = rpc_accounting::Entity::find()
@ -210,6 +219,7 @@ pub async fn get_aggregate_rpc_stats_from_params(
) )
.order_by_asc(rpc_accounting::Column::PeriodDatetime.min()); .order_by_asc(rpc_accounting::Column::PeriodDatetime.min());
// TODO: DRYer
let q = if query_window_seconds != 0 { let q = if query_window_seconds != 0 {
/* /*
let query_start_timestamp: u64 = query_start let query_start_timestamp: u64 = query_start
@ -233,7 +243,7 @@ pub async fn get_aggregate_rpc_stats_from_params(
.group_by(Expr::cust("query_window")) .group_by(Expr::cust("query_window"))
} else { } else {
// TODO: order by more than this? // TODO: order by more than this?
// query_window_seconds // query_window_seconds is not set so we aggregate all records
q q
}; };
@ -297,20 +307,21 @@ pub async fn get_user_stats(chain_id: u64) -> u64 {
} }
/// stats grouped by key_id and error_repsponse and method and key /// stats grouped by key_id and error_repsponse and method and key
///
/// TODO: take a "timebucket" duration in minutes that will make a more advanced
pub async fn get_detailed_stats( pub async fn get_detailed_stats(
app: &Web3ProxyApp, app: &Web3ProxyApp,
bearer: Option<TypedHeader<Authorization<Bearer>>>, bearer: Option<TypedHeader<Authorization<Bearer>>>,
db_conn: DatabaseConnection,
redis_conn: RedisConnection,
params: HashMap<String, String>, params: HashMap<String, String>,
) -> anyhow::Result<HashMap<&str, serde_json::Value>> { ) -> anyhow::Result<HashMap<&str, serde_json::Value>> {
let user_id = get_user_from_params(redis_conn, bearer, &params).await?; let db_conn = app.db_conn().context("connecting to db")?;
let user_key = get_user_key_from_params(user_id, &params)?; let redis_conn = app.redis_conn().await.context("connecting to redis")?;
let user_id = get_user_id_from_params(redis_conn, bearer, &params).await?;
let user_key_id = get_user_key_id_from_params(user_id, &params)?;
let chain_id = get_chain_id_from_params(app, &params)?; let chain_id = get_chain_id_from_params(app, &params)?;
let query_start = get_query_start_from_params(&params)?; let query_start = get_query_start_from_params(&params)?;
let query_window_seconds = get_query_window_seconds_from_params(&params)?;
let page = get_page_from_params(&params)?; let page = get_page_from_params(&params)?;
// TODO: handle secondary users, too
// TODO: page size from config // TODO: page size from config
let page_size = 200; let page_size = 200;
@ -324,9 +335,16 @@ pub async fn get_detailed_stats(
response.insert("chain_id", serde_json::to_value(chain_id)?); response.insert("chain_id", serde_json::to_value(chain_id)?);
response.insert( response.insert(
"query_start", "query_start",
serde_json::to_value(query_start.timestamp())?, serde_json::to_value(query_start.timestamp() as u64)?,
); );
if query_window_seconds != 0 {
response.insert(
"query_window_seconds",
serde_json::to_value(query_window_seconds)?,
);
}
// TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users? // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users?
// TODO: how do we count uptime? // TODO: how do we count uptime?
let q = rpc_accounting::Entity::find() let q = rpc_accounting::Entity::find()
@ -382,7 +400,7 @@ pub async fn get_detailed_stats(
(condition, q) (condition, q)
}; };
let (condition, q) = if user_id.is_zero() { let (condition, q) = if user_id == 0 {
// 0 means everyone. don't filter on user // 0 means everyone. don't filter on user
(condition, q) (condition, q)
} else { } else {
@ -394,21 +412,26 @@ pub async fn get_detailed_stats(
rpc_accounting::Relation::UserKeys.def(), rpc_accounting::Relation::UserKeys.def(),
) )
.column(user_keys::Column::UserId) .column(user_keys::Column::UserId)
// no need to group_by user_id when we are grouping by key_id .group_by(user_keys::Column::UserId);
// .group_by(user_keys::Column::UserId)
.column(user_keys::Column::Id)
.group_by(user_keys::Column::Id);
let condition = condition.add(user_keys::Column::UserId.eq(user_id)); let condition = condition.add(user_keys::Column::UserId.eq(user_id));
let q = if user_key_id == 0 {
q.column(user_keys::Column::UserId)
.group_by(user_keys::Column::UserId)
} else {
response.insert("user_key_id", serde_json::to_value(user_key_id)?);
// no need to group_by user_id when we are grouping by key_id
q.column(user_keys::Column::Id)
.group_by(user_keys::Column::Id)
};
(condition, q) (condition, q)
}; };
let q = q.filter(condition); let q = q.filter(condition);
// TODO: enum between searching on user_key_id on user_id
// TODO: handle secondary users, too
// log query here. i think sea orm has a useful log level for this // log query here. i think sea orm has a useful log level for this
// TODO: transform this into a nested hashmap instead of a giant table? // TODO: transform this into a nested hashmap instead of a giant table?
@ -431,18 +454,20 @@ pub async fn get_detailed_stats(
/// ///
/// TODO: take a "timebucket" duration in minutes that will make a more advanced /// TODO: take a "timebucket" duration in minutes that will make a more advanced
pub async fn get_revert_logs( pub async fn get_revert_logs(
chain_id: u64, app: &Web3ProxyApp,
db_conn: &DatabaseConnection, bearer: Option<TypedHeader<Authorization<Bearer>>>,
page: usize, params: HashMap<String, String>,
page_size: usize,
query_start: chrono::NaiveDateTime,
user_id: u64,
user_key_id: u64,
) -> anyhow::Result<HashMap<&str, serde_json::Value>> { ) -> anyhow::Result<HashMap<&str, serde_json::Value>> {
// aggregate stats, but grouped by method and error let db_conn = app.db_conn().context("connecting to db")?;
trace!(?chain_id, %query_start, ?user_id, "get_aggregate_stats"); let redis_conn = app.redis_conn().await.context("connecting to redis")?;
// TODO: minimum query_start of 90 days ago? let user_id = get_user_id_from_params(redis_conn, bearer, &params).await?;
let user_key_id = get_user_key_id_from_params(user_id, &params)?;
let chain_id = get_chain_id_from_params(app, &params)?;
let query_start = get_query_start_from_params(&params)?;
let query_window_seconds = get_query_window_seconds_from_params(&params)?;
let page = get_page_from_params(&params)?;
let page_size = get_page_from_params(&params)?;
let mut response = HashMap::new(); let mut response = HashMap::new();
@ -451,9 +476,16 @@ pub async fn get_revert_logs(
response.insert("chain_id", serde_json::to_value(chain_id)?); response.insert("chain_id", serde_json::to_value(chain_id)?);
response.insert( response.insert(
"query_start", "query_start",
serde_json::to_value(query_start.timestamp())?, serde_json::to_value(query_start.timestamp() as u64)?,
); );
if query_window_seconds != 0 {
response.insert(
"query_window_seconds",
serde_json::to_value(query_window_seconds)?,
);
}
// TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users? // TODO: how do we get count reverts compared to other errors? does it matter? what about http errors to our users?
// TODO: how do we count uptime? // TODO: how do we count uptime?
let q = rpc_accounting::Entity::find() let q = rpc_accounting::Entity::find()
@ -545,7 +577,7 @@ pub async fn get_revert_logs(
// TODO: transform this into a nested hashmap instead of a giant table? // TODO: transform this into a nested hashmap instead of a giant table?
let r = q let r = q
.into_json() .into_json()
.paginate(db_conn, page_size) .paginate(&db_conn, page_size)
.fetch_page(page) .fetch_page(page)
.await?; .await?;