node: refactor package node (#21105)

This PR significantly changes the APIs for instantiating Ethereum nodes in
a Go program. The new APIs are not backwards-compatible, but we feel that
this is made up for by the much simpler way of registering services on
node.Node. You can find more information and rationale in the design
document: https://gist.github.com/renaynay/5bec2de19fde66f4d04c535fd24f0775.

There is also a new feature in Node's Go API: it is now possible to
register arbitrary handlers on the user-facing HTTP server. In geth, this
facility is used to enable GraphQL.

There is a single minor change relevant for geth users in this PR: The
GraphQL API is no longer available separately from the JSON-RPC HTTP
server. If you want GraphQL, you need to enable it using the
./geth --http --graphql flag combination.

The --graphql.port and --graphql.addr flags are no longer available.
This commit is contained in:
rene 2020-08-03 19:40:46 +02:00 committed by GitHub
parent b2b14e6ce3
commit c0c01612e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
63 changed files with 2606 additions and 2887 deletions

@ -235,23 +235,20 @@ func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network u
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Assemble the Ethereum light client protocol // Assemble the Ethereum light client protocol
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { cfg := eth.DefaultConfig
cfg := eth.DefaultConfig cfg.SyncMode = downloader.LightSync
cfg.SyncMode = downloader.LightSync cfg.NetworkId = network
cfg.NetworkId = network cfg.Genesis = genesis
cfg.Genesis = genesis lesBackend, err := les.New(stack, &cfg)
return les.New(ctx, &cfg) if err != nil {
}); err != nil { return nil, fmt.Errorf("Failed to register the Ethereum service: %w", err)
return nil, err
} }
// Assemble the ethstats monitoring and reporting service' // Assemble the ethstats monitoring and reporting service'
if stats != "" { if stats != "" {
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { if err := ethstats.New(stack, lesBackend.ApiBackend, lesBackend.Engine(), stats); err != nil {
var serv *les.LightEthereum
ctx.Service(&serv)
return ethstats.New(stats, nil, serv)
}); err != nil {
return nil, err return nil, err
} }
} }
@ -268,7 +265,7 @@ func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network u
// Attach to the client and retrieve and interesting metadatas // Attach to the client and retrieve and interesting metadatas
api, err := stack.Attach() api, err := stack.Attach()
if err != nil { if err != nil {
stack.Stop() stack.Close()
return nil, err return nil, err
} }
client := ethclient.NewClient(api) client := ethclient.NewClient(api)

