thread fast rng

Bryan Stitt 2022-11-12 06:11:58 +00:00
parent 8e3547bbd0
commit 9ae2337d1d
12 changed files with 183 additions and 35 deletions

Cargo.lock

@@ -3623,6 +3623,15 @@ dependencies = [
"rand_core 0.3.1",
]
[[package]]
name = "rand_xoshiro"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
dependencies = [
"rand_core 0.6.3",
]
[[package]]
name = "rayon"
version = "1.5.3"
@@ -4921,6 +4930,14 @@ dependencies = [
"syn",
]
[[package]]
name = "thread-fast-rng"
version = "0.1.0"
dependencies = [
"rand 0.8.5",
"rand_xoshiro",
]
[[package]]
name = "thread-id"
version = "4.0.0"
@@ -5608,7 +5625,6 @@ dependencies = [
"parking_lot 0.12.1",
"petgraph",
"proctitle",
"rand 0.8.5",
"redis-rate-limiter",
"regex",
"reqwest",
@@ -5620,6 +5636,7 @@ dependencies = [
"serde_json",
"serde_prometheus",
"siwe",
"thread-fast-rng",
"time 0.3.17",
"tokio",
"tokio-stream",

@@ -4,6 +4,7 @@ members = [
"entities",
"migration",
"redis-rate-limiter",
"thread-fast-rng",
"web3_proxy",
]

@@ -97,7 +97,7 @@ Flame graphs make a developer's job of finding slow code painless:
$ cat /proc/sys/kernel/perf_event_paranoid
4
$ echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid
0
-1
$ CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --bin web3_proxy

@@ -235,9 +235,9 @@ These are roughly in order of completion
- [x] test that runs check_config against example.toml
- [-] add configurable size limits to all the Caches
- instead of configuring each cache with MB sizes, have one value for total memory footprint and then percentages for each cache
- [ ] improve sorting servers by weight
- if the utilization is > 100%, increase weight by 1? maybe just add utilization to the weight?
- if utilization is > hard limit, add a lot to the weight
- [x] improve sorting servers by weight. don't force to lower weights, still have a probability that smaller weights might be chosen
- [ ] add block timestamp to the /status page
- [ ] cache the status page for a second
- [ ] actually block unauthenticated requests instead of emitting warning of "allowing without auth during development!"

@@ -0,0 +1,10 @@
[package]
name = "thread-fast-rng"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rand = "0.8.5"
rand_xoshiro = "0.6.0"

@@ -0,0 +1,87 @@
//! works just like rand::thread_rng but with an rng that is **not** cryptographically secure
//!
//! TODO: currently uses Xoshiro256Plus. do some benchmarks
pub use rand;
use rand::{Error, Rng, RngCore, SeedableRng};
use rand_xoshiro::Xoshiro256Plus;
use std::{cell::UnsafeCell, rc::Rc};
#[derive(Clone, Debug)]
pub struct ThreadFastRng {
// Rc is explicitly !Send and !Sync
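// which makes ThreadFastRng !Send + !Sync as well, so the UnsafeCell below is only ever touched by one thread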
rng: Rc<UnsafeCell<Xoshiro256Plus>>,
}
thread_local! {
pub static THREAD_FAST_RNG: Rc<UnsafeCell<Xoshiro256Plus>> = {
// use a cryptographically secure rng for the seed
let mut crypto_rng = rand::thread_rng();
let seed = crypto_rng.gen();
// use a fast rng for things that aren't cryptography
let rng = Xoshiro256Plus::seed_from_u64(seed);
Rc::new(UnsafeCell::new(rng))
};
}
pub fn thread_fast_rng() -> ThreadFastRng {
let rng = THREAD_FAST_RNG.with(|t| t.clone());
ThreadFastRng { rng }
}
impl Default for ThreadFastRng {
fn default() -> Self {
thread_fast_rng()
}
}
impl RngCore for ThreadFastRng {
#[inline(always)]
fn next_u32(&mut self) -> u32 {
// SAFETY: We must make sure to stop using `rng` before anyone else
// creates another mutable reference
let rng = unsafe { &mut *self.rng.get() };
rng.next_u32()
}
#[inline(always)]
fn next_u64(&mut self) -> u64 {
// SAFETY: We must make sure to stop using `rng` before anyone else
// creates another mutable reference
let rng = unsafe { &mut *self.rng.get() };
rng.next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
// SAFETY: We must make sure to stop using `rng` before anyone else
// creates another mutable reference
let rng = unsafe { &mut *self.rng.get() };
rng.fill_bytes(dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
// SAFETY: We must make sure to stop using `rng` before anyone else
// creates another mutable reference
let rng = unsafe { &mut *self.rng.get() };
rng.try_fill_bytes(dest)
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_thread_fast_rng() {
let mut r = thread_fast_rng();
r.gen::<i32>();
assert_eq!(r.gen_range(0..1), 0);
}
#[test]
fn test_thread_fast_rng_struct() {
let mut r = ThreadFastRng::default();
r.gen::<i32>();
assert_eq!(r.gen_range(0..1), 0);
}
}
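A minimal usage sketch (a hypothetical standalone caller, mirroring the web3_proxy call sites below): the crate re-exports rand, so callers get the Rng trait and the fast rng from one place:

use thread_fast_rng::{rand::Rng, thread_fast_rng};

fn main() {
    // Rng's convenience methods work because ThreadFastRng implements RngCore
    let x: u64 = thread_fast_rng().gen();
    let jitter_ms = thread_fast_rng().gen_range(100..300);
    println!("{x} {jitter_ms}");
}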

@@ -18,6 +18,7 @@ deferred-rate-limiter = { path = "../deferred-rate-limiter" }
entities = { path = "../entities" }
migration = { path = "../migration" }
redis-rate-limiter = { path = "../redis-rate-limiter" }
thread-fast-rng = { path = "../thread-fast-rng" }
anyhow = { version = "1.0.66", features = ["backtrace"] }
arc-swap = "1.5.1"
@@ -47,7 +48,6 @@ num-traits = "0.2.15"
parking_lot = { version = "0.12.1", features = ["arc_lock"] }
petgraph = "0.6.2"
proctitle = "0.1.1"
rand = "0.8.5"
# TODO: regex has several "perf" features that we might want to use
regex = "1.7.0"
reqwest = { version = "0.11.12", default-features = false, features = ["json", "tokio-rustls"] }

@@ -1,5 +1,6 @@
use metered::{metered, HitCount, Throughput};
use serde::Serialize;
use thread_fast_rng::{rand::Rng, thread_fast_rng};
#[derive(Default, Debug, Serialize)]
pub struct Biz {
@@ -10,7 +11,7 @@ pub struct Biz {
impl Biz {
#[measure([HitCount, Throughput])]
pub fn biz(&self) {
let delay = std::time::Duration::from_millis(rand::random::<u64>() % 200);
let delay = std::time::Duration::from_millis(thread_fast_rng().gen::<u64>() % 200);
std::thread::sleep(delay);
}
}

@@ -189,7 +189,7 @@ pub struct Web3ConnectionConfig {
pub soft_limit: u32,
/// the requests per second at which the server throws errors (rate limit or otherwise)
pub hard_limit: Option<u64>,
/// All else equal, a server with a lower weight receives requests
/// All else equal, a server with a lower weight receives more requests. Ranges 0-100
pub weight: u32,
/// Subscribe to the firehose of pending transactions
/// Don't do this with free rpcs

@@ -10,7 +10,6 @@ use ethers::prelude::{Bytes, Middleware, ProviderError, TxHash, H256, U64};
use futures::future::try_join_all;
use futures::StreamExt;
use parking_lot::RwLock;
use rand::Rng;
use redis_rate_limiter::{RedisPool, RedisRateLimitResult, RedisRateLimiter};
use sea_orm::DatabaseConnection;
use serde::ser::{SerializeStruct, Serializer};
@@ -21,6 +20,8 @@ use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{self, AtomicU32, AtomicU64};
use std::{cmp::Ordering, sync::Arc};
use thread_fast_rng::rand::Rng;
use thread_fast_rng::thread_fast_rng;
use tokio::sync::broadcast;
use tokio::sync::RwLock as AsyncRwLock;
use tokio::time::{interval, sleep, sleep_until, Duration, Instant, MissedTickBehavior};
@@ -49,9 +50,8 @@ pub struct Web3Connection {
pub(super) soft_limit: u32,
/// TODO: have an enum for this so that "no limit" prints pretty?
block_data_limit: AtomicU64,
/// Lower weights are higher priority when sending requests
pub(super) weight: u32,
// TODO: async lock?
/// Lower weights are higher priority when sending requests. 0 to 99.
pub(super) weight: f64,
pub(super) head_block_id: RwLock<Option<BlockId>>,
pub(super) open_request_handle_metrics: Arc<OpenRequestHandleMetrics>,
}
@@ -84,17 +84,20 @@ impl Web3Connection {
// TODO: is cache size 1 okay? i think we need
RedisRateLimiter::new(
"web3_proxy",
&format!("{}:{}", chain_id, url_str),
&format!("{}:{}", chain_id, name),
hard_rate_limit,
60.0,
redis_pool,
)
});
// turn weight 0 into 100% and weight 100 into 0%
let weight = (100 - weight) as f64 / 100.0;
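// e.g. a configured weight of 0 becomes 1.00 (most preferred),
// 25 becomes 0.75, and 100 becomes 0.00 (least preferred)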
let new_connection = Self {
name,
http_client,
url: url_str.clone(),
url: url_str,
active_requests: 0.into(),
total_requests: 0.into(),
provider: AsyncRwLock::new(None),
@@ -293,7 +296,7 @@ impl Web3Connection {
let mut sleep_ms = if initial_sleep {
let first_sleep_ms = min(
cap_ms,
rand::thread_rng().gen_range(base_ms..(base_ms * range_multiplier)),
thread_fast_rng().gen_range(base_ms..(base_ms * range_multiplier)),
);
let reconnect_in = Duration::from_millis(first_sleep_ms);
@@ -308,9 +311,10 @@
// retry until we succeed
while let Err(err) = self.reconnect(block_sender).await {
// thread_rng is cryptographically secure. we don't need that, and we don't need this to be super efficient either, so thread_fast_rng is fine
sleep_ms = min(
cap_ms,
rand::thread_rng().gen_range(base_ms..(sleep_ms * range_multiplier)),
thread_fast_rng().gen_range(base_ms..(sleep_ms * range_multiplier)),
);
let retry_in = Duration::from_millis(sleep_ms);
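// hypothetical numbers: with base_ms = 1_000, range_multiplier = 3, and cap_ms = 60_000,
// the first sleep is drawn from 1_000..3_000 ms, each retry from 1_000..(previous * 3) ms,
// and the min() caps every sleep at 60_000 ms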

@@ -26,11 +26,10 @@ use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use serde_json::json;
use serde_json::value::RawValue;
use std::cmp;
use std::cmp::Reverse;
use std::fmt;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use thread_fast_rng::rand::seq::SliceRandom;
use tokio::sync::RwLock as AsyncRwLock;
use tokio::sync::{broadcast, watch};
use tokio::task;
@@ -400,7 +399,7 @@ impl Web3Connections {
// filter the synced rpcs
// TODO: we are going to be checking "has_block_data" a lot now
let mut synced_rpcs: Vec<Arc<Web3Connection>> = self
let synced_rpcs: Vec<Arc<Web3Connection>> = self
.conns
.values()
.filter(|x| !skip.contains(x))
@@ -414,35 +413,62 @@
return Err(anyhow::anyhow!("no servers are synced"));
}
let mut minimum = 0.0;
// we sort on a bunch of values. cache them here so that we don't do this math multiple times.
let sort_cache: HashMap<_, _> = synced_rpcs
let weight_map: HashMap<_, f64> = synced_rpcs
.iter()
.map(|rpc| {
// TODO: get active requests and the soft limit out of redis?
// TODO: put this on the rpc object instead?
let weight = rpc.weight;
// TODO: are active requests what we want? do we want a counter for requests in the last second + any actives longer than that?
// TODO: get active requests out of redis?
// TODO: do something with hard limit instead?
let active_requests = rpc.active_requests();
let soft_limit = rpc.soft_limit;
let utilization = active_requests as f32 / soft_limit as f32;
// TODO: maybe store weight as the percentile
let available_requests = soft_limit as f64 * weight - active_requests as f64;
// TODO: utilization isn't enough we need to sort on some combination of utilization and if a server is archive or not
// TODO: if a server's utilization is high and it has a low weight, it will keep getting requests. this isn't really what we want
if available_requests < 0.0 {
minimum = available_requests.min(minimum);
}
// TODO: double check this sorts how we want
(rpc.clone(), (weight, utilization, Reverse(soft_limit)))
(rpc.clone(), available_requests)
})
.collect();
synced_rpcs.sort_unstable_by(|a, b| {
let a_sorts = sort_cache.get(a).expect("sort_cache should always have a");
let b_sorts = sort_cache.get(b).expect("sort_cache should always have b");
// partial_cmp because we are comparing floats
a_sorts.partial_cmp(b_sorts).unwrap_or(cmp::Ordering::Equal)
});
// we can't have negative numbers. shift up if any are negative
let weight_map: HashMap<_, f64> = if minimum < 0.0 {
weight_map
.into_iter()
.map(|(rpc, weight)| {
// TODO: is a simple shift the right way to handle this?
// minimum is negative here, so subtracting it moves the smallest weight up to 0
let x = weight - minimum;
(rpc, x)
})
.collect()
} else {
weight_map
};
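// worked example (hypothetical numbers): available_requests of 300.0 and -50.0
// give minimum = -50.0, so the shifted weights become 350.0 and 0.0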
let sorted_rpcs = {
let mut rng = thread_fast_rng::thread_fast_rng();
synced_rpcs
.choose_multiple_weighted(&mut rng, synced_rpcs.len(), |rpc| {
*weight_map
.get(rpc)
.expect("rpc should always be in the weight map")
})
.unwrap()
.collect::<Vec<_>>()
};
// now that the rpcs are sorted, try to get an active request handle for one of them
for rpc in synced_rpcs.into_iter() {
for rpc in sorted_rpcs.into_iter() {
// increment our connection counter
match rpc.try_request_handle(authorization).await {
Ok(OpenRequestResult::Handle(handle)) => {

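For context, choose_multiple_weighted comes from rand 0.8's SliceRandom trait (imported above via thread_fast_rng::rand). It samples without replacement, so every synced rpc still appears somewhere in the ordering, but servers with more available requests tend to land near the front. A standalone sketch with made-up server names and weights:

use rand::seq::SliceRandom;

fn main() {
    let mut rng = rand::thread_rng();
    let rpcs = [("archive", 350.0), ("full", 120.0), ("backup", 1.0)];
    // sample all three, weighted: "archive" usually sorts first,
    // but "backup" still has a small chance at any position
    let ordered: Vec<_> = rpcs
        .choose_multiple_weighted(&mut rng, rpcs.len(), |(_, weight)| *weight)
        .unwrap()
        .collect();
    println!("{ordered:?}");
}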
@@ -12,13 +12,13 @@ use metered::metered;
use metered::HitCount;
use metered::ResponseTime;
use metered::Throughput;
use rand::Rng;
use sea_orm::ActiveEnum;
use sea_orm::ActiveModelTrait;
use serde_json::json;
use std::fmt;
use std::sync::atomic::{self, AtomicBool, Ordering};
use std::sync::Arc;
use thread_fast_rng::rand::Rng;
use tokio::time::{sleep, Duration, Instant};
use tracing::Level;
use tracing::{debug, error, trace, warn};
@@ -222,7 +222,9 @@ impl OpenRequestHandle {
} else if log_revert_chance == 1.0 {
trace!(%method, "gaurenteed chance. SAVING on revert");
error_handler
} else if rand::thread_rng().gen_range(0.0f64..=1.0) < log_revert_chance {
} else if thread_fast_rng::thread_fast_rng().gen_range(0.0f64..=1.0)
< log_revert_chance
{
trace!(%method, "missed chance. skipping save on revert");
RequestErrorHandler::DebugLevel
} else {