rpc: enforce the 128KB request limits on websockets too
This commit is contained in:
parent
6a2d2869f6
commit
555f42cfd8
@ -27,16 +27,16 @@ import (
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/cors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
contentType = "application/json"
|
||||
maxHTTPRequestContentLength = 1024 * 128
|
||||
maxRequestContentLength = 1024 * 128
|
||||
)
|
||||
|
||||
var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
||||
@ -182,8 +182,8 @@ func validateRequest(r *http.Request) (int, error) {
|
||||
if r.Method == http.MethodPut || r.Method == http.MethodDelete {
|
||||
return http.StatusMethodNotAllowed, errors.New("method not allowed")
|
||||
}
|
||||
if r.ContentLength > maxHTTPRequestContentLength {
|
||||
err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength)
|
||||
if r.ContentLength > maxRequestContentLength {
|
||||
err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength)
|
||||
return http.StatusRequestEntityTooLarge, err
|
||||
}
|
||||
mt, _, err := mime.ParseMediaType(r.Header.Get("content-type"))
|
||||
|
@ -32,7 +32,7 @@ func TestHTTPErrorResponseWithPut(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
|
||||
body := make([]rune, maxHTTPRequestContentLength+1)
|
||||
body := make([]rune, maxRequestContentLength+1)
|
||||
testHTTPErrorResponse(t,
|
||||
http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
|
40
rpc/json.go
40
rpc/json.go
@ -78,10 +78,10 @@ type jsonNotification struct {
|
||||
type jsonCodec struct {
|
||||
closer sync.Once // close closed channel once
|
||||
closed chan interface{} // closed on Close
|
||||
decMu sync.Mutex // guards d
|
||||
d *json.Decoder // decodes incoming requests
|
||||
encMu sync.Mutex // guards e
|
||||
e *json.Encoder // encodes responses
|
||||
decMu sync.Mutex // guards the decoder
|
||||
decode func(v interface{}) error // decoder to allow multiple transports
|
||||
encMu sync.Mutex // guards the encoder
|
||||
encode func(v interface{}) error // encoder to allow multiple transports
|
||||
rw io.ReadWriteCloser // connection
|
||||
}
|
||||
|
||||
@ -96,11 +96,29 @@ func (err *jsonError) ErrorCode() int {
|
||||
return err.Code
|
||||
}
|
||||
|
||||
// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0
|
||||
// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
|
||||
// on explicitly given encoding and decoding methods.
|
||||
func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec {
|
||||
return &jsonCodec{
|
||||
closed: make(chan interface{}),
|
||||
encode: encode,
|
||||
decode: decode,
|
||||
rw: rwc,
|
||||
}
|
||||
}
|
||||
|
||||
// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0.
|
||||
func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec {
|
||||
d := json.NewDecoder(rwc)
|
||||
d.UseNumber()
|
||||
return &jsonCodec{closed: make(chan interface{}), d: d, e: json.NewEncoder(rwc), rw: rwc}
|
||||
enc := json.NewEncoder(rwc)
|
||||
dec := json.NewDecoder(rwc)
|
||||
dec.UseNumber()
|
||||
|
||||
return &jsonCodec{
|
||||
closed: make(chan interface{}),
|
||||
encode: enc.Encode,
|
||||
decode: dec.Decode,
|
||||
rw: rwc,
|
||||
}
|
||||
}
|
||||
|
||||
// isBatch returns true when the first non-whitespace characters is '['
|
||||
@ -123,14 +141,12 @@ func (c *jsonCodec) ReadRequestHeaders() ([]rpcRequest, bool, Error) {
|
||||
defer c.decMu.Unlock()
|
||||
|
||||
var incomingMsg json.RawMessage
|
||||
if err := c.d.Decode(&incomingMsg); err != nil {
|
||||
if err := c.decode(&incomingMsg); err != nil {
|
||||
return nil, false, &invalidRequestError{err.Error()}
|
||||
}
|
||||
|
||||
if isBatch(incomingMsg) {
|
||||
return parseBatchRequest(incomingMsg)
|
||||
}
|
||||
|
||||
return parseRequest(incomingMsg)
|
||||
}
|
||||
|
||||
@ -338,7 +354,7 @@ func (c *jsonCodec) Write(res interface{}) error {
|
||||
c.encMu.Lock()
|
||||
defer c.encMu.Unlock()
|
||||
|
||||
return c.e.Encode(res)
|
||||
return c.encode(res)
|
||||
}
|
||||
|
||||
// Close the underlying connection
|
||||
|
@ -17,8 +17,10 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -32,6 +34,23 @@ import (
|
||||
"gopkg.in/fatih/set.v0"
|
||||
)
|
||||
|
||||
// websocketJSONCodec is a custom JSON codec with payload size enforcement and
|
||||
// special number parsing.
|
||||
var websocketJSONCodec = websocket.Codec{
|
||||
// Marshal is the stock JSON marshaller used by the websocket library too.
|
||||
Marshal: func(v interface{}) ([]byte, byte, error) {
|
||||
msg, err := json.Marshal(v)
|
||||
return msg, websocket.TextFrame, err
|
||||
},
|
||||
// Unmarshal is a specialized unmarshaller to properly convert numbers.
|
||||
Unmarshal: func(msg []byte, payloadType byte, v interface{}) error {
|
||||
dec := json.NewDecoder(bytes.NewReader(msg))
|
||||
dec.UseNumber()
|
||||
|
||||
return dec.Decode(v)
|
||||
},
|
||||
}
|
||||
|
||||
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
|
||||
//
|
||||
// allowedOrigins should be a comma-separated list of allowed origin URLs.
|
||||
@ -40,7 +59,16 @@ func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
|
||||
return websocket.Server{
|
||||
Handshake: wsHandshakeValidator(allowedOrigins),
|
||||
Handler: func(conn *websocket.Conn) {
|
||||
srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
|
||||
// Create a custom encode/decode pair to enforce payload size and number encoding
|
||||
conn.MaxPayloadBytes = maxRequestContentLength
|
||||
|
||||
encoder := func(v interface{}) error {
|
||||
return websocketJSONCodec.Send(conn, v)
|
||||
}
|
||||
decoder := func(v interface{}) error {
|
||||
return websocketJSONCodec.Receive(conn, v)
|
||||
}
|
||||
srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user