@ -239,8 +239,9 @@ func initGenesis(ctx *cli.Context) error {
if err := json.NewDecoder(file).Decode(genesis); err != nil { if err := json.NewDecoder(file).Decode(genesis); err != nil {
utils.Fatalf("invalid genesis file: %v", err) utils.Fatalf("invalid genesis file: %v", err)
} }
// Open an initialise both full and light databases // Open an initialise both full and light databases
stack := makeFullNode(ctx) stack, _ := makeConfigNode(ctx)
defer stack.Close() defer stack.Close()
for _, name := range []string{"chaindata", "lightchaindata"} { for _, name := range []string{"chaindata", "lightchaindata"} {
@ -277,7 +278,7 @@ func importChain(ctx *cli.Context) error {
utils.SetupMetrics(ctx) utils.SetupMetrics(ctx)
// Start system runtime metrics collection // Start system runtime metrics collection
go metrics.CollectProcessMetrics(3 * time.Second) go metrics.CollectProcessMetrics(3 * time.Second)
stack := makeFullNode(ctx) stack, _ := makeFullNode(ctx)
defer stack.Close() defer stack.Close()
chain, db := utils.MakeChain(ctx, stack, false) chain, db := utils.MakeChain(ctx, stack, false)
@ -371,7 +372,7 @@ func exportChain(ctx *cli.Context) error {
if len(ctx.Args()) < 1 { if len(ctx.Args()) < 1 {
utils.Fatalf("This command requires an argument.") utils.Fatalf("This command requires an argument.")
} }
stack := makeFullNode(ctx) stack, _ := makeFullNode(ctx)
defer stack.Close() defer stack.Close()
chain, _ := utils.MakeChain(ctx, stack, true) chain, _ := utils.MakeChain(ctx, stack, true)
@ -406,7 +407,7 @@ func importPreimages(ctx *cli.Context) error {
if len(ctx.Args()) < 1 { if len(ctx.Args()) < 1 {
utils.Fatalf("This command requires an argument.") utils.Fatalf("This command requires an argument.")
} }
stack := makeFullNode(ctx) stack, _ := makeFullNode(ctx)
defer stack.Close() defer stack.Close()
db := utils.MakeChainDatabase(ctx, stack) db := utils.MakeChainDatabase(ctx, stack)
@ -424,7 +425,7 @@ func exportPreimages(ctx *cli.Context) error {
if len(ctx.Args()) < 1 { if len(ctx.Args()) < 1 {
utils.Fatalf("This command requires an argument.") utils.Fatalf("This command requires an argument.")
} }
stack := makeFullNode(ctx) stack, _ := makeFullNode(ctx)
defer stack.Close() defer stack.Close()
db := utils.MakeChainDatabase(ctx, stack) db := utils.MakeChainDatabase(ctx, stack)
@ -446,7 +447,7 @@ func copyDb(ctx *cli.Context) error {
utils.Fatalf("Source ancient chain directory path argument missing") utils.Fatalf("Source ancient chain directory path argument missing")
} }
// Initialize a new chain for the running node to sync into // Initialize a new chain for the running node to sync into
stack := makeFullNode(ctx) stack, _ := makeFullNode(ctx)
defer stack.Close() defer stack.Close()
chain, chainDb := utils.MakeChain(ctx, stack, false) chain, chainDb := utils.MakeChain(ctx, stack, false)
@ -554,7 +555,7 @@ func confirmAndRemoveDB(database string, kind string) {
} }
func dump(ctx *cli.Context) error { func dump(ctx *cli.Context) error {
stack := makeFullNode(ctx) stack, _ := makeFullNode(ctx)
defer stack.Close() defer stack.Close()
chain, chainDb := utils.MakeChain(ctx, stack, true) chain, chainDb := utils.MakeChain(ctx, stack, true)

@ -28,6 +28,7 @@ import (
"github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/cmd/utils"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
whisper "github.com/ethereum/go-ethereum/whisper/whisperv6" whisper "github.com/ethereum/go-ethereum/whisper/whisperv6"
@ -144,9 +145,10 @@ func enableWhisper(ctx *cli.Context) bool {
return false return false
} }
func makeFullNode(ctx *cli.Context) *node.Node { func makeFullNode(ctx *cli.Context) (*node.Node, ethapi.Backend) {
stack, cfg := makeConfigNode(ctx) stack, cfg := makeConfigNode(ctx)
utils.RegisterEthService(stack, &cfg.Eth)
backend := utils.RegisterEthService(stack, &cfg.Eth)
// Whisper must be explicitly enabled by specifying at least 1 whisper flag or in dev mode // Whisper must be explicitly enabled by specifying at least 1 whisper flag or in dev mode
shhEnabled := enableWhisper(ctx) shhEnabled := enableWhisper(ctx)
@ -165,13 +167,13 @@ func makeFullNode(ctx *cli.Context) *node.Node {
} }
// Configure GraphQL if requested // Configure GraphQL if requested
if ctx.GlobalIsSet(utils.GraphQLEnabledFlag.Name) { if ctx.GlobalIsSet(utils.GraphQLEnabledFlag.Name) {
utils.RegisterGraphQLService(stack, cfg.Node.GraphQLEndpoint(), cfg.Node.GraphQLCors, cfg.Node.GraphQLVirtualHosts, cfg.Node.HTTPTimeouts) utils.RegisterGraphQLService(stack, backend, cfg.Node)
} }
// Add the Ethereum Stats daemon if requested. // Add the Ethereum Stats daemon if requested.
if cfg.Ethstats.URL != "" { if cfg.Ethstats.URL != "" {
utils.RegisterEthStatsService(stack, cfg.Ethstats.URL) utils.RegisterEthStatsService(stack, backend, cfg.Ethstats.URL)
} }
return stack return stack, backend
} }
// dumpConfig is the dumpconfig command. // dumpConfig is the dumpconfig command.

@ -78,12 +78,12 @@ JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/JavaScript-Cons
func localConsole(ctx *cli.Context) error { func localConsole(ctx *cli.Context) error {
// Create and start the node based on the CLI flags // Create and start the node based on the CLI flags
prepare(ctx) prepare(ctx)
node := makeFullNode(ctx) stack, backend := makeFullNode(ctx)
startNode(ctx, node) startNode(ctx, stack, backend)
defer node.Close() defer stack.Close()
// Attach to the newly started node and start the JavaScript console // Attach to the newly started node and start the JavaScript console
client, err := node.Attach() client, err := stack.Attach()
if err != nil { if err != nil {
utils.Fatalf("Failed to attach to the inproc geth: %v", err) utils.Fatalf("Failed to attach to the inproc geth: %v", err)
} }
@ -190,12 +190,12 @@ func dialRPC(endpoint string) (*rpc.Client, error) {
// everything down. // everything down.
func ephemeralConsole(ctx *cli.Context) error { func ephemeralConsole(ctx *cli.Context) error {
// Create and start the node based on the CLI flags // Create and start the node based on the CLI flags
node := makeFullNode(ctx) stack, backend := makeFullNode(ctx)
startNode(ctx, node) startNode(ctx, stack, backend)
defer node.Close() defer stack.Close()
// Attach to the newly started node and start the JavaScript console // Attach to the newly started node and start the JavaScript console
client, err := node.Attach() client, err := stack.Attach()
if err != nil { if err != nil {
utils.Fatalf("Failed to attach to the inproc geth: %v", err) utils.Fatalf("Failed to attach to the inproc geth: %v", err)
} }

@ -119,8 +119,7 @@ func testDAOForkBlockNewChain(t *testing.T, test int, genesis string, expectBloc
} else { } else {
// Force chain initialization // Force chain initialization
args := []string{"--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--ipcdisable", "--datadir", datadir} args := []string{"--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--ipcdisable", "--datadir", datadir}
geth := runGeth(t, append(args, []string{"--exec", "2+2", "console"}...)...) runGeth(t, append(args, []string{"--exec", "2+2", "console"}...)...).WaitExit()
geth.WaitExit()
} }
// Retrieve the DAO config flag from the database // Retrieve the DAO config flag from the database
path := filepath.Join(datadir, "geth", "chaindata") path := filepath.Join(datadir, "geth", "chaindata")

@ -36,8 +36,8 @@ import (
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/internal/debug" "github.com/ethereum/go-ethereum/internal/debug"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/internal/flags" "github.com/ethereum/go-ethereum/internal/flags"
"github.com/ethereum/go-ethereum/les"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
@ -171,8 +171,6 @@ var (
utils.LegacyRPCCORSDomainFlag, utils.LegacyRPCCORSDomainFlag,
utils.LegacyRPCVirtualHostsFlag, utils.LegacyRPCVirtualHostsFlag,
utils.GraphQLEnabledFlag, utils.GraphQLEnabledFlag,
utils.GraphQLListenAddrFlag,
utils.GraphQLPortFlag,
utils.GraphQLCORSDomainFlag, utils.GraphQLCORSDomainFlag,
utils.GraphQLVirtualHostsFlag, utils.GraphQLVirtualHostsFlag,
utils.HTTPApiFlag, utils.HTTPApiFlag,
@ -350,18 +348,20 @@ func geth(ctx *cli.Context) error {
if args := ctx.Args(); len(args) > 0 { if args := ctx.Args(); len(args) > 0 {
return fmt.Errorf("invalid command: %q", args[0]) return fmt.Errorf("invalid command: %q", args[0])
} }
prepare(ctx) prepare(ctx)
node := makeFullNode(ctx) stack, backend := makeFullNode(ctx)
defer node.Close() defer stack.Close()
startNode(ctx, node)
node.Wait() startNode(ctx, stack, backend)
stack.Wait()
return nil return nil
} }
// startNode boots up the system node and all registered protocols, after which // startNode boots up the system node and all registered protocols, after which
// it unlocks any requested accounts, and starts the RPC/IPC interfaces and the // it unlocks any requested accounts, and starts the RPC/IPC interfaces and the
// miner. // miner.
func startNode(ctx *cli.Context, stack *node.Node) { func startNode(ctx *cli.Context, stack *node.Node, backend ethapi.Backend) {
debug.Memsize.Add("node", stack) debug.Memsize.Add("node", stack)
// Start up the node itself // Start up the node itself
@ -381,25 +381,6 @@ func startNode(ctx *cli.Context, stack *node.Node) {
} }
ethClient := ethclient.NewClient(rpcClient) ethClient := ethclient.NewClient(rpcClient)
// Set contract backend for ethereum service if local node
// is serving LES requests.
if ctx.GlobalInt(utils.LegacyLightServFlag.Name) > 0 || ctx.GlobalInt(utils.LightServeFlag.Name) > 0 {
var ethService *eth.Ethereum
if err := stack.Service(&ethService); err != nil {
utils.Fatalf("Failed to retrieve ethereum service: %v", err)
}
ethService.SetContractBackend(ethClient)
}
// Set contract backend for les service if local node is
// running as a light client.
if ctx.GlobalString(utils.SyncModeFlag.Name) == "light" {
var lesService *les.LightEthereum
if err := stack.Service(&lesService); err != nil {
utils.Fatalf("Failed to retrieve light ethereum service: %v", err)
}
lesService.SetContractBackend(ethClient)
}
go func() { go func() {
// Open any wallets already attached // Open any wallets already attached
for _, wallet := range stack.AccountManager().Wallets() { for _, wallet := range stack.AccountManager().Wallets() {
@ -451,7 +432,7 @@ func startNode(ctx *cli.Context, stack *node.Node) {
if timestamp := time.Unix(int64(done.Latest.Time), 0); time.Since(timestamp) < 10*time.Minute { if timestamp := time.Unix(int64(done.Latest.Time), 0); time.Since(timestamp) < 10*time.Minute {
log.Info("Synchronisation completed", "latestnum", done.Latest.Number, "latesthash", done.Latest.Hash(), log.Info("Synchronisation completed", "latestnum", done.Latest.Number, "latesthash", done.Latest.Hash(),
"age", common.PrettyAge(timestamp)) "age", common.PrettyAge(timestamp))
stack.Stop() stack.Close()
} }
} }
}() }()
@ -463,24 +444,24 @@ func startNode(ctx *cli.Context, stack *node.Node) {
if ctx.GlobalString(utils.SyncModeFlag.Name) == "light" { if ctx.GlobalString(utils.SyncModeFlag.Name) == "light" {
utils.Fatalf("Light clients do not support mining") utils.Fatalf("Light clients do not support mining")
} }
var ethereum *eth.Ethereum ethBackend, ok := backend.(*eth.EthAPIBackend)
if err := stack.Service(&ethereum); err != nil { if !ok {
utils.Fatalf("Ethereum service not running: %v", err) utils.Fatalf("Ethereum service not running: %v", err)
} }
// Set the gas price to the limits from the CLI and start mining // Set the gas price to the limits from the CLI and start mining
gasprice := utils.GlobalBig(ctx, utils.MinerGasPriceFlag.Name) gasprice := utils.GlobalBig(ctx, utils.MinerGasPriceFlag.Name)
if ctx.GlobalIsSet(utils.LegacyMinerGasPriceFlag.Name) && !ctx.GlobalIsSet(utils.MinerGasPriceFlag.Name) { if ctx.GlobalIsSet(utils.LegacyMinerGasPriceFlag.Name) && !ctx.GlobalIsSet(utils.MinerGasPriceFlag.Name) {
gasprice = utils.GlobalBig(ctx, utils.LegacyMinerGasPriceFlag.Name) gasprice = utils.GlobalBig(ctx, utils.LegacyMinerGasPriceFlag.Name)
} }
ethereum.TxPool().SetGasPrice(gasprice) ethBackend.TxPool().SetGasPrice(gasprice)
// start mining
threads := ctx.GlobalInt(utils.MinerThreadsFlag.Name) threads := ctx.GlobalInt(utils.MinerThreadsFlag.Name)
if ctx.GlobalIsSet(utils.LegacyMinerThreadsFlag.Name) && !ctx.GlobalIsSet(utils.MinerThreadsFlag.Name) { if ctx.GlobalIsSet(utils.LegacyMinerThreadsFlag.Name) && !ctx.GlobalIsSet(utils.MinerThreadsFlag.Name) {
threads = ctx.GlobalInt(utils.LegacyMinerThreadsFlag.Name) threads = ctx.GlobalInt(utils.LegacyMinerThreadsFlag.Name)
log.Warn("The flag --minerthreads is deprecated and will be removed in the future, please use --miner.threads") log.Warn("The flag --minerthreads is deprecated and will be removed in the future, please use --miner.threads")
} }
if err := ethBackend.StartMining(threads); err != nil {
if err := ethereum.StartMining(threads); err != nil {
utils.Fatalf("Failed to start mining: %v", err) utils.Fatalf("Failed to start mining: %v", err)
} }
} }

@ -142,8 +142,6 @@ var AppHelpFlagGroups = []flags.FlagGroup{
utils.WSApiFlag, utils.WSApiFlag,
utils.WSAllowedOriginsFlag, utils.WSAllowedOriginsFlag,
utils.GraphQLEnabledFlag, utils.GraphQLEnabledFlag,
utils.GraphQLListenAddrFlag,
utils.GraphQLPortFlag,
utils.GraphQLCORSDomainFlag, utils.GraphQLCORSDomainFlag,
utils.GraphQLVirtualHostsFlag, utils.GraphQLVirtualHostsFlag,
utils.RPCGlobalGasCap, utils.RPCGlobalGasCap,
@ -231,6 +229,8 @@ var AppHelpFlagGroups = []flags.FlagGroup{
utils.LegacyWSApiFlag, utils.LegacyWSApiFlag,
utils.LegacyGpoBlocksFlag, utils.LegacyGpoBlocksFlag,
utils.LegacyGpoPercentileFlag, utils.LegacyGpoPercentileFlag,
utils.LegacyGraphQLListenAddrFlag,
utils.LegacyGraphQLPortFlag,
}, debug.DeprecatedFlags...), }, debug.DeprecatedFlags...),
}, },
{ {

@ -289,7 +289,7 @@ func createNode(ctx *cli.Context) error {
config.PrivateKey = privKey config.PrivateKey = privKey
} }
if services := ctx.String("services"); services != "" { if services := ctx.String("services"); services != "" {
config.Services = strings.Split(services, ",") config.Lifecycles = strings.Split(services, ",")
} }
node, err := client.CreateNode(config) node, err := client.CreateNode(config)
if err != nil { if err != nil {

@ -73,7 +73,7 @@ func StartNode(stack *node.Node) {
defer signal.Stop(sigc) defer signal.Stop(sigc)
<-sigc <-sigc
log.Info("Got interrupt, shutting down...") log.Info("Got interrupt, shutting down...")
go stack.Stop() go stack.Close()
for i := 10; i > 0; i-- { for i := 10; i > 0; i-- {
<-sigc <-sigc
if i > 1 { if i > 1 {

@ -19,7 +19,6 @@ package utils
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -49,6 +48,7 @@ import (
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethstats" "github.com/ethereum/go-ethereum/ethstats"
"github.com/ethereum/go-ethereum/graphql" "github.com/ethereum/go-ethereum/graphql"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/internal/flags" "github.com/ethereum/go-ethereum/internal/flags"
"github.com/ethereum/go-ethereum/les" "github.com/ethereum/go-ethereum/les"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
@ -63,7 +63,6 @@ import (
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
whisper "github.com/ethereum/go-ethereum/whisper/whisperv6" whisper "github.com/ethereum/go-ethereum/whisper/whisperv6"
pcsclite "github.com/gballet/go-libpcsclite" pcsclite "github.com/gballet/go-libpcsclite"
cli "gopkg.in/urfave/cli.v1" cli "gopkg.in/urfave/cli.v1"
@ -517,6 +516,20 @@ var (
Usage: "API's offered over the HTTP-RPC interface", Usage: "API's offered over the HTTP-RPC interface",
Value: "", Value: "",
} }
GraphQLEnabledFlag = cli.BoolFlag{
Name: "graphql",
Usage: "Enable GraphQL on the HTTP-RPC server. Note that GraphQL can only be started if an HTTP server is started as well.",
}
GraphQLCORSDomainFlag = cli.StringFlag{
Name: "graphql.corsdomain",
Usage: "Comma separated list of domains from which to accept cross origin requests (browser enforced)",
Value: "",
}
GraphQLVirtualHostsFlag = cli.StringFlag{
Name: "graphql.vhosts",
Usage: "Comma separated list of virtual hostnames from which to accept requests (server enforced). Accepts '*' wildcard.",
Value: strings.Join(node.DefaultConfig.GraphQLVirtualHosts, ","),
}
WSEnabledFlag = cli.BoolFlag{ WSEnabledFlag = cli.BoolFlag{
Name: "ws", Name: "ws",
Usage: "Enable the WS-RPC server", Usage: "Enable the WS-RPC server",
@ -541,30 +554,6 @@ var (
Usage: "Origins from which to accept websockets requests", Usage: "Origins from which to accept websockets requests",
Value: "", Value: "",
} }
GraphQLEnabledFlag = cli.BoolFlag{
Name: "graphql",
Usage: "Enable the GraphQL server",
}
GraphQLListenAddrFlag = cli.StringFlag{
Name: "graphql.addr",
Usage: "GraphQL server listening interface",
Value: node.DefaultGraphQLHost,
}
GraphQLPortFlag = cli.IntFlag{
Name: "graphql.port",
Usage: "GraphQL server listening port",
Value: node.DefaultGraphQLPort,
}
GraphQLCORSDomainFlag = cli.StringFlag{
Name: "graphql.corsdomain",
Usage: "Comma separated list of domains from which to accept cross origin requests (browser enforced)",
Value: "",
}
GraphQLVirtualHostsFlag = cli.StringFlag{
Name: "graphql.vhosts",
Usage: "Comma separated list of virtual hostnames from which to accept requests (server enforced). Accepts '*' wildcard.",
Value: strings.Join(node.DefaultConfig.GraphQLVirtualHosts, ","),
}
ExecFlag = cli.StringFlag{ ExecFlag = cli.StringFlag{
Name: "exec", Name: "exec",
Usage: "Execute JavaScript statement", Usage: "Execute JavaScript statement",
@ -951,13 +940,6 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) {
// setGraphQL creates the GraphQL listener interface string from the set // setGraphQL creates the GraphQL listener interface string from the set
// command line flags, returning empty if the GraphQL endpoint is disabled. // command line flags, returning empty if the GraphQL endpoint is disabled.
func setGraphQL(ctx *cli.Context, cfg *node.Config) { func setGraphQL(ctx *cli.Context, cfg *node.Config) {
if ctx.GlobalBool(GraphQLEnabledFlag.Name) && cfg.GraphQLHost == "" {
cfg.GraphQLHost = "127.0.0.1"
if ctx.GlobalIsSet(GraphQLListenAddrFlag.Name) {
cfg.GraphQLHost = ctx.GlobalString(GraphQLListenAddrFlag.Name)
}
}
cfg.GraphQLPort = ctx.GlobalInt(GraphQLPortFlag.Name)
if ctx.GlobalIsSet(GraphQLCORSDomainFlag.Name) { if ctx.GlobalIsSet(GraphQLCORSDomainFlag.Name) {
cfg.GraphQLCors = splitAndTrim(ctx.GlobalString(GraphQLCORSDomainFlag.Name)) cfg.GraphQLCors = splitAndTrim(ctx.GlobalString(GraphQLCORSDomainFlag.Name))
} }
@ -1692,70 +1674,46 @@ func setDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) {
} }
// RegisterEthService adds an Ethereum client to the stack. // RegisterEthService adds an Ethereum client to the stack.
func RegisterEthService(stack *node.Node, cfg *eth.Config) { func RegisterEthService(stack *node.Node, cfg *eth.Config) ethapi.Backend {
var err error
if cfg.SyncMode == downloader.LightSync { if cfg.SyncMode == downloader.LightSync {
err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { backend, err := les.New(stack, cfg)
return les.New(ctx, cfg) if err != nil {
}) Fatalf("Failed to register the Ethereum service: %v", err)
}
return backend.ApiBackend
} else { } else {
err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { backend, err := eth.New(stack, cfg)
fullNode, err := eth.New(ctx, cfg) if err != nil {
if fullNode != nil && cfg.LightServ > 0 { Fatalf("Failed to register the Ethereum service: %v", err)
ls, _ := les.NewLesServer(fullNode, cfg) }
fullNode.AddLesServer(ls) if cfg.LightServ > 0 {
_, err := les.NewLesServer(stack, backend, cfg)
if err != nil {
Fatalf("Failed to create the LES server: %v", err)
} }
return fullNode, err }
}) return backend.APIBackend
}
if err != nil {
Fatalf("Failed to register the Ethereum service: %v", err)
} }
} }
// RegisterShhService configures Whisper and adds it to the given node. // RegisterShhService configures Whisper and adds it to the given node.
func RegisterShhService(stack *node.Node, cfg *whisper.Config) { func RegisterShhService(stack *node.Node, cfg *whisper.Config) {
if err := stack.Register(func(n *node.ServiceContext) (node.Service, error) { if _, err := whisper.New(stack, cfg); err != nil {
return whisper.New(cfg), nil
}); err != nil {
Fatalf("Failed to register the Whisper service: %v", err) Fatalf("Failed to register the Whisper service: %v", err)
} }
} }
// RegisterEthStatsService configures the Ethereum Stats daemon and adds it to // RegisterEthStatsService configures the Ethereum Stats daemon and adds it to
// the given node. // the given node.
func RegisterEthStatsService(stack *node.Node, url string) { func RegisterEthStatsService(stack *node.Node, backend ethapi.Backend, url string) {
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { if err := ethstats.New(stack, backend, backend.Engine(), url); err != nil {
// Retrieve both eth and les services
var ethServ *eth.Ethereum
ctx.Service(&ethServ)
var lesServ *les.LightEthereum
ctx.Service(&lesServ)
// Let ethstats use whichever is not nil
return ethstats.New(url, ethServ, lesServ)
}); err != nil {
Fatalf("Failed to register the Ethereum Stats service: %v", err) Fatalf("Failed to register the Ethereum Stats service: %v", err)
} }
} }
// RegisterGraphQLService is a utility function to construct a new service and register it against a node. // RegisterGraphQLService is a utility function to construct a new service and register it against a node.
func RegisterGraphQLService(stack *node.Node, endpoint string, cors, vhosts []string, timeouts rpc.HTTPTimeouts) { func RegisterGraphQLService(stack *node.Node, backend ethapi.Backend, cfg node.Config) {
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { if err := graphql.New(stack, backend, cfg.GraphQLCors, cfg.GraphQLVirtualHosts); err != nil {
// Try to construct the GraphQL service backed by a full node
var ethServ *eth.Ethereum
if err := ctx.Service(&ethServ); err == nil {
return graphql.New(ethServ.APIBackend, endpoint, cors, vhosts, timeouts)
}
// Try to construct the GraphQL service backed by a light node
var lesServ *les.LightEthereum
if err := ctx.Service(&lesServ); err == nil {
return graphql.New(lesServ.ApiBackend, endpoint, cors, vhosts, timeouts)
}
// Well, this should not have happened, bail out
return nil, errors.New("no Ethereum service")
}); err != nil {
Fatalf("Failed to register the GraphQL service: %v", err) Fatalf("Failed to register the GraphQL service: %v", err)
} }
} }

@ -89,6 +89,8 @@ var (
Name: "testnet", Name: "testnet",
Usage: "Pre-configured test network (Deprecated: Please choose one of --goerli, --rinkeby, or --ropsten.)", Usage: "Pre-configured test network (Deprecated: Please choose one of --goerli, --rinkeby, or --ropsten.)",
} }
// (Deprecated May 2020, shown in aliased flags section)
LegacyRPCEnabledFlag = cli.BoolFlag{ LegacyRPCEnabledFlag = cli.BoolFlag{
Name: "rpc", Name: "rpc",
Usage: "Enable the HTTP-RPC server (deprecated, use --http)", Usage: "Enable the HTTP-RPC server (deprecated, use --http)",
@ -158,6 +160,17 @@ var (
Usage: "Comma separated enode URLs for P2P v5 discovery bootstrap (light server, light nodes) (deprecated, use --bootnodes)", Usage: "Comma separated enode URLs for P2P v5 discovery bootstrap (light server, light nodes) (deprecated, use --bootnodes)",
Value: "", Value: "",
} }
// (Deprecated July 2020, shown in aliased flags section)
LegacyGraphQLListenAddrFlag = cli.StringFlag{
Name: "graphql.addr",
Usage: "GraphQL server listening interface (deprecated, graphql can only be enabled on the HTTP-RPC server endpoint, use --graphql)",
}
LegacyGraphQLPortFlag = cli.IntFlag{
Name: "graphql.port",
Usage: "GraphQL server listening port (deprecated, graphql can only be enabled on the HTTP-RPC server endpoint, use --graphql)",
Value: node.DefaultHTTPPort,
}
) )
// showDeprecated displays deprecated flags that will be soon removed from the codebase. // showDeprecated displays deprecated flags that will be soon removed from the codebase.

@ -221,8 +221,7 @@ func initialize() {
MaxMessageSize: uint32(*argMaxSize), MaxMessageSize: uint32(*argMaxSize),
MinimumAcceptedPOW: *argPoW, MinimumAcceptedPOW: *argPoW,
} }
shh = whisper.StandaloneWhisperService(cfg)
shh = whisper.New(cfg)
if *argPoW != whisper.DefaultMinimumPoW { if *argPoW != whisper.DefaultMinimumPoW {
err := shh.SetMinimumPoW(*argPoW) err := shh.SetMinimumPoW(*argPoW)
@ -433,7 +432,7 @@ func run() {
return return
} }
defer server.Stop() defer server.Stop()
shh.Start(nil) shh.Start()
defer shh.Stop() defer shh.Stop()
if !*forwarderMode { if !*forwarderMode {

@ -109,7 +109,8 @@ func newTester(t *testing.T, confOverride func(*eth.Config)) *tester {
if confOverride != nil { if confOverride != nil {
confOverride(ethConf) confOverride(ethConf)
} }
if err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { return eth.New(ctx, ethConf) }); err != nil { ethBackend, err := eth.New(stack, ethConf)
if err != nil {
t.Fatalf("failed to register Ethereum protocol: %v", err) t.Fatalf("failed to register Ethereum protocol: %v", err)
} }
// Start the node and assemble the JavaScript console around it // Start the node and assemble the JavaScript console around it
@ -135,13 +136,10 @@ func newTester(t *testing.T, confOverride func(*eth.Config)) *tester {
t.Fatalf("failed to create JavaScript console: %v", err) t.Fatalf("failed to create JavaScript console: %v", err)
} }
// Create the final tester and return // Create the final tester and return
var ethereum *eth.Ethereum
stack.Service(&ethereum)
return &tester{ return &tester{
workspace: workspace, workspace: workspace,
stack: stack, stack: stack,
ethereum: ethereum, ethereum: ethBackend,
console: console, console: console,
input: prompter, input: prompter,
output: printer, output: printer,

@ -22,6 +22,7 @@ import (
"math" "math"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -74,6 +75,7 @@ type freezer struct {
tables map[string]*freezerTable // Data tables for storing everything tables map[string]*freezerTable // Data tables for storing everything
instanceLock fileutil.Releaser // File-system lock to prevent double opens instanceLock fileutil.Releaser // File-system lock to prevent double opens
quit chan struct{} quit chan struct{}
closeOnce sync.Once
} }
// newFreezer creates a chain freezer that moves ancient chain data into // newFreezer creates a chain freezer that moves ancient chain data into
@ -128,16 +130,18 @@ func newFreezer(datadir string, namespace string) (*freezer, error) {
// Close terminates the chain freezer, unmapping all the data files. // Close terminates the chain freezer, unmapping all the data files.
func (f *freezer) Close() error { func (f *freezer) Close() error {
f.quit <- struct{}{}
var errs []error var errs []error
for _, table := range f.tables { f.closeOnce.Do(func() {
if err := table.Close(); err != nil { f.quit <- struct{}{}
for _, table := range f.tables {
if err := table.Close(); err != nil {
errs = append(errs, err)
}
}
if err := f.instanceLock.Release(); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
} })
if err := f.instanceLock.Release(); err != nil {
errs = append(errs, err)
}
if errs != nil { if errs != nil {
return fmt.Errorf("%v", errs) return fmt.Errorf("%v", errs)
} }

@ -23,6 +23,7 @@ import (
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/bloombits" "github.com/ethereum/go-ethereum/core/bloombits"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
@ -33,6 +34,7 @@ import (
"github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/miner"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
@ -257,6 +259,10 @@ func (b *EthAPIBackend) TxPoolContent() (map[common.Address]types.Transactions,
return b.eth.TxPool().Content() return b.eth.TxPool().Content()
} }
func (b *EthAPIBackend) TxPool() *core.TxPool {
return b.eth.TxPool()
}
func (b *EthAPIBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { func (b *EthAPIBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription {
return b.eth.TxPool().SubscribeNewTxsEvent(ch) return b.eth.TxPool().SubscribeNewTxsEvent(ch)
} }
@ -307,3 +313,19 @@ func (b *EthAPIBackend) ServiceFilter(ctx context.Context, session *bloombits.Ma
go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests) go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests)
} }
} }
func (b *EthAPIBackend) Engine() consensus.Engine {
return b.eth.engine
}
func (b *EthAPIBackend) CurrentHeader() *types.Header {
return b.eth.blockchain.CurrentHeader()
}
func (b *EthAPIBackend) Miner() *miner.Miner {
return b.eth.Miner()
}
func (b *EthAPIBackend) StartMining(threads int) error {
return b.eth.StartMining(threads)
}

@ -26,7 +26,6 @@ import (
"sync/atomic" "sync/atomic"
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus"
@ -54,15 +53,6 @@ import (
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
type LesServer interface {
Start(srvr *p2p.Server)
Stop()
APIs() []rpc.API
Protocols() []p2p.Protocol
SetBloomBitsIndexer(bbIndexer *core.ChainIndexer)
SetContractBackend(bind.ContractBackend)
}
// Ethereum implements the Ethereum full node service. // Ethereum implements the Ethereum full node service.
type Ethereum struct { type Ethereum struct {
config *Config config *Config
@ -71,7 +61,6 @@ type Ethereum struct {
txPool *core.TxPool txPool *core.TxPool
blockchain *core.BlockChain blockchain *core.BlockChain
protocolManager *ProtocolManager protocolManager *ProtocolManager
lesServer LesServer
dialCandidates enode.Iterator dialCandidates enode.Iterator
// DB interfaces // DB interfaces
@ -94,25 +83,14 @@ type Ethereum struct {
networkID uint64 networkID uint64
netRPCService *ethapi.PublicNetAPI netRPCService *ethapi.PublicNetAPI
p2pServer *p2p.Server
lock sync.RWMutex // Protects the variadic fields (e.g. gas price and etherbase) lock sync.RWMutex // Protects the variadic fields (e.g. gas price and etherbase)
} }
func (s *Ethereum) AddLesServer(ls LesServer) {
s.lesServer = ls
ls.SetBloomBitsIndexer(s.bloomIndexer)
}
// SetClient sets a rpc client which connecting to our local node.
func (s *Ethereum) SetContractBackend(backend bind.ContractBackend) {
// Pass the rpc client to les server if it is enabled.
if s.lesServer != nil {
s.lesServer.SetContractBackend(backend)
}
}
// New creates a new Ethereum object (including the // New creates a new Ethereum object (including the
// initialisation of the common Ethereum object) // initialisation of the common Ethereum object)
func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { func New(stack *node.Node, config *Config) (*Ethereum, error) {
// Ensure configuration values are compatible and sane // Ensure configuration values are compatible and sane
if config.SyncMode == downloader.LightSync { if config.SyncMode == downloader.LightSync {
return nil, errors.New("can't run eth.Ethereum in light sync mode, use les.LightEthereum") return nil, errors.New("can't run eth.Ethereum in light sync mode, use les.LightEthereum")
@ -136,7 +114,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
log.Info("Allocated trie memory caches", "clean", common.StorageSize(config.TrieCleanCache)*1024*1024, "dirty", common.StorageSize(config.TrieDirtyCache)*1024*1024) log.Info("Allocated trie memory caches", "clean", common.StorageSize(config.TrieCleanCache)*1024*1024, "dirty", common.StorageSize(config.TrieDirtyCache)*1024*1024)
// Assemble the Ethereum object // Assemble the Ethereum object
chainDb, err := ctx.OpenDatabaseWithFreezer("chaindata", config.DatabaseCache, config.DatabaseHandles, config.DatabaseFreezer, "eth/db/chaindata/") chainDb, err := stack.OpenDatabaseWithFreezer("chaindata", config.DatabaseCache, config.DatabaseHandles, config.DatabaseFreezer, "eth/db/chaindata/")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -149,15 +127,16 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
eth := &Ethereum{ eth := &Ethereum{
config: config, config: config,
chainDb: chainDb, chainDb: chainDb,
eventMux: ctx.EventMux, eventMux: stack.EventMux(),
accountManager: ctx.AccountManager, accountManager: stack.AccountManager(),
engine: CreateConsensusEngine(ctx, chainConfig, &config.Ethash, config.Miner.Notify, config.Miner.Noverify, chainDb), engine: CreateConsensusEngine(stack, chainConfig, &config.Ethash, config.Miner.Notify, config.Miner.Noverify, chainDb),
closeBloomHandler: make(chan struct{}), closeBloomHandler: make(chan struct{}),
networkID: config.NetworkId, networkID: config.NetworkId,
gasPrice: config.Miner.GasPrice, gasPrice: config.Miner.GasPrice,
etherbase: config.Miner.Etherbase, etherbase: config.Miner.Etherbase,
bloomRequests: make(chan chan *bloombits.Retrieval), bloomRequests: make(chan chan *bloombits.Retrieval),
bloomIndexer: NewBloomIndexer(chainDb, params.BloomBitsBlocks, params.BloomConfirms), bloomIndexer: NewBloomIndexer(chainDb, params.BloomBitsBlocks, params.BloomConfirms),
p2pServer: stack.Server(),
} }
bcVersion := rawdb.ReadDatabaseVersion(chainDb) bcVersion := rawdb.ReadDatabaseVersion(chainDb)
@ -183,7 +162,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
} }
cacheConfig = &core.CacheConfig{ cacheConfig = &core.CacheConfig{
TrieCleanLimit: config.TrieCleanCache, TrieCleanLimit: config.TrieCleanCache,
TrieCleanJournal: ctx.ResolvePath(config.TrieCleanCacheJournal), TrieCleanJournal: stack.ResolvePath(config.TrieCleanCacheJournal),
TrieCleanRejournal: config.TrieCleanCacheRejournal, TrieCleanRejournal: config.TrieCleanCacheRejournal,
TrieCleanNoPrefetch: config.NoPrefetch, TrieCleanNoPrefetch: config.NoPrefetch,
TrieDirtyLimit: config.TrieDirtyCache, TrieDirtyLimit: config.TrieDirtyCache,
@ -205,7 +184,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
eth.bloomIndexer.Start(eth.blockchain) eth.bloomIndexer.Start(eth.blockchain)
if config.TxPool.Journal != "" { if config.TxPool.Journal != "" {
config.TxPool.Journal = ctx.ResolvePath(config.TxPool.Journal) config.TxPool.Journal = stack.ResolvePath(config.TxPool.Journal)
} }
eth.txPool = core.NewTxPool(config.TxPool, chainConfig, eth.blockchain) eth.txPool = core.NewTxPool(config.TxPool, chainConfig, eth.blockchain)
@ -221,18 +200,25 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock) eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock)
eth.miner.SetExtra(makeExtraData(config.Miner.ExtraData)) eth.miner.SetExtra(makeExtraData(config.Miner.ExtraData))
eth.APIBackend = &EthAPIBackend{ctx.ExtRPCEnabled(), eth, nil} eth.APIBackend = &EthAPIBackend{stack.Config().ExtRPCEnabled(), eth, nil}
gpoParams := config.GPO gpoParams := config.GPO
if gpoParams.Default == nil { if gpoParams.Default == nil {
gpoParams.Default = config.Miner.GasPrice gpoParams.Default = config.Miner.GasPrice
} }
eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams)
eth.dialCandidates, err = eth.setupDiscovery(&ctx.Config.P2P) eth.dialCandidates, err = eth.setupDiscovery(&stack.Config().P2P)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Start the RPC service
eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer, eth.NetVersion())
// Register the backend on the node
stack.RegisterAPIs(eth.APIs())
stack.RegisterProtocols(eth.Protocols())
stack.RegisterLifecycle(eth)
return eth, nil return eth, nil
} }
@ -254,7 +240,7 @@ func makeExtraData(extra []byte) []byte {
} }
// CreateConsensusEngine creates the required type of consensus engine instance for an Ethereum service // CreateConsensusEngine creates the required type of consensus engine instance for an Ethereum service
func CreateConsensusEngine(ctx *node.ServiceContext, chainConfig *params.ChainConfig, config *ethash.Config, notify []string, noverify bool, db ethdb.Database) consensus.Engine { func CreateConsensusEngine(stack *node.Node, chainConfig *params.ChainConfig, config *ethash.Config, notify []string, noverify bool, db ethdb.Database) consensus.Engine {
// If proof-of-authority is requested, set it up // If proof-of-authority is requested, set it up
if chainConfig.Clique != nil { if chainConfig.Clique != nil {
return clique.New(chainConfig.Clique, db) return clique.New(chainConfig.Clique, db)
@ -272,7 +258,7 @@ func CreateConsensusEngine(ctx *node.ServiceContext, chainConfig *params.ChainCo
return ethash.NewShared() return ethash.NewShared()
default: default:
engine := ethash.New(ethash.Config{ engine := ethash.New(ethash.Config{
CacheDir: ctx.ResolvePath(config.CacheDir), CacheDir: stack.ResolvePath(config.CacheDir),
CachesInMem: config.CachesInMem, CachesInMem: config.CachesInMem,
CachesOnDisk: config.CachesOnDisk, CachesOnDisk: config.CachesOnDisk,
CachesLockMmap: config.CachesLockMmap, CachesLockMmap: config.CachesLockMmap,
@ -291,18 +277,9 @@ func CreateConsensusEngine(ctx *node.ServiceContext, chainConfig *params.ChainCo
func (s *Ethereum) APIs() []rpc.API { func (s *Ethereum) APIs() []rpc.API {
apis := ethapi.GetAPIs(s.APIBackend) apis := ethapi.GetAPIs(s.APIBackend)
// Append any APIs exposed explicitly by the les server
if s.lesServer != nil {
apis = append(apis, s.lesServer.APIs()...)
}
// Append any APIs exposed explicitly by the consensus engine // Append any APIs exposed explicitly by the consensus engine
apis = append(apis, s.engine.APIs(s.BlockChain())...) apis = append(apis, s.engine.APIs(s.BlockChain())...)
// Append any APIs exposed explicitly by the les server
if s.lesServer != nil {
apis = append(apis, s.lesServer.APIs()...)
}
// Append all the local APIs and return // Append all the local APIs and return
return append(apis, []rpc.API{ return append(apis, []rpc.API{
{ {
@ -517,8 +494,9 @@ func (s *Ethereum) NetVersion() uint64 { return s.networkID }
func (s *Ethereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader } func (s *Ethereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader }
func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.protocolManager.acceptTxs) == 1 } func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.protocolManager.acceptTxs) == 1 }
func (s *Ethereum) ArchiveMode() bool { return s.config.NoPruning } func (s *Ethereum) ArchiveMode() bool { return s.config.NoPruning }
func (s *Ethereum) BloomIndexer() *core.ChainIndexer { return s.bloomIndexer }
// Protocols implements node.Service, returning all the currently configured // Protocols returns all the currently configured
// network protocols to start. // network protocols to start.
func (s *Ethereum) Protocols() []p2p.Protocol { func (s *Ethereum) Protocols() []p2p.Protocol {
protos := make([]p2p.Protocol, len(ProtocolVersions)) protos := make([]p2p.Protocol, len(ProtocolVersions))
@ -527,47 +505,35 @@ func (s *Ethereum) Protocols() []p2p.Protocol {
protos[i].Attributes = []enr.Entry{s.currentEthEntry()} protos[i].Attributes = []enr.Entry{s.currentEthEntry()}
protos[i].DialCandidates = s.dialCandidates protos[i].DialCandidates = s.dialCandidates
} }
if s.lesServer != nil {
protos = append(protos, s.lesServer.Protocols()...)
}
return protos return protos
} }
// Start implements node.Service, starting all internal goroutines needed by the // Start implements node.Lifecycle, starting all internal goroutines needed by the
// Ethereum protocol implementation. // Ethereum protocol implementation.
func (s *Ethereum) Start(srvr *p2p.Server) error { func (s *Ethereum) Start() error {
s.startEthEntryUpdate(srvr.LocalNode()) s.startEthEntryUpdate(s.p2pServer.LocalNode())
// Start the bloom bits servicing goroutines // Start the bloom bits servicing goroutines
s.startBloomHandlers(params.BloomBitsBlocks) s.startBloomHandlers(params.BloomBitsBlocks)
// Start the RPC service
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.NetVersion())
// Figure out a max peers count based on the server limits // Figure out a max peers count based on the server limits
maxPeers := srvr.MaxPeers maxPeers := s.p2pServer.MaxPeers
if s.config.LightServ > 0 { if s.config.LightServ > 0 {
if s.config.LightPeers >= srvr.MaxPeers { if s.config.LightPeers >= s.p2pServer.MaxPeers {
return fmt.Errorf("invalid peer config: light peer count (%d) >= total peer count (%d)", s.config.LightPeers, srvr.MaxPeers) return fmt.Errorf("invalid peer config: light peer count (%d) >= total peer count (%d)", s.config.LightPeers, s.p2pServer.MaxPeers)
} }
maxPeers -= s.config.LightPeers maxPeers -= s.config.LightPeers
} }
// Start the networking layer and the light server if requested // Start the networking layer and the light server if requested
s.protocolManager.Start(maxPeers) s.protocolManager.Start(maxPeers)
if s.lesServer != nil {
s.lesServer.Start(srvr)
}
return nil return nil
} }
// Stop implements node.Service, terminating all internal goroutines used by the // Stop implements node.Lifecycle, terminating all internal goroutines used by the
// Ethereum protocol. // Ethereum protocol.
func (s *Ethereum) Stop() error { func (s *Ethereum) Stop() error {
// Stop all the peer-related stuff first. // Stop all the peer-related stuff first.
s.protocolManager.Stop() s.protocolManager.Stop()
if s.lesServer != nil {
s.lesServer.Stop()
}
// Then stop everything else. // Then stop everything else.
s.bloomIndexer.Close() s.bloomIndexer.Close()

@ -187,17 +187,18 @@ var (
func newTestBackend(t *testing.T) (*node.Node, []*types.Block) { func newTestBackend(t *testing.T) (*node.Node, []*types.Block) {
// Generate test chain. // Generate test chain.
genesis, blocks := generateTestChain() genesis, blocks := generateTestChain()
// Create node
// Start Ethereum service.
var ethservice *eth.Ethereum
n, err := node.New(&node.Config{}) n, err := node.New(&node.Config{})
n.Register(func(ctx *node.ServiceContext) (node.Service, error) { if err != nil {
config := &eth.Config{Genesis: genesis} t.Fatalf("can't create new node: %v", err)
config.Ethash.PowMode = ethash.ModeFake }
ethservice, err = eth.New(ctx, config) // Create Ethereum Service
return ethservice, err config := &eth.Config{Genesis: genesis}
}) config.Ethash.PowMode = ethash.ModeFake
ethservice, err := eth.New(n, config)
if err != nil {
t.Fatalf("can't create new ethereum service: %v", err)
}
// Import the test chain. // Import the test chain.
if err := n.Start(); err != nil { if err := n.Start(); err != nil {
t.Fatalf("can't start test node: %v", err) t.Fatalf("can't start test node: %v", err)
@ -231,7 +232,7 @@ func generateTestChain() (*core.Genesis, []*types.Block) {
func TestHeader(t *testing.T) { func TestHeader(t *testing.T) {
backend, chain := newTestBackend(t) backend, chain := newTestBackend(t)
client, _ := backend.Attach() client, _ := backend.Attach()
defer backend.Stop() defer backend.Close()
defer client.Close() defer client.Close()
tests := map[string]struct { tests := map[string]struct {
@ -275,7 +276,7 @@ func TestHeader(t *testing.T) {
func TestBalanceAt(t *testing.T) { func TestBalanceAt(t *testing.T) {
backend, _ := newTestBackend(t) backend, _ := newTestBackend(t)
client, _ := backend.Attach() client, _ := backend.Attach()
defer backend.Stop() defer backend.Close()
defer client.Close() defer client.Close()
tests := map[string]struct { tests := map[string]struct {
@ -321,7 +322,7 @@ func TestBalanceAt(t *testing.T) {
func TestTransactionInBlockInterrupted(t *testing.T) { func TestTransactionInBlockInterrupted(t *testing.T) {
backend, _ := newTestBackend(t) backend, _ := newTestBackend(t)
client, _ := backend.Attach() client, _ := backend.Attach()
defer backend.Stop() defer backend.Close()
defer client.Close() defer client.Close()
ec := NewClient(client) ec := NewClient(client)
@ -339,7 +340,7 @@ func TestTransactionInBlockInterrupted(t *testing.T) {
func TestChainID(t *testing.T) { func TestChainID(t *testing.T) {
backend, _ := newTestBackend(t) backend, _ := newTestBackend(t)
client, _ := backend.Attach() client, _ := backend.Attach()
defer backend.Stop() defer backend.Close()
defer client.Close() defer client.Close()
ec := NewClient(client) ec := NewClient(client)

@ -36,9 +36,12 @@ import (
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/les" "github.com/ethereum/go-ethereum/les"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/miner"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -56,23 +59,33 @@ const (
chainHeadChanSize = 10 chainHeadChanSize = 10
) )
type txPool interface { // backend encompasses the bare-minimum functionality needed for ethstats reporting
// SubscribeNewTxsEvent should return an event subscription of type backend interface {
// NewTxsEvent and send events to the given channel. SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription
SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription
CurrentHeader() *types.Header
HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error)
GetTd(ctx context.Context, hash common.Hash) *big.Int
Stats() (pending int, queued int)
Downloader() *downloader.Downloader
} }
type blockChain interface { // fullNodeBackend encompasses the functionality necessary for a full node
SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription // reporting to ethstats
type fullNodeBackend interface {
backend
Miner() *miner.Miner
BlockByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Block, error)
CurrentBlock() *types.Block
SuggestPrice(ctx context.Context) (*big.Int, error)
} }
// Service implements an Ethereum netstats reporting daemon that pushes local // Service implements an Ethereum netstats reporting daemon that pushes local
// chain statistics up to a monitoring server. // chain statistics up to a monitoring server.
type Service struct { type Service struct {
server *p2p.Server // Peer-to-peer server to retrieve networking infos server *p2p.Server // Peer-to-peer server to retrieve networking infos
eth *eth.Ethereum // Full Ethereum service if monitoring a full node backend backend
les *les.LightEthereum // Light Ethereum service if monitoring a light node engine consensus.Engine // Consensus engine to retrieve variadic block fields
engine consensus.Engine // Consensus engine to retrieve variadic block fields
node string // Name of the node to display on the monitoring page node string // Name of the node to display on the monitoring page
pass string // Password to authorize access to the monitoring page pass string // Password to authorize access to the monitoring page
@ -83,50 +96,37 @@ type Service struct {
} }
// New returns a monitoring service ready for stats reporting. // New returns a monitoring service ready for stats reporting.
func New(url string, ethServ *eth.Ethereum, lesServ *les.LightEthereum) (*Service, error) { func New(node *node.Node, backend backend, engine consensus.Engine, url string) error {
// Parse the netstats connection url // Parse the netstats connection url
re := regexp.MustCompile("([^:@]*)(:([^@]*))?@(.+)") re := regexp.MustCompile("([^:@]*)(:([^@]*))?@(.+)")
parts := re.FindStringSubmatch(url) parts := re.FindStringSubmatch(url)
if len(parts) != 5 { if len(parts) != 5 {
return nil, fmt.Errorf("invalid netstats url: \"%s\", should be nodename:secret@host:port", url) return fmt.Errorf("invalid netstats url: \"%s\", should be nodename:secret@host:port", url)
} }
// Assemble and return the stats service ethstats := &Service{
var engine consensus.Engine backend: backend,
if ethServ != nil { engine: engine,
engine = ethServ.Engine() server: node.Server(),
} else { node: parts[1],
engine = lesServ.Engine() pass: parts[3],
host: parts[4],
pongCh: make(chan struct{}),
histCh: make(chan []uint64, 1),
} }
return &Service{
eth: ethServ, node.RegisterLifecycle(ethstats)
les: lesServ, return nil
engine: engine,
node: parts[1],
pass: parts[3],
host: parts[4],
pongCh: make(chan struct{}),
histCh: make(chan []uint64, 1),
}, nil
} }
// Protocols implements node.Service, returning the P2P network protocols used // Start implements node.Lifecycle, starting up the monitoring and reporting daemon.
// by the stats service (nil as it doesn't use the devp2p overlay network). func (s *Service) Start() error {
func (s *Service) Protocols() []p2p.Protocol { return nil }
// APIs implements node.Service, returning the RPC API endpoints provided by the
// stats service (nil as it doesn't provide any user callable APIs).
func (s *Service) APIs() []rpc.API { return nil }
// Start implements node.Service, starting up the monitoring and reporting daemon.
func (s *Service) Start(server *p2p.Server) error {
s.server = server
go s.loop() go s.loop()
log.Info("Stats daemon started") log.Info("Stats daemon started")
return nil return nil
} }
// Stop implements node.Service, terminating the monitoring and reporting daemon. // Stop implements node.Lifecycle, terminating the monitoring and reporting daemon.
func (s *Service) Stop() error { func (s *Service) Stop() error {
log.Info("Stats daemon stopped") log.Info("Stats daemon stopped")
return nil return nil
@ -136,22 +136,12 @@ func (s *Service) Stop() error {
// until termination. // until termination.
func (s *Service) loop() { func (s *Service) loop() {
// Subscribe to chain events to execute updates on // Subscribe to chain events to execute updates on
var blockchain blockChain
var txpool txPool
if s.eth != nil {
blockchain = s.eth.BlockChain()
txpool = s.eth.TxPool()
} else {
blockchain = s.les.BlockChain()
txpool = s.les.TxPool()
}
chainHeadCh := make(chan core.ChainHeadEvent, chainHeadChanSize) chainHeadCh := make(chan core.ChainHeadEvent, chainHeadChanSize)
headSub := blockchain.SubscribeChainHeadEvent(chainHeadCh) headSub := s.backend.SubscribeChainHeadEvent(chainHeadCh)
defer headSub.Unsubscribe() defer headSub.Unsubscribe()
txEventCh := make(chan core.NewTxsEvent, txChanSize) txEventCh := make(chan core.NewTxsEvent, txChanSize)
txSub := txpool.SubscribeNewTxsEvent(txEventCh) txSub := s.backend.SubscribeNewTxsEvent(txEventCh)
defer txSub.Unsubscribe() defer txSub.Unsubscribe()
// Start a goroutine that exhausts the subscriptions to avoid events piling up // Start a goroutine that exhausts the subscriptions to avoid events piling up
@ -560,13 +550,15 @@ func (s *Service) assembleBlockStats(block *types.Block) *blockStats {
txs []txStats txs []txStats
uncles []*types.Header uncles []*types.Header
) )
if s.eth != nil {
// Full nodes have all needed information available // check if backend is a full node
fullBackend, ok := s.backend.(fullNodeBackend)
if ok {
if block == nil { if block == nil {
block = s.eth.BlockChain().CurrentBlock() block = fullBackend.CurrentBlock()
} }
header = block.Header() header = block.Header()
td = s.eth.BlockChain().GetTd(header.Hash(), header.Number.Uint64()) td = fullBackend.GetTd(context.Background(), header.Hash())
txs = make([]txStats, len(block.Transactions())) txs = make([]txStats, len(block.Transactions()))
for i, tx := range block.Transactions() { for i, tx := range block.Transactions() {
@ -578,11 +570,12 @@ func (s *Service) assembleBlockStats(block *types.Block) *blockStats {
if block != nil { if block != nil {
header = block.Header() header = block.Header()
} else { } else {
header = s.les.BlockChain().CurrentHeader() header = s.backend.CurrentHeader()
} }
td = s.les.BlockChain().GetTd(header.Hash(), header.Number.Uint64()) td = s.backend.GetTd(context.Background(), header.Hash())
txs = []txStats{} txs = []txStats{}
} }
// Assemble and return the block stats // Assemble and return the block stats
author, _ := s.engine.Author(header) author, _ := s.engine.Author(header)
@ -613,12 +606,7 @@ func (s *Service) reportHistory(conn *websocket.Conn, list []uint64) error {
indexes = append(indexes, list...) indexes = append(indexes, list...)
} else { } else {
// No indexes requested, send back the top ones // No indexes requested, send back the top ones
var head int64 head := s.backend.CurrentHeader().Number.Int64()
if s.eth != nil {
head = s.eth.BlockChain().CurrentHeader().Number.Int64()
} else {
head = s.les.BlockChain().CurrentHeader().Number.Int64()
}
start := head - historyUpdateRange + 1 start := head - historyUpdateRange + 1
if start < 0 { if start < 0 {
start = 0 start = 0
@ -630,12 +618,13 @@ func (s *Service) reportHistory(conn *websocket.Conn, list []uint64) error {
// Gather the batch of blocks to report // Gather the batch of blocks to report
history := make([]*blockStats, len(indexes)) history := make([]*blockStats, len(indexes))
for i, number := range indexes { for i, number := range indexes {
fullBackend, ok := s.backend.(fullNodeBackend)
// Retrieve the next block if it's known to us // Retrieve the next block if it's known to us
var block *types.Block var block *types.Block
if s.eth != nil { if ok {
block = s.eth.BlockChain().GetBlockByNumber(number) block, _ = fullBackend.BlockByNumber(context.Background(), rpc.BlockNumber(number)) // TODO ignore error here ?
} else { } else {
if header := s.les.BlockChain().GetHeaderByNumber(number); header != nil { if header, _ := s.backend.HeaderByNumber(context.Background(), rpc.BlockNumber(number)); header != nil {
block = types.NewBlockWithHeader(header) block = types.NewBlockWithHeader(header)
} }
} }
@ -673,12 +662,7 @@ type pendStats struct {
// it to the stats server. // it to the stats server.
func (s *Service) reportPending(conn *websocket.Conn) error { func (s *Service) reportPending(conn *websocket.Conn) error {
// Retrieve the pending count from the local blockchain // Retrieve the pending count from the local blockchain
var pending int pending, _ := s.backend.Stats()
if s.eth != nil {
pending, _ = s.eth.TxPool().Stats()
} else {
pending = s.les.TxPool().Stats()
}
// Assemble the transaction stats and send it to the server // Assemble the transaction stats and send it to the server
log.Trace("Sending pending transactions to ethstats", "count", pending) log.Trace("Sending pending transactions to ethstats", "count", pending)
@ -705,7 +689,7 @@ type nodeStats struct {
Uptime int `json:"uptime"` Uptime int `json:"uptime"`
} }
// reportPending retrieves various stats about the node at the networking and // reportStats retrieves various stats about the node at the networking and
// mining layer and reports it to the stats server. // mining layer and reports it to the stats server.
func (s *Service) reportStats(conn *websocket.Conn) error { func (s *Service) reportStats(conn *websocket.Conn) error {
// Gather the syncing and mining infos from the local miner instance // Gather the syncing and mining infos from the local miner instance
@ -715,18 +699,20 @@ func (s *Service) reportStats(conn *websocket.Conn) error {
syncing bool syncing bool
gasprice int gasprice int
) )
if s.eth != nil { // check if backend is a full node
mining = s.eth.Miner().Mining() fullBackend, ok := s.backend.(fullNodeBackend)
hashrate = int(s.eth.Miner().HashRate()) if ok {
mining = fullBackend.Miner().Mining()
hashrate = int(fullBackend.Miner().HashRate())
sync := s.eth.Downloader().Progress() sync := fullBackend.Downloader().Progress()
syncing = s.eth.BlockChain().CurrentHeader().Number.Uint64() >= sync.HighestBlock syncing = fullBackend.CurrentHeader().Number.Uint64() >= sync.HighestBlock
price, _ := s.eth.APIBackend.SuggestPrice(context.Background()) price, _ := fullBackend.SuggestPrice(context.Background())
gasprice = int(price.Uint64()) gasprice = int(price.Uint64())
} else { } else {
sync := s.les.Downloader().Progress() sync := s.backend.Downloader().Progress()
syncing = s.les.BlockChain().CurrentHeader().Number.Uint64() >= sync.HighestBlock syncing = s.backend.CurrentHeader().Number.Uint64() >= sync.HighestBlock
} }
// Assemble the node stats and send it to the server // Assemble the node stats and send it to the server
log.Trace("Sending node details to ethstats") log.Trace("Sending node details to ethstats")

@ -17,12 +17,118 @@
package graphql package graphql
import ( import (
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing" "testing"
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/node"
"github.com/stretchr/testify/assert"
) )
func TestBuildSchema(t *testing.T) { func TestBuildSchema(t *testing.T) {
stack, err := node.New(&node.DefaultConfig)
if err != nil {
t.Fatalf("could not create new node: %v", err)
}
// Make sure the schema can be parsed and matched up to the object model. // Make sure the schema can be parsed and matched up to the object model.
if _, err := newHandler(nil); err != nil { if err := newHandler(stack, nil, []string{}, []string{}); err != nil {
t.Errorf("Could not construct GraphQL handler: %v", err) t.Errorf("Could not construct GraphQL handler: %v", err)
} }
} }
// Tests that a graphQL request is successfully handled when graphql is enabled on the specified endpoint
func TestGraphQLHTTPOnSamePort_GQLRequest_Successful(t *testing.T) {
stack := createNode(t, true)
defer stack.Close()
// start node
if err := stack.Start(); err != nil {
t.Fatalf("could not start node: %v", err)
}
// create http request
body := strings.NewReader("{\"query\": \"{block{number}}\",\"variables\": null}")
gqlReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s/graphql", "127.0.0.1:9393"), body)
if err != nil {
t.Error("could not issue new http request ", err)
}
gqlReq.Header.Set("Content-Type", "application/json")
// read from response
resp := doHTTPRequest(t, gqlReq)
bodyBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("could not read from response body: %v", err)
}
expected := "{\"data\":{\"block\":{\"number\":\"0x0\"}}}"
assert.Equal(t, expected, string(bodyBytes))
}
// Tests that a graphQL request is not handled successfully when graphql is not enabled on the specified endpoint
func TestGraphQLHTTPOnSamePort_GQLRequest_Unsuccessful(t *testing.T) {
stack := createNode(t, false)
defer stack.Close()
if err := stack.Start(); err != nil {
t.Fatalf("could not start node: %v", err)
}
// create http request
body := strings.NewReader("{\"query\": \"{block{number}}\",\"variables\": null}")
gqlReq, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s/graphql", "127.0.0.1:9393"), body)
if err != nil {
t.Error("could not issue new http request ", err)
}
gqlReq.Header.Set("Content-Type", "application/json")
// read from response
resp := doHTTPRequest(t, gqlReq)
bodyBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("could not read from response body: %v", err)
}
// make sure the request is not handled successfully
assert.Equal(t, 404, resp.StatusCode)
assert.Equal(t, "404 page not found\n", string(bodyBytes))
}
func createNode(t *testing.T, gqlEnabled bool) *node.Node {
stack, err := node.New(&node.Config{
HTTPHost: "127.0.0.1",
HTTPPort: 9393,
WSHost: "127.0.0.1",
WSPort: 9393,
})
if err != nil {
t.Fatalf("could not create node: %v", err)
}
if !gqlEnabled {
return stack
}
createGQLService(t, stack, "127.0.0.1:9393")
return stack
}
func createGQLService(t *testing.T, stack *node.Node, endpoint string) {
// create backend
ethBackend, err := eth.New(stack, &eth.DefaultConfig)
if err != nil {
t.Fatalf("could not create eth backend: %v", err)
}
// create gql service
err = New(stack, ethBackend.APIBackend, []string{}, []string{})
if err != nil {
t.Fatalf("could not create graphql service: %v", err)
}
}
func doHTTPRequest(t *testing.T, req *http.Request) *http.Response {
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatal("could not issue a GET request to the given endpoint", err)
}
return resp
}

@ -17,99 +17,36 @@
package graphql package graphql
import ( import (
"fmt"
"net"
"net/http"
"github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
"github.com/graph-gophers/graphql-go" "github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/relay" "github.com/graph-gophers/graphql-go/relay"
) )
// Service encapsulates a GraphQL service.
type Service struct {
endpoint string // The host:port endpoint for this service.
cors []string // Allowed CORS domains
vhosts []string // Recognised vhosts
timeouts rpc.HTTPTimeouts // Timeout settings for HTTP requests.
backend ethapi.Backend // The backend that queries will operate on.
handler http.Handler // The `http.Handler` used to answer queries.
listener net.Listener // The listening socket.
}
// New constructs a new GraphQL service instance. // New constructs a new GraphQL service instance.
func New(backend ethapi.Backend, endpoint string, cors, vhosts []string, timeouts rpc.HTTPTimeouts) (*Service, error) { func New(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) error {
return &Service{ if backend == nil {
endpoint: endpoint, panic("missing backend")
cors: cors,
vhosts: vhosts,
timeouts: timeouts,
backend: backend,
}, nil
}
// Protocols returns the list of protocols exported by this service.
func (s *Service) Protocols() []p2p.Protocol { return nil }
// APIs returns the list of APIs exported by this service.
func (s *Service) APIs() []rpc.API { return nil }
// Start is called after all services have been constructed and the networking
// layer was also initialized to spawn any goroutines required by the service.
func (s *Service) Start(server *p2p.Server) error {
var err error
s.handler, err = newHandler(s.backend)
if err != nil {
return err
} }
if s.listener, err = net.Listen("tcp", s.endpoint); err != nil { // check if http server with given endpoint exists and enable graphQL on it
return err return newHandler(stack, backend, cors, vhosts)
}
// create handler stack and wrap the graphql handler
handler := node.NewHTTPHandlerStack(s.handler, s.cors, s.vhosts)
// make sure timeout values are meaningful
node.CheckTimeouts(&s.timeouts)
// create http server
httpSrv := &http.Server{
Handler: handler,
ReadTimeout: s.timeouts.ReadTimeout,
WriteTimeout: s.timeouts.WriteTimeout,
IdleTimeout: s.timeouts.IdleTimeout,
}
go httpSrv.Serve(s.listener)
log.Info("GraphQL endpoint opened", "url", fmt.Sprintf("http://%s", s.endpoint))
return nil
} }
// newHandler returns a new `http.Handler` that will answer GraphQL queries. // newHandler returns a new `http.Handler` that will answer GraphQL queries.
// It additionally exports an interactive query browser on the / endpoint. // It additionally exports an interactive query browser on the / endpoint.
func newHandler(backend ethapi.Backend) (http.Handler, error) { func newHandler(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) error {
q := Resolver{backend} q := Resolver{backend}
s, err := graphql.ParseSchema(schema, &q) s, err := graphql.ParseSchema(schema, &q)
if err != nil { if err != nil {
return nil, err return err
} }
h := &relay.Handler{Schema: s} h := &relay.Handler{Schema: s}
handler := node.NewHTTPHandlerStack(h, cors, vhosts)
mux := http.NewServeMux() stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{})
mux.Handle("/", GraphiQL{}) stack.RegisterHandler("GraphQL", "/graphql", handler)
mux.Handle("/graphql", h) stack.RegisterHandler("GraphQL", "/graphql/", handler)
mux.Handle("/graphql/", h)
return mux, nil
}
// Stop terminates all goroutines belonging to the service, blocking until they
// are all terminated.
func (s *Service) Stop() error {
if s.listener != nil {
s.listener.Close()
s.listener = nil
log.Info("GraphQL endpoint closed", "url", fmt.Sprintf("http://%s", s.endpoint))
}
return nil return nil
} }

@ -23,6 +23,7 @@ import (
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/bloombits" "github.com/ethereum/go-ethereum/core/bloombits"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
@ -45,14 +46,16 @@ type Backend interface {
ChainDb() ethdb.Database ChainDb() ethdb.Database
AccountManager() *accounts.Manager AccountManager() *accounts.Manager
ExtRPCEnabled() bool ExtRPCEnabled() bool
RPCTxFeeCap() float64 // global tx fee cap for all transaction related APIs
RPCGasCap() uint64 // global gas cap for eth_call over rpc: DoS protection RPCGasCap() uint64 // global gas cap for eth_call over rpc: DoS protection
RPCTxFeeCap() float64 // global tx fee cap for all transaction related APIs
// Blockchain API // Blockchain API
SetHead(number uint64) SetHead(number uint64)
HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error) HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error)
HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error)
HeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Header, error) HeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Header, error)
CurrentHeader() *types.Header
CurrentBlock() *types.Block
BlockByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Block, error) BlockByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Block, error)
BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error)
BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error) BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error)
@ -84,7 +87,7 @@ type Backend interface {
SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription
ChainConfig() *params.ChainConfig ChainConfig() *params.ChainConfig
CurrentBlock() *types.Block Engine() consensus.Engine
} }
func GetAPIs(apiBackend Backend) []rpc.API { func GetAPIs(apiBackend Backend) []rpc.API {

@ -23,6 +23,7 @@ import (
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/bloombits" "github.com/ethereum/go-ethereum/core/bloombits"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
@ -282,3 +283,11 @@ func (b *LesApiBackend) ServiceFilter(ctx context.Context, session *bloombits.Ma
go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests) go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests)
} }
} }
func (b *LesApiBackend) Engine() consensus.Engine {
return b.eth.engine
}
func (b *LesApiBackend) CurrentHeader() *types.Header {
return b.eth.blockchain.CurrentHeader()
}

@ -55,7 +55,7 @@ func TestMain(m *testing.M) {
log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(true)))) log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(true))))
// register the Delivery service which will run as a devp2p // register the Delivery service which will run as a devp2p
// protocol when using the exec adapter // protocol when using the exec adapter
adapters.RegisterServices(services) adapters.RegisterLifecycles(services)
os.Exit(m.Run()) os.Exit(m.Run())
} }
@ -392,7 +392,7 @@ func getCapacityInfo(ctx context.Context, t *testing.T, server *rpc.Client) (min
return return
} }
var services = adapters.Services{ var services = adapters.LifecycleConstructors{
"lesclient": newLesClientService, "lesclient": newLesClientService,
"lesserver": newLesServerService, "lesserver": newLesServerService,
} }
@ -414,7 +414,7 @@ func NewNetwork() (*simulations.Network, func(), error) {
return net, teardown, nil return net, teardown, nil
} }
func NewAdapter(adapterType string, services adapters.Services) (adapter adapters.NodeAdapter, teardown func(), err error) { func NewAdapter(adapterType string, services adapters.LifecycleConstructors) (adapter adapters.NodeAdapter, teardown func(), err error) {
teardown = func() {} teardown = func() {}
switch adapterType { switch adapterType {
case "sim": case "sim":
@ -454,7 +454,7 @@ func testSim(t *testing.T, serverCount, clientCount int, serverDir, clientDir []
for i := range clients { for i := range clients {
clientconf := adapters.RandomNodeConfig() clientconf := adapters.RandomNodeConfig()
clientconf.Services = []string{"lesclient"} clientconf.Lifecycles = []string{"lesclient"}
if len(clientDir) == clientCount { if len(clientDir) == clientCount {
clientconf.DataDir = clientDir[i] clientconf.DataDir = clientDir[i]
} }
@ -467,7 +467,7 @@ func testSim(t *testing.T, serverCount, clientCount int, serverDir, clientDir []
for i := range servers { for i := range servers {
serverconf := adapters.RandomNodeConfig() serverconf := adapters.RandomNodeConfig()
serverconf.Services = []string{"lesserver"} serverconf.Lifecycles = []string{"lesserver"}
if len(serverDir) == serverCount { if len(serverDir) == serverCount {
serverconf.DataDir = serverDir[i] serverconf.DataDir = serverDir[i]
} }
@ -492,26 +492,25 @@ func testSim(t *testing.T, serverCount, clientCount int, serverDir, clientDir []
return test(ctx, net, servers, clients) return test(ctx, net, servers, clients)
} }
func newLesClientService(ctx *adapters.ServiceContext) (node.Service, error) { func newLesClientService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
config := eth.DefaultConfig config := eth.DefaultConfig
config.SyncMode = downloader.LightSync config.SyncMode = downloader.LightSync
config.Ethash.PowMode = ethash.ModeFake config.Ethash.PowMode = ethash.ModeFake
return New(ctx.NodeContext, &config) return New(stack, &config)
} }
func newLesServerService(ctx *adapters.ServiceContext) (node.Service, error) { func newLesServerService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
config := eth.DefaultConfig config := eth.DefaultConfig
config.SyncMode = downloader.FullSync config.SyncMode = downloader.FullSync
config.LightServ = testServerCapacity config.LightServ = testServerCapacity
config.LightPeers = testMaxClients config.LightPeers = testMaxClients
ethereum, err := eth.New(ctx.NodeContext, &config) ethereum, err := eth.New(stack, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
server, err := NewLesServer(ethereum, &config) _, err = NewLesServer(stack, ethereum, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ethereum.AddLesServer(server)
return ethereum, nil return ethereum, nil
} }

@ -51,16 +51,6 @@ type CheckpointOracle struct {
// New creates a checkpoint oracle handler with given configs and callback. // New creates a checkpoint oracle handler with given configs and callback.
func New(config *params.CheckpointOracleConfig, getLocal func(uint64) params.TrustedCheckpoint) *CheckpointOracle { func New(config *params.CheckpointOracleConfig, getLocal func(uint64) params.TrustedCheckpoint) *CheckpointOracle {
if config == nil {
log.Info("Checkpoint registrar is not enabled")
return nil
}
if config.Address == (common.Address{}) || uint64(len(config.Signers)) < config.Threshold {
log.Warn("Invalid checkpoint registrar config")
return nil
}
log.Info("Configured checkpoint registrar", "address", config.Address, "signers", len(config.Signers), "threshold", config.Threshold)
return &CheckpointOracle{ return &CheckpointOracle{
config: config, config: config,
getLocal: getLocal, getLocal: getLocal,

@ -22,7 +22,6 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
@ -37,7 +36,6 @@ import (
"github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/les/checkpointoracle"
lpc "github.com/ethereum/go-ethereum/les/lespay/client" lpc "github.com/ethereum/go-ethereum/les/lespay/client"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
@ -72,14 +70,17 @@ type LightEthereum struct {
engine consensus.Engine engine consensus.Engine
accountManager *accounts.Manager accountManager *accounts.Manager
netRPCService *ethapi.PublicNetAPI netRPCService *ethapi.PublicNetAPI
p2pServer *p2p.Server
} }
func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { // New creates an instance of the light client.
chainDb, err := ctx.OpenDatabase("lightchaindata", config.DatabaseCache, config.DatabaseHandles, "eth/db/chaindata/") func New(stack *node.Node, config *eth.Config) (*LightEthereum, error) {
chainDb, err := stack.OpenDatabase("lightchaindata", config.DatabaseCache, config.DatabaseHandles, "eth/db/chaindata/")
if err != nil { if err != nil {
return nil, err return nil, err
} }
lespayDb, err := ctx.OpenDatabase("lespay", 0, 0, "eth/db/lespay") lespayDb, err := stack.OpenDatabase("lespay", 0, 0, "eth/db/lespay")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -100,17 +101,18 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
}, },
peers: peers, peers: peers,
eventMux: ctx.EventMux, eventMux: stack.EventMux(),
reqDist: newRequestDistributor(peers, &mclock.System{}), reqDist: newRequestDistributor(peers, &mclock.System{}),
accountManager: ctx.AccountManager, accountManager: stack.AccountManager(),
engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb), engine: eth.CreateConsensusEngine(stack, chainConfig, &config.Ethash, nil, false, chainDb),
bloomRequests: make(chan chan *bloombits.Retrieval), bloomRequests: make(chan chan *bloombits.Retrieval),
bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
valueTracker: lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)), valueTracker: lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)),
p2pServer: stack.Server(),
} }
peers.subscribe((*vtSubscription)(leth.valueTracker)) peers.subscribe((*vtSubscription)(leth.valueTracker))
dnsdisc, err := leth.setupDiscovery(&ctx.Config.P2P) dnsdisc, err := leth.setupDiscovery(&stack.Config().P2P)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,11 +141,7 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
// Set up checkpoint oracle. // Set up checkpoint oracle.
oracle := config.CheckpointOracle leth.oracle = leth.setupOracle(stack, genesisHash, config)
if oracle == nil {
oracle = params.CheckpointOracles[genesisHash]
}
leth.oracle = checkpointoracle.New(oracle, leth.localCheckpoint)
// Note: AddChildIndexer starts the update process for the child // Note: AddChildIndexer starts the update process for the child
leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer) leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer)
@ -160,7 +158,7 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig)
} }
leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil} leth.ApiBackend = &LesApiBackend{stack.Config().ExtRPCEnabled(), leth, nil}
gpoParams := config.GPO gpoParams := config.GPO
if gpoParams.Default == nil { if gpoParams.Default == nil {
gpoParams.Default = config.Miner.GasPrice gpoParams.Default = config.Miner.GasPrice
@ -172,6 +170,14 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction) log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
leth.blockchain.DisableCheckFreq() leth.blockchain.DisableCheckFreq()
} }
leth.netRPCService = ethapi.NewPublicNetAPI(leth.p2pServer, leth.config.NetworkId)
// Register the backend on the node
stack.RegisterAPIs(leth.APIs())
stack.RegisterProtocols(leth.Protocols())
stack.RegisterLifecycle(leth)
return leth, nil return leth, nil
} }
@ -265,8 +271,7 @@ func (s *LightEthereum) LesVersion() int { return int(ClientP
func (s *LightEthereum) Downloader() *downloader.Downloader { return s.handler.downloader } func (s *LightEthereum) Downloader() *downloader.Downloader { return s.handler.downloader }
func (s *LightEthereum) EventMux() *event.TypeMux { return s.eventMux } func (s *LightEthereum) EventMux() *event.TypeMux { return s.eventMux }
// Protocols implements node.Service, returning all the currently configured // Protocols returns all the currently configured network protocols to start.
// network protocols to start.
func (s *LightEthereum) Protocols() []p2p.Protocol { func (s *LightEthereum) Protocols() []p2p.Protocol {
return s.makeProtocols(ClientProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} { return s.makeProtocols(ClientProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
if p := s.peers.peer(id.String()); p != nil { if p := s.peers.peer(id.String()); p != nil {
@ -276,9 +281,9 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
}, s.dialCandidates) }, s.dialCandidates)
} }
// Start implements node.Service, starting all internal goroutines needed by the // Start implements node.Lifecycle, starting all internal goroutines needed by the
// light ethereum protocol implementation. // light ethereum protocol implementation.
func (s *LightEthereum) Start(srvr *p2p.Server) error { func (s *LightEthereum) Start() error {
log.Warn("Light client mode is an experimental feature") log.Warn("Light client mode is an experimental feature")
s.serverPool.start() s.serverPool.start()
@ -287,11 +292,10 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error {
s.startBloomHandlers(params.BloomBitsBlocksClient) s.startBloomHandlers(params.BloomBitsBlocksClient)
s.handler.start() s.handler.start()
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId)
return nil return nil
} }
// Stop implements node.Service, terminating all internal goroutines used by the // Stop implements node.Lifecycle, terminating all internal goroutines used by the
// Ethereum protocol. // Ethereum protocol.
func (s *LightEthereum) Stop() error { func (s *LightEthereum) Stop() error {
close(s.closeCh) close(s.closeCh)
@ -314,11 +318,3 @@ func (s *LightEthereum) Stop() error {
log.Info("Light ethereum stopped") log.Info("Light ethereum stopped")
return nil return nil
} }
// SetClient sets the rpc client and binds the registrar contract.
func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) {
if s.oracle == nil {
return
}
s.oracle.Start(backend)
}

@ -26,9 +26,12 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/les/checkpointoracle" "github.com/ethereum/go-ethereum/les/checkpointoracle"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
@ -145,3 +148,26 @@ func (c *lesCommons) localCheckpoint(index uint64) params.TrustedCheckpoint {
BloomRoot: light.GetBloomTrieRoot(c.chainDb, index, sectionHead), BloomRoot: light.GetBloomTrieRoot(c.chainDb, index, sectionHead),
} }
} }
// setupOracle sets up the checkpoint oracle contract client.
func (c *lesCommons) setupOracle(node *node.Node, genesis common.Hash, ethconfig *eth.Config) *checkpointoracle.CheckpointOracle {
config := ethconfig.CheckpointOracle
if config == nil {
// Try loading default config.
config = params.CheckpointOracles[genesis]
}
if config == nil {
log.Info("Checkpoint registrar is not enabled")
return nil
}
if config.Address == (common.Address{}) || uint64(len(config.Signers)) < config.Threshold {
log.Warn("Invalid checkpoint registrar config")
return nil
}
oracle := checkpointoracle.New(config, c.localCheckpoint)
rpcClient, _ := node.Attach()
client := ethclient.NewClient(rpcClient)
oracle.Start(client)
log.Info("Configured checkpoint registrar", "address", config.Address, "signers", len(config.Signers), "threshold", config.Threshold)
return oracle
}

@ -20,14 +20,12 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"time" "time"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/checkpointoracle"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
@ -55,9 +53,11 @@ type LesServer struct {
minCapacity, maxCapacity, freeCapacity uint64 minCapacity, maxCapacity, freeCapacity uint64
threadsIdle int // Request serving threads count when system is idle. threadsIdle int // Request serving threads count when system is idle.
threadsBusy int // Request serving threads count when system is busy(block insertion). threadsBusy int // Request serving threads count when system is busy(block insertion).
p2pSrv *p2p.Server
} }
func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) { func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
// Collect les protocol version information supported by local node. // Collect les protocol version information supported by local node.
lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions)) lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions))
for i, pv := range AdvertiseProtocolVersions { for i, pv := range AdvertiseProtocolVersions {
@ -88,17 +88,15 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100), servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100),
threadsBusy: config.LightServ/100 + 1, threadsBusy: config.LightServ/100 + 1,
threadsIdle: threads, threadsIdle: threads,
p2pSrv: node.Server(),
} }
srv.handler = newServerHandler(srv, e.BlockChain(), e.ChainDb(), e.TxPool(), e.Synced) srv.handler = newServerHandler(srv, e.BlockChain(), e.ChainDb(), e.TxPool(), e.Synced)
srv.costTracker, srv.minCapacity = newCostTracker(e.ChainDb(), config) srv.costTracker, srv.minCapacity = newCostTracker(e.ChainDb(), config)
srv.freeCapacity = srv.minCapacity srv.freeCapacity = srv.minCapacity
srv.oracle = srv.setupOracle(node, e.BlockChain().Genesis().Hash(), config)
// Set up checkpoint oracle. // Initialize the bloom trie indexer.
oracle := config.CheckpointOracle e.BloomIndexer().AddChildIndexer(srv.bloomTrieIndexer)
if oracle == nil {
oracle = params.CheckpointOracles[e.BlockChain().Genesis().Hash()]
}
srv.oracle = checkpointoracle.New(oracle, srv.localCheckpoint)
// Initialize server capacity management fields. // Initialize server capacity management fields.
srv.defParams = flowcontrol.ServerParams{ srv.defParams = flowcontrol.ServerParams{
@ -125,6 +123,11 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
"chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot) "chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot)
} }
srv.chtIndexer.Start(e.BlockChain()) srv.chtIndexer.Start(e.BlockChain())
node.RegisterProtocols(srv.Protocols())
node.RegisterAPIs(srv.APIs())
node.RegisterLifecycle(srv)
return srv, nil return srv, nil
} }
@ -166,14 +169,14 @@ func (s *LesServer) Protocols() []p2p.Protocol {
} }
// Start starts the LES server // Start starts the LES server
func (s *LesServer) Start(srvr *p2p.Server) { func (s *LesServer) Start() error {
s.privateKey = srvr.PrivateKey s.privateKey = s.p2pSrv.PrivateKey
s.handler.start() s.handler.start()
s.wg.Add(1) s.wg.Add(1)
go s.capacityManagement() go s.capacityManagement()
if srvr.DiscV5 != nil { if s.p2pSrv.DiscV5 != nil {
for _, topic := range s.lesTopics { for _, topic := range s.lesTopics {
topic := topic topic := topic
go func() { go func() {
@ -181,14 +184,16 @@ func (s *LesServer) Start(srvr *p2p.Server) {
logger.Info("Starting topic registration") logger.Info("Starting topic registration")
defer logger.Info("Terminated topic registration") defer logger.Info("Terminated topic registration")
srvr.DiscV5.RegisterTopic(topic, s.closeCh) s.p2pSrv.DiscV5.RegisterTopic(topic, s.closeCh)
}() }()
} }
} }
return nil
} }
// Stop stops the LES service // Stop stops the LES service
func (s *LesServer) Stop() { func (s *LesServer) Stop() error {
close(s.closeCh) close(s.closeCh)
// Disconnect existing sessions. // Disconnect existing sessions.
@ -207,18 +212,8 @@ func (s *LesServer) Stop() {
s.chtIndexer.Close() s.chtIndexer.Close()
s.wg.Wait() s.wg.Wait()
log.Info("Les server stopped") log.Info("Les server stopped")
}
func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) { return nil
bloomIndexer.AddChildIndexer(s.bloomTrieIndexer)
}
// SetClient sets the rpc client and starts running checkpoint contract if it is not yet watched.
func (s *LesServer) SetContractBackend(backend bind.ContractBackend) {
if s.oracle == nil {
return
}
s.oracle.Start(backend)
} }
// capacityManagement starts an event handler loop that updates the recharge curve of // capacityManagement starts an event handler loop that updates the recharge curve of

@ -61,30 +61,31 @@ func main() {
genesis := makeGenesis(faucets, sealers) genesis := makeGenesis(faucets, sealers)
var ( var (
nodes []*node.Node nodes []*eth.Ethereum
enodes []*enode.Node enodes []*enode.Node
) )
for _, sealer := range sealers { for _, sealer := range sealers {
// Start the node and wait until it's up // Start the node and wait until it's up
node, err := makeSealer(genesis) stack, ethBackend, err := makeSealer(genesis)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer node.Close() defer stack.Close()
for node.Server().NodeInfo().Ports.Listener == 0 { for stack.Server().NodeInfo().Ports.Listener == 0 {
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
} }
// Connect the node to al the previous ones // Connect the node to all the previous ones
for _, n := range enodes { for _, n := range enodes {
node.Server().AddPeer(n) stack.Server().AddPeer(n)
} }
// Start tracking the node and it's enode // Start tracking the node and its enode
nodes = append(nodes, node) nodes = append(nodes, ethBackend)
enodes = append(enodes, node.Server().Self()) enodes = append(enodes, stack.Server().Self())
// Inject the signer key and start sealing with it // Inject the signer key and start sealing with it
store := node.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore) store := stack.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore)
signer, err := store.ImportECDSA(sealer, "") signer, err := store.ImportECDSA(sealer, "")
if err != nil { if err != nil {
panic(err) panic(err)
@ -93,15 +94,11 @@ func main() {
panic(err) panic(err)
} }
} }
// Iterate over all the nodes and start signing with them
time.Sleep(3 * time.Second)
// Iterate over all the nodes and start signing on them
time.Sleep(3 * time.Second)
for _, node := range nodes { for _, node := range nodes {
var ethereum *eth.Ethereum if err := node.StartMining(1); err != nil {
if err := node.Service(&ethereum); err != nil {
panic(err)
}
if err := ethereum.StartMining(1); err != nil {
panic(err) panic(err)
} }
} }
@ -110,25 +107,22 @@ func main() {
// Start injecting transactions from the faucet like crazy // Start injecting transactions from the faucet like crazy
nonces := make([]uint64, len(faucets)) nonces := make([]uint64, len(faucets))
for { for {
// Pick a random signer node
index := rand.Intn(len(faucets)) index := rand.Intn(len(faucets))
backend := nodes[index%len(nodes)]
// Fetch the accessor for the relevant signer
var ethereum *eth.Ethereum
if err := nodes[index%len(nodes)].Service(&ethereum); err != nil {
panic(err)
}
// Create a self transaction and inject into the pool // Create a self transaction and inject into the pool
tx, err := types.SignTx(types.NewTransaction(nonces[index], crypto.PubkeyToAddress(faucets[index].PublicKey), new(big.Int), 21000, big.NewInt(100000000000), nil), types.HomesteadSigner{}, faucets[index]) tx, err := types.SignTx(types.NewTransaction(nonces[index], crypto.PubkeyToAddress(faucets[index].PublicKey), new(big.Int), 21000, big.NewInt(100000000000), nil), types.HomesteadSigner{}, faucets[index])
if err != nil { if err != nil {
panic(err) panic(err)
} }
if err := ethereum.TxPool().AddLocal(tx); err != nil { if err := backend.TxPool().AddLocal(tx); err != nil {
panic(err) panic(err)
} }
nonces[index]++ nonces[index]++
// Wait if we're too saturated // Wait if we're too saturated
if pend, _ := ethereum.TxPool().Stats(); pend > 2048 { if pend, _ := backend.TxPool().Stats(); pend > 2048 {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
} }
@ -171,7 +165,7 @@ func makeGenesis(faucets []*ecdsa.PrivateKey, sealers []*ecdsa.PrivateKey) *core
return genesis return genesis
} }
func makeSealer(genesis *core.Genesis) (*node.Node, error) { func makeSealer(genesis *core.Genesis) (*node.Node, *eth.Ethereum, error) {
// Define the basic configurations for the Ethereum node // Define the basic configurations for the Ethereum node
datadir, _ := ioutil.TempDir("", "") datadir, _ := ioutil.TempDir("", "")
@ -189,27 +183,28 @@ func makeSealer(genesis *core.Genesis) (*node.Node, error) {
// Start the node and configure a full Ethereum node on it // Start the node and configure a full Ethereum node on it
stack, err := node.New(config) stack, err := node.New(config)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { // Create and register the backend
return eth.New(ctx, &eth.Config{ ethBackend, err := eth.New(stack, &eth.Config{
Genesis: genesis, Genesis: genesis,
NetworkId: genesis.Config.ChainID.Uint64(), NetworkId: genesis.Config.ChainID.Uint64(),
SyncMode: downloader.FullSync, SyncMode: downloader.FullSync,
DatabaseCache: 256, DatabaseCache: 256,
DatabaseHandles: 256, DatabaseHandles: 256,
TxPool: core.DefaultTxPoolConfig, TxPool: core.DefaultTxPoolConfig,
GPO: eth.DefaultConfig.GPO, GPO: eth.DefaultConfig.GPO,
Miner: miner.Config{ Miner: miner.Config{
GasFloor: genesis.GasLimit * 9 / 10, GasFloor: genesis.GasLimit * 9 / 10,
GasCeil: genesis.GasLimit * 11 / 10, GasCeil: genesis.GasLimit * 11 / 10,
GasPrice: big.NewInt(1), GasPrice: big.NewInt(1),
Recommit: time.Second, Recommit: time.Second,
}, },
}) })
}); err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// Start the node and return if successful
return stack, stack.Start() err = stack.Start()
return stack, ethBackend, err
} }

@ -61,43 +61,39 @@ func main() {
genesis := makeGenesis(faucets) genesis := makeGenesis(faucets)
var ( var (
nodes []*node.Node nodes []*eth.Ethereum
enodes []*enode.Node enodes []*enode.Node
) )
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
// Start the node and wait until it's up // Start the node and wait until it's up
node, err := makeMiner(genesis) stack, ethBackend, err := makeMiner(genesis)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer node.Close() defer stack.Close()
for node.Server().NodeInfo().Ports.Listener == 0 { for stack.Server().NodeInfo().Ports.Listener == 0 {
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
} }
// Connect the node to al the previous ones // Connect the node to all the previous ones
for _, n := range enodes { for _, n := range enodes {
node.Server().AddPeer(n) stack.Server().AddPeer(n)
} }
// Start tracking the node and it's enode // Start tracking the node and its enode
nodes = append(nodes, node) nodes = append(nodes, ethBackend)
enodes = append(enodes, node.Server().Self()) enodes = append(enodes, stack.Server().Self())
// Inject the signer key and start sealing with it // Inject the signer key and start sealing with it
store := node.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore) store := stack.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore)
if _, err := store.NewAccount(""); err != nil { if _, err := store.NewAccount(""); err != nil {
panic(err) panic(err)
} }
} }
// Iterate over all the nodes and start signing with them
time.Sleep(3 * time.Second)
// Iterate over all the nodes and start mining
time.Sleep(3 * time.Second)
for _, node := range nodes { for _, node := range nodes {
var ethereum *eth.Ethereum if err := node.StartMining(1); err != nil {
if err := node.Service(&ethereum); err != nil {
panic(err)
}
if err := ethereum.StartMining(1); err != nil {
panic(err) panic(err)
} }
} }
@ -106,25 +102,22 @@ func main() {
// Start injecting transactions from the faucets like crazy // Start injecting transactions from the faucets like crazy
nonces := make([]uint64, len(faucets)) nonces := make([]uint64, len(faucets))
for { for {
// Pick a random mining node
index := rand.Intn(len(faucets)) index := rand.Intn(len(faucets))
backend := nodes[index%len(nodes)]
// Fetch the accessor for the relevant signer
var ethereum *eth.Ethereum
if err := nodes[index%len(nodes)].Service(&ethereum); err != nil {
panic(err)
}
// Create a self transaction and inject into the pool // Create a self transaction and inject into the pool
tx, err := types.SignTx(types.NewTransaction(nonces[index], crypto.PubkeyToAddress(faucets[index].PublicKey), new(big.Int), 21000, big.NewInt(100000000000+rand.Int63n(65536)), nil), types.HomesteadSigner{}, faucets[index]) tx, err := types.SignTx(types.NewTransaction(nonces[index], crypto.PubkeyToAddress(faucets[index].PublicKey), new(big.Int), 21000, big.NewInt(100000000000+rand.Int63n(65536)), nil), types.HomesteadSigner{}, faucets[index])
if err != nil { if err != nil {
panic(err) panic(err)
} }
if err := ethereum.TxPool().AddLocal(tx); err != nil { if err := backend.TxPool().AddLocal(tx); err != nil {
panic(err) panic(err)
} }
nonces[index]++ nonces[index]++
// Wait if we're too saturated // Wait if we're too saturated
if pend, _ := ethereum.TxPool().Stats(); pend > 2048 { if pend, _ := backend.TxPool().Stats(); pend > 2048 {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
} }
@ -149,7 +142,7 @@ func makeGenesis(faucets []*ecdsa.PrivateKey) *core.Genesis {
return genesis return genesis
} }
func makeMiner(genesis *core.Genesis) (*node.Node, error) { func makeMiner(genesis *core.Genesis) (*node.Node, *eth.Ethereum, error) {
// Define the basic configurations for the Ethereum node // Define the basic configurations for the Ethereum node
datadir, _ := ioutil.TempDir("", "") datadir, _ := ioutil.TempDir("", "")
@ -165,31 +158,31 @@ func makeMiner(genesis *core.Genesis) (*node.Node, error) {
NoUSB: true, NoUSB: true,
UseLightweightKDF: true, UseLightweightKDF: true,
} }
// Start the node and configure a full Ethereum node on it // Create the node and configure a full Ethereum node on it
stack, err := node.New(config) stack, err := node.New(config)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { ethBackend, err := eth.New(stack, &eth.Config{
return eth.New(ctx, &eth.Config{ Genesis: genesis,
Genesis: genesis, NetworkId: genesis.Config.ChainID.Uint64(),
NetworkId: genesis.Config.ChainID.Uint64(), SyncMode: downloader.FullSync,
SyncMode: downloader.FullSync, DatabaseCache: 256,
DatabaseCache: 256, DatabaseHandles: 256,
DatabaseHandles: 256, TxPool: core.DefaultTxPoolConfig,
TxPool: core.DefaultTxPoolConfig, GPO: eth.DefaultConfig.GPO,
GPO: eth.DefaultConfig.GPO, Ethash: eth.DefaultConfig.Ethash,
Ethash: eth.DefaultConfig.Ethash, Miner: miner.Config{
Miner: miner.Config{ GasFloor: genesis.GasLimit * 9 / 10,
GasFloor: genesis.GasLimit * 9 / 10, GasCeil: genesis.GasLimit * 11 / 10,
GasCeil: genesis.GasLimit * 11 / 10, GasPrice: big.NewInt(1),
GasPrice: big.NewInt(1), Recommit: time.Second,
Recommit: time.Second, },
}, })
}) if err != nil {
}); err != nil { return nil, nil, err
return nil, err
} }
// Start the node and return if successful
return stack, stack.Start() err = stack.Start()
return stack, ethBackend, err
} }

@ -175,49 +175,44 @@ func NewNode(datadir string, config *NodeConfig) (stack *Node, _ error) {
ethConf.SyncMode = downloader.LightSync ethConf.SyncMode = downloader.LightSync
ethConf.NetworkId = uint64(config.EthereumNetworkID) ethConf.NetworkId = uint64(config.EthereumNetworkID)
ethConf.DatabaseCache = config.EthereumDatabaseCache ethConf.DatabaseCache = config.EthereumDatabaseCache
if err := rawStack.Register(func(ctx *node.ServiceContext) (node.Service, error) { lesBackend, err := les.New(rawStack, &ethConf)
return les.New(ctx, &ethConf) if err != nil {
}); err != nil {
return nil, fmt.Errorf("ethereum init: %v", err) return nil, fmt.Errorf("ethereum init: %v", err)
} }
// If netstats reporting is requested, do it // If netstats reporting is requested, do it
if config.EthereumNetStats != "" { if config.EthereumNetStats != "" {
if err := rawStack.Register(func(ctx *node.ServiceContext) (node.Service, error) { if err := ethstats.New(rawStack, lesBackend.ApiBackend, lesBackend.Engine(), config.EthereumNetStats); err != nil {
var lesServ *les.LightEthereum
ctx.Service(&lesServ)
return ethstats.New(config.EthereumNetStats, nil, lesServ)
}); err != nil {
return nil, fmt.Errorf("netstats init: %v", err) return nil, fmt.Errorf("netstats init: %v", err)
} }
} }
} }
// Register the Whisper protocol if requested // Register the Whisper protocol if requested
if config.WhisperEnabled { if config.WhisperEnabled {
if err := rawStack.Register(func(*node.ServiceContext) (node.Service, error) { if _, err := whisper.New(rawStack, &whisper.DefaultConfig); err != nil {
return whisper.New(&whisper.DefaultConfig), nil
}); err != nil {
return nil, fmt.Errorf("whisper init: %v", err) return nil, fmt.Errorf("whisper init: %v", err)
} }
} }
return &Node{rawStack}, nil return &Node{rawStack}, nil
} }
// Close terminates a running node along with all it's services, tearing internal // Close terminates a running node along with all it's services, tearing internal state
// state doen too. It's not possible to restart a closed node. // down. It is not possible to restart a closed node.
func (n *Node) Close() error { func (n *Node) Close() error {
return n.node.Close() return n.node.Close()
} }
// Start creates a live P2P node and starts running it. // Start creates a live P2P node and starts running it.
func (n *Node) Start() error { func (n *Node) Start() error {
// TODO: recreate the node so it can be started multiple times
return n.node.Start() return n.node.Start()
} }
// Stop terminates a running node along with all it's services. If the node was // Stop terminates a running node along with all its services. If the node was not started,
// not started, an error is returned. // an error is returned. It is not possible to restart a stopped node.
//
// Deprecated: use Close()
func (n *Node) Stop() error { func (n *Node) Stop() error {
return n.node.Stop() return n.node.Close()
} }
// GetEthereumClient retrieves a client to access the Ethereum subsystem. // GetEthereumClient retrieves a client to access the Ethereum subsystem.

@ -23,26 +23,46 @@ import (
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/debug"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
// PrivateAdminAPI is the collection of administrative API methods exposed only // apis returns the collection of built-in RPC APIs.
// over a secure RPC channel. func (n *Node) apis() []rpc.API {
type PrivateAdminAPI struct { return []rpc.API{
node *Node // Node interfaced by this API {
Namespace: "admin",
Version: "1.0",
Service: &privateAdminAPI{n},
}, {
Namespace: "admin",
Version: "1.0",
Service: &publicAdminAPI{n},
Public: true,
}, {
Namespace: "debug",
Version: "1.0",
Service: debug.Handler,
}, {
Namespace: "web3",
Version: "1.0",
Service: &publicWeb3API{n},
Public: true,
},
}
} }
// NewPrivateAdminAPI creates a new API definition for the private admin methods // privateAdminAPI is the collection of administrative API methods exposed only
// of the node itself. // over a secure RPC channel.
func NewPrivateAdminAPI(node *Node) *PrivateAdminAPI { type privateAdminAPI struct {
return &PrivateAdminAPI{node: node} node *Node // Node interfaced by this API
} }
// AddPeer requests connecting to a remote node, and also maintaining the new // AddPeer requests connecting to a remote node, and also maintaining the new
// connection at all times, even reconnecting if it is lost. // connection at all times, even reconnecting if it is lost.
func (api *PrivateAdminAPI) AddPeer(url string) (bool, error) { func (api *privateAdminAPI) AddPeer(url string) (bool, error) {
// Make sure the server is running, fail otherwise // Make sure the server is running, fail otherwise
server := api.node.Server() server := api.node.Server()
if server == nil { if server == nil {
@ -58,7 +78,7 @@ func (api *PrivateAdminAPI) AddPeer(url string) (bool, error) {
} }
// RemovePeer disconnects from a remote node if the connection exists // RemovePeer disconnects from a remote node if the connection exists
func (api *PrivateAdminAPI) RemovePeer(url string) (bool, error) { func (api *privateAdminAPI) RemovePeer(url string) (bool, error) {
// Make sure the server is running, fail otherwise // Make sure the server is running, fail otherwise
server := api.node.Server() server := api.node.Server()
if server == nil { if server == nil {
@ -74,7 +94,7 @@ func (api *PrivateAdminAPI) RemovePeer(url string) (bool, error) {
} }
// AddTrustedPeer allows a remote node to always connect, even if slots are full // AddTrustedPeer allows a remote node to always connect, even if slots are full
func (api *PrivateAdminAPI) AddTrustedPeer(url string) (bool, error) { func (api *privateAdminAPI) AddTrustedPeer(url string) (bool, error) {
// Make sure the server is running, fail otherwise // Make sure the server is running, fail otherwise
server := api.node.Server() server := api.node.Server()
if server == nil { if server == nil {
@ -90,7 +110,7 @@ func (api *PrivateAdminAPI) AddTrustedPeer(url string) (bool, error) {
// RemoveTrustedPeer removes a remote node from the trusted peer set, but it // RemoveTrustedPeer removes a remote node from the trusted peer set, but it
// does not disconnect it automatically. // does not disconnect it automatically.
func (api *PrivateAdminAPI) RemoveTrustedPeer(url string) (bool, error) { func (api *privateAdminAPI) RemoveTrustedPeer(url string) (bool, error) {
// Make sure the server is running, fail otherwise // Make sure the server is running, fail otherwise
server := api.node.Server() server := api.node.Server()
if server == nil { if server == nil {
@ -106,7 +126,7 @@ func (api *PrivateAdminAPI) RemoveTrustedPeer(url string) (bool, error) {
// PeerEvents creates an RPC subscription which receives peer events from the // PeerEvents creates an RPC subscription which receives peer events from the
// node's p2p.Server // node's p2p.Server
func (api *PrivateAdminAPI) PeerEvents(ctx context.Context) (*rpc.Subscription, error) { func (api *privateAdminAPI) PeerEvents(ctx context.Context) (*rpc.Subscription, error) {
// Make sure the server is running, fail otherwise // Make sure the server is running, fail otherwise
server := api.node.Server() server := api.node.Server()
if server == nil { if server == nil {
@ -143,14 +163,11 @@ func (api *PrivateAdminAPI) PeerEvents(ctx context.Context) (*rpc.Subscription,
} }
// StartRPC starts the HTTP RPC API server. // StartRPC starts the HTTP RPC API server.
func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis *string, vhosts *string) (bool, error) { func (api *privateAdminAPI) StartRPC(host *string, port *int, cors *string, apis *string, vhosts *string) (bool, error) {
api.node.lock.Lock() api.node.lock.Lock()
defer api.node.lock.Unlock() defer api.node.lock.Unlock()
if api.node.httpHandler != nil { // Determine host and port.
return false, fmt.Errorf("HTTP RPC already running on %s", api.node.httpEndpoint)
}
if host == nil { if host == nil {
h := DefaultHTTPHost h := DefaultHTTPHost
if api.node.config.HTTPHost != "" { if api.node.config.HTTPHost != "" {
@ -162,57 +179,55 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
port = &api.node.config.HTTPPort port = &api.node.config.HTTPPort
} }
allowedOrigins := api.node.config.HTTPCors // Determine config.
config := httpConfig{
CorsAllowedOrigins: api.node.config.HTTPCors,
Vhosts: api.node.config.HTTPVirtualHosts,
Modules: api.node.config.HTTPModules,
}
if cors != nil { if cors != nil {
allowedOrigins = nil config.CorsAllowedOrigins = nil
for _, origin := range strings.Split(*cors, ",") { for _, origin := range strings.Split(*cors, ",") {
allowedOrigins = append(allowedOrigins, strings.TrimSpace(origin)) config.CorsAllowedOrigins = append(config.CorsAllowedOrigins, strings.TrimSpace(origin))
} }
} }
allowedVHosts := api.node.config.HTTPVirtualHosts
if vhosts != nil { if vhosts != nil {
allowedVHosts = nil config.Vhosts = nil
for _, vhost := range strings.Split(*host, ",") { for _, vhost := range strings.Split(*host, ",") {
allowedVHosts = append(allowedVHosts, strings.TrimSpace(vhost)) config.Vhosts = append(config.Vhosts, strings.TrimSpace(vhost))
} }
} }
modules := api.node.httpWhitelist
if apis != nil { if apis != nil {
modules = nil config.Modules = nil
for _, m := range strings.Split(*apis, ",") { for _, m := range strings.Split(*apis, ",") {
modules = append(modules, strings.TrimSpace(m)) config.Modules = append(config.Modules, strings.TrimSpace(m))
} }
} }
if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil { if err := api.node.http.setListenAddr(*host, *port); err != nil {
return false, err
}
if err := api.node.http.enableRPC(api.node.rpcAPIs, config); err != nil {
return false, err
}
if err := api.node.http.start(); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
} }
// StopRPC terminates an already running HTTP RPC API endpoint. // StopRPC shuts down the HTTP server.
func (api *PrivateAdminAPI) StopRPC() (bool, error) { func (api *privateAdminAPI) StopRPC() (bool, error) {
api.node.lock.Lock() api.node.http.stop()
defer api.node.lock.Unlock()
if api.node.httpHandler == nil {
return false, fmt.Errorf("HTTP RPC not running")
}
api.node.stopHTTP()
return true, nil return true, nil
} }
// StartWS starts the websocket RPC API server. // StartWS starts the websocket RPC API server.
func (api *PrivateAdminAPI) StartWS(host *string, port *int, allowedOrigins *string, apis *string) (bool, error) { func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *string, apis *string) (bool, error) {
api.node.lock.Lock() api.node.lock.Lock()
defer api.node.lock.Unlock() defer api.node.lock.Unlock()
if api.node.wsHandler != nil { // Determine host and port.
return false, fmt.Errorf("WebSocket RPC already running on %s", api.node.wsEndpoint)
}
if host == nil { if host == nil {
h := DefaultWSHost h := DefaultWSHost
if api.node.config.WSHost != "" { if api.node.config.WSHost != "" {
@ -224,55 +239,56 @@ func (api *PrivateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str
port = &api.node.config.WSPort port = &api.node.config.WSPort
} }
origins := api.node.config.WSOrigins // Determine config.
if allowedOrigins != nil { config := wsConfig{
origins = nil Modules: api.node.config.WSModules,
for _, origin := range strings.Split(*allowedOrigins, ",") { Origins: api.node.config.WSOrigins,
origins = append(origins, strings.TrimSpace(origin)) // ExposeAll: api.node.config.WSExposeAll,
}
} }
modules := api.node.config.WSModules
if apis != nil { if apis != nil {
modules = nil config.Modules = nil
for _, m := range strings.Split(*apis, ",") { for _, m := range strings.Split(*apis, ",") {
modules = append(modules, strings.TrimSpace(m)) config.Modules = append(config.Modules, strings.TrimSpace(m))
}
}
if allowedOrigins != nil {
config.Origins = nil
for _, origin := range strings.Split(*allowedOrigins, ",") {
config.Origins = append(config.Origins, strings.TrimSpace(origin))
} }
} }
if err := api.node.startWS(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, origins, api.node.config.WSExposeAll); err != nil { // Enable WebSocket on the server.
server := api.node.wsServerForPort(*port)
if err := server.setListenAddr(*host, *port); err != nil {
return false, err return false, err
} }
return true, nil if err := server.enableWS(api.node.rpcAPIs, config); err != nil {
} return false, err
// StopWS terminates an already running websocket RPC API endpoint.
func (api *PrivateAdminAPI) StopWS() (bool, error) {
api.node.lock.Lock()
defer api.node.lock.Unlock()
if api.node.wsHandler == nil {
return false, fmt.Errorf("WebSocket RPC not running")
} }
api.node.stopWS() if err := server.start(); err != nil {
return false, err
}
api.node.http.log.Info("WebSocket endpoint opened", "url", api.node.WSEndpoint())
return true, nil return true, nil
} }
// PublicAdminAPI is the collection of administrative API methods exposed over // StopWS terminates all WebSocket servers.
// both secure and unsecure RPC channels. func (api *privateAdminAPI) StopWS() (bool, error) {
type PublicAdminAPI struct { api.node.http.stopWS()
node *Node // Node interfaced by this API api.node.ws.stop()
return true, nil
} }
// NewPublicAdminAPI creates a new API definition for the public admin methods // publicAdminAPI is the collection of administrative API methods exposed over
// of the node itself. // both secure and unsecure RPC channels.
func NewPublicAdminAPI(node *Node) *PublicAdminAPI { type publicAdminAPI struct {
return &PublicAdminAPI{node: node} node *Node // Node interfaced by this API
} }
// Peers retrieves all the information we know about each individual peer at the // Peers retrieves all the information we know about each individual peer at the
// protocol granularity. // protocol granularity.
func (api *PublicAdminAPI) Peers() ([]*p2p.PeerInfo, error) { func (api *publicAdminAPI) Peers() ([]*p2p.PeerInfo, error) {
server := api.node.Server() server := api.node.Server()
if server == nil { if server == nil {
return nil, ErrNodeStopped return nil, ErrNodeStopped
@ -282,7 +298,7 @@ func (api *PublicAdminAPI) Peers() ([]*p2p.PeerInfo, error) {
// NodeInfo retrieves all the information we know about the host node at the // NodeInfo retrieves all the information we know about the host node at the
// protocol granularity. // protocol granularity.
func (api *PublicAdminAPI) NodeInfo() (*p2p.NodeInfo, error) { func (api *publicAdminAPI) NodeInfo() (*p2p.NodeInfo, error) {
server := api.node.Server() server := api.node.Server()
if server == nil { if server == nil {
return nil, ErrNodeStopped return nil, ErrNodeStopped
@ -291,27 +307,22 @@ func (api *PublicAdminAPI) NodeInfo() (*p2p.NodeInfo, error) {
} }
// Datadir retrieves the current data directory the node is using. // Datadir retrieves the current data directory the node is using.
func (api *PublicAdminAPI) Datadir() string { func (api *publicAdminAPI) Datadir() string {
return api.node.DataDir() return api.node.DataDir()
} }
// PublicWeb3API offers helper utils // publicWeb3API offers helper utils
type PublicWeb3API struct { type publicWeb3API struct {
stack *Node stack *Node
} }
// NewPublicWeb3API creates a new Web3Service instance
func NewPublicWeb3API(stack *Node) *PublicWeb3API {
return &PublicWeb3API{stack}
}
// ClientVersion returns the node name // ClientVersion returns the node name
func (s *PublicWeb3API) ClientVersion() string { func (s *publicWeb3API) ClientVersion() string {
return s.stack.Server().Name return s.stack.Server().Name
} }
// Sha3 applies the ethereum sha3 implementation on the input. // Sha3 applies the ethereum sha3 implementation on the input.
// It assumes the input is hex encoded. // It assumes the input is hex encoded.
func (s *PublicWeb3API) Sha3(input hexutil.Bytes) hexutil.Bytes { func (s *publicWeb3API) Sha3(input hexutil.Bytes) hexutil.Bytes {
return crypto.Keccak256(input) return crypto.Keccak256(input)
} }

350
node/api_test.go Normal file

@ -0,0 +1,350 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
import (
"bytes"
"io"
"net"
"net/http"
"net/url"
"strings"
"testing"
"github.com/ethereum/go-ethereum/rpc"
"github.com/stretchr/testify/assert"
)
// This test uses the admin_startRPC and admin_startWS APIs,
// checking whether the HTTP server is started correctly.
func TestStartRPC(t *testing.T) {
type test struct {
name string
cfg Config
fn func(*testing.T, *Node, *privateAdminAPI)
// Checks. These run after the node is configured and all API calls have been made.
wantReachable bool // whether the HTTP server should be reachable at all
wantHandlers bool // whether RegisterHandler handlers should be accessible
wantRPC bool // whether JSON-RPC/HTTP should be accessible
wantWS bool // whether JSON-RPC/WS should be accessible
}
tests := []test{
{
name: "all off",
cfg: Config{},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
},
wantReachable: false,
wantHandlers: false,
wantRPC: false,
wantWS: false,
},
{
name: "rpc enabled through config",
cfg: Config{HTTPHost: "127.0.0.1"},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
},
wantReachable: true,
wantHandlers: true,
wantRPC: true,
wantWS: false,
},
{
name: "rpc enabled through API",
cfg: Config{},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StartRPC(sp("127.0.0.1"), ip(0), nil, nil, nil)
assert.NoError(t, err)
},
wantReachable: true,
wantHandlers: true,
wantRPC: true,
wantWS: false,
},
{
name: "rpc start again after failure",
cfg: Config{},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
// Listen on a random port.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("can't listen:", err)
}
defer listener.Close()
port := listener.Addr().(*net.TCPAddr).Port
// Now try to start RPC on that port. This should fail.
_, err = api.StartRPC(sp("127.0.0.1"), ip(port), nil, nil, nil)
if err == nil {
t.Fatal("StartRPC should have failed on port", port)
}
// Try again after unblocking the port. It should work this time.
listener.Close()
_, err = api.StartRPC(sp("127.0.0.1"), ip(port), nil, nil, nil)
assert.NoError(t, err)
},
wantReachable: true,
wantHandlers: true,
wantRPC: true,
wantWS: false,
},
{
name: "rpc stopped through API",
cfg: Config{HTTPHost: "127.0.0.1"},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StopRPC()
assert.NoError(t, err)
},
wantReachable: false,
wantHandlers: false,
wantRPC: false,
wantWS: false,
},
{
name: "rpc stopped twice",
cfg: Config{HTTPHost: "127.0.0.1"},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StopRPC()
assert.NoError(t, err)
_, err = api.StopRPC()
assert.NoError(t, err)
},
wantReachable: false,
wantHandlers: false,
wantRPC: false,
wantWS: false,
},
{
name: "ws enabled through config",
cfg: Config{WSHost: "127.0.0.1"},
wantReachable: true,
wantHandlers: false,
wantRPC: false,
wantWS: true,
},
{
name: "ws enabled through API",
cfg: Config{},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StartWS(sp("127.0.0.1"), ip(0), nil, nil)
assert.NoError(t, err)
},
wantReachable: true,
wantHandlers: false,
wantRPC: false,
wantWS: true,
},
{
name: "ws stopped through API",
cfg: Config{WSHost: "127.0.0.1"},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StopWS()
assert.NoError(t, err)
},
wantReachable: false,
wantHandlers: false,
wantRPC: false,
wantWS: false,
},
{
name: "ws stopped twice",
cfg: Config{WSHost: "127.0.0.1"},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StopWS()
assert.NoError(t, err)
_, err = api.StopWS()
assert.NoError(t, err)
},
wantReachable: false,
wantHandlers: false,
wantRPC: false,
wantWS: false,
},
{
name: "ws enabled after RPC",
cfg: Config{HTTPHost: "127.0.0.1"},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
wsport := n.http.port
_, err := api.StartWS(sp("127.0.0.1"), ip(wsport), nil, nil)
assert.NoError(t, err)
},
wantReachable: true,
wantHandlers: true,
wantRPC: true,
wantWS: true,
},
{
name: "ws enabled after RPC then stopped",
cfg: Config{HTTPHost: "127.0.0.1"},
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
wsport := n.http.port
_, err := api.StartWS(sp("127.0.0.1"), ip(wsport), nil, nil)
assert.NoError(t, err)
_, err = api.StopWS()
assert.NoError(t, err)
},
wantReachable: true,
wantHandlers: true,
wantRPC: true,
wantWS: false,
},
{
name: "rpc stopped with ws enabled",
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StartRPC(sp("127.0.0.1"), ip(0), nil, nil, nil)
assert.NoError(t, err)
wsport := n.http.port
_, err = api.StartWS(sp("127.0.0.1"), ip(wsport), nil, nil)
assert.NoError(t, err)
_, err = api.StopRPC()
assert.NoError(t, err)
},
wantReachable: false,
wantHandlers: false,
wantRPC: false,
wantWS: false,
},
{
name: "rpc enabled after ws",
fn: func(t *testing.T, n *Node, api *privateAdminAPI) {
_, err := api.StartWS(sp("127.0.0.1"), ip(0), nil, nil)
assert.NoError(t, err)
wsport := n.http.port
_, err = api.StartRPC(sp("127.0.0.1"), ip(wsport), nil, nil, nil)
assert.NoError(t, err)
},
wantReachable: true,
wantHandlers: true,
wantRPC: true,
wantWS: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Apply some sane defaults.
config := test.cfg
// config.Logger = testlog.Logger(t, log.LvlDebug)
config.NoUSB = true
config.P2P.NoDiscovery = true
// Create Node.
stack, err := New(&config)
if err != nil {
t.Fatal("can't create node:", err)
}
defer stack.Close()
// Register the test handler.
stack.RegisterHandler("test", "/test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
if err := stack.Start(); err != nil {
t.Fatal("can't start node:", err)
}
// Run the API call hook.
if test.fn != nil {
test.fn(t, stack, &privateAdminAPI{stack})
}
// Check if the HTTP endpoints are available.
baseURL := stack.HTTPEndpoint()
reachable := checkReachable(baseURL)
handlersAvailable := checkBodyOK(baseURL + "/test")
rpcAvailable := checkRPC(baseURL)
wsAvailable := checkRPC(strings.Replace(baseURL, "http://", "ws://", 1))
if reachable != test.wantReachable {
t.Errorf("HTTP server is %sreachable, want it %sreachable", not(reachable), not(test.wantReachable))
}
if handlersAvailable != test.wantHandlers {
t.Errorf("RegisterHandler handlers %savailable, want them %savailable", not(handlersAvailable), not(test.wantHandlers))
}
if rpcAvailable != test.wantRPC {
t.Errorf("HTTP RPC %savailable, want it %savailable", not(rpcAvailable), not(test.wantRPC))
}
if wsAvailable != test.wantWS {
t.Errorf("WS RPC %savailable, want it %savailable", not(wsAvailable), not(test.wantWS))
}
})
}
}
// checkReachable checks if the TCP endpoint in rawurl is open.
func checkReachable(rawurl string) bool {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
conn, err := net.Dial("tcp", u.Host)
if err != nil {
return false
}
conn.Close()
return true
}
// checkBodyOK checks whether the given HTTP URL responds with 200 OK and body "OK".
func checkBodyOK(url string) bool {
resp, err := http.Get(url)
if err != nil {
return false
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return false
}
buf := make([]byte, 2)
if _, err = io.ReadFull(resp.Body, buf); err != nil {
return false
}
return bytes.Equal(buf, []byte("OK"))
}
// checkRPC checks whether JSON-RPC works against the given URL.
func checkRPC(url string) bool {
c, err := rpc.Dial(url)
if err != nil {
return false
}
defer c.Close()
_, err = c.SupportedModules()
return err == nil
}
// string/int pointer helpers.
func sp(s string) *string { return &s }
func ip(i int) *int { return &i }
func not(ok bool) string {
if ok {
return ""
}
return "not "
}

@ -162,15 +162,6 @@ type Config struct {
// private APIs to untrusted users is a major security risk. // private APIs to untrusted users is a major security risk.
WSExposeAll bool `toml:",omitempty"` WSExposeAll bool `toml:",omitempty"`
// GraphQLHost is the host interface on which to start the GraphQL server. If this
// field is empty, no GraphQL API endpoint will be started.
GraphQLHost string
// GraphQLPort is the TCP port number on which to start the GraphQL server. The
// default zero value is/ valid and will pick a port number randomly (useful
// for ephemeral nodes).
GraphQLPort int `toml:",omitempty"`
// GraphQLCors is the Cross-Origin Resource Sharing header to send to requesting // GraphQLCors is the Cross-Origin Resource Sharing header to send to requesting
// clients. Please be aware that CORS is a browser enforced security, it's fully // clients. Please be aware that CORS is a browser enforced security, it's fully
// useless for custom HTTP clients. // useless for custom HTTP clients.
@ -247,15 +238,6 @@ func (c *Config) HTTPEndpoint() string {
return fmt.Sprintf("%s:%d", c.HTTPHost, c.HTTPPort) return fmt.Sprintf("%s:%d", c.HTTPHost, c.HTTPPort)
} }
// GraphQLEndpoint resolves a GraphQL endpoint based on the configured host interface
// and port parameters.
func (c *Config) GraphQLEndpoint() string {
if c.GraphQLHost == "" {
return ""
}
return fmt.Sprintf("%s:%d", c.GraphQLHost, c.GraphQLPort)
}
// DefaultHTTPEndpoint returns the HTTP endpoint used by default. // DefaultHTTPEndpoint returns the HTTP endpoint used by default.
func DefaultHTTPEndpoint() string { func DefaultHTTPEndpoint() string {
config := &Config{HTTPHost: DefaultHTTPHost, HTTPPort: DefaultHTTPPort} config := &Config{HTTPHost: DefaultHTTPHost, HTTPPort: DefaultHTTPPort}
@ -280,7 +262,7 @@ func DefaultWSEndpoint() string {
// ExtRPCEnabled returns the indicator whether node enables the external // ExtRPCEnabled returns the indicator whether node enables the external
// RPC(http, ws or graphql). // RPC(http, ws or graphql).
func (c *Config) ExtRPCEnabled() bool { func (c *Config) ExtRPCEnabled() bool {
return c.HTTPHost != "" || c.WSHost != "" || c.GraphQLHost != "" return c.HTTPHost != "" || c.WSHost != ""
} }
// NodeName returns the devp2p node identifier. // NodeName returns the devp2p node identifier.

@ -45,7 +45,6 @@ var DefaultConfig = Config{
HTTPTimeouts: rpc.DefaultHTTPTimeouts, HTTPTimeouts: rpc.DefaultHTTPTimeouts,
WSPort: DefaultWSPort, WSPort: DefaultWSPort,
WSModules: []string{"net", "web3"}, WSModules: []string{"net", "web3"},
GraphQLPort: DefaultGraphQLPort,
GraphQLVirtualHosts: []string{"localhost"}, GraphQLVirtualHosts: []string{"localhost"},
P2P: p2p.Config{ P2P: p2p.Config{
ListenAddr: ":30303", ListenAddr: ":30303",

@ -22,6 +22,43 @@ resources to provide RPC APIs. Services can also offer devp2p protocols, which a
up to the devp2p network when the node instance is started. up to the devp2p network when the node instance is started.
Node Lifecycle
The Node object has a lifecycle consisting of three basic states, INITIALIZING, RUNNING
and CLOSED.
New()
INITIALIZING Start()
Close() RUNNING
CLOSED Close()
Creating a Node allocates basic resources such as the data directory and returns the node
in its INITIALIZING state. Lifecycle objects, RPC APIs and peer-to-peer networking
protocols can be registered in this state. Basic operations such as opening a key-value
database are permitted while initializing.
Once everything is registered, the node can be started, which moves it into the RUNNING
state. Starting the node starts all registered Lifecycle objects and enables RPC and
peer-to-peer networking. Note that no additional Lifecycles, APIs or p2p protocols can be
registered while the node is running.
Closing the node releases all held resources. The actions performed by Close depend on the
state it was in. When closing a node in INITIALIZING state, resources related to the data
directory are released. If the node was RUNNING, closing it also stops all Lifecycle
objects and shuts down RPC and peer-to-peer networking.
You must always call Close on Node, even if the node was not started.
Resources Managed By Node Resources Managed By Node
All file-system resources used by a node instance are located in a directory called the All file-system resources used by a node instance are located in a directory called the

@ -48,21 +48,6 @@ func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.
return httpSrv, listener.Addr(), err return httpSrv, listener.Addr(), err
} }
// startWSEndpoint starts a websocket endpoint.
func startWSEndpoint(endpoint string, handler http.Handler) (*http.Server, net.Addr, error) {
// start the HTTP listener
var (
listener net.Listener
err error
)
if listener, err = net.Listen("tcp", endpoint); err != nil {
return nil, nil, err
}
wsSrv := &http.Server{Handler: handler}
go wsSrv.Serve(listener)
return wsSrv, listener.Addr(), err
}
// checkModuleAvailability checks that all names given in modules are actually // checkModuleAvailability checks that all names given in modules are actually
// available API services. It assumes that the MetadataApi module ("rpc") is always available; // available API services. It assumes that the MetadataApi module ("rpc") is always available;
// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints. // the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints.

@ -39,17 +39,6 @@ func convertFileLockError(err error) error {
return err return err
} }
// DuplicateServiceError is returned during Node startup if a registered service
// constructor returns a service of the same type that was already started.
type DuplicateServiceError struct {
Kind reflect.Type
}
// Error generates a textual representation of the duplicate service error.
func (e *DuplicateServiceError) Error() string {
return fmt.Sprintf("duplicate service: %v", e.Kind)
}
// StopError is returned if a Node fails to stop either any of its registered // StopError is returned if a Node fails to stop either any of its registered
// services or itself. // services or itself.
type StopError struct { type StopError struct {

31
node/lifecycle.go Normal file

@ -0,0 +1,31 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
// Lifecycle encompasses the behavior of services that can be started and stopped
// on the node. Lifecycle management is delegated to the node, but it is the
// responsibility of the service-specific package to configure and register the
// service on the node using the `RegisterLifecycle` method.
type Lifecycle interface {
// Start is called after all services have been constructed and the networking
// layer was also initialized to spawn any goroutines required by the service.
Start() error
// Stop terminates all goroutines belonging to the service, blocking until they
// are all terminated.
Stop() error
}

File diff suppressed because it is too large Load Diff

@ -21,26 +21,20 @@ import (
"log" "log"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
) )
// SampleService is a trivial network service that can be attached to a node for // SampleLifecycle is a trivial network service that can be attached to a node for
// life cycle management. // life cycle management.
// //
// The following methods are needed to implement a node.Service: // The following methods are needed to implement a node.Lifecycle:
// - Protocols() []p2p.Protocol - devp2p protocols the service can communicate on
// - APIs() []rpc.API - api methods the service wants to expose on rpc channels
// - Start() error - method invoked when the node is ready to start the service // - Start() error - method invoked when the node is ready to start the service
// - Stop() error - method invoked when the node terminates the service // - Stop() error - method invoked when the node terminates the service
type SampleService struct{} type SampleLifecycle struct{}
func (s *SampleService) Protocols() []p2p.Protocol { return nil } func (s *SampleLifecycle) Start() error { fmt.Println("Service starting..."); return nil }
func (s *SampleService) APIs() []rpc.API { return nil } func (s *SampleLifecycle) Stop() error { fmt.Println("Service stopping..."); return nil }
func (s *SampleService) Start(*p2p.Server) error { fmt.Println("Service starting..."); return nil }
func (s *SampleService) Stop() error { fmt.Println("Service stopping..."); return nil }
func ExampleService() { func ExampleLifecycle() {
// Create a network node to run protocols with the default values. // Create a network node to run protocols with the default values.
stack, err := node.New(&node.Config{}) stack, err := node.New(&node.Config{})
if err != nil { if err != nil {
@ -48,29 +42,18 @@ func ExampleService() {
} }
defer stack.Close() defer stack.Close()
// Create and register a simple network service. This is done through the definition // Create and register a simple network Lifecycle.
// of a node.ServiceConstructor that will instantiate a node.Service. The reason for service := new(SampleLifecycle)
// the factory method approach is to support service restarts without relying on the stack.RegisterLifecycle(service)
// individual implementations' support for such operations.
constructor := func(context *node.ServiceContext) (node.Service, error) {
return new(SampleService), nil
}
if err := stack.Register(constructor); err != nil {
log.Fatalf("Failed to register service: %v", err)
}
// Boot up the entire protocol stack, do a restart and terminate // Boot up the entire protocol stack, do a restart and terminate
if err := stack.Start(); err != nil { if err := stack.Start(); err != nil {
log.Fatalf("Failed to start the protocol stack: %v", err) log.Fatalf("Failed to start the protocol stack: %v", err)
} }
if err := stack.Restart(); err != nil { if err := stack.Close(); err != nil {
log.Fatalf("Failed to restart the protocol stack: %v", err)
}
if err := stack.Stop(); err != nil {
log.Fatalf("Failed to stop the protocol stack: %v", err) log.Fatalf("Failed to stop the protocol stack: %v", err)
} }
// Output: // Output:
// Service starting... // Service starting...
// Service stopping... // Service stopping...
// Service starting...
// Service stopping...
} }

@ -18,14 +18,18 @@ package node
import ( import (
"errors" "errors"
"fmt"
"io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"os" "os"
"reflect" "reflect"
"strings"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
@ -43,20 +47,28 @@ func testNodeConfig() *Config {
} }
} }
// Tests that an empty protocol stack can be started, restarted and stopped. // Tests that an empty protocol stack can be closed more than once.
func TestNodeLifeCycle(t *testing.T) { func TestNodeCloseMultipleTimes(t *testing.T) {
stack, err := New(testNodeConfig()) stack, err := New(testNodeConfig())
if err != nil { if err != nil {
t.Fatalf("failed to create protocol stack: %v", err) t.Fatalf("failed to create protocol stack: %v", err)
} }
defer stack.Close() stack.Close()
// Ensure that a stopped node can be stopped again // Ensure that a stopped node can be stopped again
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
if err := stack.Stop(); err != ErrNodeStopped { if err := stack.Close(); err != ErrNodeStopped {
t.Fatalf("iter %d: stop failure mismatch: have %v, want %v", i, err, ErrNodeStopped) t.Fatalf("iter %d: stop failure mismatch: have %v, want %v", i, err, ErrNodeStopped)
} }
} }
}
func TestNodeStartMultipleTimes(t *testing.T) {
stack, err := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
// Ensure that a node can be successfully started, but only once // Ensure that a node can be successfully started, but only once
if err := stack.Start(); err != nil { if err := stack.Start(); err != nil {
t.Fatalf("failed to start node: %v", err) t.Fatalf("failed to start node: %v", err)
@ -64,17 +76,11 @@ func TestNodeLifeCycle(t *testing.T) {
if err := stack.Start(); err != ErrNodeRunning { if err := stack.Start(); err != ErrNodeRunning {
t.Fatalf("start failure mismatch: have %v, want %v ", err, ErrNodeRunning) t.Fatalf("start failure mismatch: have %v, want %v ", err, ErrNodeRunning)
} }
// Ensure that a node can be restarted arbitrarily many times
for i := 0; i < 3; i++ {
if err := stack.Restart(); err != nil {
t.Fatalf("iter %d: failed to restart node: %v", i, err)
}
}
// Ensure that a node can be stopped, but only once // Ensure that a node can be stopped, but only once
if err := stack.Stop(); err != nil { if err := stack.Close(); err != nil {
t.Fatalf("failed to stop node: %v", err) t.Fatalf("failed to stop node: %v", err)
} }
if err := stack.Stop(); err != ErrNodeStopped { if err := stack.Close(); err != ErrNodeStopped {
t.Fatalf("stop failure mismatch: have %v, want %v ", err, ErrNodeStopped) t.Fatalf("stop failure mismatch: have %v, want %v ", err, ErrNodeStopped)
} }
} }
@ -94,92 +100,152 @@ func TestNodeUsedDataDir(t *testing.T) {
t.Fatalf("failed to create original protocol stack: %v", err) t.Fatalf("failed to create original protocol stack: %v", err)
} }
defer original.Close() defer original.Close()
if err := original.Start(); err != nil { if err := original.Start(); err != nil {
t.Fatalf("failed to start original protocol stack: %v", err) t.Fatalf("failed to start original protocol stack: %v", err)
} }
defer original.Stop()
// Create a second node based on the same data directory and ensure failure // Create a second node based on the same data directory and ensure failure
duplicate, err := New(&Config{DataDir: dir}) _, err = New(&Config{DataDir: dir})
if err != nil { if err != ErrDatadirUsed {
t.Fatalf("failed to create duplicate protocol stack: %v", err)
}
defer duplicate.Close()
if err := duplicate.Start(); err != ErrDatadirUsed {
t.Fatalf("duplicate datadir failure mismatch: have %v, want %v", err, ErrDatadirUsed) t.Fatalf("duplicate datadir failure mismatch: have %v, want %v", err, ErrDatadirUsed)
} }
} }
// Tests whether services can be registered and duplicates caught. // Tests whether a Lifecycle can be registered.
func TestServiceRegistry(t *testing.T) { func TestLifecycleRegistry_Successful(t *testing.T) {
stack, err := New(testNodeConfig()) stack, err := New(testNodeConfig())
if err != nil { if err != nil {
t.Fatalf("failed to create protocol stack: %v", err) t.Fatalf("failed to create protocol stack: %v", err)
} }
defer stack.Close() defer stack.Close()
// Register a batch of unique services and ensure they start successfully noop := NewNoop()
services := []ServiceConstructor{NewNoopServiceA, NewNoopServiceB, NewNoopServiceC} stack.RegisterLifecycle(noop)
for i, constructor := range services {
if err := stack.Register(constructor); err != nil { if !containsLifecycle(stack.lifecycles, noop) {
t.Fatalf("service #%d: registration failed: %v", i, err) t.Fatalf("lifecycle was not properly registered on the node, %v", err)
}
}
// Tests whether a service's protocols can be registered properly on the node's p2p server.
func TestRegisterProtocols(t *testing.T) {
stack, err := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close()
fs, err := NewFullService(stack)
if err != nil {
t.Fatalf("could not create full service: %v", err)
}
for _, protocol := range fs.Protocols() {
if !containsProtocol(stack.server.Protocols, protocol) {
t.Fatalf("protocol %v was not successfully registered", protocol)
} }
} }
if err := stack.Start(); err != nil {
t.Fatalf("failed to start original service stack: %v", err) for _, api := range fs.APIs() {
} if !containsAPI(stack.rpcAPIs, api) {
if err := stack.Stop(); err != nil { t.Fatalf("api %v was not successfully registered", api)
t.Fatalf("failed to stop original service stack: %v", err)
}
// Duplicate one of the services and retry starting the node
if err := stack.Register(NewNoopServiceB); err != nil {
t.Fatalf("duplicate registration failed: %v", err)
}
if err := stack.Start(); err == nil {
t.Fatalf("duplicate service started")
} else {
if _, ok := err.(*DuplicateServiceError); !ok {
t.Fatalf("duplicate error mismatch: have %v, want %v", err, DuplicateServiceError{})
} }
} }
} }
// Tests that registered services get started and stopped correctly. // This test checks that open databases are closed with node.
func TestServiceLifeCycle(t *testing.T) { func TestNodeCloseClosesDB(t *testing.T) {
stack, err := New(testNodeConfig()) stack, _ := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close() defer stack.Close()
// Register a batch of life-cycle instrumented services db, err := stack.OpenDatabase("mydb", 0, 0, "")
services := map[string]InstrumentingWrapper{ if err != nil {
"A": InstrumentedServiceMakerA, t.Fatal("can't open DB:", err)
"B": InstrumentedServiceMakerB,
"C": InstrumentedServiceMakerC,
} }
if err = db.Put([]byte{}, []byte{}); err != nil {
t.Fatal("can't Put on open DB:", err)
}
stack.Close()
if err = db.Put([]byte{}, []byte{}); err == nil {
t.Fatal("Put succeeded after node is closed")
}
}
// This test checks that OpenDatabase can be used from within a Lifecycle Start method.
func TestNodeOpenDatabaseFromLifecycleStart(t *testing.T) {
stack, _ := New(testNodeConfig())
defer stack.Close()
var db ethdb.Database
var err error
stack.RegisterLifecycle(&InstrumentedService{
startHook: func() {
db, err = stack.OpenDatabase("mydb", 0, 0, "")
if err != nil {
t.Fatal("can't open DB:", err)
}
},
stopHook: func() {
db.Close()
},
})
stack.Start()
stack.Close()
}
// This test checks that OpenDatabase can be used from within a Lifecycle Stop method.
func TestNodeOpenDatabaseFromLifecycleStop(t *testing.T) {
stack, _ := New(testNodeConfig())
defer stack.Close()
stack.RegisterLifecycle(&InstrumentedService{
stopHook: func() {
db, err := stack.OpenDatabase("mydb", 0, 0, "")
if err != nil {
t.Fatal("can't open DB:", err)
}
db.Close()
},
})
stack.Start()
stack.Close()
}
// Tests that registered Lifecycles get started and stopped correctly.
func TestLifecycleLifeCycle(t *testing.T) {
stack, _ := New(testNodeConfig())
defer stack.Close()
started := make(map[string]bool) started := make(map[string]bool)
stopped := make(map[string]bool) stopped := make(map[string]bool)
for id, maker := range services { // Create a batch of instrumented services
id := id // Closure for the constructor lifecycles := map[string]Lifecycle{
constructor := func(*ServiceContext) (Service, error) { "A": &InstrumentedService{
return &InstrumentedService{ startHook: func() { started["A"] = true },
startHook: func(*p2p.Server) { started[id] = true }, stopHook: func() { stopped["A"] = true },
stopHook: func() { stopped[id] = true }, },
}, nil "B": &InstrumentedService{
} startHook: func() { started["B"] = true },
if err := stack.Register(maker(constructor)); err != nil { stopHook: func() { stopped["B"] = true },
t.Fatalf("service %s: registration failed: %v", id, err) },
} "C": &InstrumentedService{
startHook: func() { started["C"] = true },
stopHook: func() { stopped["C"] = true },
},
}
// register lifecycles on node
for _, lifecycle := range lifecycles {
stack.RegisterLifecycle(lifecycle)
} }
// Start the node and check that all services are running // Start the node and check that all services are running
if err := stack.Start(); err != nil { if err := stack.Start(); err != nil {
t.Fatalf("failed to start protocol stack: %v", err) t.Fatalf("failed to start protocol stack: %v", err)
} }
for id := range services { for id := range lifecycles {
if !started[id] { if !started[id] {
t.Fatalf("service %s: freshly started service not running", id) t.Fatalf("service %s: freshly started service not running", id)
} }
@ -188,470 +254,286 @@ func TestServiceLifeCycle(t *testing.T) {
} }
} }
// Stop the node and check that all services have been stopped // Stop the node and check that all services have been stopped
if err := stack.Stop(); err != nil { if err := stack.Close(); err != nil {
t.Fatalf("failed to stop protocol stack: %v", err) t.Fatalf("failed to stop protocol stack: %v", err)
} }
for id := range services { for id := range lifecycles {
if !stopped[id] { if !stopped[id] {
t.Fatalf("service %s: freshly terminated service still running", id) t.Fatalf("service %s: freshly terminated service still running", id)
} }
} }
} }
// Tests that services are restarted cleanly as new instances. // Tests that if a Lifecycle fails to start, all others started before it will be
func TestServiceRestarts(t *testing.T) { // shut down.
func TestLifecycleStartupError(t *testing.T) {
stack, err := New(testNodeConfig()) stack, err := New(testNodeConfig())
if err != nil { if err != nil {
t.Fatalf("failed to create protocol stack: %v", err) t.Fatalf("failed to create protocol stack: %v", err)
} }
defer stack.Close() defer stack.Close()
// Define a service that does not support restarts
var (
running bool
started int
)
constructor := func(*ServiceContext) (Service, error) {
running = false
return &InstrumentedService{
startHook: func(*p2p.Server) {
if running {
panic("already running")
}
running = true
started++
},
}, nil
}
// Register the service and start the protocol stack
if err := stack.Register(constructor); err != nil {
t.Fatalf("failed to register the service: %v", err)
}
if err := stack.Start(); err != nil {
t.Fatalf("failed to start protocol stack: %v", err)
}
defer stack.Stop()
if !running || started != 1 {
t.Fatalf("running/started mismatch: have %v/%d, want true/1", running, started)
}
// Restart the stack a few times and check successful service restarts
for i := 0; i < 3; i++ {
if err := stack.Restart(); err != nil {
t.Fatalf("iter %d: failed to restart stack: %v", i, err)
}
}
if !running || started != 4 {
t.Fatalf("running/started mismatch: have %v/%d, want true/4", running, started)
}
}
// Tests that if a service fails to initialize itself, none of the other services
// will be allowed to even start.
func TestServiceConstructionAbortion(t *testing.T) {
stack, err := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close()
// Define a batch of good services
services := map[string]InstrumentingWrapper{
"A": InstrumentedServiceMakerA,
"B": InstrumentedServiceMakerB,
"C": InstrumentedServiceMakerC,
}
started := make(map[string]bool) started := make(map[string]bool)
for id, maker := range services { stopped := make(map[string]bool)
id := id // Closure for the constructor
constructor := func(*ServiceContext) (Service, error) { // Create a batch of instrumented services
return &InstrumentedService{ lifecycles := map[string]Lifecycle{
startHook: func(*p2p.Server) { started[id] = true }, "A": &InstrumentedService{
}, nil startHook: func() { started["A"] = true },
} stopHook: func() { stopped["A"] = true },
if err := stack.Register(maker(constructor)); err != nil { },
t.Fatalf("service %s: registration failed: %v", id, err) "B": &InstrumentedService{
} startHook: func() { started["B"] = true },
stopHook: func() { stopped["B"] = true },
},
"C": &InstrumentedService{
startHook: func() { started["C"] = true },
stopHook: func() { stopped["C"] = true },
},
} }
// register lifecycles on node
for _, lifecycle := range lifecycles {
stack.RegisterLifecycle(lifecycle)
}
// Register a service that fails to construct itself // Register a service that fails to construct itself
failure := errors.New("fail") failure := errors.New("fail")
failer := func(*ServiceContext) (Service, error) { failer := &InstrumentedService{start: failure}
return nil, failure stack.RegisterLifecycle(failer)
}
if err := stack.Register(failer); err != nil {
t.Fatalf("failer registration failed: %v", err)
}
// Start the protocol stack and ensure none of the services get started
for i := 0; i < 100; i++ {
if err := stack.Start(); err != failure {
t.Fatalf("iter %d: stack startup failure mismatch: have %v, want %v", i, err, failure)
}
for id := range services {
if started[id] {
t.Fatalf("service %s: started should not have", id)
}
delete(started, id)
}
}
}
// Tests that if a service fails to start, all others started before it will be
// shut down.
func TestServiceStartupAbortion(t *testing.T) {
stack, err := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close()
// Register a batch of good services
services := map[string]InstrumentingWrapper{
"A": InstrumentedServiceMakerA,
"B": InstrumentedServiceMakerB,
"C": InstrumentedServiceMakerC,
}
started := make(map[string]bool)
stopped := make(map[string]bool)
for id, maker := range services {
id := id // Closure for the constructor
constructor := func(*ServiceContext) (Service, error) {
return &InstrumentedService{
startHook: func(*p2p.Server) { started[id] = true },
stopHook: func() { stopped[id] = true },
}, nil
}
if err := stack.Register(maker(constructor)); err != nil {
t.Fatalf("service %s: registration failed: %v", id, err)
}
}
// Register a service that fails to start
failure := errors.New("fail")
failer := func(*ServiceContext) (Service, error) {
return &InstrumentedService{
start: failure,
}, nil
}
if err := stack.Register(failer); err != nil {
t.Fatalf("failer registration failed: %v", err)
}
// Start the protocol stack and ensure all started services stop // Start the protocol stack and ensure all started services stop
for i := 0; i < 100; i++ { if err := stack.Start(); err != failure {
if err := stack.Start(); err != failure { t.Fatalf("stack startup failure mismatch: have %v, want %v", err, failure)
t.Fatalf("iter %d: stack startup failure mismatch: have %v, want %v", i, err, failure) }
} for id := range lifecycles {
for id := range services { if started[id] && !stopped[id] {
if started[id] && !stopped[id] { t.Fatalf("service %s: started but not stopped", id)
t.Fatalf("service %s: started but not stopped", id)
}
delete(started, id)
delete(stopped, id)
} }
delete(started, id)
delete(stopped, id)
} }
} }
// Tests that even if a registered service fails to shut down cleanly, it does // Tests that even if a registered Lifecycle fails to shut down cleanly, it does
// not influence the rest of the shutdown invocations. // not influence the rest of the shutdown invocations.
func TestServiceTerminationGuarantee(t *testing.T) { func TestLifecycleTerminationGuarantee(t *testing.T) {
stack, err := New(testNodeConfig()) stack, err := New(testNodeConfig())
if err != nil { if err != nil {
t.Fatalf("failed to create protocol stack: %v", err) t.Fatalf("failed to create protocol stack: %v", err)
} }
defer stack.Close() defer stack.Close()
// Register a batch of good services
services := map[string]InstrumentingWrapper{
"A": InstrumentedServiceMakerA,
"B": InstrumentedServiceMakerB,
"C": InstrumentedServiceMakerC,
}
started := make(map[string]bool) started := make(map[string]bool)
stopped := make(map[string]bool) stopped := make(map[string]bool)
for id, maker := range services { // Create a batch of instrumented services
id := id // Closure for the constructor lifecycles := map[string]Lifecycle{
constructor := func(*ServiceContext) (Service, error) { "A": &InstrumentedService{
return &InstrumentedService{ startHook: func() { started["A"] = true },
startHook: func(*p2p.Server) { started[id] = true }, stopHook: func() { stopped["A"] = true },
stopHook: func() { stopped[id] = true }, },
}, nil "B": &InstrumentedService{
} startHook: func() { started["B"] = true },
if err := stack.Register(maker(constructor)); err != nil { stopHook: func() { stopped["B"] = true },
t.Fatalf("service %s: registration failed: %v", id, err) },
} "C": &InstrumentedService{
startHook: func() { started["C"] = true },
stopHook: func() { stopped["C"] = true },
},
} }
// register lifecycles on node
for _, lifecycle := range lifecycles {
stack.RegisterLifecycle(lifecycle)
}
// Register a service that fails to shot down cleanly // Register a service that fails to shot down cleanly
failure := errors.New("fail") failure := errors.New("fail")
failer := func(*ServiceContext) (Service, error) { failer := &InstrumentedService{stop: failure}
return &InstrumentedService{ stack.RegisterLifecycle(failer)
stop: failure,
}, nil
}
if err := stack.Register(failer); err != nil {
t.Fatalf("failer registration failed: %v", err)
}
// Start the protocol stack, and ensure that a failing shut down terminates all // Start the protocol stack, and ensure that a failing shut down terminates all
for i := 0; i < 100; i++ { // Start the stack and make sure all is online
// Start the stack and make sure all is online
if err := stack.Start(); err != nil {
t.Fatalf("iter %d: failed to start protocol stack: %v", i, err)
}
for id := range services {
if !started[id] {
t.Fatalf("iter %d, service %s: service not running", i, id)
}
if stopped[id] {
t.Fatalf("iter %d, service %s: service already stopped", i, id)
}
}
// Stop the stack, verify failure and check all terminations
err := stack.Stop()
if err, ok := err.(*StopError); !ok {
t.Fatalf("iter %d: termination failure mismatch: have %v, want StopError", i, err)
} else {
failer := reflect.TypeOf(&InstrumentedService{})
if err.Services[failer] != failure {
t.Fatalf("iter %d: failer termination failure mismatch: have %v, want %v", i, err.Services[failer], failure)
}
if len(err.Services) != 1 {
t.Fatalf("iter %d: failure count mismatch: have %d, want %d", i, len(err.Services), 1)
}
}
for id := range services {
if !stopped[id] {
t.Fatalf("iter %d, service %s: service not terminated", i, id)
}
delete(started, id)
delete(stopped, id)
}
}
}
// TestServiceRetrieval tests that individual services can be retrieved.
func TestServiceRetrieval(t *testing.T) {
// Create a simple stack and register two service types
stack, err := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close()
if err := stack.Register(NewNoopService); err != nil {
t.Fatalf("noop service registration failed: %v", err)
}
if err := stack.Register(NewInstrumentedService); err != nil {
t.Fatalf("instrumented service registration failed: %v", err)
}
// Make sure none of the services can be retrieved until started
var noopServ *NoopService
if err := stack.Service(&noopServ); err != ErrNodeStopped {
t.Fatalf("noop service retrieval mismatch: have %v, want %v", err, ErrNodeStopped)
}
var instServ *InstrumentedService
if err := stack.Service(&instServ); err != ErrNodeStopped {
t.Fatalf("instrumented service retrieval mismatch: have %v, want %v", err, ErrNodeStopped)
}
// Start the stack and ensure everything is retrievable now
if err := stack.Start(); err != nil {
t.Fatalf("failed to start stack: %v", err)
}
defer stack.Stop()
if err := stack.Service(&noopServ); err != nil {
t.Fatalf("noop service retrieval mismatch: have %v, want %v", err, nil)
}
if err := stack.Service(&instServ); err != nil {
t.Fatalf("instrumented service retrieval mismatch: have %v, want %v", err, nil)
}
}
// Tests that all protocols defined by individual services get launched.
func TestProtocolGather(t *testing.T) {
stack, err := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close()
// Register a batch of services with some configured number of protocols
services := map[string]struct {
Count int
Maker InstrumentingWrapper
}{
"zero": {0, InstrumentedServiceMakerA},
"one": {1, InstrumentedServiceMakerB},
"many": {10, InstrumentedServiceMakerC},
}
for id, config := range services {
protocols := make([]p2p.Protocol, config.Count)
for i := 0; i < len(protocols); i++ {
protocols[i].Name = id
protocols[i].Version = uint(i)
}
constructor := func(*ServiceContext) (Service, error) {
return &InstrumentedService{
protocols: protocols,
}, nil
}
if err := stack.Register(config.Maker(constructor)); err != nil {
t.Fatalf("service %s: registration failed: %v", id, err)
}
}
// Start the services and ensure all protocols start successfully
if err := stack.Start(); err != nil { if err := stack.Start(); err != nil {
t.Fatalf("failed to start protocol stack: %v", err) t.Fatalf("failed to start protocol stack: %v", err)
} }
defer stack.Stop() for id := range lifecycles {
if !started[id] {
protocols := stack.Server().Protocols t.Fatalf("service %s: service not running", id)
if len(protocols) != 11 { }
t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26) if stopped[id] {
} t.Fatalf("service %s: service already stopped", id)
for id, config := range services {
for ver := 0; ver < config.Count; ver++ {
launched := false
for i := 0; i < len(protocols); i++ {
if protocols[i].Name == id && protocols[i].Version == uint(ver) {
launched = true
break
}
}
if !launched {
t.Errorf("configured protocol not launched: %s v%d", id, ver)
}
} }
} }
// Stop the stack, verify failure and check all terminations
err = stack.Close()
if err, ok := err.(*StopError); !ok {
t.Fatalf("termination failure mismatch: have %v, want StopError", err)
} else {
failer := reflect.TypeOf(&InstrumentedService{})
if err.Services[failer] != failure {
t.Fatalf("failer termination failure mismatch: have %v, want %v", err.Services[failer], failure)
}
if len(err.Services) != 1 {
t.Fatalf("failure count mismatch: have %d, want %d", len(err.Services), 1)
}
}
for id := range lifecycles {
if !stopped[id] {
t.Fatalf("service %s: service not terminated", id)
}
delete(started, id)
delete(stopped, id)
}
stack.server = &p2p.Server{}
stack.server.PrivateKey = testNodeKey
} }
// Tests that all APIs defined by individual services get exposed. // Tests whether a handler can be successfully mounted on the canonical HTTP server
func TestAPIGather(t *testing.T) { // on the given path
stack, err := New(testNodeConfig()) func TestRegisterHandler_Successful(t *testing.T) {
if err != nil { node := createNode(t, 7878, 7979)
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close()
// Register a batch of services with some configured APIs // create and mount handler
calls := make(chan string, 1) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
makeAPI := func(result string) *OneMethodAPI { w.Write([]byte("success"))
return &OneMethodAPI{fun: func() { calls <- result }} })
} node.RegisterHandler("test", "/test", handler)
services := map[string]struct {
APIs []rpc.API // start node
Maker InstrumentingWrapper if err := node.Start(); err != nil {
}{ t.Fatalf("could not start node: %v", err)
"Zero APIs": {
[]rpc.API{}, InstrumentedServiceMakerA},
"Single API": {
[]rpc.API{
{Namespace: "single", Version: "1", Service: makeAPI("single.v1"), Public: true},
}, InstrumentedServiceMakerB},
"Many APIs": {
[]rpc.API{
{Namespace: "multi", Version: "1", Service: makeAPI("multi.v1"), Public: true},
{Namespace: "multi.v2", Version: "2", Service: makeAPI("multi.v2"), Public: true},
{Namespace: "multi.v2.nested", Version: "2", Service: makeAPI("multi.v2.nested"), Public: true},
}, InstrumentedServiceMakerC},
} }
for id, config := range services { // create HTTP request
config := config httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7878/test", nil)
constructor := func(*ServiceContext) (Service, error) {
return &InstrumentedService{apis: config.APIs}, nil
}
if err := stack.Register(config.Maker(constructor)); err != nil {
t.Fatalf("service %s: registration failed: %v", id, err)
}
}
// Start the services and ensure all API start successfully
if err := stack.Start(); err != nil {
t.Fatalf("failed to start protocol stack: %v", err)
}
defer stack.Stop()
// Connect to the RPC server and verify the various registered endpoints
client, err := stack.Attach()
if err != nil {
t.Fatalf("failed to connect to the inproc API server: %v", err)
}
defer client.Close()
tests := []struct {
Method string
Result string
}{
{"single_theOneMethod", "single.v1"},
{"multi_theOneMethod", "multi.v1"},
{"multi.v2_theOneMethod", "multi.v2"},
{"multi.v2.nested_theOneMethod", "multi.v2.nested"},
}
for i, test := range tests {
if err := client.Call(nil, test.Method); err != nil {
t.Errorf("test %d: API request failed: %v", i, err)
}
select {
case result := <-calls:
if result != test.Result {
t.Errorf("test %d: result mismatch: have %s, want %s", i, result, test.Result)
}
case <-time.After(time.Second):
t.Fatalf("test %d: rpc execution timeout", i)
}
}
}
func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) {
node := startHTTP(t)
defer node.stopHTTP()
wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
if err != nil { if err != nil {
t.Error("could not issue new http request ", err) t.Error("could not issue new http request ", err)
} }
wsReq.Header.Set("Connection", "upgrade")
wsReq.Header.Set("Upgrade", "websocket")
wsReq.Header.Set("Sec-WebSocket-Version", "13")
wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
resp := doHTTPRequest(t, wsReq)
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
}
func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) {
node := startHTTP(t)
defer node.stopHTTP()
httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
if err != nil {
t.Error("could not issue new http request ", err)
}
httpReq.Header.Set("Accept-Encoding", "gzip")
// check response
resp := doHTTPRequest(t, httpReq) resp := doHTTPRequest(t, httpReq)
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) buf := make([]byte, 7)
_, err = io.ReadFull(resp.Body, buf)
if err != nil {
t.Fatalf("could not read response: %v", err)
}
assert.Equal(t, "success", string(buf))
} }
func startHTTP(t *testing.T) *Node { // Tests that the given handler will not be successfully mounted since no HTTP server
conf := &Config{HTTPPort: 7453, WSPort: 7453} // is enabled for RPC
func TestRegisterHandler_Unsuccessful(t *testing.T) {
node, err := New(&DefaultConfig)
if err != nil {
t.Fatalf("could not create new node: %v", err)
}
// create and mount handler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("success"))
})
node.RegisterHandler("test", "/test", handler)
}
// Tests whether websocket requests can be handled on the same port as a regular http server.
func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) {
node := startHTTP(t, 0, 0)
defer node.Close()
ws := strings.Replace(node.HTTPEndpoint(), "http://", "ws://", 1)
if node.WSEndpoint() != ws {
t.Fatalf("endpoints should be the same")
}
if !checkRPC(ws) {
t.Fatalf("ws request failed")
}
if !checkRPC(node.HTTPEndpoint()) {
t.Fatalf("http request failed")
}
}
func TestWebsocketHTTPOnSeparatePort_WSRequest(t *testing.T) {
// try and get a free port
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("can't listen:", err)
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
node := startHTTP(t, 0, port)
defer node.Close()
wsOnHTTP := strings.Replace(node.HTTPEndpoint(), "http://", "ws://", 1)
ws := fmt.Sprintf("ws://127.0.0.1:%d", port)
if node.WSEndpoint() == wsOnHTTP {
t.Fatalf("endpoints should not be the same")
}
// ensure ws endpoint matches the expected endpoint
if node.WSEndpoint() != ws {
t.Fatalf("ws endpoint is incorrect: expected %s, got %s", ws, node.WSEndpoint())
}
if !checkRPC(ws) {
t.Fatalf("ws request failed")
}
if !checkRPC(node.HTTPEndpoint()) {
t.Fatalf("http request failed")
}
}
func createNode(t *testing.T, httpPort, wsPort int) *Node {
conf := &Config{
HTTPHost: "127.0.0.1",
HTTPPort: httpPort,
WSHost: "127.0.0.1",
WSPort: wsPort,
}
node, err := New(conf) node, err := New(conf)
if err != nil { if err != nil {
t.Error("could not create a new node ", err) t.Fatalf("could not create a new node: %v", err)
} }
return node
}
err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{}) func startHTTP(t *testing.T, httpPort, wsPort int) *Node {
node := createNode(t, httpPort, wsPort)
err := node.Start()
if err != nil { if err != nil {
t.Error("could not start http service on node ", err) t.Fatalf("could not start http service on node: %v", err)
} }
return node return node
} }
func doHTTPRequest(t *testing.T, req *http.Request) *http.Response { func doHTTPRequest(t *testing.T, req *http.Request) *http.Response {
client := &http.Client{} client := http.DefaultClient
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
t.Error("could not issue a GET request to the given endpoint", err) t.Fatalf("could not issue a GET request to the given endpoint: %v", err)
} }
return resp return resp
} }
func containsProtocol(stackProtocols []p2p.Protocol, protocol p2p.Protocol) bool {
for _, a := range stackProtocols {
if reflect.DeepEqual(a, protocol) {
return true
}
}
return false
}
func containsAPI(stackAPIs []rpc.API, api rpc.API) bool {
for _, a := range stackAPIs {
if reflect.DeepEqual(a, api) {
return true
}
}
return false
}

@ -18,17 +18,304 @@ package node
import ( import (
"compress/gzip" "compress/gzip"
"context"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"sort"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rpc"
"github.com/rs/cors" "github.com/rs/cors"
) )
// httpConfig is the JSON-RPC/HTTP configuration.
type httpConfig struct {
Modules []string
CorsAllowedOrigins []string
Vhosts []string
}
// wsConfig is the JSON-RPC/Websocket configuration
type wsConfig struct {
Origins []string
Modules []string
}
type rpcHandler struct {
http.Handler
server *rpc.Server
}
type httpServer struct {
log log.Logger
timeouts rpc.HTTPTimeouts
mux http.ServeMux // registered handlers go here
mu sync.Mutex
server *http.Server
listener net.Listener // non-nil when server is running
// HTTP RPC handler things.
httpConfig httpConfig
httpHandler atomic.Value // *rpcHandler
// WebSocket handler things.
wsConfig wsConfig
wsHandler atomic.Value // *rpcHandler
// These are set by setListenAddr.
endpoint string
host string
port int
handlerNames map[string]string
}
func newHTTPServer(log log.Logger, timeouts rpc.HTTPTimeouts) *httpServer {
h := &httpServer{log: log, timeouts: timeouts, handlerNames: make(map[string]string)}
h.httpHandler.Store((*rpcHandler)(nil))
h.wsHandler.Store((*rpcHandler)(nil))
return h
}
// setListenAddr configures the listening address of the server.
// The address can only be set while the server isn't running.
func (h *httpServer) setListenAddr(host string, port int) error {
h.mu.Lock()
defer h.mu.Unlock()
if h.listener != nil && (host != h.host || port != h.port) {
return fmt.Errorf("HTTP server already running on %s", h.endpoint)
}
h.host, h.port = host, port
h.endpoint = fmt.Sprintf("%s:%d", host, port)
return nil
}
// listenAddr returns the listening address of the server.
func (h *httpServer) listenAddr() string {
h.mu.Lock()
defer h.mu.Unlock()
if h.listener != nil {
return h.listener.Addr().String()
}
return h.endpoint
}
// start starts the HTTP server if it is enabled and not already running.
func (h *httpServer) start() error {
h.mu.Lock()
defer h.mu.Unlock()
if h.endpoint == "" || h.listener != nil {
return nil // already running or not configured
}
// Initialize the server.
h.server = &http.Server{Handler: h}
if h.timeouts != (rpc.HTTPTimeouts{}) {
CheckTimeouts(&h.timeouts)
h.server.ReadTimeout = h.timeouts.ReadTimeout
h.server.WriteTimeout = h.timeouts.WriteTimeout
h.server.IdleTimeout = h.timeouts.IdleTimeout
}
// Start the server.
listener, err := net.Listen("tcp", h.endpoint)
if err != nil {
// If the server fails to start, we need to clear out the RPC and WS
// configuration so they can be configured another time.
h.disableRPC()
h.disableWS()
return err
}
h.listener = listener
go h.server.Serve(listener)
// if server is websocket only, return after logging
if h.wsAllowed() && !h.rpcAllowed() {
h.log.Info("WebSocket enabled", "url", fmt.Sprintf("ws://%v", listener.Addr()))
return nil
}
// Log http endpoint.
h.log.Info("HTTP server started",
"endpoint", listener.Addr(),
"cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","),
"vhosts", strings.Join(h.httpConfig.Vhosts, ","),
)
// Log all handlers mounted on server.
var paths []string
for path := range h.handlerNames {
paths = append(paths, path)
}
sort.Strings(paths)
logged := make(map[string]bool, len(paths))
for _, path := range paths {
name := h.handlerNames[path]
if !logged[name] {
log.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path)
logged[name] = true
}
}
return nil
}
func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rpc := h.httpHandler.Load().(*rpcHandler)
if r.RequestURI == "/" {
// Serve JSON-RPC on the root path.
ws := h.wsHandler.Load().(*rpcHandler)
if ws != nil && isWebsocket(r) {
ws.ServeHTTP(w, r)
return
}
if rpc != nil {
rpc.ServeHTTP(w, r)
return
}
} else if rpc != nil {
// Requests to a path below root are handled by the mux,
// which has all the handlers registered via Node.RegisterHandler.
// These are made available when RPC is enabled.
h.mux.ServeHTTP(w, r)
return
}
w.WriteHeader(404)
}
// stop shuts down the HTTP server.
func (h *httpServer) stop() {
h.mu.Lock()
defer h.mu.Unlock()
h.doStop()
}
func (h *httpServer) doStop() {
if h.listener == nil {
return // not running
}
// Shut down the server.
httpHandler := h.httpHandler.Load().(*rpcHandler)
wsHandler := h.httpHandler.Load().(*rpcHandler)
if httpHandler != nil {
h.httpHandler.Store((*rpcHandler)(nil))
httpHandler.server.Stop()
}
if wsHandler != nil {
h.wsHandler.Store((*rpcHandler)(nil))
wsHandler.server.Stop()
}
h.server.Shutdown(context.Background())
h.listener.Close()
h.log.Info("HTTP server stopped", "endpoint", h.listener.Addr())
// Clear out everything to allow re-configuring it later.
h.host, h.port, h.endpoint = "", 0, ""
h.server, h.listener = nil, nil
}
// enableRPC turns on JSON-RPC over HTTP on the server.
func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
h.mu.Lock()
defer h.mu.Unlock()
if h.rpcAllowed() {
return fmt.Errorf("JSON-RPC over HTTP is already enabled")
}
// Create RPC server and handler.
srv := rpc.NewServer()
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
return err
}
h.httpConfig = config
h.httpHandler.Store(&rpcHandler{
Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts),
server: srv,
})
return nil
}
// disableRPC stops the HTTP RPC handler. This is internal, the caller must hold h.mu.
func (h *httpServer) disableRPC() bool {
handler := h.httpHandler.Load().(*rpcHandler)
if handler != nil {
h.httpHandler.Store((*rpcHandler)(nil))
handler.server.Stop()
}
return handler != nil
}
// enableWS turns on JSON-RPC over WebSocket on the server.
func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
h.mu.Lock()
defer h.mu.Unlock()
if h.wsAllowed() {
return fmt.Errorf("JSON-RPC over WebSocket is already enabled")
}
// Create RPC server and handler.
srv := rpc.NewServer()
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
return err
}
h.wsConfig = config
h.wsHandler.Store(&rpcHandler{
Handler: srv.WebsocketHandler(config.Origins),
server: srv,
})
return nil
}
// stopWS disables JSON-RPC over WebSocket and also stops the server if it only serves WebSocket.
func (h *httpServer) stopWS() {
h.mu.Lock()
defer h.mu.Unlock()
if h.disableWS() {
if !h.rpcAllowed() {
h.doStop()
}
}
}
// disableWS disables the WebSocket handler. This is internal, the caller must hold h.mu.
func (h *httpServer) disableWS() bool {
ws := h.wsHandler.Load().(*rpcHandler)
if ws != nil {
h.wsHandler.Store((*rpcHandler)(nil))
ws.server.Stop()
}
return ws != nil
}
// rpcAllowed returns true when JSON-RPC over HTTP is enabled.
func (h *httpServer) rpcAllowed() bool {
return h.httpHandler.Load().(*rpcHandler) != nil
}
// wsAllowed returns true when JSON-RPC over WebSocket is enabled.
func (h *httpServer) wsAllowed() bool {
return h.wsHandler.Load().(*rpcHandler) != nil
}
// isWebsocket checks the header of an http request for a websocket upgrade request.
func isWebsocket(r *http.Request) bool {
return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
strings.ToLower(r.Header.Get("Connection")) == "upgrade"
}
// NewHTTPHandlerStack returns wrapped http-related handlers // NewHTTPHandlerStack returns wrapped http-related handlers
func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler { func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler {
// Wrap the CORS-handler within a host-handler // Wrap the CORS-handler within a host-handler
@ -45,8 +332,8 @@ func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedOrigins: allowedOrigins, AllowedOrigins: allowedOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet}, AllowedMethods: []string{http.MethodPost, http.MethodGet},
MaxAge: 600,
AllowedHeaders: []string{"*"}, AllowedHeaders: []string{"*"},
MaxAge: 600,
}) })
return c.Handler(srv) return c.Handler(srv)
} }
@ -138,22 +425,68 @@ func newGzipHandler(next http.Handler) http.Handler {
}) })
} }
// NewWebsocketUpgradeHandler returns a websocket handler that serves an incoming request only if it contains an upgrade type ipcServer struct {
// request to the websocket protocol. If not, serves the the request with the http handler. log log.Logger
func NewWebsocketUpgradeHandler(h http.Handler, ws http.Handler) http.Handler { endpoint string
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isWebsocket(r) { mu sync.Mutex
ws.ServeHTTP(w, r) listener net.Listener
log.Debug("serving websocket request") srv *rpc.Server
return }
func newIPCServer(log log.Logger, endpoint string) *ipcServer {
return &ipcServer{log: log, endpoint: endpoint}
}
// Start starts the httpServer's http.Server
func (is *ipcServer) start(apis []rpc.API) error {
is.mu.Lock()
defer is.mu.Unlock()
if is.listener != nil {
return nil // already running
}
listener, srv, err := rpc.StartIPCEndpoint(is.endpoint, apis)
if err != nil {
return err
}
is.log.Info("IPC endpoint opened", "url", is.endpoint)
is.listener, is.srv = listener, srv
return nil
}
func (is *ipcServer) stop() error {
is.mu.Lock()
defer is.mu.Unlock()
if is.listener == nil {
return nil // not running
}
err := is.listener.Close()
is.srv.Stop()
is.listener, is.srv = nil, nil
is.log.Info("IPC endpoint closed", "url", is.endpoint)
return err
}
// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules,
// and then registers all of the APIs exposed by the services.
func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error {
if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
}
// Generate the whitelist based on the allowed modules
whitelist := make(map[string]bool)
for _, module := range modules {
whitelist[module] = true
}
// Register all the APIs exposed by the services
for _, api := range apis {
if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
if err := srv.RegisterName(api.Namespace, api.Service); err != nil {
return err
}
} }
}
h.ServeHTTP(w, r) return nil
})
}
// isWebsocket checks the header of an http request for a websocket upgrade request.
func isWebsocket(r *http.Request) bool {
return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
strings.ToLower(r.Header.Get("Connection")) == "upgrade"
} }

@ -1,38 +1,110 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node package node
import ( import (
"bytes"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewWebsocketUpgradeHandler_websocket(t *testing.T) { // TestCorsHandler makes sure CORS are properly handled on the http server.
srv := rpc.NewServer() func TestCorsHandler(t *testing.T) {
srv := createAndStartServer(t, httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, wsConfig{})
defer srv.stop()
handler := NewWebsocketUpgradeHandler(nil, srv.WebsocketHandler([]string{})) resp := testRequest(t, "origin", "test.com", "", srv)
ts := httptest.NewServer(handler) assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin"))
defer ts.Close()
responses := make(chan *http.Response) resp2 := testRequest(t, "origin", "bad", "", srv)
go func(responses chan *http.Response) { assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin"))
client := &http.Client{} }
req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) // TestVhosts makes sure vhosts are properly handled on the http server.
req.Header.Set("Connection", "upgrade") func TestVhosts(t *testing.T) {
req.Header.Set("Upgrade", "websocket") srv := createAndStartServer(t, httpConfig{Vhosts: []string{"test"}}, false, wsConfig{})
req.Header.Set("Sec-WebSocket-Version", "13") defer srv.stop()
req.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
resp := testRequest(t, "", "", "test", srv)
resp, err := client.Do(req) assert.Equal(t, resp.StatusCode, http.StatusOK)
if err != nil {
t.Error("could not issue a GET request to the test http server", err) resp2 := testRequest(t, "", "", "bad", srv)
} assert.Equal(t, resp2.StatusCode, http.StatusForbidden)
responses <- resp }
}(responses)
// TestWebsocketOrigins makes sure the websocket origins are properly handled on the websocket server.
response := <-responses func TestWebsocketOrigins(t *testing.T) {
assert.Equal(t, "websocket", response.Header.Get("Upgrade")) srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: []string{"test"}})
defer srv.stop()
dialer := websocket.DefaultDialer
_, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{
"Content-type": []string{"application/json"},
"Sec-WebSocket-Version": []string{"13"},
"Origin": []string{"test"},
})
assert.NoError(t, err)
_, _, err = dialer.Dial("ws://"+srv.listenAddr(), http.Header{
"Content-type": []string{"application/json"},
"Sec-WebSocket-Version": []string{"13"},
"Origin": []string{"bad"},
})
assert.Error(t, err)
}
func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
t.Helper()
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
assert.NoError(t, srv.enableRPC(nil, conf))
if ws {
assert.NoError(t, srv.enableWS(nil, wsConf))
}
assert.NoError(t, srv.setListenAddr("localhost", 0))
assert.NoError(t, srv.start())
return srv
}
func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response {
t.Helper()
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,method":"rpc_modules"}`))
req, _ := http.NewRequest("POST", "http://"+srv.listenAddr(), body)
req.Header.Set("content-type", "application/json")
if key != "" && value != "" {
req.Header.Set(key, value)
}
if host != "" {
req.Host = host
}
client := http.DefaultClient
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
return resp
} }

@ -1,122 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
import (
"path/filepath"
"reflect"
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
)
// ServiceContext is a collection of service independent options inherited from
// the protocol stack, that is passed to all constructors to be optionally used;
// as well as utility methods to operate on the service environment.
type ServiceContext struct {
services map[reflect.Type]Service // Index of the already constructed services
Config Config
EventMux *event.TypeMux // Event multiplexer used for decoupled notifications
AccountManager *accounts.Manager // Account manager created by the node.
}
// OpenDatabase opens an existing database with the given name (or creates one
// if no previous can be found) from within the node's data directory. If the
// node is an ephemeral one, a memory database is returned.
func (ctx *ServiceContext) OpenDatabase(name string, cache int, handles int, namespace string) (ethdb.Database, error) {
if ctx.Config.DataDir == "" {
return rawdb.NewMemoryDatabase(), nil
}
return rawdb.NewLevelDBDatabase(ctx.Config.ResolvePath(name), cache, handles, namespace)
}
// OpenDatabaseWithFreezer opens an existing database with the given name (or
// creates one if no previous can be found) from within the node's data directory,
// also attaching a chain freezer to it that moves ancient chain data from the
// database to immutable append-only files. If the node is an ephemeral one, a
// memory database is returned.
func (ctx *ServiceContext) OpenDatabaseWithFreezer(name string, cache int, handles int, freezer string, namespace string) (ethdb.Database, error) {
if ctx.Config.DataDir == "" {
return rawdb.NewMemoryDatabase(), nil
}
root := ctx.Config.ResolvePath(name)
switch {
case freezer == "":
freezer = filepath.Join(root, "ancient")
case !filepath.IsAbs(freezer):
freezer = ctx.Config.ResolvePath(freezer)
}
return rawdb.NewLevelDBDatabaseWithFreezer(root, cache, handles, freezer, namespace)
}
// ResolvePath resolves a user path into the data directory if that was relative
// and if the user actually uses persistent storage. It will return an empty string
// for emphemeral storage and the user's own input for absolute paths.
func (ctx *ServiceContext) ResolvePath(path string) string {
return ctx.Config.ResolvePath(path)
}
// Service retrieves a currently running service registered of a specific type.
func (ctx *ServiceContext) Service(service interface{}) error {
element := reflect.ValueOf(service).Elem()
if running, ok := ctx.services[element.Type()]; ok {
element.Set(reflect.ValueOf(running))
return nil
}
return ErrServiceUnknown
}
// ExtRPCEnabled returns the indicator whether node enables the external
// RPC(http, ws or graphql).
func (ctx *ServiceContext) ExtRPCEnabled() bool {
return ctx.Config.ExtRPCEnabled()
}
// ServiceConstructor is the function signature of the constructors needed to be
// registered for service instantiation.
type ServiceConstructor func(ctx *ServiceContext) (Service, error)
// Service is an individual protocol that can be registered into a node.
//
// Notes:
//
// • Service life-cycle management is delegated to the node. The service is allowed to
// initialize itself upon creation, but no goroutines should be spun up outside of the
// Start method.
//
// • Restart logic is not required as the node will create a fresh instance
// every time a service is started.
type Service interface {
// Protocols retrieves the P2P protocols the service wishes to start.
Protocols() []p2p.Protocol
// APIs retrieves the list of RPC descriptors the service provides
APIs() []rpc.API
// Start is called after all services have been constructed and the networking
// layer was also initialized to spawn any goroutines required by the service.
Start(server *p2p.Server) error
// Stop terminates all goroutines belonging to the service, blocking until they
// are all terminated.
Stop() error
}

@ -1,98 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"testing"
)
// Tests that databases are correctly created persistent or ephemeral based on
// the configured service context.
func TestContextDatabases(t *testing.T) {
// Create a temporary folder and ensure no database is contained within
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatalf("failed to create temporary data directory: %v", err)
}
defer os.RemoveAll(dir)
if _, err := os.Stat(filepath.Join(dir, "database")); err == nil {
t.Fatalf("non-created database already exists")
}
// Request the opening/creation of a database and ensure it persists to disk
ctx := &ServiceContext{Config: Config{Name: "unit-test", DataDir: dir}}
db, err := ctx.OpenDatabase("persistent", 0, 0, "")
if err != nil {
t.Fatalf("failed to open persistent database: %v", err)
}
db.Close()
if _, err := os.Stat(filepath.Join(dir, "unit-test", "persistent")); err != nil {
t.Fatalf("persistent database doesn't exists: %v", err)
}
// Request th opening/creation of an ephemeral database and ensure it's not persisted
ctx = &ServiceContext{Config: Config{DataDir: ""}}
db, err = ctx.OpenDatabase("ephemeral", 0, 0, "")
if err != nil {
t.Fatalf("failed to open ephemeral database: %v", err)
}
db.Close()
if _, err := os.Stat(filepath.Join(dir, "ephemeral")); err == nil {
t.Fatalf("ephemeral database exists")
}
}
// Tests that already constructed services can be retrieves by later ones.
func TestContextServices(t *testing.T) {
stack, err := New(testNodeConfig())
if err != nil {
t.Fatalf("failed to create protocol stack: %v", err)
}
defer stack.Close()
// Define a verifier that ensures a NoopA is before it and NoopB after
verifier := func(ctx *ServiceContext) (Service, error) {
var objA *NoopServiceA
if ctx.Service(&objA) != nil {
return nil, fmt.Errorf("former service not found")
}
var objB *NoopServiceB
if err := ctx.Service(&objB); err != ErrServiceUnknown {
return nil, fmt.Errorf("latters lookup error mismatch: have %v, want %v", err, ErrServiceUnknown)
}
return new(NoopService), nil
}
// Register the collection of services
if err := stack.Register(NewNoopServiceA); err != nil {
t.Fatalf("former failed to register service: %v", err)
}
if err := stack.Register(verifier); err != nil {
t.Fatalf("failed to register service verifier: %v", err)
}
if err := stack.Register(NewNoopServiceB); err != nil {
t.Fatalf("latter failed to register service: %v", err)
}
// Start the protocol stack and ensure services are constructed in order
if err := stack.Start(); err != nil {
t.Fatalf("failed to start stack: %v", err)
}
defer stack.Stop()
}

@ -20,61 +20,40 @@
package node package node
import ( import (
"reflect"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
// NoopService is a trivial implementation of the Service interface. // NoopLifecycle is a trivial implementation of the Service interface.
type NoopService struct{} type NoopLifecycle struct{}
func (s *NoopService) Protocols() []p2p.Protocol { return nil } func (s *NoopLifecycle) Start() error { return nil }
func (s *NoopService) APIs() []rpc.API { return nil } func (s *NoopLifecycle) Stop() error { return nil }
func (s *NoopService) Start(*p2p.Server) error { return nil }
func (s *NoopService) Stop() error { return nil }
func NewNoopService(*ServiceContext) (Service, error) { return new(NoopService), nil } func NewNoop() *Noop {
noop := new(Noop)
return noop
}
// Set of services all wrapping the base NoopService resulting in the same method // Set of services all wrapping the base NoopLifecycle resulting in the same method
// signatures but different outer types. // signatures but different outer types.
type NoopServiceA struct{ NoopService } type Noop struct{ NoopLifecycle }
type NoopServiceB struct{ NoopService }
type NoopServiceC struct{ NoopService }
func NewNoopServiceA(*ServiceContext) (Service, error) { return new(NoopServiceA), nil } // InstrumentedService is an implementation of Lifecycle for which all interface
func NewNoopServiceB(*ServiceContext) (Service, error) { return new(NoopServiceB), nil }
func NewNoopServiceC(*ServiceContext) (Service, error) { return new(NoopServiceC), nil }
// InstrumentedService is an implementation of Service for which all interface
// methods can be instrumented both return value as well as event hook wise. // methods can be instrumented both return value as well as event hook wise.
type InstrumentedService struct { type InstrumentedService struct {
start error
stop error
startHook func()
stopHook func()
protocols []p2p.Protocol protocols []p2p.Protocol
apis []rpc.API
start error
stop error
protocolsHook func()
startHook func(*p2p.Server)
stopHook func()
} }
func NewInstrumentedService(*ServiceContext) (Service, error) { return new(InstrumentedService), nil } func (s *InstrumentedService) Start() error {
func (s *InstrumentedService) Protocols() []p2p.Protocol {
if s.protocolsHook != nil {
s.protocolsHook()
}
return s.protocols
}
func (s *InstrumentedService) APIs() []rpc.API {
return s.apis
}
func (s *InstrumentedService) Start(server *p2p.Server) error {
if s.startHook != nil { if s.startHook != nil {
s.startHook(server) s.startHook()
} }
return s.start return s.start
} }
@ -86,48 +65,49 @@ func (s *InstrumentedService) Stop() error {
return s.stop return s.stop
} }
// InstrumentingWrapper is a method to specialize a service constructor returning type FullService struct{}
// a generic InstrumentedService into one returning a wrapping specific one.
type InstrumentingWrapper func(base ServiceConstructor) ServiceConstructor
func InstrumentingWrapperMaker(base ServiceConstructor, kind reflect.Type) ServiceConstructor { func NewFullService(stack *Node) (*FullService, error) {
return func(ctx *ServiceContext) (Service, error) { fs := new(FullService)
obj, err := base(ctx)
if err != nil {
return nil, err
}
wrapper := reflect.New(kind)
wrapper.Elem().Field(0).Set(reflect.ValueOf(obj).Elem())
return wrapper.Interface().(Service), nil stack.RegisterProtocols(fs.Protocols())
stack.RegisterAPIs(fs.APIs())
stack.RegisterLifecycle(fs)
return fs, nil
}
func (f *FullService) Start() error { return nil }
func (f *FullService) Stop() error { return nil }
func (f *FullService) Protocols() []p2p.Protocol {
return []p2p.Protocol{
p2p.Protocol{
Name: "test1",
Version: uint(1),
},
p2p.Protocol{
Name: "test2",
Version: uint(2),
},
} }
} }
// Set of services all wrapping the base InstrumentedService resulting in the func (f *FullService) APIs() []rpc.API {
// same method signatures but different outer types. return []rpc.API{
type InstrumentedServiceA struct{ InstrumentedService } {
type InstrumentedServiceB struct{ InstrumentedService } Namespace: "admin",
type InstrumentedServiceC struct{ InstrumentedService } Version: "1.0",
},
func InstrumentedServiceMakerA(base ServiceConstructor) ServiceConstructor { {
return InstrumentingWrapperMaker(base, reflect.TypeOf(InstrumentedServiceA{})) Namespace: "debug",
} Version: "1.0",
Public: true,
func InstrumentedServiceMakerB(base ServiceConstructor) ServiceConstructor { },
return InstrumentingWrapperMaker(base, reflect.TypeOf(InstrumentedServiceB{})) {
} Namespace: "net",
Version: "1.0",
func InstrumentedServiceMakerC(base ServiceConstructor) ServiceConstructor { Public: true,
return InstrumentingWrapperMaker(base, reflect.TypeOf(InstrumentedServiceC{})) },
}
// OneMethodAPI is a single-method API handler to be returned by test services.
type OneMethodAPI struct {
fun func()
}
func (api *OneMethodAPI) TheOneMethod() {
if api.fun != nil {
api.fun()
} }
} }

@ -75,11 +75,11 @@ func (e *ExecAdapter) Name() string {
// NewNode returns a new ExecNode using the given config // NewNode returns a new ExecNode using the given config
func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) { func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) {
if len(config.Services) == 0 { if len(config.Lifecycles) == 0 {
return nil, errors.New("node must have at least one service") return nil, errors.New("node must have at least one service lifecycle")
} }
for _, service := range config.Services { for _, service := range config.Lifecycles {
if _, exists := serviceFuncs[service]; !exists { if _, exists := lifecycleConstructorFuncs[service]; !exists {
return nil, fmt.Errorf("unknown node service %q", service) return nil, fmt.Errorf("unknown node service %q", service)
} }
} }
@ -263,7 +263,7 @@ func (n *ExecNode) waitForStartupJSON(ctx context.Context) (string, chan nodeSta
func (n *ExecNode) execCommand() *exec.Cmd { func (n *ExecNode) execCommand() *exec.Cmd {
return &exec.Cmd{ return &exec.Cmd{
Path: reexec.Self(), Path: reexec.Self(),
Args: []string{"p2p-node", strings.Join(n.Config.Node.Services, ","), n.ID.String()}, Args: []string{"p2p-node", strings.Join(n.Config.Node.Lifecycles, ","), n.ID.String()},
} }
} }
@ -400,7 +400,7 @@ func execP2PNode() {
defer signal.Stop(sigc) defer signal.Stop(sigc)
<-sigc <-sigc
log.Info("Received SIGTERM, shutting down...") log.Info("Received SIGTERM, shutting down...")
stack.Stop() stack.Close()
}() }()
stack.Wait() // Wait for the stack to exit. stack.Wait() // Wait for the stack to exit.
} }
@ -434,44 +434,36 @@ func startExecNodeStack() (*node.Node, error) {
return nil, fmt.Errorf("error creating node stack: %v", err) return nil, fmt.Errorf("error creating node stack: %v", err)
} }
// register the services, collecting them into a map so we can wrap // Register the services, collecting them into a map so they can
// them in a snapshot service // be accessed by the snapshot API.
services := make(map[string]node.Service, len(serviceNames)) services := make(map[string]node.Lifecycle, len(serviceNames))
for _, name := range serviceNames { for _, name := range serviceNames {
serviceFunc, exists := serviceFuncs[name] lifecycleFunc, exists := lifecycleConstructorFuncs[name]
if !exists { if !exists {
return nil, fmt.Errorf("unknown node service %q", err) return nil, fmt.Errorf("unknown node service %q", err)
} }
constructor := func(nodeCtx *node.ServiceContext) (node.Service, error) { ctx := &ServiceContext{
ctx := &ServiceContext{ RPCDialer: &wsRPCDialer{addrs: conf.PeerAddrs},
RPCDialer: &wsRPCDialer{addrs: conf.PeerAddrs}, Config: conf.Node,
NodeContext: nodeCtx,
Config: conf.Node,
}
if conf.Snapshots != nil {
ctx.Snapshot = conf.Snapshots[name]
}
service, err := serviceFunc(ctx)
if err != nil {
return nil, err
}
services[name] = service
return service, nil
} }
if err := stack.Register(constructor); err != nil { if conf.Snapshots != nil {
return stack, fmt.Errorf("error registering service %q: %v", name, err) ctx.Snapshot = conf.Snapshots[name]
} }
service, err := lifecycleFunc(ctx, stack)
if err != nil {
return nil, err
}
services[name] = service
stack.RegisterLifecycle(service)
} }
// register the snapshot service // Add the snapshot API.
err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { stack.RegisterAPIs([]rpc.API{{
return &snapshotService{services}, nil Namespace: "simulation",
}) Version: "1.0",
if err != nil { Service: SnapshotAPI{services},
return stack, fmt.Errorf("error starting snapshot service: %v", err) }})
}
// start the stack
if err = stack.Start(); err != nil { if err = stack.Start(); err != nil {
err = fmt.Errorf("error starting stack: %v", err) err = fmt.Errorf("error starting stack: %v", err)
} }
@ -490,35 +482,9 @@ type nodeStartupJSON struct {
NodeInfo *p2p.NodeInfo NodeInfo *p2p.NodeInfo
} }
// snapshotService is a node.Service which wraps a list of services and
// exposes an API to generate a snapshot of those services
type snapshotService struct {
services map[string]node.Service
}
func (s *snapshotService) APIs() []rpc.API {
return []rpc.API{{
Namespace: "simulation",
Version: "1.0",
Service: SnapshotAPI{s.services},
}}
}
func (s *snapshotService) Protocols() []p2p.Protocol {
return nil
}
func (s *snapshotService) Start(*p2p.Server) error {
return nil
}
func (s *snapshotService) Stop() error {
return nil
}
// SnapshotAPI provides an RPC method to create snapshots of services // SnapshotAPI provides an RPC method to create snapshots of services
type SnapshotAPI struct { type SnapshotAPI struct {
services map[string]node.Service services map[string]node.Lifecycle
} }
func (api SnapshotAPI) Snapshot() (map[string][]byte, error) { func (api SnapshotAPI) Snapshot() (map[string][]byte, error) {

@ -37,29 +37,21 @@ import (
// SimAdapter is a NodeAdapter which creates in-memory simulation nodes and // SimAdapter is a NodeAdapter which creates in-memory simulation nodes and
// connects them using net.Pipe // connects them using net.Pipe
type SimAdapter struct { type SimAdapter struct {
pipe func() (net.Conn, net.Conn, error) pipe func() (net.Conn, net.Conn, error)
mtx sync.RWMutex mtx sync.RWMutex
nodes map[enode.ID]*SimNode nodes map[enode.ID]*SimNode
services map[string]ServiceFunc lifecycles LifecycleConstructors
} }
// NewSimAdapter creates a SimAdapter which is capable of running in-memory // NewSimAdapter creates a SimAdapter which is capable of running in-memory
// simulation nodes running any of the given services (the services to run on a // simulation nodes running any of the given services (the services to run on a
// particular node are passed to the NewNode function in the NodeConfig) // particular node are passed to the NewNode function in the NodeConfig)
// the adapter uses a net.Pipe for in-memory simulated network connections // the adapter uses a net.Pipe for in-memory simulated network connections
func NewSimAdapter(services map[string]ServiceFunc) *SimAdapter { func NewSimAdapter(services LifecycleConstructors) *SimAdapter {
return &SimAdapter{ return &SimAdapter{
pipe: pipes.NetPipe, pipe: pipes.NetPipe,
nodes: make(map[enode.ID]*SimNode), nodes: make(map[enode.ID]*SimNode),
services: services, lifecycles: services,
}
}
func NewTCPAdapter(services map[string]ServiceFunc) *SimAdapter {
return &SimAdapter{
pipe: pipes.TCPPipe,
nodes: make(map[enode.ID]*SimNode),
services: services,
} }
} }
@ -85,11 +77,11 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) {
} }
// check the services are valid // check the services are valid
if len(config.Services) == 0 { if len(config.Lifecycles) == 0 {
return nil, errors.New("node must have at least one service") return nil, errors.New("node must have at least one service")
} }
for _, service := range config.Services { for _, service := range config.Lifecycles {
if _, exists := s.services[service]; !exists { if _, exists := s.lifecycles[service]; !exists {
return nil, fmt.Errorf("unknown node service %q", service) return nil, fmt.Errorf("unknown node service %q", service)
} }
} }
@ -119,7 +111,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) {
config: config, config: config,
node: n, node: n,
adapter: s, adapter: s,
running: make(map[string]node.Service), running: make(map[string]node.Lifecycle),
} }
s.nodes[id] = simNode s.nodes[id] = simNode
return simNode, nil return simNode, nil
@ -155,11 +147,7 @@ func (s *SimAdapter) DialRPC(id enode.ID) (*rpc.Client, error) {
if !ok { if !ok {
return nil, fmt.Errorf("unknown node: %s", id) return nil, fmt.Errorf("unknown node: %s", id)
} }
handler, err := node.node.RPCHandler() return node.node.Attach()
if err != nil {
return nil, err
}
return rpc.DialInProc(handler), nil
} }
// GetNode returns the node with the given ID if it exists // GetNode returns the node with the given ID if it exists
@ -179,7 +167,7 @@ type SimNode struct {
config *NodeConfig config *NodeConfig
adapter *SimAdapter adapter *SimAdapter
node *node.Node node *node.Node
running map[string]node.Service running map[string]node.Lifecycle
client *rpc.Client client *rpc.Client
registerOnce sync.Once registerOnce sync.Once
} }
@ -227,7 +215,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error {
// simulation_snapshot RPC method // simulation_snapshot RPC method
func (sn *SimNode) Snapshots() (map[string][]byte, error) { func (sn *SimNode) Snapshots() (map[string][]byte, error) {
sn.lock.RLock() sn.lock.RLock()
services := make(map[string]node.Service, len(sn.running)) services := make(map[string]node.Lifecycle, len(sn.running))
for name, service := range sn.running { for name, service := range sn.running {
services[name] = service services[name] = service
} }
@ -252,35 +240,30 @@ func (sn *SimNode) Snapshots() (map[string][]byte, error) {
// Start registers the services and starts the underlying devp2p node // Start registers the services and starts the underlying devp2p node
func (sn *SimNode) Start(snapshots map[string][]byte) error { func (sn *SimNode) Start(snapshots map[string][]byte) error {
newService := func(name string) func(ctx *node.ServiceContext) (node.Service, error) {
return func(nodeCtx *node.ServiceContext) (node.Service, error) {
ctx := &ServiceContext{
RPCDialer: sn.adapter,
NodeContext: nodeCtx,
Config: sn.config,
}
if snapshots != nil {
ctx.Snapshot = snapshots[name]
}
serviceFunc := sn.adapter.services[name]
service, err := serviceFunc(ctx)
if err != nil {
return nil, err
}
sn.running[name] = service
return service, nil
}
}
// ensure we only register the services once in the case of the node // ensure we only register the services once in the case of the node
// being stopped and then started again // being stopped and then started again
var regErr error var regErr error
sn.registerOnce.Do(func() { sn.registerOnce.Do(func() {
for _, name := range sn.config.Services { for _, name := range sn.config.Lifecycles {
if err := sn.node.Register(newService(name)); err != nil { ctx := &ServiceContext{
RPCDialer: sn.adapter,
Config: sn.config,
}
if snapshots != nil {
ctx.Snapshot = snapshots[name]
}
serviceFunc := sn.adapter.lifecycles[name]
service, err := serviceFunc(ctx, sn.node)
if err != nil {
regErr = err regErr = err
break break
} }
// if the service has already been registered, don't register it again.
if _, ok := sn.running[name]; ok {
continue
}
sn.running[name] = service
sn.node.RegisterLifecycle(service)
} }
}) })
if regErr != nil { if regErr != nil {
@ -292,13 +275,12 @@ func (sn *SimNode) Start(snapshots map[string][]byte) error {
} }
// create an in-process RPC client // create an in-process RPC client
handler, err := sn.node.RPCHandler() client, err := sn.node.Attach()
if err != nil { if err != nil {
return err return err
} }
sn.lock.Lock() sn.lock.Lock()
sn.client = rpc.DialInProc(handler) sn.client = client
sn.lock.Unlock() sn.lock.Unlock()
return nil return nil
@ -312,21 +294,21 @@ func (sn *SimNode) Stop() error {
sn.client = nil sn.client = nil
} }
sn.lock.Unlock() sn.lock.Unlock()
return sn.node.Stop() return sn.node.Close()
} }
// Service returns a running service by name // Service returns a running service by name
func (sn *SimNode) Service(name string) node.Service { func (sn *SimNode) Service(name string) node.Lifecycle {
sn.lock.RLock() sn.lock.RLock()
defer sn.lock.RUnlock() defer sn.lock.RUnlock()
return sn.running[name] return sn.running[name]
} }
// Services returns a copy of the underlying services // Services returns a copy of the underlying services
func (sn *SimNode) Services() []node.Service { func (sn *SimNode) Services() []node.Lifecycle {
sn.lock.RLock() sn.lock.RLock()
defer sn.lock.RUnlock() defer sn.lock.RUnlock()
services := make([]node.Service, 0, len(sn.running)) services := make([]node.Lifecycle, 0, len(sn.running))
for _, service := range sn.running { for _, service := range sn.running {
services = append(services, service) services = append(services, service)
} }
@ -334,10 +316,10 @@ func (sn *SimNode) Services() []node.Service {
} }
// ServiceMap returns a map by names of the underlying services // ServiceMap returns a map by names of the underlying services
func (sn *SimNode) ServiceMap() map[string]node.Service { func (sn *SimNode) ServiceMap() map[string]node.Lifecycle {
sn.lock.RLock() sn.lock.RLock()
defer sn.lock.RUnlock() defer sn.lock.RUnlock()
services := make(map[string]node.Service, len(sn.running)) services := make(map[string]node.Lifecycle, len(sn.running))
for name, service := range sn.running { for name, service := range sn.running {
services[name] = service services[name] = service
} }

@ -96,11 +96,11 @@ type NodeConfig struct {
// Use an existing database instead of a temporary one if non-empty // Use an existing database instead of a temporary one if non-empty
DataDir string DataDir string
// Services are the names of the services which should be run when // Lifecycles are the names of the service lifecycles which should be run when
// starting the node (for SimNodes it should be the names of services // starting the node (for SimNodes it should be the names of service lifecycles
// contained in SimAdapter.services, for other nodes it should be // contained in SimAdapter.lifecycles, for other nodes it should be
// services registered by calling the RegisterService function) // service lifecycles registered by calling the RegisterLifecycle function)
Services []string Lifecycles []string
// Properties are the names of the properties this node should hold // Properties are the names of the properties this node should hold
// within running services (e.g. "bootnode", "lightnode" or any custom values) // within running services (e.g. "bootnode", "lightnode" or any custom values)
@ -137,7 +137,7 @@ func (n *NodeConfig) MarshalJSON() ([]byte, error) {
confJSON := nodeConfigJSON{ confJSON := nodeConfigJSON{
ID: n.ID.String(), ID: n.ID.String(),
Name: n.Name, Name: n.Name,
Services: n.Services, Services: n.Lifecycles,
Properties: n.Properties, Properties: n.Properties,
Port: n.Port, Port: n.Port,
EnableMsgEvents: n.EnableMsgEvents, EnableMsgEvents: n.EnableMsgEvents,
@ -175,7 +175,7 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error {
} }
n.Name = confJSON.Name n.Name = confJSON.Name
n.Services = confJSON.Services n.Lifecycles = confJSON.Services
n.Properties = confJSON.Properties n.Properties = confJSON.Properties
n.Port = confJSON.Port n.Port = confJSON.Port
n.EnableMsgEvents = confJSON.EnableMsgEvents n.EnableMsgEvents = confJSON.EnableMsgEvents
@ -233,9 +233,8 @@ func assignTCPPort() (uint16, error) {
type ServiceContext struct { type ServiceContext struct {
RPCDialer RPCDialer
NodeContext *node.ServiceContext Config *NodeConfig
Config *NodeConfig Snapshot []byte
Snapshot []byte
} }
// RPCDialer is used when initialising services which need to connect to // RPCDialer is used when initialising services which need to connect to
@ -245,27 +244,29 @@ type RPCDialer interface {
DialRPC(id enode.ID) (*rpc.Client, error) DialRPC(id enode.ID) (*rpc.Client, error)
} }
// Services is a collection of services which can be run in a simulation // LifecycleConstructor allows a Lifecycle to be constructed during node start-up.
type Services map[string]ServiceFunc // While the service-specific package usually takes care of Lifecycle creation and registration,
// for testing purposes, it is useful to be able to construct a Lifecycle on spot.
type LifecycleConstructor func(ctx *ServiceContext, stack *node.Node) (node.Lifecycle, error)
// ServiceFunc returns a node.Service which can be used to boot a devp2p node // LifecycleConstructors stores LifecycleConstructor functions to call during node start-up.
type ServiceFunc func(ctx *ServiceContext) (node.Service, error) type LifecycleConstructors map[string]LifecycleConstructor
// serviceFuncs is a map of registered services which are used to boot devp2p // lifecycleConstructorFuncs is a map of registered services which are used to boot devp2p
// nodes // nodes
var serviceFuncs = make(Services) var lifecycleConstructorFuncs = make(LifecycleConstructors)
// RegisterServices registers the given Services which can then be used to // RegisterLifecycles registers the given Services which can then be used to
// start devp2p nodes using either the Exec or Docker adapters. // start devp2p nodes using either the Exec or Docker adapters.
// //
// It should be called in an init function so that it has the opportunity to // It should be called in an init function so that it has the opportunity to
// execute the services before main() is called. // execute the services before main() is called.
func RegisterServices(services Services) { func RegisterLifecycles(lifecycles LifecycleConstructors) {
for name, f := range services { for name, f := range lifecycles {
if _, exists := serviceFuncs[name]; exists { if _, exists := lifecycleConstructorFuncs[name]; exists {
panic(fmt.Sprintf("node service already exists: %q", name)) panic(fmt.Sprintf("node service already exists: %q", name))
} }
serviceFuncs[name] = f lifecycleConstructorFuncs[name] = f
} }
// now we have registered the services, run reexec.Init() which will // now we have registered the services, run reexec.Init() which will

@ -26,8 +26,8 @@ import (
func newTestNetwork(t *testing.T, nodeCount int) (*Network, []enode.ID) { func newTestNetwork(t *testing.T, nodeCount int) (*Network, []enode.ID) {
t.Helper() t.Helper()
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
return NewNoopService(nil), nil return NewNoopService(nil), nil
}, },
}) })

@ -31,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/simulations" "github.com/ethereum/go-ethereum/p2p/simulations"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters" "github.com/ethereum/go-ethereum/p2p/simulations/adapters"
"github.com/ethereum/go-ethereum/rpc"
) )
var adapterType = flag.String("adapter", "sim", `node adapter to use (one of "sim", "exec" or "docker")`) var adapterType = flag.String("adapter", "sim", `node adapter to use (one of "sim", "exec" or "docker")`)
@ -45,12 +44,14 @@ func main() {
log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
// register a single ping-pong service // register a single ping-pong service
services := map[string]adapters.ServiceFunc{ services := map[string]adapters.LifecycleConstructor{
"ping-pong": func(ctx *adapters.ServiceContext) (node.Service, error) { "ping-pong": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
return newPingPongService(ctx.Config.ID), nil pps := newPingPongService(ctx.Config.ID)
stack.RegisterProtocols(pps.Protocols())
return pps, nil
}, },
} }
adapters.RegisterServices(services) adapters.RegisterLifecycles(services)
// create the NodeAdapter // create the NodeAdapter
var adapter adapters.NodeAdapter var adapter adapters.NodeAdapter
@ -110,11 +111,7 @@ func (p *pingPongService) Protocols() []p2p.Protocol {
}} }}
} }
func (p *pingPongService) APIs() []rpc.API { func (p *pingPongService) Start() error {
return nil
}
func (p *pingPongService) Start(server *p2p.Server) error {
p.log.Info("ping-pong service starting") p.log.Info("ping-pong service starting")
return nil return nil
} }

@ -64,12 +64,15 @@ type testService struct {
state atomic.Value state atomic.Value
} }
func newTestService(ctx *adapters.ServiceContext) (node.Service, error) { func newTestService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
svc := &testService{ svc := &testService{
id: ctx.Config.ID, id: ctx.Config.ID,
peers: make(map[enode.ID]*testPeer), peers: make(map[enode.ID]*testPeer),
} }
svc.state.Store(ctx.Snapshot) svc.state.Store(ctx.Snapshot)
stack.RegisterProtocols(svc.Protocols())
stack.RegisterAPIs(svc.APIs())
return svc, nil return svc, nil
} }
@ -126,7 +129,7 @@ func (t *testService) APIs() []rpc.API {
}} }}
} }
func (t *testService) Start(server *p2p.Server) error { func (t *testService) Start() error {
return nil return nil
} }
@ -288,7 +291,7 @@ func (t *TestAPI) Events(ctx context.Context) (*rpc.Subscription, error) {
return rpcSub, nil return rpcSub, nil
} }
var testServices = adapters.Services{ var testServices = adapters.LifecycleConstructors{
"test": newTestService, "test": newTestService,
} }

