Merge branch 'develop' into felipe/new-cache

mergify[bot] authored 2023-05-18 04:53:58 +00:00, committed by GitHub
commit 991dc2627d
15 changed files with 77 additions and 507 deletions

@@ -121,7 +121,6 @@ type Backend struct {
 	wsURL        string
 	authUsername string
 	authPassword string
-	rateLimiter  BackendRateLimiter
 	client       *LimitedHTTPClient
 	dialer       *websocket.Dialer
 	maxRetries   int
@@ -243,7 +242,6 @@ func NewBackend(
 	name string,
 	rpcURL string,
 	wsURL string,
-	rateLimiter BackendRateLimiter,
 	rpcSemaphore *semaphore.Weighted,
 	opts ...BackendOpt,
 ) *Backend {
@@ -251,7 +249,6 @@ func NewBackend(
 		Name:   name,
 		rpcURL: rpcURL,
 		wsURL:  wsURL,
-		rateLimiter:     rateLimiter,
 		maxResponseSize: math.MaxInt64,
 		client: &LimitedHTTPClient{
 			Client: http.Client{Timeout: 5 * time.Second},
@@ -281,15 +278,6 @@ func NewBackend(
 }

 func (b *Backend) Forward(ctx context.Context, reqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
-	if !b.Online() {
-		RecordBatchRPCError(ctx, b.Name, reqs, ErrBackendOffline)
-		return nil, ErrBackendOffline
-	}
-
-	if b.IsRateLimited() {
-		RecordBatchRPCError(ctx, b.Name, reqs, ErrBackendOverCapacity)
-		return nil, ErrBackendOverCapacity
-	}
-
 	var lastError error
 	// <= to account for the first attempt not technically being
 	// a retry
@@ -340,24 +328,12 @@ func (b *Backend) Forward(ctx context.Context, reqs []*RPCReq, isBatch bool) ([]
 		return res, err
 	}

-	b.setOffline()
 	return nil, wrapErr(lastError, "permanent error forwarding request")
 }

 func (b *Backend) ProxyWS(clientConn *websocket.Conn, methodWhitelist *StringSet) (*WSProxier, error) {
-	if !b.Online() {
-		return nil, ErrBackendOffline
-	}
-
-	if b.IsWSSaturated() {
-		return nil, ErrBackendOverCapacity
-	}
-
 	backendConn, _, err := b.dialer.Dial(b.wsURL, nil) // nolint:bodyclose
 	if err != nil {
-		b.setOffline()
-		if err := b.rateLimiter.DecBackendWSConns(b.Name); err != nil {
-			log.Error("error decrementing backend ws conns", "name", b.Name, "err", err)
-		}
 		return nil, wrapErr(err, "error dialing backend")
 	}
@@ -365,66 +341,6 @@ func (b *Backend) ProxyWS(clientConn *websocket.Conn, methodWhitelist *StringSet
 	return NewWSProxier(b, clientConn, backendConn, methodWhitelist), nil
 }

-func (b *Backend) Online() bool {
-	online, err := b.rateLimiter.IsBackendOnline(b.Name)
-	if err != nil {
-		log.Warn(
-			"error getting backend availability, assuming it is offline",
-			"name", b.Name,
-			"err", err,
-		)
-		return false
-	}
-	return online
-}
-
-func (b *Backend) IsRateLimited() bool {
-	if b.maxRPS == 0 {
-		return false
-	}
-	usedLimit, err := b.rateLimiter.IncBackendRPS(b.Name)
-	if err != nil {
-		log.Error(
-			"error getting backend used rate limit, assuming limit is exhausted",
-			"name", b.Name,
-			"err", err,
-		)
-		return true
-	}
-	return b.maxRPS < usedLimit
-}
-
-func (b *Backend) IsWSSaturated() bool {
-	if b.maxWSConns == 0 {
-		return false
-	}
-	incremented, err := b.rateLimiter.IncBackendWSConns(b.Name, b.maxWSConns)
-	if err != nil {
-		log.Error(
-			"error getting backend used ws conns, assuming limit is exhausted",
-			"name", b.Name,
-			"err", err,
-		)
-		return true
-	}
-	return !incremented
-}
-
-func (b *Backend) setOffline() {
-	err := b.rateLimiter.SetBackendOffline(b.Name, b.outOfServiceInterval)
-	if err != nil {
-		log.Warn(
-			"error setting backend offline",
-			"name", b.Name,
-			"err", err,
-		)
-	}
-}
-
 // ForwardRPC makes a call directly to a backend and populate the response into `res`
 func (b *Backend) ForwardRPC(ctx context.Context, res *RPCRes, id string, method string, params ...any) error {
 	jsonParams, err := json.Marshal(params)
@@ -968,9 +884,6 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
 func (w *WSProxier) close() {
 	w.clientConn.Close()
 	w.backendConn.Close()
-	if err := w.backend.rateLimiter.DecBackendWSConns(w.backend.Name); err != nil {
-		log.Error("error decrementing backend ws conns", "name", w.backend.Name, "err", err)
-	}
 	activeBackendWsConnsGauge.WithLabelValues(w.backend.Name).Dec()
 }
@@ -984,10 +897,6 @@ func (w *WSProxier) prepareClientMsg(msg []byte) (*RPCReq, error) {
 		return req, ErrMethodNotWhitelisted
 	}

-	if w.backend.IsRateLimited() {
-		return req, ErrBackendOverCapacity
-	}
-
 	return req, nil
 }

@@ -1,286 +0,0 @@
-package proxyd
-
-import (
-	"context"
-	"crypto/rand"
-	"encoding/hex"
-	"fmt"
-	"math"
-	"sync"
-	"time"
-
-	"github.com/ethereum/go-ethereum/log"
-	"github.com/go-redis/redis/v8"
-)
-
-const MaxRPSScript = `
-local current
-current = redis.call("incr", KEYS[1])
-if current == 1 then
-	redis.call("expire", KEYS[1], 1)
-end
-return current
-`
-
-const MaxConcurrentWSConnsScript = `
-redis.call("sadd", KEYS[1], KEYS[2])
-local total = 0
-local scanres = redis.call("sscan", KEYS[1], 0)
-for _, k in ipairs(scanres[2]) do
-	local value = redis.call("get", k)
-	if value then
-		total = total + value
-	end
-end
-
-if total < tonumber(ARGV[1]) then
-	redis.call("incr", KEYS[2])
-	redis.call("expire", KEYS[2], 300)
-	return true
-end
-
-return false
-`
-
-type BackendRateLimiter interface {
-	IsBackendOnline(name string) (bool, error)
-	SetBackendOffline(name string, duration time.Duration) error
-	IncBackendRPS(name string) (int, error)
-	IncBackendWSConns(name string, max int) (bool, error)
-	DecBackendWSConns(name string) error
-	FlushBackendWSConns(names []string) error
-}
-
-type RedisBackendRateLimiter struct {
-	rdb       *redis.Client
-	randID    string
-	touchKeys map[string]time.Duration
-	tkMtx     sync.Mutex
-}
-
-func NewRedisRateLimiter(rdb *redis.Client) BackendRateLimiter {
-	out := &RedisBackendRateLimiter{
-		rdb:       rdb,
-		randID:    randStr(20),
-		touchKeys: make(map[string]time.Duration),
-	}
-	go out.touch()
-	return out
-}
-
-func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
-	exists, err := r.rdb.Exists(context.Background(), fmt.Sprintf("backend:%s:offline", name)).Result()
-	if err != nil {
-		RecordRedisError("IsBackendOnline")
-		return false, wrapErr(err, "error getting backend availability")
-	}
-	return exists == 0, nil
-}
-
-func (r *RedisBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
-	if duration == 0 {
-		return nil
-	}
-	err := r.rdb.SetEX(
-		context.Background(),
-		fmt.Sprintf("backend:%s:offline", name),
-		1,
-		duration,
-	).Err()
-	if err != nil {
-		RecordRedisError("SetBackendOffline")
-		return wrapErr(err, "error setting backend unavailable")
-	}
-	return nil
-}
-
-func (r *RedisBackendRateLimiter) IncBackendRPS(name string) (int, error) {
-	cmd := r.rdb.Eval(
-		context.Background(),
-		MaxRPSScript,
-		[]string{fmt.Sprintf("backend:%s:ratelimit", name)},
-	)
-	rps, err := cmd.Int()
-	if err != nil {
-		RecordRedisError("IncBackendRPS")
-		return -1, wrapErr(err, "error upserting backend rate limit")
-	}
-	return rps, nil
-}
-
-func (r *RedisBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
-	connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
-	r.tkMtx.Lock()
-	r.touchKeys[connsKey] = 5 * time.Minute
-	r.tkMtx.Unlock()
-	cmd := r.rdb.Eval(
-		context.Background(),
-		MaxConcurrentWSConnsScript,
-		[]string{
-			fmt.Sprintf("backend:%s:proxies", name),
-			connsKey,
-		},
-		max,
-	)
-	incremented, err := cmd.Bool()
-	// false gets coerced to redis.nil, see https://redis.io/commands/eval#conversion-between-lua-and-redis-data-types
-	if err == redis.Nil {
-		return false, nil
-	}
-	if err != nil {
-		RecordRedisError("IncBackendWSConns")
-		return false, wrapErr(err, "error incrementing backend ws conns")
-	}
-	return incremented, nil
-}
-
-func (r *RedisBackendRateLimiter) DecBackendWSConns(name string) error {
-	connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
-	err := r.rdb.Decr(context.Background(), connsKey).Err()
-	if err != nil {
-		RecordRedisError("DecBackendWSConns")
-		return wrapErr(err, "error decrementing backend ws conns")
-	}
-	return nil
-}
-
-func (r *RedisBackendRateLimiter) FlushBackendWSConns(names []string) error {
-	ctx := context.Background()
-	for _, name := range names {
-		connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
-		err := r.rdb.SRem(
-			ctx,
-			fmt.Sprintf("backend:%s:proxies", name),
-			connsKey,
-		).Err()
-		if err != nil {
-			return wrapErr(err, "error flushing backend ws conns")
-		}
-		err = r.rdb.Del(ctx, connsKey).Err()
-		if err != nil {
-			return wrapErr(err, "error flushing backend ws conns")
-		}
-	}
-	return nil
-}
-
-func (r *RedisBackendRateLimiter) touch() {
-	for {
-		r.tkMtx.Lock()
-		for key, dur := range r.touchKeys {
-			if err := r.rdb.Expire(context.Background(), key, dur).Err(); err != nil {
-				RecordRedisError("touch")
-				log.Error("error touching redis key", "key", key, "err", err)
-			}
-		}
-		r.tkMtx.Unlock()
-		time.Sleep(5 * time.Second)
-	}
-}
-
-type LocalBackendRateLimiter struct {
-	deadBackends   map[string]time.Time
-	backendRPS     map[string]int
-	backendWSConns map[string]int
-	mtx            sync.RWMutex
-}
-
-func NewLocalBackendRateLimiter() *LocalBackendRateLimiter {
-	out := &LocalBackendRateLimiter{
-		deadBackends:   make(map[string]time.Time),
-		backendRPS:     make(map[string]int),
-		backendWSConns: make(map[string]int),
-	}
-	go out.clear()
-	return out
-}
-
-func (l *LocalBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
-	l.mtx.RLock()
-	defer l.mtx.RUnlock()
-	return l.deadBackends[name].Before(time.Now()), nil
-}
-
-func (l *LocalBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
-	l.mtx.Lock()
-	defer l.mtx.Unlock()
-	l.deadBackends[name] = time.Now().Add(duration)
-	return nil
-}
-
-func (l *LocalBackendRateLimiter) IncBackendRPS(name string) (int, error) {
-	l.mtx.Lock()
-	defer l.mtx.Unlock()
-	l.backendRPS[name] += 1
-	return l.backendRPS[name], nil
-}
-
-func (l *LocalBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
-	l.mtx.Lock()
-	defer l.mtx.Unlock()
-	if l.backendWSConns[name] == max {
-		return false, nil
-	}
-	l.backendWSConns[name] += 1
-	return true, nil
-}
-
-func (l *LocalBackendRateLimiter) DecBackendWSConns(name string) error {
-	l.mtx.Lock()
-	defer l.mtx.Unlock()
-	if l.backendWSConns[name] == 0 {
-		return nil
-	}
-	l.backendWSConns[name] -= 1
-	return nil
-}
-
-func (l *LocalBackendRateLimiter) FlushBackendWSConns(names []string) error {
-	return nil
-}
-
-func (l *LocalBackendRateLimiter) clear() {
-	for {
-		time.Sleep(time.Second)
-		l.mtx.Lock()
-		l.backendRPS = make(map[string]int)
-		l.mtx.Unlock()
-	}
-}
-
-func randStr(l int) string {
-	b := make([]byte, l)
-	if _, err := rand.Read(b); err != nil {
-		panic(err)
-	}
-	return hex.EncodeToString(b)
-}
-
-type NoopBackendRateLimiter struct{}
-
-var noopBackendRateLimiter = &NoopBackendRateLimiter{}
-
-func (n *NoopBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
-	return true, nil
-}
-
-func (n *NoopBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
-	return nil
-}
-
-func (n *NoopBackendRateLimiter) IncBackendRPS(name string) (int, error) {
-	return math.MaxInt, nil
-}
-
-func (n *NoopBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
-	return true, nil
-}
-
-func (n *NoopBackendRateLimiter) DecBackendWSConns(name string) error {
-	return nil
-}
-
-func (n *NoopBackendRateLimiter) FlushBackendWSConns(names []string) error {
-	return nil
-}

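For reference, the deleted `MaxRPSScript` above is a textbook fixed-window rate limit: `INCR` a per-backend counter, and on the first increment of the window give the key a one-second TTL so it expires on its own. `MaxConcurrentWSConnsScript` applied the same idea to concurrent WS connections by summing per-proxy counter keys. A minimal, self-contained sketch of the fixed-window pattern with go-redis v8 (the client setup and key name are illustrative, not proxyd's actual wiring):

```go
package main

import (
	"context"
	"fmt"

	"github.com/go-redis/redis/v8"
)

// Same fixed-window script as the removed MaxRPSScript: the counter for
// the current second is created on first INCR and expires after 1s.
const maxRPSScript = `
local current
current = redis.call("incr", KEYS[1])
if current == 1 then
    redis.call("expire", KEYS[1], 1)
end
return current
`

func main() {
	rdb := redis.NewClient(&redis.Options{Addr: "127.0.0.1:6379"}) // illustrative address
	ctx := context.Background()

	// Each call returns the number of requests seen this second; a caller
	// would compare it against a maxRPS cap, as Backend.IsRateLimited did.
	used, err := rdb.Eval(ctx, maxRPSScript, []string{"backend:example:ratelimit"}).Int()
	if err != nil {
		panic(err)
	}
	fmt.Println("requests this second:", used)
}
```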
@@ -2,6 +2,7 @@ package proxyd

 import (
 	"context"
+	"strings"
 	"time"

 	"github.com/go-redis/redis/v8"
@@ -44,15 +45,23 @@ func (c *cache) Put(ctx context.Context, key string, value string) error {

 type redisCache struct {
 	rdb    *redis.Client
+	prefix string
 }

-func newRedisCache(rdb *redis.Client) *redisCache {
-	return &redisCache{rdb}
+func newRedisCache(rdb *redis.Client, prefix string) *redisCache {
+	return &redisCache{rdb, prefix}
+}
+
+func (c *redisCache) namespaced(key string) string {
+	if c.prefix == "" {
+		return key
+	}
+	return strings.Join([]string{c.prefix, key}, ":")
 }

 func (c *redisCache) Get(ctx context.Context, key string) (string, error) {
 	start := time.Now()
-	val, err := c.rdb.Get(ctx, key).Result()
+	val, err := c.rdb.Get(ctx, c.namespaced(key)).Result()
 	redisCacheDurationSumm.WithLabelValues("GET").Observe(float64(time.Since(start).Milliseconds()))

 	if err == redis.Nil {
@@ -66,7 +75,7 @@ func (c *redisCache) Get(ctx context.Context, key string) (string, error) {

 func (c *redisCache) Put(ctx context.Context, key string, value string) error {
 	start := time.Now()
-	err := c.rdb.SetEX(ctx, key, value, redisTTL).Err()
+	err := c.rdb.SetEX(ctx, c.namespaced(key), value, redisTTL).Err()
 	redisCacheDurationSumm.WithLabelValues("SETEX").Observe(float64(time.Since(start).Milliseconds()))

 	if err != nil {

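The new `namespaced` helper is the heart of the cache change: when a namespace is configured, every key is stored as `<prefix>:<key>`, so several proxyd deployments can share one Redis without clobbering each other's cache entries. A standalone sketch of the behavior (the function body is lifted from the diff for illustration):

```go
package main

import (
	"fmt"
	"strings"
)

// namespaced mirrors the helper added above: an empty prefix leaves
// keys untouched; otherwise keys become "<prefix>:<key>".
func namespaced(prefix, key string) string {
	if prefix == "" {
		return key
	}
	return strings.Join([]string{prefix, key}, ":")
}

func main() {
	fmt.Println(namespaced("", "block:0x1"))       // block:0x1
	fmt.Println(namespaced("proxyd", "block:0x1")) // proxyd:block:0x1
}
```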
@@ -33,6 +33,7 @@ type CacheConfig struct {

 type RedisConfig struct {
 	URL       string `toml:"url"`
+	Namespace string `toml:"namespace"`
 }

 type MetricsConfig struct {
@@ -43,7 +44,6 @@ type MetricsConfig struct {

 type RateLimitConfig struct {
 	UseRedis bool `toml:"use_redis"`
-	EnableBackendRateLimiter bool `toml:"enable_backend_rate_limiter"`
 	BaseRate int `toml:"base_rate"`
 	BaseInterval TOMLDuration `toml:"base_interval"`
 	ExemptOrigins []string `toml:"exempt_origins"`

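The new `Namespace` field picks up the `namespace` key from the `[redis]` config section via its struct tag. A hedged sketch of how such a struct decodes, assuming a BurntSushi-style TOML decoder (which may differ from proxyd's actual config loader):

```go
package main

import (
	"fmt"

	"github.com/BurntSushi/toml"
)

// RedisConfig mirrors the struct in the diff, including the new Namespace field.
type RedisConfig struct {
	URL       string `toml:"url"`
	Namespace string `toml:"namespace"`
}

func main() {
	raw := `
url = "redis://127.0.0.1:6379"
namespace = "proxyd"
`
	var cfg RedisConfig
	if _, err := toml.Decode(raw, &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.URL, cfg.Namespace) // redis://127.0.0.1:6379 proxyd
}
```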
@@ -17,11 +17,14 @@ const (
 	PollerInterval = 1 * time.Second
 )

+type OnConsensusBroken func()
+
 // ConsensusPoller checks the consensus state for each member of a BackendGroup
 // resolves the highest common block for multiple nodes, and reconciles the consensus
 // in case of block hash divergence to minimize re-orgs
 type ConsensusPoller struct {
 	cancelFunc context.CancelFunc
+	listeners  []OnConsensusBroken

 	backendGroup *BackendGroup
 	backendState map[*Backend]*backendState
@@ -150,6 +153,16 @@ func WithAsyncHandler(asyncHandler ConsensusAsyncHandler) ConsensusOpt {
 	}
 }

+func WithListener(listener OnConsensusBroken) ConsensusOpt {
+	return func(cp *ConsensusPoller) {
+		cp.AddListener(listener)
+	}
+}
+
+func (cp *ConsensusPoller) AddListener(listener OnConsensusBroken) {
+	cp.listeners = append(cp.listeners, listener)
+}
+
 func WithBanPeriod(banPeriod time.Duration) ConsensusOpt {
 	return func(cp *ConsensusPoller) {
 		cp.banPeriod = banPeriod
@@ -220,14 +233,8 @@ func (cp *ConsensusPoller) UpdateBackend(ctx context.Context, be *Backend) {
 		return
 	}

-	// if backend exhausted rate limit we'll skip it for now
-	if be.IsRateLimited() {
-		log.Debug("skipping backend - rate limited", "backend", be.Name)
-		return
-	}
-
-	// if backend it not online or not in a health state we'll only resume checkin it after ban
-	if !be.Online() || !be.IsHealthy() {
+	// if backend is not healthy state we'll only resume checking it after ban
+	if !be.IsHealthy() {
 		log.Warn("backend banned - not online or not healthy", "backend", be.Name)
 		cp.Ban(be)
 		return
@@ -348,12 +355,11 @@ func (cp *ConsensusPoller) UpdateBackendGroupConsensus(ctx context.Context) {
 		/*
 			a serving node needs to be:
 			- healthy (network)
-			- not rate limited
-			- online
+			- updated recently
 			- not banned
 			- with minimum peer count
-			- updated recently
-			- not lagging
+			- not lagging latest block
+			- in sync
 		*/

 		peerCount, inSync, latestBlockNumber, _, lastUpdate, bannedUntil := cp.getBackendState(be)
@@ -361,7 +367,7 @@ func (cp *ConsensusPoller) UpdateBackendGroupConsensus(ctx context.Context) {
 		isBanned := time.Now().Before(bannedUntil)
 		notEnoughPeers := !be.skipPeerCountCheck && peerCount < cp.minPeerCount
 		lagging := latestBlockNumber < proposedBlock
-		if !be.IsHealthy() || be.IsRateLimited() || !be.Online() || notUpdated || isBanned || notEnoughPeers || lagging || !inSync {
+		if !be.IsHealthy() || notUpdated || isBanned || notEnoughPeers || lagging || !inSync {
 			filteredBackendsNames = append(filteredBackendsNames, be.Name)
 			continue
 		}
@@ -398,6 +404,9 @@ func (cp *ConsensusPoller) UpdateBackendGroupConsensus(ctx context.Context) {
 	if broken {
+		// propagate event to other interested parts, such as cache invalidator
+		for _, l := range cp.listeners {
+			l()
+		}
 		log.Info("consensus broken", "currentConsensusBlockNumber", currentConsensusBlockNumber, "proposedBlock", proposedBlock, "proposedBlockHash", proposedBlockHash)
 	}

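`OnConsensusBroken` is a plain callback list: `WithListener`/`AddListener` append to `cp.listeners`, and `UpdateBackendGroupConsensus` fires each callback once when consensus breaks (the intended consumer, per the comment, is a cache invalidator). A minimal sketch of the pattern, using a stand-in `poller` type rather than the real `ConsensusPoller`:

```go
package main

import "fmt"

// OnConsensusBroken matches the callback type added in the diff.
type OnConsensusBroken func()

type poller struct {
	listeners []OnConsensusBroken
}

func (p *poller) AddListener(l OnConsensusBroken) {
	p.listeners = append(p.listeners, l)
}

// broken mirrors the propagation loop in UpdateBackendGroupConsensus:
// every registered listener fires once per broken-consensus event.
func (p *poller) broken() {
	for _, l := range p.listeners {
		l()
	}
}

func main() {
	p := &poller{}
	p.AddListener(func() { fmt.Println("consensus broken: purging cache") })
	p.broken()
}
```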
@@ -289,6 +289,11 @@ func TestConsensus(t *testing.T) {
 		h2.ResetOverrides()
 		bg.Consensus.Unban()

+		listenerCalled := false
+		bg.Consensus.AddListener(func() {
+			listenerCalled = true
+		})
+
 		for _, be := range bg.Backends {
 			bg.Consensus.UpdateBackend(ctx, be)
 		}
@@ -334,7 +339,7 @@ func TestConsensus(t *testing.T) {

 		// should resolve to 0x1, since 0x2 is out of consensus at the moment
 		require.Equal(t, "0x1", bg.Consensus.GetConsensusBlockNumber().String())
-		// later, when impl events, listen to broken consensus event
+		require.True(t, listenerCalled)
 	})

 	t.Run("broken consensus with depth 2", func(t *testing.T) {

@@ -190,7 +190,7 @@ func TestOutOfServiceInterval(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, 200, statusCode)
 	RequireEqualJSON(t, []byte(goodResponse), res)
-	require.Equal(t, 2, len(badBackend.Requests()))
+	require.Equal(t, 4, len(badBackend.Requests()))
 	require.Equal(t, 2, len(goodBackend.Requests()))

 	_, statusCode, err = client.SendBatchRPC(
@@ -199,7 +199,7 @@
 	)
 	require.NoError(t, err)
 	require.Equal(t, 200, statusCode)
-	require.Equal(t, 2, len(badBackend.Requests()))
+	require.Equal(t, 8, len(badBackend.Requests()))
 	require.Equal(t, 4, len(goodBackend.Requests()))

 	time.Sleep(time.Second)
@@ -209,7 +209,7 @@
 	require.NoError(t, err)
 	require.Equal(t, 200, statusCode)
 	RequireEqualJSON(t, []byte(goodResponse), res)
-	require.Equal(t, 3, len(badBackend.Requests()))
+	require.Equal(t, 9, len(badBackend.Requests()))
 	require.Equal(t, 4, len(goodBackend.Requests()))
 }
@@ -261,7 +261,6 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
 	config.BackendOptions.MaxRetries = 2
 	// Setup redis to detect offline backends
 	config.Redis.URL = fmt.Sprintf("redis://127.0.0.1:%s", redis.Port())
-	redisClient, err := proxyd.NewRedisClient(config.Redis.URL)
 	require.NoError(t, err)

 	goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse))
@@ -286,10 +285,4 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
 	RequireEqualJSON(t, []byte(asArray(goodResponse, goodResponse)), res)
 	require.Equal(t, 1, len(badBackend.Requests()))
 	require.Equal(t, 1, len(goodBackend.Requests()))
-
-	rr := proxyd.NewRedisRateLimiter(redisClient)
-	require.NoError(t, err)
-	online, err := rr.IsBackendOnline("bad")
-	require.NoError(t, err)
-	require.Equal(t, true, online)
 }

@@ -21,23 +21,6 @@ const frontendOverLimitResponseWithID = `{"error":{"code":-32016,"message":"over

 var ethChainID = "eth_chainId"

-func TestBackendMaxRPSLimit(t *testing.T) {
-	goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
-	defer goodBackend.Close()
-
-	require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
-
-	config := ReadConfig("backend_rate_limit")
-	client := NewProxydClient("http://127.0.0.1:8545")
-	_, shutdown, err := proxyd.Start(config)
-	require.NoError(t, err)
-	defer shutdown()
-
-	limitedRes, codes := spamReqs(t, client, ethChainID, 503, 3)
-	require.Equal(t, 2, codes[200])
-	require.Equal(t, 1, codes[503])
-	RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
-}
-
 func TestFrontendMaxRPSLimit(t *testing.T) {
 	goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
 	defer goodBackend.Close()

@@ -1,21 +0,0 @@
-[server]
-rpc_port = 8545
-
-[backend]
-response_timeout_seconds = 1
-
-[backends]
-[backends.good]
-rpc_url = "$GOOD_BACKEND_RPC_URL"
-ws_url = "$GOOD_BACKEND_RPC_URL"
-max_rps = 2
-
-[backend_groups]
-[backend_groups.main]
-backends = ["good"]
-
-[rpc_method_mappings]
-eth_chainId = "main"
-
-[rate_limit]
-enable_backend_rate_limiter = true

@@ -6,6 +6,7 @@ response_timeout_seconds = 1

 [redis]
 url = "$REDIS_URL"
+namespace = "proxyd"

 [cache]
 enabled = true

@@ -20,6 +20,3 @@ backends = ["bad", "good"]

 [rpc_method_mappings]
 eth_chainId = "main"
-
-[rate_limit]
-enable_backend_rate_limiter = true

@@ -26,6 +26,3 @@ backends = ["good"]

 [rpc_method_mappings]
 eth_chainId = "main"
-
-[rate_limit]
-enable_backend_rate_limiter = true

@@ -270,32 +270,3 @@ func TestWSClientClosure(t *testing.T) {
 		})
 	}
 }
-
-func TestWSClientMaxConns(t *testing.T) {
-	backend := NewMockWSBackend(nil, nil, nil)
-	defer backend.Close()
-
-	require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
-	config := ReadConfig("ws")
-	_, shutdown, err := proxyd.Start(config)
-	require.NoError(t, err)
-	defer shutdown()
-
-	doneCh := make(chan struct{}, 1)
-	_, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, nil)
-	require.NoError(t, err)
-	_, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, func(err error) {
-		require.Contains(t, err.Error(), "unexpected EOF")
-		doneCh <- struct{}{}
-	})
-	require.NoError(t, err)
-	timeout := time.NewTicker(30 * time.Second)
-	select {
-	case <-timeout.C:
-		t.Fatalf("timed out")
-	case <-doneCh:
-		return
-	}
-}

@@ -49,19 +49,6 @@ func Start(config *Config) (*Server, func(), error) {
 		return nil, nil, errors.New("must specify a Redis URL if UseRedis is true in rate limit config")
 	}

-	var lim BackendRateLimiter
-	var err error
-	if config.RateLimit.EnableBackendRateLimiter {
-		if redisClient != nil {
-			lim = NewRedisRateLimiter(redisClient)
-		} else {
-			log.Warn("redis is not configured, using local rate limiter")
-			lim = NewLocalBackendRateLimiter()
-		}
-	} else {
-		lim = noopBackendRateLimiter
-	}
-
 	// While modifying shared globals is a bad practice, the alternative
 	// is to clone these errors on every invocation. This is inefficient.
 	// We'd also have to make sure that errors.Is and errors.As continue
@@ -157,10 +144,14 @@ func Start(config *Config) (*Server, func(), error) {
 		opts = append(opts, WithProxydIP(os.Getenv("PROXYD_IP")))
 		opts = append(opts, WithSkipPeerCountCheck(cfg.SkipPeerCountCheck))

-		back := NewBackend(name, rpcURL, wsURL, lim, rpcRequestSemaphore, opts...)
+		back := NewBackend(name, rpcURL, wsURL, rpcRequestSemaphore, opts...)
 		backendNames = append(backendNames, name)
 		backendsByName[name] = back
-		log.Info("configured backend", "name", name, "rpc_url", rpcURL, "ws_url", wsURL)
+		log.Info("configured backend",
+			"name", name,
+			"backend_names", backendNames,
+			"rpc_url", rpcURL,
+			"ws_url", wsURL)
 	}

 	backendGroups := make(map[string]*BackendGroup)
@@ -227,7 +218,7 @@ func Start(config *Config) (*Server, func(), error) {
 			log.Warn("redis is not configured, using in-memory cache")
 			cache = newMemoryCache()
 		} else {
-			cache = newRedisCache(redisClient)
+			cache = newRedisCache(redisClient, config.Redis.Namespace)
 		}
 		// Ideally, the BlocKSyncRPCURL should be the sequencer or a HA replica that's not far behind
 		ethClient, err := ethclient.Dial(blockSyncRPCURL)
@@ -335,9 +326,6 @@ func Start(config *Config) (*Server, func(), error) {
 	shutdownFunc := func() {
 		log.Info("shutting down proxyd")
 		srv.Shutdown()
-		if err := lim.FlushBackendWSConns(backendNames); err != nil {
-			log.Error("error flushing backend ws conns", "err", err)
-		}
 		log.Info("goodbye")
 	}

@@ -2,6 +2,8 @@ package proxyd

 import (
 	"context"
+	"crypto/rand"
+	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -222,6 +224,11 @@ func (s *Server) Shutdown() {
 	if s.wsServer != nil {
 		_ = s.wsServer.Shutdown(context.Background())
 	}
+	for _, bg := range s.BackendGroups {
+		if bg.Consensus != nil {
+			bg.Consensus.Shutdown()
+		}
+	}
 }

 func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
@@ -586,6 +593,14 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
 	)
 }

+func randStr(l int) string {
+	b := make([]byte, l)
+	if _, err := rand.Read(b); err != nil {
+		panic(err)
+	}
+	return hex.EncodeToString(b)
+}
+
 func (s *Server) isUnlimitedOrigin(origin string) bool {
 	for _, pat := range s.limExemptOrigins {
 		if pat.MatchString(origin) {