diff --git a/cmd/clef/main.go b/cmd/clef/main.go index 801c7e9efd..7dd1900cba 100644 --- a/cmd/clef/main.go +++ b/cmd/clef/main.go @@ -592,15 +592,16 @@ func signer(c *cli.Context) error { // start http server httpEndpoint := fmt.Sprintf("%s:%d", c.GlobalString(utils.RPCListenAddrFlag.Name), c.Int(rpcPortFlag.Name)) - listener, err := node.StartHTTPEndpoint(httpEndpoint, rpc.DefaultHTTPTimeouts, handler) + httpServer, addr, err := node.StartHTTPEndpoint(httpEndpoint, rpc.DefaultHTTPTimeouts, handler) if err != nil { utils.Fatalf("Could not start RPC api: %v", err) } - extapiURL = fmt.Sprintf("http://%v/", listener.Addr()) + extapiURL = fmt.Sprintf("http://%v/", addr) log.Info("HTTP endpoint opened", "url", extapiURL) defer func() { - listener.Close() + // Don't bother imposing a timeout here. + httpServer.Shutdown(context.Background()) log.Info("HTTP endpoint closed", "url", extapiURL) }() } diff --git a/cmd/geth/retesteth.go b/cmd/geth/retesteth.go index 102f222ada..7e11ff9513 100644 --- a/cmd/geth/retesteth.go +++ b/cmd/geth/retesteth.go @@ -905,7 +905,7 @@ func retesteth(ctx *cli.Context) error { IdleTimeout: 120 * time.Second, } httpEndpoint := fmt.Sprintf("%s:%d", ctx.GlobalString(utils.RPCListenAddrFlag.Name), ctx.Int(rpcPortFlag.Name)) - listener, err := node.StartHTTPEndpoint(httpEndpoint, RetestethHTTPTimeouts, handler) + httpServer, _, err := node.StartHTTPEndpoint(httpEndpoint, RetestethHTTPTimeouts, handler) if err != nil { utils.Fatalf("Could not start RPC api: %v", err) } @@ -913,7 +913,8 @@ func retesteth(ctx *cli.Context) error { log.Info("HTTP endpoint opened", "url", extapiURL) defer func() { - listener.Close() + // Don't bother imposing a timeout here. + httpServer.Shutdown(context.Background()) log.Info("HTTP endpoint closed", "url", httpEndpoint) }() diff --git a/node/endpoints.go b/node/endpoints.go index 8cd6b4d1c8..1baa1b5c41 100644 --- a/node/endpoints.go +++ b/node/endpoints.go @@ -26,14 +26,14 @@ import ( ) // StartHTTPEndpoint starts the HTTP RPC endpoint. -func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (net.Listener, error) { +func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (*http.Server, net.Addr, error) { // start the HTTP listener var ( listener net.Listener err error ) if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, err + return nil, nil, err } // make sure timeout values are meaningful CheckTimeouts(&timeouts) @@ -45,22 +45,22 @@ func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http. IdleTimeout: timeouts.IdleTimeout, } go httpSrv.Serve(listener) - return listener, err + return httpSrv, listener.Addr(), err } // startWSEndpoint starts a websocket endpoint. -func startWSEndpoint(endpoint string, handler http.Handler) (net.Listener, error) { +func startWSEndpoint(endpoint string, handler http.Handler) (*http.Server, net.Addr, error) { // start the HTTP listener var ( listener net.Listener err error ) if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, err + return nil, nil, err } wsSrv := &http.Server{Handler: handler} go wsSrv.Serve(listener) - return listener, err + return wsSrv, listener.Addr(), err } // checkModuleAvailability checks that all names given in modules are actually diff --git a/node/node.go b/node/node.go index 1d14317fc1..329ff425b9 100644 --- a/node/node.go +++ b/node/node.go @@ -17,9 +17,11 @@ package node import ( + "context" "errors" "fmt" "net" + "net/http" "os" "path/filepath" "reflect" @@ -59,14 +61,16 @@ type Node struct { ipcListener net.Listener // IPC RPC listener socket to serve API requests ipcHandler *rpc.Server // IPC RPC request handler to process the API requests - httpEndpoint string // HTTP endpoint (interface + port) to listen at (empty = HTTP disabled) - httpWhitelist []string // HTTP RPC modules to allow through this endpoint - httpListener net.Listener // HTTP RPC listener socket to server API requests - httpHandler *rpc.Server // HTTP RPC request handler to process the API requests + httpEndpoint string // HTTP endpoint (interface + port) to listen at (empty = HTTP disabled) + httpWhitelist []string // HTTP RPC modules to allow through this endpoint + httpListenerAddr net.Addr // Address of HTTP RPC listener socket serving API requests + httpServer *http.Server // HTTP RPC HTTP server + httpHandler *rpc.Server // HTTP RPC request handler to process the API requests - wsEndpoint string // Websocket endpoint (interface + port) to listen at (empty = websocket disabled) - wsListener net.Listener // Websocket RPC listener socket to server API requests - wsHandler *rpc.Server // Websocket RPC request handler to process the API requests + wsEndpoint string // WebSocket endpoint (interface + port) to listen at (empty = WebSocket disabled) + wsListenerAddr net.Addr // Address of WebSocket RPC listener socket serving API requests + wsHTTPServer *http.Server // WebSocket RPC HTTP server + wsHandler *rpc.Server // WebSocket RPC request handler to process the API requests stop chan struct{} // Channel to wait for termination notifications lock sync.RWMutex @@ -375,23 +379,24 @@ func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors return err } handler := NewHTTPHandlerStack(srv, cors, vhosts) - // wrap handler in websocket handler only if websocket port is the same as http rpc + // wrap handler in WebSocket handler only if WebSocket port is the same as http rpc if n.httpEndpoint == n.wsEndpoint { handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins)) } - listener, err := StartHTTPEndpoint(endpoint, timeouts, handler) + httpServer, addr, err := StartHTTPEndpoint(endpoint, timeouts, handler) if err != nil { return err } - n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()), + n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", addr), "cors", strings.Join(cors, ","), "vhosts", strings.Join(vhosts, ",")) if n.httpEndpoint == n.wsEndpoint { - n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", listener.Addr())) + n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) } // All listeners booted successfully n.httpEndpoint = endpoint - n.httpListener = listener + n.httpListenerAddr = addr + n.httpServer = httpServer n.httpHandler = srv return nil @@ -399,11 +404,10 @@ func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors // stopHTTP terminates the HTTP RPC endpoint. func (n *Node) stopHTTP() { - if n.httpListener != nil { - url := fmt.Sprintf("http://%v/", n.httpListener.Addr()) - n.httpListener.Close() - n.httpListener = nil - n.log.Info("HTTP endpoint closed", "url", url) + if n.httpServer != nil { + // Don't bother imposing a timeout here. + n.httpServer.Shutdown(context.Background()) + n.log.Info("HTTP endpoint closed", "url", fmt.Sprintf("http://%v/", n.httpListenerAddr)) } if n.httpHandler != nil { n.httpHandler.Stop() @@ -411,7 +415,7 @@ func (n *Node) stopHTTP() { } } -// startWS initializes and starts the websocket RPC endpoint. +// startWS initializes and starts the WebSocket RPC endpoint. func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrigins []string, exposeAll bool) error { // Short circuit if the WS endpoint isn't being exposed if endpoint == "" { @@ -424,26 +428,26 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig if err != nil { return err } - listener, err := startWSEndpoint(endpoint, handler) + httpServer, addr, err := startWSEndpoint(endpoint, handler) if err != nil { return err } - n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%s", listener.Addr())) + n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) // All listeners booted successfully n.wsEndpoint = endpoint - n.wsListener = listener + n.wsListenerAddr = addr + n.wsHTTPServer = httpServer n.wsHandler = srv return nil } -// stopWS terminates the websocket RPC endpoint. +// stopWS terminates the WebSocket RPC endpoint. func (n *Node) stopWS() { - if n.wsListener != nil { - n.wsListener.Close() - n.wsListener = nil - - n.log.Info("WebSocket endpoint closed", "url", fmt.Sprintf("ws://%s", n.wsEndpoint)) + if n.wsHTTPServer != nil { + // Don't bother imposing a timeout here. + n.wsHTTPServer.Shutdown(context.Background()) + n.log.Info("WebSocket endpoint closed", "url", fmt.Sprintf("ws://%v", n.wsListenerAddr)) } if n.wsHandler != nil { n.wsHandler.Stop() @@ -607,8 +611,8 @@ func (n *Node) HTTPEndpoint() string { n.lock.Lock() defer n.lock.Unlock() - if n.httpListener != nil { - return n.httpListener.Addr().String() + if n.httpListenerAddr != nil { + return n.httpListenerAddr.String() } return n.httpEndpoint } @@ -618,8 +622,8 @@ func (n *Node) WSEndpoint() string { n.lock.Lock() defer n.lock.Unlock() - if n.wsListener != nil { - return n.wsListener.Addr().String() + if n.wsListenerAddr != nil { + return n.wsListenerAddr.String() } return n.wsEndpoint }