@ -110,8 +110,8 @@ func (net *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error)
} }
// if no services are configured, use the default service // if no services are configured, use the default service
if len(conf.Services) == 0 { if len(conf.Lifecycles) == 0 {
conf.Services = []string{net.DefaultService} conf.Lifecycles = []string{net.DefaultService}
} }
// use the NodeAdapter to create the node // use the NodeAdapter to create the node
@ -913,19 +913,19 @@ func (net *Network) snapshot(addServices []string, removeServices []string) (*Sn
snap.Nodes[i].Snapshots = snapshots snap.Nodes[i].Snapshots = snapshots
for _, addSvc := range addServices { for _, addSvc := range addServices {
haveSvc := false haveSvc := false
for _, svc := range snap.Nodes[i].Node.Config.Services { for _, svc := range snap.Nodes[i].Node.Config.Lifecycles {
if svc == addSvc { if svc == addSvc {
haveSvc = true haveSvc = true
break break
} }
} }
if !haveSvc { if !haveSvc {
snap.Nodes[i].Node.Config.Services = append(snap.Nodes[i].Node.Config.Services, addSvc) snap.Nodes[i].Node.Config.Lifecycles = append(snap.Nodes[i].Node.Config.Lifecycles, addSvc)
} }
} }
if len(removeServices) > 0 { if len(removeServices) > 0 {
var cleanedServices []string var cleanedServices []string
for _, svc := range snap.Nodes[i].Node.Config.Services { for _, svc := range snap.Nodes[i].Node.Config.Lifecycles {
haveSvc := false haveSvc := false
for _, rmSvc := range removeServices { for _, rmSvc := range removeServices {
if rmSvc == svc { if rmSvc == svc {
@ -938,7 +938,7 @@ func (net *Network) snapshot(addServices []string, removeServices []string) (*Sn
} }
} }
snap.Nodes[i].Node.Config.Services = cleanedServices snap.Nodes[i].Node.Config.Lifecycles = cleanedServices
} }
} }
for _, conn := range net.Conns { for _, conn := range net.Conns {

@ -41,8 +41,8 @@ func TestSnapshot(t *testing.T) {
// create snapshot from ring network // create snapshot from ring network
// this is a minimal service, whose protocol will take exactly one message OR close of connection before quitting // this is a minimal service, whose protocol will take exactly one message OR close of connection before quitting
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
return NewNoopService(nil), nil return NewNoopService(nil), nil
}, },
}) })
@ -165,8 +165,8 @@ OUTER:
// PART II // PART II
// load snapshot and verify that exactly same connections are formed // load snapshot and verify that exactly same connections are formed
adapter = adapters.NewSimAdapter(adapters.Services{ adapter = adapters.NewSimAdapter(adapters.LifecycleConstructors{
"noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
return NewNoopService(nil), nil return NewNoopService(nil), nil
}, },
}) })
@ -256,8 +256,8 @@ OuterTwo:
t.Run("conns after load", func(t *testing.T) { t.Run("conns after load", func(t *testing.T) {
// Create new network. // Create new network.
n := NewNetwork( n := NewNetwork(
adapters.NewSimAdapter(adapters.Services{ adapters.NewSimAdapter(adapters.LifecycleConstructors{
"noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
return NewNoopService(nil), nil return NewNoopService(nil), nil
}, },
}), }),
@ -288,7 +288,7 @@ OuterTwo:
// with each other and that a snapshot fully represents the desired topology // with each other and that a snapshot fully represents the desired topology
func TestNetworkSimulation(t *testing.T) { func TestNetworkSimulation(t *testing.T) {
// create simulation network with 20 testService nodes // create simulation network with 20 testService nodes
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"test": newTestService, "test": newTestService,
}) })
network := NewNetwork(adapter, &NetworkConfig{ network := NewNetwork(adapter, &NetworkConfig{
@ -437,7 +437,7 @@ func createTestNodesWithProperty(property string, count int, network *Network) (
// It then tests again whilst excluding a node ID from being returned. // It then tests again whilst excluding a node ID from being returned.
// If a node ID is not returned, or more node IDs than expected are returned, the test fails. // If a node ID is not returned, or more node IDs than expected are returned, the test fails.
func TestGetNodeIDs(t *testing.T) { func TestGetNodeIDs(t *testing.T) {
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"test": newTestService, "test": newTestService,
}) })
network := NewNetwork(adapter, &NetworkConfig{ network := NewNetwork(adapter, &NetworkConfig{
@ -486,7 +486,7 @@ func TestGetNodeIDs(t *testing.T) {
// It then tests again whilst excluding a node from being returned. // It then tests again whilst excluding a node from being returned.
// If a node is not returned, or more nodes than expected are returned, the test fails. // If a node is not returned, or more nodes than expected are returned, the test fails.
func TestGetNodes(t *testing.T) { func TestGetNodes(t *testing.T) {
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"test": newTestService, "test": newTestService,
}) })
network := NewNetwork(adapter, &NetworkConfig{ network := NewNetwork(adapter, &NetworkConfig{
@ -534,7 +534,7 @@ func TestGetNodes(t *testing.T) {
// TestGetNodesByID creates a set of nodes and attempts to retrieve a subset of them by ID // TestGetNodesByID creates a set of nodes and attempts to retrieve a subset of them by ID
// If a node is not returned, or more nodes than expected are returned, the test fails. // If a node is not returned, or more nodes than expected are returned, the test fails.
func TestGetNodesByID(t *testing.T) { func TestGetNodesByID(t *testing.T) {
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"test": newTestService, "test": newTestService,
}) })
network := NewNetwork(adapter, &NetworkConfig{ network := NewNetwork(adapter, &NetworkConfig{
@ -579,7 +579,7 @@ func TestGetNodesByID(t *testing.T) {
// GetNodesByProperty is then checked for correctness by comparing the nodes returned to those initially created. // GetNodesByProperty is then checked for correctness by comparing the nodes returned to those initially created.
// If a node with a property is not found, or more nodes than expected are returned, the test fails. // If a node with a property is not found, or more nodes than expected are returned, the test fails.
func TestGetNodesByProperty(t *testing.T) { func TestGetNodesByProperty(t *testing.T) {
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"test": newTestService, "test": newTestService,
}) })
network := NewNetwork(adapter, &NetworkConfig{ network := NewNetwork(adapter, &NetworkConfig{
@ -624,7 +624,7 @@ func TestGetNodesByProperty(t *testing.T) {
// GetNodeIDsByProperty is then checked for correctness by comparing the node IDs returned to those initially created. // GetNodeIDsByProperty is then checked for correctness by comparing the node IDs returned to those initially created.
// If a node ID with a property is not found, or more nodes IDs than expected are returned, the test fails. // If a node ID with a property is not found, or more nodes IDs than expected are returned, the test fails.
func TestGetNodeIDsByProperty(t *testing.T) { func TestGetNodeIDsByProperty(t *testing.T) {
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"test": newTestService, "test": newTestService,
}) })
network := NewNetwork(adapter, &NetworkConfig{ network := NewNetwork(adapter, &NetworkConfig{
@ -705,8 +705,8 @@ func benchmarkMinimalServiceTmp(b *testing.B) {
// this is a minimal service, whose protocol will close a channel upon run of protocol // this is a minimal service, whose protocol will close a channel upon run of protocol
// making it possible to bench the time it takes for the service to start and protocol actually to be run // making it possible to bench the time it takes for the service to start and protocol actually to be run
protoCMap := make(map[enode.ID]map[enode.ID]chan struct{}) protoCMap := make(map[enode.ID]map[enode.ID]chan struct{})
adapter := adapters.NewSimAdapter(adapters.Services{ adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{
"noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
protoCMap[ctx.Config.ID] = make(map[enode.ID]chan struct{}) protoCMap[ctx.Config.ID] = make(map[enode.ID]chan struct{})
svc := NewNoopService(protoCMap[ctx.Config.ID]) svc := NewNoopService(protoCMap[ctx.Config.ID])
return svc, nil return svc, nil

@ -66,7 +66,7 @@ func (t *NoopService) APIs() []rpc.API {
return []rpc.API{} return []rpc.API{}
} }
func (t *NoopService) Start(server *p2p.Server) error { func (t *NoopService) Start() error {
return nil return nil
} }

@ -1,67 +0,0 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package testing
import (
"fmt"
"sync"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
)
type TestPeer interface {
ID() enode.ID
Drop()
}
// TestPeerPool is an example peerPool to demonstrate registration of peer connections
type TestPeerPool struct {
lock sync.Mutex
peers map[enode.ID]TestPeer
}
func NewTestPeerPool() *TestPeerPool {
return &TestPeerPool{peers: make(map[enode.ID]TestPeer)}
}
func (p *TestPeerPool) Add(peer TestPeer) {
p.lock.Lock()
defer p.lock.Unlock()
log.Trace(fmt.Sprintf("pp add peer %v", peer.ID()))
p.peers[peer.ID()] = peer
}
func (p *TestPeerPool) Remove(peer TestPeer) {
p.lock.Lock()
defer p.lock.Unlock()
delete(p.peers, peer.ID())
}
func (p *TestPeerPool) Has(id enode.ID) bool {
p.lock.Lock()
defer p.lock.Unlock()
_, ok := p.peers[id]
return ok
}
func (p *TestPeerPool) Get(id enode.ID) TestPeer {
p.lock.Lock()
defer p.lock.Unlock()
return p.peers[id]
}

@ -1,283 +0,0 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package testing
import (
"errors"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
)
var errTimedOut = errors.New("timed out")
// ProtocolSession is a quasi simulation of a pivot node running
// a service and a number of dummy peers that can send (trigger) or
// receive (expect) messages
type ProtocolSession struct {
Server *p2p.Server
Nodes []*enode.Node
adapter *adapters.SimAdapter
events chan *p2p.PeerEvent
}
// Exchange is the basic units of protocol tests
// the triggers and expects in the arrays are run immediately and asynchronously
// thus one cannot have multiple expects for the SAME peer with DIFFERENT message types
// because it's unpredictable which expect will receive which message
// (with expect #1 and #2, messages might be sent #2 and #1, and both expects will complain about wrong message code)
// an exchange is defined on a session
type Exchange struct {
Label string
Triggers []Trigger
Expects []Expect
Timeout time.Duration
}
// Trigger is part of the exchange, incoming message for the pivot node
// sent by a peer
type Trigger struct {
Msg interface{} // type of message to be sent
Code uint64 // code of message is given
Peer enode.ID // the peer to send the message to
Timeout time.Duration // timeout duration for the sending
}
// Expect is part of an exchange, outgoing message from the pivot node
// received by a peer
type Expect struct {
Msg interface{} // type of message to expect
Code uint64 // code of message is now given
Peer enode.ID // the peer that expects the message
Timeout time.Duration // timeout duration for receiving
}
// Disconnect represents a disconnect event, used and checked by TestDisconnected
type Disconnect struct {
Peer enode.ID // discconnected peer
Error error // disconnect reason
}
// trigger sends messages from peers
func (s *ProtocolSession) trigger(trig Trigger) error {
simNode, ok := s.adapter.GetNode(trig.Peer)
if !ok {
return fmt.Errorf("trigger: peer %v does not exist (1- %v)", trig.Peer, len(s.Nodes))
}
mockNode, ok := simNode.Services()[0].(*mockNode)
if !ok {
return fmt.Errorf("trigger: peer %v is not a mock", trig.Peer)
}
errc := make(chan error)
go func() {
log.Trace(fmt.Sprintf("trigger %v (%v)....", trig.Msg, trig.Code))
errc <- mockNode.Trigger(&trig)
log.Trace(fmt.Sprintf("triggered %v (%v)", trig.Msg, trig.Code))
}()
t := trig.Timeout
if t == time.Duration(0) {
t = 1000 * time.Millisecond
}
select {
case err := <-errc:
return err
case <-time.After(t):
return fmt.Errorf("timout expecting %v to send to peer %v", trig.Msg, trig.Peer)
}
}
// expect checks an expectation of a message sent out by the pivot node
func (s *ProtocolSession) expect(exps []Expect) error {
// construct a map of expectations for each node
peerExpects := make(map[enode.ID][]Expect)
for _, exp := range exps {
if exp.Msg == nil {
return errors.New("no message to expect")
}
peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp)
}
// construct a map of mockNodes for each node
mockNodes := make(map[enode.ID]*mockNode)
for nodeID := range peerExpects {
simNode, ok := s.adapter.GetNode(nodeID)
if !ok {
return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(s.Nodes))
}
mockNode, ok := simNode.Services()[0].(*mockNode)
if !ok {
return fmt.Errorf("trigger: peer %v is not a mock", nodeID)
}
mockNodes[nodeID] = mockNode
}
// done chanell cancels all created goroutines when function returns
done := make(chan struct{})
defer close(done)
// errc catches the first error from
errc := make(chan error)
wg := &sync.WaitGroup{}
wg.Add(len(mockNodes))
for nodeID, mockNode := range mockNodes {
nodeID := nodeID
mockNode := mockNode
go func() {
defer wg.Done()
// Sum all Expect timeouts to give the maximum
// time for all expectations to finish.
// mockNode.Expect checks all received messages against
// a list of expected messages and timeout for each
// of them can not be checked separately.
var t time.Duration
for _, exp := range peerExpects[nodeID] {
if exp.Timeout == time.Duration(0) {
t += 2000 * time.Millisecond
} else {
t += exp.Timeout
}
}
alarm := time.NewTimer(t)
defer alarm.Stop()
// expectErrc is used to check if error returned
// from mockNode.Expect is not nil and to send it to
// errc only in that case.
// done channel will be closed when function
expectErrc := make(chan error)
go func() {
select {
case expectErrc <- mockNode.Expect(peerExpects[nodeID]...):
case <-done:
case <-alarm.C:
}
}()
select {
case err := <-expectErrc:
if err != nil {
select {
case errc <- err:
case <-done:
case <-alarm.C:
errc <- errTimedOut
}
}
case <-done:
case <-alarm.C:
errc <- errTimedOut
}
}()
}
go func() {
wg.Wait()
// close errc when all goroutines finish to return nill err from errc
close(errc)
}()
return <-errc
}
// TestExchanges tests a series of exchanges against the session
func (s *ProtocolSession) TestExchanges(exchanges ...Exchange) error {
for i, e := range exchanges {
if err := s.testExchange(e); err != nil {
return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err)
}
log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label))
}
return nil
}
// testExchange tests a single Exchange.
// Default timeout value is 2 seconds.
func (s *ProtocolSession) testExchange(e Exchange) error {
errc := make(chan error)
done := make(chan struct{})
defer close(done)
go func() {
for _, trig := range e.Triggers {
err := s.trigger(trig)
if err != nil {
errc <- err
return
}
}
select {
case errc <- s.expect(e.Expects):
case <-done:
}
}()
// time out globally or finish when all expectations satisfied
t := e.Timeout
if t == 0 {
t = 2000 * time.Millisecond
}
alarm := time.NewTimer(t)
defer alarm.Stop()
select {
case err := <-errc:
return err
case <-alarm.C:
return errTimedOut
}
}
// TestDisconnected tests the disconnections given as arguments
// the disconnect structs describe what disconnect error is expected on which peer
func (s *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error {
expects := make(map[enode.ID]error)
for _, disconnect := range disconnects {
expects[disconnect.Peer] = disconnect.Error
}
timeout := time.After(time.Second)
for len(expects) > 0 {
select {
case event := <-s.events:
if event.Type != p2p.PeerEventTypeDrop {
continue
}
expectErr, ok := expects[event.Peer]
if !ok {
continue
}
if !(expectErr == nil && event.Error == "" || expectErr != nil && expectErr.Error() == event.Error) {
return fmt.Errorf("unexpected error on peer %v. expected '%v', got '%v'", event.Peer, expectErr, event.Error)
}
delete(expects, event.Peer)
case <-timeout:
return fmt.Errorf("timed out waiting for peers to disconnect")
}
}
return nil
}

@ -1,284 +0,0 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
/*
the p2p/testing package provides a unit test scheme to check simple
protocol message exchanges with one pivot node and a number of dummy peers
The pivot test node runs a node.Service, the dummy peers run a mock node
that can be used to send and receive messages
*/
package testing
import (
"bytes"
"crypto/ecdsa"
"fmt"
"io"
"io/ioutil"
"strings"
"sync"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/simulations"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
)
// ProtocolTester is the tester environment used for unit testing protocol
// message exchanges. It uses p2p/simulations framework
type ProtocolTester struct {
*ProtocolSession
network *simulations.Network
}
// NewProtocolTester constructs a new ProtocolTester
// it takes as argument the pivot node id, the number of dummy peers and the
// protocol run function called on a peer connection by the p2p server
func NewProtocolTester(prvkey *ecdsa.PrivateKey, nodeCount int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester {
services := adapters.Services{
"test": func(ctx *adapters.ServiceContext) (node.Service, error) {
return &testNode{run}, nil
},
"mock": func(ctx *adapters.ServiceContext) (node.Service, error) {
return newMockNode(), nil
},
}
adapter := adapters.NewSimAdapter(services)
net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{})
nodeConfig := &adapters.NodeConfig{
PrivateKey: prvkey,
EnableMsgEvents: true,
Services: []string{"test"},
}
if _, err := net.NewNodeWithConfig(nodeConfig); err != nil {
panic(err.Error())
}
if err := net.Start(nodeConfig.ID); err != nil {
panic(err.Error())
}
node := net.GetNode(nodeConfig.ID).Node.(*adapters.SimNode)
peers := make([]*adapters.NodeConfig, nodeCount)
nodes := make([]*enode.Node, nodeCount)
for i := 0; i < nodeCount; i++ {
peers[i] = adapters.RandomNodeConfig()
peers[i].Services = []string{"mock"}
if _, err := net.NewNodeWithConfig(peers[i]); err != nil {
panic(fmt.Sprintf("error initializing peer %v: %v", peers[i].ID, err))
}
if err := net.Start(peers[i].ID); err != nil {
panic(fmt.Sprintf("error starting peer %v: %v", peers[i].ID, err))
}
nodes[i] = peers[i].Node()
}
events := make(chan *p2p.PeerEvent, 1000)
node.SubscribeEvents(events)
ps := &ProtocolSession{
Server: node.Server(),
Nodes: nodes,
adapter: adapter,
events: events,
}
self := &ProtocolTester{
ProtocolSession: ps,
network: net,
}
self.Connect(nodeConfig.ID, peers...)
return self
}
// Stop stops the p2p server
func (t *ProtocolTester) Stop() {
t.Server.Stop()
t.network.Shutdown()
}
// Connect brings up the remote peer node and connects it using the
// p2p/simulations network connection with the in memory network adapter
func (t *ProtocolTester) Connect(selfID enode.ID, peers ...*adapters.NodeConfig) {
for _, peer := range peers {
log.Trace(fmt.Sprintf("connect to %v", peer.ID))
if err := t.network.Connect(selfID, peer.ID); err != nil {
panic(fmt.Sprintf("error connecting to peer %v: %v", peer.ID, err))
}
}
}
// testNode wraps a protocol run function and implements the node.Service
// interface
type testNode struct {
run func(*p2p.Peer, p2p.MsgReadWriter) error
}
func (t *testNode) Protocols() []p2p.Protocol {
return []p2p.Protocol{{
Length: 100,
Run: t.run,
}}
}
func (t *testNode) APIs() []rpc.API {
return nil
}
func (t *testNode) Start(server *p2p.Server) error {
return nil
}
func (t *testNode) Stop() error {
return nil
}
// mockNode is a testNode which doesn't actually run a protocol, instead
// exposing channels so that tests can manually trigger and expect certain
// messages
type mockNode struct {
testNode
trigger chan *Trigger
expect chan []Expect
err chan error
stop chan struct{}
stopOnce sync.Once
}
func newMockNode() *mockNode {
mock := &mockNode{
trigger: make(chan *Trigger),
expect: make(chan []Expect),
err: make(chan error),
stop: make(chan struct{}),
}
mock.testNode.run = mock.Run
return mock
}
// Run is a protocol run function which just loops waiting for tests to
// instruct it to either trigger or expect a message from the peer
func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
for {
select {
case trig := <-m.trigger:
wmsg := Wrap(trig.Msg)
m.err <- p2p.Send(rw, trig.Code, wmsg)
case exps := <-m.expect:
m.err <- expectMsgs(rw, exps)
case <-m.stop:
return nil
}
}
}
func (m *mockNode) Trigger(trig *Trigger) error {
m.trigger <- trig
return <-m.err
}
func (m *mockNode) Expect(exp ...Expect) error {
m.expect <- exp
return <-m.err
}
func (m *mockNode) Stop() error {
m.stopOnce.Do(func() { close(m.stop) })
return nil
}
func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
matched := make([]bool, len(exps))
for {
msg, err := rw.ReadMsg()
if err != nil {
if err == io.EOF {
break
}
return err
}
actualContent, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return err
}
var found bool
for i, exp := range exps {
if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(Wrap(exp.Msg))) {
if matched[i] {
return fmt.Errorf("message #%d received two times", i)
}
matched[i] = true
found = true
break
}
}
if !found {
expected := make([]string, 0)
for i, exp := range exps {
if matched[i] {
continue
}
expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(Wrap(exp.Msg))))
}
return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
}
done := true
for _, m := range matched {
if !m {
done = false
break
}
}
if done {
return nil
}
}
for i, m := range matched {
if !m {
return fmt.Errorf("expected message #%d not received", i)
}
}
return nil
}
// mustEncodeMsg uses rlp to encode a message.
// In case of error it panics.
func mustEncodeMsg(msg interface{}) []byte {
contentEnc, err := rlp.EncodeToBytes(msg)
if err != nil {
panic("content encode error: " + err.Error())
}
return contentEnc
}
type WrappedMsg struct {
Context []byte
Size uint32
Payload []byte
}
func Wrap(msg interface{}) interface{} {
data, _ := rlp.EncodeToBytes(msg)
return &WrappedMsg{
Size: uint32(len(data)),
Payload: data,
}
}

@ -27,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/node"
whisper "github.com/ethereum/go-ethereum/whisper/whisperv6" whisper "github.com/ethereum/go-ethereum/whisper/whisperv6"
) )
@ -89,7 +90,11 @@ func TestMailServer(t *testing.T) {
} }
var server WMailServer var server WMailServer
shh = whisper.New(&whisper.DefaultConfig)
stack, w := newNode(t)
defer stack.Close()
shh = w
shh.RegisterServer(&server) shh.RegisterServer(&server)
err = server.Init(shh, dir, password, powRequirement) err = server.Init(shh, dir, password, powRequirement)
@ -210,3 +215,21 @@ func createRequest(t *testing.T, p *ServerTestParams) *whisper.Envelope {
} }
return env return env
} }
// newNode creates a new node using a default config and
// creates and registers a new Whisper service on it.
func newNode(t *testing.T) (*node.Node, *whisper.Whisper) {
stack, err := node.New(&node.DefaultConfig)
if err != nil {
t.Fatalf("could not create new node: %v", err)
}
w, err := whisper.New(stack, &whisper.DefaultConfig)
if err != nil {
t.Fatalf("could not create new whisper service: %v", err)
}
err = stack.Start()
if err != nil {
t.Fatalf("could not start node: %v", err)
}
return stack, w
}

@ -23,7 +23,8 @@ import (
) )
func TestMultipleTopicCopyInNewMessageFilter(t *testing.T) { func TestMultipleTopicCopyInNewMessageFilter(t *testing.T) {
w := New(nil) stack, w := newNodeWithWhisper(t)
defer stack.Close()
keyID, err := w.GenerateSymKey() keyID, err := w.GenerateSymKey()
if err != nil { if err != nil {

@ -92,7 +92,10 @@ func TestInstallFilters(t *testing.T) {
InitSingleTest() InitSingleTest()
const SizeTestFilters = 256 const SizeTestFilters = 256
w := New(&Config{})
stack, w := newNodeWithWhisper(t)
defer stack.Close()
filters := NewFilters(w) filters := NewFilters(w)
tst := generateTestCases(t, SizeTestFilters) tst := generateTestCases(t, SizeTestFilters)
@ -130,7 +133,9 @@ func TestInstallFilters(t *testing.T) {
func TestInstallSymKeyGeneratesHash(t *testing.T) { func TestInstallSymKeyGeneratesHash(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&Config{}) stack, w := newNodeWithWhisper(t)
defer stack.Close()
filters := NewFilters(w) filters := NewFilters(w)
filter, _ := generateFilter(t, true) filter, _ := generateFilter(t, true)
@ -157,7 +162,9 @@ func TestInstallSymKeyGeneratesHash(t *testing.T) {
func TestInstallIdenticalFilters(t *testing.T) { func TestInstallIdenticalFilters(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&Config{}) stack, w := newNodeWithWhisper(t)
defer stack.Close()
filters := NewFilters(w) filters := NewFilters(w)
filter1, _ := generateFilter(t, true) filter1, _ := generateFilter(t, true)
@ -227,7 +234,9 @@ func TestInstallIdenticalFilters(t *testing.T) {
func TestInstallFilterWithSymAndAsymKeys(t *testing.T) { func TestInstallFilterWithSymAndAsymKeys(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&Config{}) stack, w := newNodeWithWhisper(t)
defer stack.Close()
filters := NewFilters(w) filters := NewFilters(w)
filter1, _ := generateFilter(t, true) filter1, _ := generateFilter(t, true)
@ -641,7 +650,9 @@ func TestWatchers(t *testing.T) {
var x, firstID string var x, firstID string
var err error var err error
w := New(&Config{}) stack, w := newNodeWithWhisper(t)
defer stack.Close()
filters := NewFilters(w) filters := NewFilters(w)
tst := generateTestCases(t, NumFilters) tst := generateTestCases(t, NumFilters)
for i = 0; i < NumFilters; i++ { for i = 0; i < NumFilters; i++ {

@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
@ -93,7 +94,7 @@ type Whisper struct {
} }
// New creates a Whisper client ready to communicate through the Ethereum P2P network. // New creates a Whisper client ready to communicate through the Ethereum P2P network.
func New(cfg *Config) *Whisper { func New(stack *node.Node, cfg *Config) (*Whisper, error) {
if cfg == nil { if cfg == nil {
cfg = &DefaultConfig cfg = &DefaultConfig
} }
@ -132,7 +133,10 @@ func New(cfg *Config) *Whisper {
}, },
} }
return whisper stack.RegisterAPIs(whisper.APIs())
stack.RegisterProtocols(whisper.Protocols())
stack.RegisterLifecycle(whisper)
return whisper, nil
} }
// MinPow returns the PoW value required by this node. // MinPow returns the PoW value required by this node.
@ -634,9 +638,9 @@ func (whisper *Whisper) Send(envelope *Envelope) error {
return err return err
} }
// Start implements node.Service, starting the background data propagation thread // Start implements node.Lifecycle, starting the background data propagation thread
// of the Whisper protocol. // of the Whisper protocol.
func (whisper *Whisper) Start(*p2p.Server) error { func (whisper *Whisper) Start() error {
log.Info("started whisper v." + ProtocolVersionStr) log.Info("started whisper v." + ProtocolVersionStr)
whisper.wg.Add(1) whisper.wg.Add(1)
go whisper.update() go whisper.update()
@ -650,7 +654,7 @@ func (whisper *Whisper) Start(*p2p.Server) error {
return nil return nil
} }
// Stop implements node.Service, stopping the background data propagation thread // Stop implements node.Lifecycle, stopping the background data propagation thread
// of the Whisper protocol. // of the Whisper protocol.
func (whisper *Whisper) Stop() error { func (whisper *Whisper) Stop() error {
close(whisper.quit) close(whisper.quit)
@ -1092,3 +1096,45 @@ func addBloom(a, b []byte) []byte {
} }
return c return c
} }
func StandaloneWhisperService(cfg *Config) *Whisper {
if cfg == nil {
cfg = &DefaultConfig
}
whisper := &Whisper{
privateKeys: make(map[string]*ecdsa.PrivateKey),
symKeys: make(map[string][]byte),
envelopes: make(map[common.Hash]*Envelope),
expirations: make(map[uint32]mapset.Set),
peers: make(map[*Peer]struct{}),
messageQueue: make(chan *Envelope, messageQueueLimit),
p2pMsgQueue: make(chan *Envelope, messageQueueLimit),
quit: make(chan struct{}),
syncAllowance: DefaultSyncAllowance,
}
whisper.filters = NewFilters(whisper)
whisper.settings.Store(minPowIdx, cfg.MinimumAcceptedPOW)
whisper.settings.Store(maxMsgSizeIdx, cfg.MaxMessageSize)
whisper.settings.Store(overflowIdx, false)
whisper.settings.Store(restrictConnectionBetweenLightClientsIdx, cfg.RestrictConnectionBetweenLightClients)
// p2p whisper sub protocol handler
whisper.protocol = p2p.Protocol{
Name: ProtocolName,
Version: uint(ProtocolVersion),
Length: NumberOfMessageCodes,
Run: whisper.HandlePeer,
NodeInfo: func() interface{} {
return map[string]interface{}{
"version": ProtocolVersionStr,
"maxMessageSize": whisper.MaxMessageSize(),
"minimumPoW": whisper.MinPow(),
}
},
}
return whisper
}

@ -25,13 +25,15 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/node"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
) )
func TestWhisperBasic(t *testing.T) { func TestWhisperBasic(t *testing.T) {
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
p := w.Protocols() defer stack.Close()
shh := p[0]
shh := w.Protocols()[0]
if shh.Name != ProtocolName { if shh.Name != ProtocolName {
t.Fatalf("failed Protocol Name: %v.", shh.Name) t.Fatalf("failed Protocol Name: %v.", shh.Name)
} }
@ -111,11 +113,10 @@ func TestWhisperBasic(t *testing.T) {
} }
func TestWhisperAsymmetricKeyImport(t *testing.T) { func TestWhisperAsymmetricKeyImport(t *testing.T) {
var ( stack, w := newNodeWithWhisper(t)
w = New(&DefaultConfig) defer stack.Close()
privateKeys []*ecdsa.PrivateKey
)
var privateKeys []*ecdsa.PrivateKey
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
id, err := w.NewKeyPair() id, err := w.NewKeyPair()
if err != nil { if err != nil {
@ -142,7 +143,9 @@ func TestWhisperAsymmetricKeyImport(t *testing.T) {
} }
func TestWhisperIdentityManagement(t *testing.T) { func TestWhisperIdentityManagement(t *testing.T) {
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
defer stack.Close()
id1, err := w.NewKeyPair() id1, err := w.NewKeyPair()
if err != nil { if err != nil {
t.Fatalf("failed to generate new key pair: %s.", err) t.Fatalf("failed to generate new key pair: %s.", err)
@ -261,12 +264,14 @@ func TestWhisperIdentityManagement(t *testing.T) {
func TestWhisperSymKeyManagement(t *testing.T) { func TestWhisperSymKeyManagement(t *testing.T) {
InitSingleTest() InitSingleTest()
var ( var (
k1, k2 []byte k1, k2 []byte
w = New(&DefaultConfig)
id2 = string("arbitrary-string-2") id2 = string("arbitrary-string-2")
) )
stack, w := newNodeWithWhisper(t)
defer stack.Close()
id1, err := w.GenerateSymKey() id1, err := w.GenerateSymKey()
if err != nil { if err != nil {
t.Fatalf("failed GenerateSymKey with seed %d: %s.", seed, err) t.Fatalf("failed GenerateSymKey with seed %d: %s.", seed, err)
@ -365,7 +370,7 @@ func TestWhisperSymKeyManagement(t *testing.T) {
w.DeleteSymKey(id1) w.DeleteSymKey(id1)
k1, err = w.GetSymKey(id1) k1, err = w.GetSymKey(id1)
if err == nil { if err == nil {
t.Fatalf("failed w.GetSymKey(id1): false positive.") t.Fatal("failed w.GetSymKey(id1): false positive.")
} }
if k1 != nil { if k1 != nil {
t.Fatalf("failed GetSymKey(id1): false positive. key=%v", k1) t.Fatalf("failed GetSymKey(id1): false positive. key=%v", k1)
@ -451,11 +456,12 @@ func TestWhisperSymKeyManagement(t *testing.T) {
func TestExpiry(t *testing.T) { func TestExpiry(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
defer stack.Close()
w.SetMinimumPowTest(0.0000001) w.SetMinimumPowTest(0.0000001)
defer w.SetMinimumPowTest(DefaultMinimumPoW) defer w.SetMinimumPowTest(DefaultMinimumPoW)
w.Start(nil) w.Start()
defer w.Stop()
params, err := generateMessageParams() params, err := generateMessageParams()
if err != nil { if err != nil {
@ -517,11 +523,12 @@ func TestExpiry(t *testing.T) {
func TestCustomization(t *testing.T) { func TestCustomization(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
defer stack.Close()
defer w.SetMinimumPowTest(DefaultMinimumPoW) defer w.SetMinimumPowTest(DefaultMinimumPoW)
defer w.SetMaxMessageSize(DefaultMaxMessageSize) defer w.SetMaxMessageSize(DefaultMaxMessageSize)
w.Start(nil) w.Start()
defer w.Stop()
const smallPoW = 0.00001 const smallPoW = 0.00001
@ -610,11 +617,12 @@ func TestCustomization(t *testing.T) {
func TestSymmetricSendCycle(t *testing.T) { func TestSymmetricSendCycle(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
defer stack.Close()
defer w.SetMinimumPowTest(DefaultMinimumPoW) defer w.SetMinimumPowTest(DefaultMinimumPoW)
defer w.SetMaxMessageSize(DefaultMaxMessageSize) defer w.SetMaxMessageSize(DefaultMaxMessageSize)
w.Start(nil) w.Start()
defer w.Stop()
filter1, err := generateFilter(t, true) filter1, err := generateFilter(t, true)
if err != nil { if err != nil {
@ -701,11 +709,12 @@ func TestSymmetricSendCycle(t *testing.T) {
func TestSymmetricSendWithoutAKey(t *testing.T) { func TestSymmetricSendWithoutAKey(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
defer stack.Close()
defer w.SetMinimumPowTest(DefaultMinimumPoW) defer w.SetMinimumPowTest(DefaultMinimumPoW)
defer w.SetMaxMessageSize(DefaultMaxMessageSize) defer w.SetMaxMessageSize(DefaultMaxMessageSize)
w.Start(nil) w.Start()
defer w.Stop()
filter, err := generateFilter(t, true) filter, err := generateFilter(t, true)
if err != nil { if err != nil {
@ -771,11 +780,12 @@ func TestSymmetricSendWithoutAKey(t *testing.T) {
func TestSymmetricSendKeyMismatch(t *testing.T) { func TestSymmetricSendKeyMismatch(t *testing.T) {
InitSingleTest() InitSingleTest()
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
defer stack.Close()
defer w.SetMinimumPowTest(DefaultMinimumPoW) defer w.SetMinimumPowTest(DefaultMinimumPoW)
defer w.SetMaxMessageSize(DefaultMaxMessageSize) defer w.SetMaxMessageSize(DefaultMaxMessageSize)
w.Start(nil) w.Start()
defer w.Stop()
filter, err := generateFilter(t, true) filter, err := generateFilter(t, true)
if err != nil { if err != nil {
@ -882,17 +892,37 @@ func TestBloom(t *testing.T) {
t.Fatal("bloomFilterMatch false negative") t.Fatal("bloomFilterMatch false negative")
} }
w := New(&DefaultConfig) stack, w := newNodeWithWhisper(t)
defer stack.Close()
f := w.BloomFilter() f := w.BloomFilter()
if f != nil { if f != nil {
t.Fatal("wrong bloom on creation") t.Fatal("wrong bloom on creation")
} }
err = w.SetBloomFilter(x) err = w.SetBloomFilter(x)
if err != nil { if err != nil {
t.Fatalf("failed to set bloom filter: %s", err) t.Fatalf("failed to set bloom filter: %v", err)
} }
f = w.BloomFilter() f = w.BloomFilter()
if !BloomFilterMatch(f, x) || !BloomFilterMatch(x, f) { if !BloomFilterMatch(f, x) || !BloomFilterMatch(x, f) {
t.Fatal("retireved wrong bloom filter") t.Fatal("retireved wrong bloom filter")
} }
} }
// newNodeWithWhisper creates a new node using a default config and
// creates and registers a new Whisper service on it.
func newNodeWithWhisper(t *testing.T) (*node.Node, *Whisper) {
stack, err := node.New(&node.DefaultConfig)
if err != nil {
t.Fatalf("could not create new node: %v", err)
}
w, err := New(stack, &DefaultConfig)
if err != nil {
t.Fatalf("could not create new whisper service: %v", err)
}
err = stack.Start()
if err != nil {
t.Fatalf("could not start node: %v", err)
}
return stack, w
}