package proxyd

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net/http"
	"os"
	"time"

	"github.com/ethereum/go-ethereum/common/math"
	"github.com/ethereum/go-ethereum/log"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/redis/go-redis/v9"
	"golang.org/x/exp/slog"
	"golang.org/x/sync/semaphore"
)

// SetLogLevel replaces the go-ethereum default logger with one that writes
// JSON-formatted records to stdout, dropping records below logLevel.
func SetLogLevel(logLevel slog.Leveler) {
	log.SetDefault(log.NewLogger(slog.NewJSONHandler(
		os.Stdout, &slog.HandlerOptions{Level: logLevel})))
}

// Start validates config and builds the proxy from it. It builds the redis
// clients, the backends, the backend groups, the optional cache and the
// Server. It then configures consensus pollers for consensus-aware backend
// groups. Finally it launches the metrics, RPC and WS listeners in
// background goroutines.
//
// It returns the server and a function that shuts it down. Every error
// return happens before any listener goroutine is started, so a failed
// Start does not leave servers running. Some misconfigurations are reported
// through log.Crit instead of an error:
//   - an invalid routing strategy;
//   - a consensus-aware group whose backend is missing from its fallback
//     map;
//   - consensus_ha without a redis URL.
//
// Start also overwrites the Message field of the package-level errors
// ErrOverRateLimit, ErrMethodNotWhitelisted and ErrTooManyBatchRequests
// when the corresponding config values are set.
func Start(config *Config) (*Server, func(), error) {
	if len(config.Backends) == 0 {
		return nil, nil, errors.New("must define at least one backend")
	}
	if len(config.BackendGroups) == 0 {
		return nil, nil, errors.New("must define at least one backend group")
	}
	if len(config.RPCMethodMappings) == 0 {
		return nil, nil, errors.New("must define at least one RPC method mapping")
	}

	for authKey := range config.Authentication {
		if authKey == "none" {
			return nil, nil, errors.New("cannot use none as an auth key")
		}
	}

	// redis primary client
	var redisClient *redis.Client
	if config.Redis.URL != "" {
		rURL, err := ReadFromEnvOrConfig(config.Redis.URL)
		if err != nil {
			return nil, nil, err
		}
		redisClient, err = NewRedisClient(rURL)
		if err != nil {
			return nil, nil, err
		}
	}

	// redis read replica client
	// if read endpoint is not set, use primary endpoint
	var redisReadClient = redisClient
	if config.Redis.ReadURL != "" {
		if redisClient == nil {
			return nil, nil, errors.New("must specify a Redis primary URL. only read endpoint is set")
		}
		rURL, err := ReadFromEnvOrConfig(config.Redis.ReadURL)
		if err != nil {
			return nil, nil, err
		}
		redisReadClient, err = NewRedisClient(rURL)
		if err != nil {
			return nil, nil, err
		}
	}

	if redisClient == nil && config.RateLimit.UseRedis {
		return nil, nil, errors.New("must specify a Redis URL if UseRedis is true in rate limit config")
	}

	// While modifying shared globals is a bad practice, the alternative
	// is to clone these errors on every invocation. This is inefficient.
	// We'd also have to make sure that errors.Is and errors.As continue
	// to function properly on the cloned errors.
	if config.RateLimit.ErrorMessage != "" {
		ErrOverRateLimit.Message = config.RateLimit.ErrorMessage
	}
	if config.WhitelistErrorMessage != "" {
		ErrMethodNotWhitelisted.Message = config.WhitelistErrorMessage
	}
	if config.BatchConfig.ErrorMessage != "" {
		ErrTooManyBatchRequests.Message = config.BatchConfig.ErrorMessage
	}

	if config.SenderRateLimit.Enabled {
		if config.SenderRateLimit.Limit <= 0 {
			return nil, nil, errors.New("limit in sender_rate_limit must be > 0")
		}
		if time.Duration(config.SenderRateLimit.Interval) < time.Second {
			return nil, nil, errors.New("interval in sender_rate_limit must be >= 1s")
		}
	}

	// A zero MaxConcurrentRPCs means "unbounded"; model that with the
	// largest possible semaphore weight.
	maxConcurrentRPCs := config.Server.MaxConcurrentRPCs
	if maxConcurrentRPCs == 0 {
		maxConcurrentRPCs = math.MaxInt64
	}
	rpcRequestSemaphore := semaphore.NewWeighted(maxConcurrentRPCs)

	backendNames := make([]string, 0)
	backendsByName := make(map[string]*Backend)
	for name, cfg := range config.Backends {
		opts := make([]BackendOpt, 0)

		rpcURL, err := ReadFromEnvOrConfig(cfg.RPCURL)
		if err != nil {
			return nil, nil, err
		}
		wsURL, err := ReadFromEnvOrConfig(cfg.WSURL)
		if err != nil {
			return nil, nil, err
		}
		if rpcURL == "" {
			return nil, nil, fmt.Errorf("must define an RPC URL for backend %s", name)
		}

		// Global backend options: each is applied only when set to a
		// non-zero (or positive) value.
		if config.BackendOptions.ResponseTimeoutSeconds != 0 {
			timeout := secondsToDuration(config.BackendOptions.ResponseTimeoutSeconds)
			opts = append(opts, WithTimeout(timeout))
		}
		if config.BackendOptions.MaxRetries != 0 {
			opts = append(opts, WithMaxRetries(config.BackendOptions.MaxRetries))
		}
		if config.BackendOptions.MaxResponseSizeBytes != 0 {
			opts = append(opts, WithMaxResponseSize(config.BackendOptions.MaxResponseSizeBytes))
		}
		if config.BackendOptions.OutOfServiceSeconds != 0 {
			opts = append(opts, WithOutOfServiceDuration(secondsToDuration(config.BackendOptions.OutOfServiceSeconds)))
		}
		if config.BackendOptions.MaxDegradedLatencyThreshold > 0 {
			opts = append(opts, WithMaxDegradedLatencyThreshold(time.Duration(config.BackendOptions.MaxDegradedLatencyThreshold)))
		}
		if config.BackendOptions.MaxLatencyThreshold > 0 {
			opts = append(opts, WithMaxLatencyThreshold(time.Duration(config.BackendOptions.MaxLatencyThreshold)))
		}
		if config.BackendOptions.MaxErrorRateThreshold > 0 {
			opts = append(opts, WithMaxErrorRateThreshold(config.BackendOptions.MaxErrorRateThreshold))
		}

		// Per-backend options.
		if cfg.MaxRPS != 0 {
			opts = append(opts, WithMaxRPS(cfg.MaxRPS))
		}
		if cfg.MaxWSConns != 0 {
			opts = append(opts, WithMaxWSConns(cfg.MaxWSConns))
		}
		if cfg.Password != "" {
			passwordVal, err := ReadFromEnvOrConfig(cfg.Password)
			if err != nil {
				return nil, nil, err
			}
			opts = append(opts, WithBasicAuth(cfg.Username, passwordVal))
		}

		headers := map[string]string{}
		for headerName, headerValue := range cfg.Headers {
			headerValue, err := ReadFromEnvOrConfig(headerValue)
			if err != nil {
				return nil, nil, err
			}
			headers[headerName] = headerValue
		}
		opts = append(opts, WithHeaders(headers))

		tlsConfig, err := configureBackendTLS(cfg)
		if err != nil {
			return nil, nil, err
		}
		if tlsConfig != nil {
			log.Info("using custom TLS config for backend", "name", name)
			opts = append(opts, WithTLSConfig(tlsConfig))
		}
		if cfg.StripTrailingXFF {
			opts = append(opts, WithStrippedTrailingXFF())
		}
		opts = append(opts, WithProxydIP(os.Getenv("PROXYD_IP")))
		opts = append(opts, WithConsensusSkipPeerCountCheck(cfg.ConsensusSkipPeerCountCheck))
		opts = append(opts, WithConsensusForcedCandidate(cfg.ConsensusForcedCandidate))
		opts = append(opts, WithWeight(cfg.Weight))

		receiptsTarget, err := ReadFromEnvOrConfig(cfg.ConsensusReceiptsTarget)
		if err != nil {
			return nil, nil, err
		}
		receiptsTarget, err = validateReceiptsTarget(receiptsTarget)
		if err != nil {
			return nil, nil, err
		}
		opts = append(opts, WithConsensusReceiptTarget(receiptsTarget))

		back := NewBackend(name, rpcURL, wsURL, rpcRequestSemaphore, opts...)
		backendNames = append(backendNames, name)
		backendsByName[name] = back
		log.Info("configured backend",
			"name", name,
			"backend_names", backendNames,
			"rpc_url", rpcURL,
			"ws_url", wsURL)
	}

	backendGroups := make(map[string]*BackendGroup)
	for bgName, bg := range config.BackendGroups {
		backends := make([]*Backend, 0)
		// Every backend in the group is recorded here: true when it is
		// listed in bg.Fallbacks, false when it is a primary.
		fallbackBackends := make(map[string]bool)
		fallbackCount := 0
		for _, bName := range bg.Backends {
			if backendsByName[bName] == nil {
				return nil, nil, fmt.Errorf("backend %s is not defined", bName)
			}
			backends = append(backends, backendsByName[bName])

			for _, fb := range bg.Fallbacks {
				if bName == fb {
					fallbackBackends[bName] = true
					log.Info("configured backend as fallback",
						"backend_name", bName,
						"backend_group", bgName,
					)
					fallbackCount++
				}
			}

			if _, ok := fallbackBackends[bName]; !ok {
				fallbackBackends[bName] = false
				log.Info("configured backend as primary",
					"backend_name", bName,
					"backend_group", bgName,
				)
			}
		}

		// Detects fallbacks that name a backend not present in the group.
		if fallbackCount != len(bg.Fallbacks) {
			return nil, nil,
				fmt.Errorf(
					"error: number of fallbacks instantiated (%d) did not match configured (%d) for backend group %s",
					fallbackCount, len(bg.Fallbacks), bgName,
				)
		}

		backendGroups[bgName] = &BackendGroup{
			Name:             bgName,
			Backends:         backends,
			WeightedRouting:  bg.WeightedRouting,
			FallbackBackends: fallbackBackends,
			routingStrategy:  bg.RoutingStrategy,
		}
	}

	var wsBackendGroup *BackendGroup
	if config.WSBackendGroup != "" {
		wsBackendGroup = backendGroups[config.WSBackendGroup]
		if wsBackendGroup == nil {
			return nil, nil, fmt.Errorf("ws backend group %s does not exist", config.WSBackendGroup)
		}
	}

	if wsBackendGroup == nil && config.Server.WSPort != 0 {
		return nil, nil, fmt.Errorf("a ws port was defined, but no ws group was defined")
	}

	for _, bg := range config.RPCMethodMappings {
		if backendGroups[bg] == nil {
			return nil, nil, fmt.Errorf("undefined backend group %s", bg)
		}
	}

	// Auth keys may be env references; resolve them so the server is keyed
	// by the actual secret.
	var resolvedAuth map[string]string
	if config.Authentication != nil {
		resolvedAuth = make(map[string]string)
		for secret, alias := range config.Authentication {
			resolvedSecret, err := ReadFromEnvOrConfig(secret)
			if err != nil {
				return nil, nil, err
			}
			resolvedAuth[resolvedSecret] = alias
		}
	}

	var (
		cache    Cache
		rpcCache RPCCache
	)
	if config.Cache.Enabled {
		if redisClient == nil {
			log.Warn("redis is not configured, using in-memory cache")
			cache = newMemoryCache()
		} else {
			ttl := defaultCacheTtl
			if config.Cache.TTL != 0 {
				ttl = time.Duration(config.Cache.TTL)
			}
			cache = newRedisCache(redisClient, redisReadClient, config.Redis.Namespace, ttl)
		}
		rpcCache = newRPCCache(newCacheWithCompression(cache))
	}

	srv, err := NewServer(
		backendGroups,
		wsBackendGroup,
		NewStringSetFromStrings(config.WSMethodWhitelist),
		config.RPCMethodMappings,
		config.Server.MaxBodySizeBytes,
		resolvedAuth,
		secondsToDuration(config.Server.TimeoutSeconds),
		config.Server.MaxUpstreamBatchSize,
		config.Server.EnableXServedByHeader,
		rpcCache,
		config.RateLimit,
		config.SenderRateLimit,
		config.Server.EnableRequestLog,
		config.Server.MaxRequestBodyLogLen,
		config.BatchConfig.MaxSize,
		redisClient,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating server: %w", err)
	}

	// Configure routing strategies and consensus pollers BEFORE any
	// listener goroutine is started. This gives two guarantees.
	//
	// First, the assignment bg.Consensus = cp happens-before the `go`
	// statements below, so serving goroutines observe it without a data
	// race. NOTE(review): this assumes the request path reads
	// bg.Consensus; confirm against the backend group code. It also means
	// no request can reach a consensus-aware group before its poller is
	// installed.
	//
	// Second, the error return inside this loop cannot leave listeners
	// running on a failed Start.
	for bgName, bg := range backendGroups {
		bgcfg := config.BackendGroups[bgName]

		if !bgcfg.ValidateRoutingStrategy(bgName) {
			log.Crit("Invalid routing strategy provided. Valid options: fallback, multicall, consensus_aware, \"\"", "name", bgName)
		}

		log.Info("configuring routing strategy for backend_group", "name", bgName, "routing_strategy", bgcfg.RoutingStrategy)

		if bgcfg.RoutingStrategy != ConsensusAwareRoutingStrategy {
			continue
		}

		log.Info("creating poller for consensus aware backend_group", "name", bgName)

		copts := make([]ConsensusOpt, 0)

		if bgcfg.ConsensusAsyncHandler == "noop" {
			copts = append(copts, WithAsyncHandler(NewNoopAsyncHandler()))
		}
		if bgcfg.ConsensusBanPeriod > 0 {
			copts = append(copts, WithBanPeriod(time.Duration(bgcfg.ConsensusBanPeriod)))
		}
		if bgcfg.ConsensusMaxUpdateThreshold > 0 {
			copts = append(copts, WithMaxUpdateThreshold(time.Duration(bgcfg.ConsensusMaxUpdateThreshold)))
		}
		if bgcfg.ConsensusMaxBlockLag > 0 {
			copts = append(copts, WithMaxBlockLag(bgcfg.ConsensusMaxBlockLag))
		}
		if bgcfg.ConsensusMinPeerCount > 0 {
			copts = append(copts, WithMinPeerCount(uint64(bgcfg.ConsensusMinPeerCount)))
		}
		if bgcfg.ConsensusMaxBlockRange > 0 {
			copts = append(copts, WithMaxBlockRange(bgcfg.ConsensusMaxBlockRange))
		}
		if bgcfg.ConsensusPollerInterval > 0 {
			copts = append(copts, WithPollerInterval(time.Duration(bgcfg.ConsensusPollerInterval)))
		}

		for _, be := range bgcfg.Backends {
			if fallback, ok := bg.FallbackBackends[be]; !ok {
				log.Crit("error backend not found in backend fallback configurations", "backend_name", be)
			} else {
				log.Debug("configuring new backend for group", "backend_group", bgName, "backend_name", be, "fallback", fallback)
				RecordBackendGroupFallbacks(bg, be, fallback)
			}
		}

		var tracker ConsensusTracker
		if bgcfg.ConsensusHA {
			if bgcfg.ConsensusHARedis.URL == "" {
				log.Crit("must specify a consensus_ha_redis config when consensus_ha is true")
			}
			topts := make([]RedisConsensusTrackerOpt, 0)
			if bgcfg.ConsensusHALockPeriod > 0 {
				topts = append(topts, WithLockPeriod(time.Duration(bgcfg.ConsensusHALockPeriod)))
			}
			if bgcfg.ConsensusHAHeartbeatInterval > 0 {
				topts = append(topts, WithHeartbeatInterval(time.Duration(bgcfg.ConsensusHAHeartbeatInterval)))
			}
			consensusHARedisClient, err := NewRedisClient(bgcfg.ConsensusHARedis.URL)
			if err != nil {
				return nil, nil, err
			}
			// Namespace the HA state per backend group.
			ns := fmt.Sprintf("%s:%s", bgcfg.ConsensusHARedis.Namespace, bg.Name)
			tracker = NewRedisConsensusTracker(context.Background(), consensusHARedisClient, bg, ns, topts...)
			copts = append(copts, WithTracker(tracker))
		}

		cp := NewConsensusPoller(bg, copts...)
		bg.Consensus = cp

		// The tracker is initialized only after the poller has been
		// attached to the group.
		if bgcfg.ConsensusHA {
			tracker.(*RedisConsensusTracker).Init()
		}
	}

	// Enable to support browser websocket connections.
	// See https://pkg.go.dev/github.com/gorilla/websocket#hdr-Origin_Considerations
	if config.Server.AllowAllOrigins {
		srv.upgrader.CheckOrigin = func(r *http.Request) bool {
			return true
		}
	}

	if config.Metrics.Enabled {
		addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)
		log.Info("starting metrics server", "addr", addr)
		go func() {
			if err := http.ListenAndServe(addr, promhttp.Handler()); err != nil {
				log.Error("error starting metrics server", "err", err)
			}
		}()
	}

	// To allow integration tests to cleanly come up, wait
	// 10ms to give the below goroutines enough time to
	// encounter an error creating their servers
	errTimer := time.NewTimer(10 * time.Millisecond)

	if config.Server.RPCPort != 0 {
		go func() {
			if err := srv.RPCListenAndServe(config.Server); err != nil {
				if errors.Is(err, http.ErrServerClosed) {
					log.Info("RPC server shut down")
					return
				}
				log.Crit("error starting RPC server", "err", err)
			}
		}()
	}

	if config.Server.WSPort != 0 {
		go func() {
			if err := srv.WSListenAndServe(config.Server.WSHost, config.Server.WSPort); err != nil {
				if errors.Is(err, http.ErrServerClosed) {
					log.Info("WS server shut down")
					return
				}
				log.Crit("error starting WS server", "err", err)
			}
		}()
	} else if !config.Server.EnableWS {
		log.Info("WS server not enabled (ws_port is set to 0)")
	}

	<-errTimer.C
	log.Info("started proxyd")

	shutdownFunc := func() {
		log.Info("shutting down proxyd")
		srv.Shutdown()
		log.Info("goodbye")
	}

	return srv, shutdownFunc, nil
}

