diff --git a/rpc/http.go b/rpc/http.go index a46d8c2b39..9805d69b63 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -27,16 +27,16 @@ import ( "mime" "net" "net/http" + "strings" "sync" "time" "github.com/rs/cors" - "strings" ) const ( - contentType = "application/json" - maxHTTPRequestContentLength = 1024 * 128 + contentType = "application/json" + maxRequestContentLength = 1024 * 128 ) var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0") @@ -182,8 +182,8 @@ func validateRequest(r *http.Request) (int, error) { if r.Method == http.MethodPut || r.Method == http.MethodDelete { return http.StatusMethodNotAllowed, errors.New("method not allowed") } - if r.ContentLength > maxHTTPRequestContentLength { - err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength) + if r.ContentLength > maxRequestContentLength { + err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength) return http.StatusRequestEntityTooLarge, err } mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")) diff --git a/rpc/http_test.go b/rpc/http_test.go index aed84f6835..b3f694d8af 100644 --- a/rpc/http_test.go +++ b/rpc/http_test.go @@ -32,7 +32,7 @@ func TestHTTPErrorResponseWithPut(t *testing.T) { } func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { - body := make([]rune, maxHTTPRequestContentLength+1) + body := make([]rune, maxRequestContentLength+1) testHTTPErrorResponse(t, http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge) } diff --git a/rpc/json.go b/rpc/json.go index 2e7fd599e2..837011f51b 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -76,13 +76,13 @@ type jsonNotification struct { // jsonCodec reads and writes JSON-RPC messages to the underlying connection. It // also has support for parsing arguments and serializing (result) objects. type jsonCodec struct { - closer sync.Once // close closed channel once - closed chan interface{} // closed on Close - decMu sync.Mutex // guards d - d *json.Decoder // decodes incoming requests - encMu sync.Mutex // guards e - e *json.Encoder // encodes responses - rw io.ReadWriteCloser // connection + closer sync.Once // close closed channel once + closed chan interface{} // closed on Close + decMu sync.Mutex // guards the decoder + decode func(v interface{}) error // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode func(v interface{}) error // encoder to allow multiple transports + rw io.ReadWriteCloser // connection } func (err *jsonError) Error() string { @@ -96,11 +96,29 @@ func (err *jsonError) ErrorCode() int { return err.Code } -// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0 +// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based +// on explicitly given encoding and decoding methods. +func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec { + return &jsonCodec{ + closed: make(chan interface{}), + encode: encode, + decode: decode, + rw: rwc, + } +} + +// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0. func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec { - d := json.NewDecoder(rwc) - d.UseNumber() - return &jsonCodec{closed: make(chan interface{}), d: d, e: json.NewEncoder(rwc), rw: rwc} + enc := json.NewEncoder(rwc) + dec := json.NewDecoder(rwc) + dec.UseNumber() + + return &jsonCodec{ + closed: make(chan interface{}), + encode: enc.Encode, + decode: dec.Decode, + rw: rwc, + } } // isBatch returns true when the first non-whitespace characters is '[' @@ -123,14 +141,12 @@ func (c *jsonCodec) ReadRequestHeaders() ([]rpcRequest, bool, Error) { defer c.decMu.Unlock() var incomingMsg json.RawMessage - if err := c.d.Decode(&incomingMsg); err != nil { + if err := c.decode(&incomingMsg); err != nil { return nil, false, &invalidRequestError{err.Error()} } - if isBatch(incomingMsg) { return parseBatchRequest(incomingMsg) } - return parseRequest(incomingMsg) } @@ -338,7 +354,7 @@ func (c *jsonCodec) Write(res interface{}) error { c.encMu.Lock() defer c.encMu.Unlock() - return c.e.Encode(res) + return c.encode(res) } // Close the underlying connection diff --git a/rpc/websocket.go b/rpc/websocket.go index 4214fc86a0..a6e1cec28a 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -17,8 +17,10 @@ package rpc import ( + "bytes" "context" "crypto/tls" + "encoding/json" "fmt" "net" "net/http" @@ -32,6 +34,23 @@ import ( "gopkg.in/fatih/set.v0" ) +// websocketJSONCodec is a custom JSON codec with payload size enforcement and +// special number parsing. +var websocketJSONCodec = websocket.Codec{ + // Marshal is the stock JSON marshaller used by the websocket library too. + Marshal: func(v interface{}) ([]byte, byte, error) { + msg, err := json.Marshal(v) + return msg, websocket.TextFrame, err + }, + // Unmarshal is a specialized unmarshaller to properly convert numbers. + Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { + dec := json.NewDecoder(bytes.NewReader(msg)) + dec.UseNumber() + + return dec.Decode(v) + }, +} + // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. // // allowedOrigins should be a comma-separated list of allowed origin URLs. @@ -40,7 +59,16 @@ func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { return websocket.Server{ Handshake: wsHandshakeValidator(allowedOrigins), Handler: func(conn *websocket.Conn) { - srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) + // Create a custom encode/decode pair to enforce payload size and number encoding + conn.MaxPayloadBytes = maxRequestContentLength + + encoder := func(v interface{}) error { + return websocketJSONCodec.Send(conn, v) + } + decoder := func(v interface{}) error { + return websocketJSONCodec.Receive(conn, v) + } + srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) }, } }