// validateReceiptsTarget normalizes and checks a consensus receipts target.
// An empty value defaults to ReceiptsTargetDebugGetRawReceipts. Any value
// other than the four known ReceiptsTarget* constants is rejected with an
// error.
func validateReceiptsTarget(val string) (string, error) {
	if val == "" {
		val = ReceiptsTargetDebugGetRawReceipts
	}
	switch val {
	case ReceiptsTargetDebugGetRawReceipts,
		ReceiptsTargetAlchemyGetTransactionReceipts,
		ReceiptsTargetEthGetTransactionReceipts,
		ReceiptsTargetParityGetTransactionReceipts:
		return val, nil
	default:
		return "", fmt.Errorf("invalid receipts target: %s", val)
	}
}

// secondsToDuration converts a whole number of seconds to a time.Duration.
func secondsToDuration(seconds int) time.Duration {
	return time.Duration(seconds) * time.Second
}

// configureBackendTLS builds a custom TLS client config for a backend. It
// returns (nil, nil) when cfg.CAFile is empty, meaning the default TLS
// settings are used. In that case ClientCertFile and ClientKeyFile are
// ignored. A client certificate is attached only when both the cert and key
// file paths are set.
func configureBackendTLS(cfg *BackendConfig) (*tls.Config, error) {
	if cfg.CAFile == "" {
		return nil, nil
	}

	tlsConfig, err := CreateTLSClient(cfg.CAFile)
	if err != nil {
		return nil, err
	}

	if cfg.ClientCertFile != "" && cfg.ClientKeyFile != "" {
		cert, err := ParseKeyPair(cfg.ClientCertFile, cfg.ClientKeyFile)
		if err != nil {
			return nil, err
		}
		tlsConfig.Certificates = []tls.Certificate{cert}
	}

	return tlsConfig, nil
}