rpc: add limit for batch request items and response size (#26681)
This PR adds server-side limits for JSON-RPC batch requests. Before this change, batches were limited only by processing time. The server would pick calls from the batch and answer them until the response timeout occurred, then stop processing the remaining batch items. Here, we are adding two additional limits which can be configured: - the 'item limit': batches can have at most N items - the 'response size limit': batches can contain at most X response bytes These limits are optional in package rpc. In Geth, we set a default limit of 1000 items and 25MB response size. When a batch goes over the limit, an error response is returned to the client. However, doing this correctly isn't always possible. In JSON-RPC, only method calls with a valid `id` can be responded to. Since batches may also contain non-call messages or notifications, the best effort thing we can do to report an error with the batch itself is reporting the limit violation as an error for the first method call in the batch. If a batch is too large, but contains only notifications and responses, the error will be reported with a null `id`. The RPC client was also changed so it can deal with errors resulting from too large batches. An older client connected to the server code in this PR could get stuck until the request timeout occurred when the batch is too large. **Upgrading to a version of the RPC client containing this change is strongly recommended to avoid timeout issues.** For some weird reason, when writing the original client implementation, @fjl worked off of the assumption that responses could be distributed across batches arbitrarily. So for a batch request containing requests `[A B C]`, the server could respond with `[A B C]` but also with `[A B] [C]` or even `[A] [B] [C]` and it wouldn't make a difference to the client. So in the implementation of BatchCallContext, the client waited for all requests in the batch individually. If the server didn't respond to some of the requests in the batch, the client would eventually just time out (if a context was used). With the addition of batch limits into the server, we anticipate that people will hit this kind of error way more often. To handle this properly, the client now waits for a single response batch and expects it to contain all responses to the requests. --------- Co-authored-by: Felix Lange <fjl@twurst.com> Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
parent
5ac4da3653
commit
f3314bb6df
@ -732,6 +732,7 @@ func signer(c *cli.Context) error {
|
||||
cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name))
|
||||
|
||||
srv := rpc.NewServer()
|
||||
srv.SetBatchLimits(node.DefaultConfig.BatchRequestLimit, node.DefaultConfig.BatchResponseMaxSize)
|
||||
err := node.RegisterApis(rpcAPI, []string{"account"}, srv)
|
||||
if err != nil {
|
||||
utils.Fatalf("Could not register API: %w", err)
|
||||
|
@ -168,6 +168,8 @@ var (
|
||||
utils.RPCGlobalEVMTimeoutFlag,
|
||||
utils.RPCGlobalTxFeeCapFlag,
|
||||
utils.AllowUnprotectedTxs,
|
||||
utils.BatchRequestLimit,
|
||||
utils.BatchResponseMaxSize,
|
||||
}
|
||||
|
||||
metricsFlags = []cli.Flag{
|
||||
|
@ -713,6 +713,18 @@ var (
|
||||
Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC",
|
||||
Category: flags.APICategory,
|
||||
}
|
||||
BatchRequestLimit = &cli.IntFlag{
|
||||
Name: "rpc.batch-request-limit",
|
||||
Usage: "Maximum number of requests in a batch",
|
||||
Value: node.DefaultConfig.BatchRequestLimit,
|
||||
Category: flags.APICategory,
|
||||
}
|
||||
BatchResponseMaxSize = &cli.IntFlag{
|
||||
Name: "rpc.batch-response-max-size",
|
||||
Usage: "Maximum number of bytes returned from a batched call",
|
||||
Value: node.DefaultConfig.BatchResponseMaxSize,
|
||||
Category: flags.APICategory,
|
||||
}
|
||||
EnablePersonal = &cli.BoolFlag{
|
||||
Name: "rpc.enabledeprecatedpersonal",
|
||||
Usage: "Enables the (deprecated) personal namespace",
|
||||
@ -1130,6 +1142,14 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) {
|
||||
if ctx.IsSet(AllowUnprotectedTxs.Name) {
|
||||
cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name)
|
||||
}
|
||||
|
||||
if ctx.IsSet(BatchRequestLimit.Name) {
|
||||
cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name)
|
||||
}
|
||||
|
||||
if ctx.IsSet(BatchResponseMaxSize.Name) {
|
||||
cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// setGraphQL creates the GraphQL listener interface string from the set
|
||||
|
@ -176,6 +176,10 @@ func (api *adminAPI) StartHTTP(host *string, port *int, cors *string, apis *stri
|
||||
CorsAllowedOrigins: api.node.config.HTTPCors,
|
||||
Vhosts: api.node.config.HTTPVirtualHosts,
|
||||
Modules: api.node.config.HTTPModules,
|
||||
rpcEndpointConfig: rpcEndpointConfig{
|
||||
batchItemLimit: api.node.config.BatchRequestLimit,
|
||||
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
|
||||
},
|
||||
}
|
||||
if cors != nil {
|
||||
config.CorsAllowedOrigins = nil
|
||||
@ -250,6 +254,10 @@ func (api *adminAPI) StartWS(host *string, port *int, allowedOrigins *string, ap
|
||||
Modules: api.node.config.WSModules,
|
||||
Origins: api.node.config.WSOrigins,
|
||||
// ExposeAll: api.node.config.WSExposeAll,
|
||||
rpcEndpointConfig: rpcEndpointConfig{
|
||||
batchItemLimit: api.node.config.BatchRequestLimit,
|
||||
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
|
||||
},
|
||||
}
|
||||
if apis != nil {
|
||||
config.Modules = nil
|
||||
|
@ -197,6 +197,12 @@ type Config struct {
|
||||
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
|
||||
AllowUnprotectedTxs bool `toml:",omitempty"`
|
||||
|
||||
// BatchRequestLimit is the maximum number of requests in a batch.
|
||||
BatchRequestLimit int `toml:",omitempty"`
|
||||
|
||||
// BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call.
|
||||
BatchResponseMaxSize int `toml:",omitempty"`
|
||||
|
||||
// JWTSecret is the path to the hex-encoded jwt secret.
|
||||
JWTSecret string `toml:",omitempty"`
|
||||
|
||||
|
@ -56,6 +56,8 @@ var DefaultConfig = Config{
|
||||
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
|
||||
WSPort: DefaultWSPort,
|
||||
WSModules: []string{"net", "web3"},
|
||||
BatchRequestLimit: 1000,
|
||||
BatchResponseMaxSize: 25 * 1000 * 1000,
|
||||
GraphQLVirtualHosts: []string{"localhost"},
|
||||
P2P: p2p.Config{
|
||||
ListenAddr: ":30303",
|
||||
|
19
node/node.go
19
node/node.go
@ -101,10 +101,11 @@ func New(conf *Config) (*Node, error) {
|
||||
if strings.HasSuffix(conf.Name, ".ipc") {
|
||||
return nil, errors.New(`Config.Name cannot end in ".ipc"`)
|
||||
}
|
||||
|
||||
server := rpc.NewServer()
|
||||
server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize)
|
||||
node := &Node{
|
||||
config: conf,
|
||||
inprocHandler: rpc.NewServer(),
|
||||
inprocHandler: server,
|
||||
eventmux: new(event.TypeMux),
|
||||
log: conf.Logger,
|
||||
stop: make(chan struct{}),
|
||||
@ -403,6 +404,11 @@ func (n *Node) startRPC() error {
|
||||
openAPIs, allAPIs = n.getAPIs()
|
||||
)
|
||||
|
||||
rpcConfig := rpcEndpointConfig{
|
||||
batchItemLimit: n.config.BatchRequestLimit,
|
||||
batchResponseSizeLimit: n.config.BatchResponseMaxSize,
|
||||
}
|
||||
|
||||
initHttp := func(server *httpServer, port int) error {
|
||||
if err := server.setListenAddr(n.config.HTTPHost, port); err != nil {
|
||||
return err
|
||||
@ -412,6 +418,7 @@ func (n *Node) startRPC() error {
|
||||
Vhosts: n.config.HTTPVirtualHosts,
|
||||
Modules: n.config.HTTPModules,
|
||||
prefix: n.config.HTTPPathPrefix,
|
||||
rpcEndpointConfig: rpcConfig,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -428,6 +435,7 @@ func (n *Node) startRPC() error {
|
||||
Modules: n.config.WSModules,
|
||||
Origins: n.config.WSOrigins,
|
||||
prefix: n.config.WSPathPrefix,
|
||||
rpcEndpointConfig: rpcConfig,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -441,16 +449,19 @@ func (n *Node) startRPC() error {
|
||||
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
|
||||
return err
|
||||
}
|
||||
sharedConfig := rpcConfig
|
||||
sharedConfig.jwtSecret = secret
|
||||
if err := server.enableRPC(allAPIs, httpConfig{
|
||||
CorsAllowedOrigins: DefaultAuthCors,
|
||||
Vhosts: n.config.AuthVirtualHosts,
|
||||
Modules: DefaultAuthModules,
|
||||
prefix: DefaultAuthPrefix,
|
||||
jwtSecret: secret,
|
||||
rpcEndpointConfig: sharedConfig,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
servers = append(servers, server)
|
||||
|
||||
// Enable auth via WS
|
||||
server = n.wsServerForPort(port, true)
|
||||
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
|
||||
@ -460,7 +471,7 @@ func (n *Node) startRPC() error {
|
||||
Modules: DefaultAuthModules,
|
||||
Origins: DefaultAuthOrigins,
|
||||
prefix: DefaultAuthPrefix,
|
||||
jwtSecret: secret,
|
||||
rpcEndpointConfig: sharedConfig,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ type httpConfig struct {
|
||||
CorsAllowedOrigins []string
|
||||
Vhosts []string
|
||||
prefix string // path prefix on which to mount http handler
|
||||
jwtSecret []byte // optional JWT secret
|
||||
rpcEndpointConfig
|
||||
}
|
||||
|
||||
// wsConfig is the JSON-RPC/Websocket configuration
|
||||
@ -49,7 +49,13 @@ type wsConfig struct {
|
||||
Origins []string
|
||||
Modules []string
|
||||
prefix string // path prefix on which to mount ws handler
|
||||
rpcEndpointConfig
|
||||
}
|
||||
|
||||
type rpcEndpointConfig struct {
|
||||
jwtSecret []byte // optional JWT secret
|
||||
batchItemLimit int
|
||||
batchResponseSizeLimit int
|
||||
}
|
||||
|
||||
type rpcHandler struct {
|
||||
@ -297,6 +303,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
|
||||
|
||||
// Create RPC server and handler.
|
||||
srv := rpc.NewServer()
|
||||
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
|
||||
if err := RegisterApis(apis, config.Modules, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -328,6 +335,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
|
||||
}
|
||||
// Create RPC server and handler.
|
||||
srv := rpc.NewServer()
|
||||
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
|
||||
if err := RegisterApis(apis, config.Modules, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -339,8 +339,10 @@ func TestJWT(t *testing.T) {
|
||||
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
|
||||
return ss
|
||||
}
|
||||
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
|
||||
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
|
||||
cfg := rpcEndpointConfig{jwtSecret: []byte("secret")}
|
||||
httpcfg := &httpConfig{rpcEndpointConfig: cfg}
|
||||
wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg}
|
||||
srv := createAndStartServer(t, httpcfg, true, wscfg, nil)
|
||||
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
|
||||
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
|
||||
|
||||
|
103
rpc/client.go
103
rpc/client.go
@ -34,14 +34,15 @@ import (
|
||||
var (
|
||||
ErrBadResult = errors.New("bad result in JSON-RPC response")
|
||||
ErrClientQuit = errors.New("client is closed")
|
||||
ErrNoResult = errors.New("no result in JSON-RPC response")
|
||||
ErrNoResult = errors.New("JSON-RPC response has no result")
|
||||
ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call")
|
||||
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
|
||||
errClientReconnected = errors.New("client reconnected")
|
||||
errDead = errors.New("connection lost")
|
||||
)
|
||||
|
||||
const (
|
||||
// Timeouts
|
||||
const (
|
||||
defaultDialTimeout = 10 * time.Second // used if context has no deadline
|
||||
subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls
|
||||
)
|
||||
@ -84,6 +85,10 @@ type Client struct {
|
||||
// This function, if non-nil, is called when the connection is lost.
|
||||
reconnectFunc reconnectFunc
|
||||
|
||||
// config fields
|
||||
batchItemLimit int
|
||||
batchResponseMaxSize int
|
||||
|
||||
// writeConn is used for writing to the connection on the caller's goroutine. It should
|
||||
// only be accessed outside of dispatch, with the write lock held. The write lock is
|
||||
// taken by sending on reqInit and released by sending on reqSent.
|
||||
@ -114,7 +119,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn {
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, clientContextKey{}, c)
|
||||
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
|
||||
handler := newHandler(ctx, conn, c.idgen, c.services)
|
||||
handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize)
|
||||
return &clientConn{conn, handler}
|
||||
}
|
||||
|
||||
@ -128,14 +133,17 @@ type readOp struct {
|
||||
batch bool
|
||||
}
|
||||
|
||||
// requestOp represents a pending request. This is used for both batch and non-batch
|
||||
// requests.
|
||||
type requestOp struct {
|
||||
ids []json.RawMessage
|
||||
err error
|
||||
resp chan *jsonrpcMessage // receives up to len(ids) responses
|
||||
sub *ClientSubscription // only set for EthSubscribe requests
|
||||
resp chan []*jsonrpcMessage // the response goes here
|
||||
sub *ClientSubscription // set for Subscribe requests.
|
||||
hadResponse bool // true when the request was responded to
|
||||
}
|
||||
|
||||
func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
|
||||
func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Send the timeout to dispatch so it can remove the request IDs.
|
||||
@ -211,7 +219,7 @@ func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*
|
||||
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
|
||||
}
|
||||
|
||||
return newClient(ctx, reconnect)
|
||||
return newClient(ctx, cfg, reconnect)
|
||||
}
|
||||
|
||||
// ClientFromContext retrieves the client from the context, if any. This can be used to perform
|
||||
@ -221,22 +229,24 @@ func ClientFromContext(ctx context.Context) (*Client, bool) {
|
||||
return client, ok
|
||||
}
|
||||
|
||||
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
|
||||
func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) {
|
||||
conn, err := connect(initctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
|
||||
c := initClient(conn, new(serviceRegistry), cfg)
|
||||
c.reconnectFunc = connect
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
|
||||
func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client {
|
||||
_, isHTTP := conn.(*httpConn)
|
||||
c := &Client{
|
||||
isHTTP: isHTTP,
|
||||
idgen: idgen,
|
||||
services: services,
|
||||
idgen: cfg.idgen,
|
||||
batchItemLimit: cfg.batchItemLimit,
|
||||
batchResponseMaxSize: cfg.batchResponseLimit,
|
||||
writeConn: conn,
|
||||
close: make(chan struct{}),
|
||||
closing: make(chan struct{}),
|
||||
@ -248,6 +258,13 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C
|
||||
reqSent: make(chan error, 1),
|
||||
reqTimeout: make(chan *requestOp),
|
||||
}
|
||||
|
||||
// Set defaults.
|
||||
if c.idgen == nil {
|
||||
c.idgen = randomIDGenerator()
|
||||
}
|
||||
|
||||
// Launch the main loop.
|
||||
if !isHTTP {
|
||||
go c.dispatch(conn)
|
||||
}
|
||||
@ -325,7 +342,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)}
|
||||
op := &requestOp{
|
||||
ids: []json.RawMessage{msg.ID},
|
||||
resp: make(chan []*jsonrpcMessage, 1),
|
||||
}
|
||||
|
||||
if c.isHTTP {
|
||||
err = c.sendHTTP(ctx, op, msg)
|
||||
@ -337,9 +357,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
|
||||
}
|
||||
|
||||
// dispatch has accepted the request and will close the channel when it quits.
|
||||
switch resp, err := op.wait(ctx, c); {
|
||||
case err != nil:
|
||||
batchresp, err := op.wait(ctx, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp := batchresp[0]
|
||||
switch {
|
||||
case resp.Error != nil:
|
||||
return resp.Error
|
||||
case len(resp.Result) == 0:
|
||||
@ -380,7 +403,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
|
||||
)
|
||||
op := &requestOp{
|
||||
ids: make([]json.RawMessage, len(b)),
|
||||
resp: make(chan *jsonrpcMessage, len(b)),
|
||||
resp: make(chan []*jsonrpcMessage, 1),
|
||||
}
|
||||
for i, elem := range b {
|
||||
msg, err := c.newMessage(elem.Method, elem.Args...)
|
||||
@ -398,28 +421,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
|
||||
} else {
|
||||
err = c.send(ctx, op, msgs)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
batchresp, err := op.wait(ctx, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for all responses to come back.
|
||||
for n := 0; n < len(b) && err == nil; n++ {
|
||||
var resp *jsonrpcMessage
|
||||
resp, err = op.wait(ctx, c)
|
||||
if err != nil {
|
||||
break
|
||||
for n := 0; n < len(batchresp) && err == nil; n++ {
|
||||
resp := batchresp[n]
|
||||
if resp == nil {
|
||||
// Ignore null responses. These can happen for batches sent via HTTP.
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the element corresponding to this response.
|
||||
// The element is guaranteed to be present because dispatch
|
||||
// only sends valid IDs to our channel.
|
||||
elem := &b[byID[string(resp.ID)]]
|
||||
if resp.Error != nil {
|
||||
index, ok := byID[string(resp.ID)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
delete(byID, string(resp.ID))
|
||||
|
||||
// Assign result and error.
|
||||
elem := &b[index]
|
||||
switch {
|
||||
case resp.Error != nil:
|
||||
elem.Error = resp.Error
|
||||
continue
|
||||
}
|
||||
if len(resp.Result) == 0 {
|
||||
case resp.Result == nil:
|
||||
elem.Error = ErrNoResult
|
||||
continue
|
||||
}
|
||||
default:
|
||||
elem.Error = json.Unmarshal(resp.Result, elem.Result)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all expected responses have been received.
|
||||
for _, index := range byID {
|
||||
elem := &b[index]
|
||||
elem.Error = ErrMissingBatchResponse
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -480,7 +523,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
|
||||
}
|
||||
op := &requestOp{
|
||||
ids: []json.RawMessage{msg.ID},
|
||||
resp: make(chan *jsonrpcMessage),
|
||||
resp: make(chan []*jsonrpcMessage, 1),
|
||||
sub: newClientSubscription(c, namespace, chanVal),
|
||||
}
|
||||
|
||||
|
@ -28,11 +28,18 @@ type ClientOption interface {
|
||||
}
|
||||
|
||||
type clientConfig struct {
|
||||
// HTTP settings
|
||||
httpClient *http.Client
|
||||
httpHeaders http.Header
|
||||
httpAuth HTTPAuth
|
||||
|
||||
// WebSocket options
|
||||
wsDialer *websocket.Dialer
|
||||
|
||||
// RPC handler options
|
||||
idgen func() ID
|
||||
batchItemLimit int
|
||||
batchResponseLimit int
|
||||
}
|
||||
|
||||
func (cfg *clientConfig) initHeaders() {
|
||||
@ -104,3 +111,25 @@ func WithHTTPAuth(a HTTPAuth) ClientOption {
|
||||
// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
|
||||
// auth information to the request.
|
||||
type HTTPAuth func(h http.Header) error
|
||||
|
||||
// WithBatchItemLimit changes the maximum number of items allowed in batch requests.
|
||||
//
|
||||
// Note: this option applies when processing incoming batch requests. It does not affect
|
||||
// batch requests sent by the client.
|
||||
func WithBatchItemLimit(limit int) ClientOption {
|
||||
return optionFunc(func(cfg *clientConfig) {
|
||||
cfg.batchItemLimit = limit
|
||||
})
|
||||
}
|
||||
|
||||
// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be
|
||||
// generated for batch requests. When this limit is reached, further calls in the batch
|
||||
// will not be processed.
|
||||
//
|
||||
// Note: this option applies when processing incoming batch requests. It does not affect
|
||||
// batch requests sent by the client.
|
||||
func WithBatchResponseSizeLimit(sizeLimit int) ClientOption {
|
||||
return optionFunc(func(cfg *clientConfig) {
|
||||
cfg.batchResponseLimit = sizeLimit
|
||||
})
|
||||
}
|
||||
|
@ -169,10 +169,12 @@ func TestClientBatchRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// This checks that, for HTTP connections, the length of batch responses is validated to
|
||||
// match the request exactly.
|
||||
func TestClientBatchRequest_len(t *testing.T) {
|
||||
b, err := json.Marshal([]jsonrpcMessage{
|
||||
{Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)},
|
||||
{Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)},
|
||||
{Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)},
|
||||
{Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode jsonrpc message:", err)
|
||||
@ -185,35 +187,100 @@ func TestClientBatchRequest_len(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
t.Run("too-few", func(t *testing.T) {
|
||||
client, err := Dial(s.URL)
|
||||
if err != nil {
|
||||
t.Fatal("failed to dial test server:", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
t.Run("too-few", func(t *testing.T) {
|
||||
batch := []BatchElem{
|
||||
{Method: "foo", Result: new(string)},
|
||||
{Method: "bar", Result: new(string)},
|
||||
{Method: "baz", Result: new(string)},
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancelFn()
|
||||
|
||||
if err := client.BatchCallContext(ctx, batch); err != nil {
|
||||
t.Fatal("error:", err)
|
||||
}
|
||||
for i, elem := range batch[:2] {
|
||||
if elem.Error != nil {
|
||||
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
|
||||
}
|
||||
}
|
||||
for i, elem := range batch[2:] {
|
||||
if elem.Error != ErrMissingBatchResponse {
|
||||
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("too-many", func(t *testing.T) {
|
||||
client, err := Dial(s.URL)
|
||||
if err != nil {
|
||||
t.Fatal("failed to dial test server:", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
batch := []BatchElem{
|
||||
{Method: "foo", Result: new(string)},
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancelFn()
|
||||
|
||||
if err := client.BatchCallContext(ctx, batch); err != nil {
|
||||
t.Fatal("error:", err)
|
||||
}
|
||||
for i, elem := range batch[:1] {
|
||||
if elem.Error != nil {
|
||||
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
|
||||
}
|
||||
}
|
||||
for i, elem := range batch[1:] {
|
||||
if elem.Error != ErrMissingBatchResponse {
|
||||
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// This checks that the client can handle the case where the server doesn't
|
||||
// respond to all requests in a batch.
|
||||
func TestClientBatchRequestLimit(t *testing.T) {
|
||||
server := newTestServer()
|
||||
defer server.Stop()
|
||||
server.SetBatchLimits(2, 100000)
|
||||
client := DialInProc(server)
|
||||
|
||||
batch := []BatchElem{
|
||||
{Method: "foo"},
|
||||
{Method: "bar"},
|
||||
{Method: "baz"},
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancelFn()
|
||||
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
|
||||
t.Errorf("expected %q but got: %v", ErrBadResult, err)
|
||||
err := client.BatchCall(batch)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("too-many", func(t *testing.T) {
|
||||
batch := []BatchElem{
|
||||
{Method: "foo"},
|
||||
// Check that the first response indicates an error with batch size.
|
||||
var err0 Error
|
||||
if !errors.As(batch[0].Error, &err0) {
|
||||
t.Log("error zero:", batch[0].Error)
|
||||
t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error)
|
||||
} else {
|
||||
if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge {
|
||||
t.Fatalf("wrong error on batch elem zero: %v", err0)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that remaining response batch elements are reported as absent.
|
||||
for i, elem := range batch[1:] {
|
||||
if elem.Error != ErrMissingBatchResponse {
|
||||
t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error)
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancelFn()
|
||||
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
|
||||
t.Errorf("expected %q but got: %v", ErrBadResult, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientNotify(t *testing.T) {
|
||||
@ -487,7 +554,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) {
|
||||
defer srv.Stop()
|
||||
|
||||
// Create the client on the other end of the pipe.
|
||||
client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) {
|
||||
cfg := new(clientConfig)
|
||||
client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) {
|
||||
return NewCodec(p2), nil
|
||||
})
|
||||
defer client.Close()
|
||||
|
@ -61,12 +61,15 @@ const (
|
||||
errcodeDefault = -32000
|
||||
errcodeNotificationsUnsupported = -32001
|
||||
errcodeTimeout = -32002
|
||||
errcodeResponseTooLarge = -32003
|
||||
errcodePanic = -32603
|
||||
errcodeMarshalError = -32603
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgTimeout = "request timed out"
|
||||
errMsgResponseTooLarge = "response too large"
|
||||
errMsgBatchTooLarge = "batch too large"
|
||||
)
|
||||
|
||||
type methodNotFoundError struct{ method string }
|
||||
|
165
rpc/handler.go
165
rpc/handler.go
@ -60,6 +60,8 @@ type handler struct {
|
||||
conn jsonWriter // where responses will be sent
|
||||
log log.Logger
|
||||
allowSubscribe bool
|
||||
batchRequestLimit int
|
||||
batchResponseMaxSize int
|
||||
|
||||
subLock sync.Mutex
|
||||
serverSubs map[ID]*Subscription
|
||||
@ -70,7 +72,7 @@ type callProc struct {
|
||||
notifiers []*Notifier
|
||||
}
|
||||
|
||||
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
|
||||
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler {
|
||||
rootCtx, cancelRoot := context.WithCancel(connCtx)
|
||||
h := &handler{
|
||||
reg: reg,
|
||||
@ -83,6 +85,8 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
|
||||
allowSubscribe: true,
|
||||
serverSubs: make(map[ID]*Subscription),
|
||||
log: log.Root(),
|
||||
batchRequestLimit: batchRequestLimit,
|
||||
batchResponseMaxSize: batchResponseMaxSize,
|
||||
}
|
||||
if conn.remoteAddr() != "" {
|
||||
h.log = h.log.New("conn", conn.remoteAddr())
|
||||
@ -134,16 +138,15 @@ func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
|
||||
b.doWrite(ctx, conn, false)
|
||||
}
|
||||
|
||||
// timeout sends the responses added so far. For the remaining unanswered call
|
||||
// messages, it sends a timeout error response.
|
||||
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
|
||||
// respondWithError sends the responses added so far. For the remaining unanswered call
|
||||
// messages, it responds with the given error.
|
||||
func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
for _, msg := range b.calls {
|
||||
if !msg.isNotification() {
|
||||
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
||||
b.resp = append(b.resp, resp)
|
||||
b.resp = append(b.resp, msg.errorResponse(err))
|
||||
}
|
||||
}
|
||||
b.doWrite(ctx, conn, true)
|
||||
@ -171,17 +174,24 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Apply limit on total number of requests.
|
||||
if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit {
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
h.respondWithBatchTooLarge(cp, msgs)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Handle non-call messages first:
|
||||
// Handle non-call messages first.
|
||||
// Here we need to find the requestOp that sent the request batch.
|
||||
calls := make([]*jsonrpcMessage, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if handled := h.handleImmediate(msg); !handled {
|
||||
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
|
||||
calls = append(calls, msg)
|
||||
}
|
||||
}
|
||||
})
|
||||
if len(calls) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Process calls on a goroutine because they may block indefinitely:
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
var (
|
||||
@ -199,10 +209,12 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
||||
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
||||
timer = time.AfterFunc(timeout, func() {
|
||||
cancel()
|
||||
callBuffer.timeout(cp.ctx, h.conn)
|
||||
err := &internalServerError{errcodeTimeout, errMsgTimeout}
|
||||
callBuffer.respondWithError(cp.ctx, h.conn, err)
|
||||
})
|
||||
}
|
||||
|
||||
responseBytes := 0
|
||||
for {
|
||||
// No need to handle rest of calls if timed out.
|
||||
if cp.ctx.Err() != nil {
|
||||
@ -214,24 +226,52 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
||||
}
|
||||
resp := h.handleCallMsg(cp, msg)
|
||||
callBuffer.pushResponse(resp)
|
||||
if resp != nil && h.batchResponseMaxSize != 0 {
|
||||
responseBytes += len(resp.Result)
|
||||
if responseBytes > h.batchResponseMaxSize {
|
||||
err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge}
|
||||
callBuffer.respondWithError(cp.ctx, h.conn, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
callBuffer.write(cp.ctx, h.conn)
|
||||
|
||||
h.addSubscriptions(cp.notifiers)
|
||||
callBuffer.write(cp.ctx, h.conn)
|
||||
for _, n := range cp.notifiers {
|
||||
n.activate()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// handleMsg handles a single message.
|
||||
func (h *handler) handleMsg(msg *jsonrpcMessage) {
|
||||
if ok := h.handleImmediate(msg); ok {
|
||||
return
|
||||
func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) {
|
||||
resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge})
|
||||
// Find the first call and add its "id" field to the error.
|
||||
// This is the best we can do, given that the protocol doesn't have a way
|
||||
// of reporting an error for the entire batch.
|
||||
for _, msg := range batch {
|
||||
if msg.isCall() {
|
||||
resp.ID = msg.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true)
|
||||
}
|
||||
|
||||
// handleMsg handles a single non-batch message.
|
||||
func (h *handler) handleMsg(msg *jsonrpcMessage) {
|
||||
msgs := []*jsonrpcMessage{msg}
|
||||
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
h.handleNonBatchCall(cp, msg)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) {
|
||||
var (
|
||||
responded sync.Once
|
||||
timer *time.Timer
|
||||
@ -266,7 +306,6 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
|
||||
for _, n := range cp.notifiers {
|
||||
n.activate()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// close cancels all requests except for inflightReq and waits for
|
||||
@ -349,23 +388,60 @@ func (h *handler) startCallProc(fn func(*callProc)) {
|
||||
}()
|
||||
}
|
||||
|
||||
// handleImmediate executes non-call messages. It returns false if the message is a
|
||||
// call or requires a reply.
|
||||
func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
|
||||
// handleResponse processes method call responses.
|
||||
func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) {
|
||||
var resolvedops []*requestOp
|
||||
handleResp := func(msg *jsonrpcMessage) {
|
||||
op := h.respWait[string(msg.ID)]
|
||||
if op == nil {
|
||||
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
|
||||
return
|
||||
}
|
||||
resolvedops = append(resolvedops, op)
|
||||
delete(h.respWait, string(msg.ID))
|
||||
|
||||
// For subscription responses, start the subscription if the server
|
||||
// indicates success. EthSubscribe gets unblocked in either case through
|
||||
// the op.resp channel.
|
||||
if op.sub != nil {
|
||||
if msg.Error != nil {
|
||||
op.err = msg.Error
|
||||
} else {
|
||||
op.err = json.Unmarshal(msg.Result, &op.sub.subid)
|
||||
if op.err == nil {
|
||||
go op.sub.run()
|
||||
h.clientSubs[op.sub.subid] = op.sub
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !op.hadResponse {
|
||||
op.hadResponse = true
|
||||
op.resp <- batch
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range batch {
|
||||
start := time.Now()
|
||||
switch {
|
||||
case msg.isResponse():
|
||||
handleResp(msg)
|
||||
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
|
||||
|
||||
case msg.isNotification():
|
||||
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
|
||||
h.handleSubscriptionResult(msg)
|
||||
return true
|
||||
continue
|
||||
}
|
||||
return false
|
||||
case msg.isResponse():
|
||||
h.handleResponse(msg)
|
||||
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
|
||||
return true
|
||||
handleCall(msg)
|
||||
|
||||
default:
|
||||
return false
|
||||
handleCall(msg)
|
||||
}
|
||||
}
|
||||
|
||||
for _, op := range resolvedops {
|
||||
h.removeRequestOp(op)
|
||||
}
|
||||
}
|
||||
|
||||
@ -381,33 +457,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponse processes method call responses.
|
||||
func (h *handler) handleResponse(msg *jsonrpcMessage) {
|
||||
op := h.respWait[string(msg.ID)]
|
||||
if op == nil {
|
||||
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
|
||||
return
|
||||
}
|
||||
delete(h.respWait, string(msg.ID))
|
||||
// For normal responses, just forward the reply to Call/BatchCall.
|
||||
if op.sub == nil {
|
||||
op.resp <- msg
|
||||
return
|
||||
}
|
||||
// For subscription responses, start the subscription if the server
|
||||
// indicates success. EthSubscribe gets unblocked in either case through
|
||||
// the op.resp channel.
|
||||
defer close(op.resp)
|
||||
if msg.Error != nil {
|
||||
op.err = msg.Error
|
||||
return
|
||||
}
|
||||
if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
|
||||
go op.sub.run()
|
||||
h.clientSubs[op.sub.subid] = op.sub
|
||||
}
|
||||
}
|
||||
|
||||
// handleCallMsg executes a call message and returns the answer.
|
||||
func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
|
||||
start := time.Now()
|
||||
@ -416,6 +465,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
|
||||
h.handleCall(ctx, msg)
|
||||
h.log.Debug("Served "+msg.Method, "duration", time.Since(start))
|
||||
return nil
|
||||
|
||||
case msg.isCall():
|
||||
resp := h.handleCall(ctx, msg)
|
||||
var ctx []interface{}
|
||||
@ -430,8 +480,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
|
||||
h.log.Debug("Served "+msg.Method, ctx...)
|
||||
}
|
||||
return resp
|
||||
|
||||
case msg.hasValidID():
|
||||
return msg.errorResponse(&invalidRequestError{"invalid request"})
|
||||
|
||||
default:
|
||||
return errorMessage(&invalidRequestError{"invalid request"})
|
||||
}
|
||||
@ -451,12 +503,14 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
||||
if callb == nil {
|
||||
return msg.errorResponse(&methodNotFoundError{method: msg.Method})
|
||||
}
|
||||
|
||||
args, err := parsePositionalArguments(msg.Params, callb.argTypes)
|
||||
if err != nil {
|
||||
return msg.errorResponse(&invalidParamsError{err.Error()})
|
||||
}
|
||||
start := time.Now()
|
||||
answer := h.runMethod(cp.ctx, msg, callb, args)
|
||||
|
||||
// Collect the statistics for RPC calls if metrics is enabled.
|
||||
// We only care about pure rpc call. Filter out subscription.
|
||||
if callb != h.unsubscribeCb {
|
||||
@ -469,6 +523,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
||||
rpcServingTimer.UpdateSince(start)
|
||||
updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start))
|
||||
}
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
|
19
rpc/http.go
19
rpc/http.go
@ -139,7 +139,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
|
||||
var cfg clientConfig
|
||||
cfg.httpClient = client
|
||||
fn := newClientTransportHTTP(endpoint, &cfg)
|
||||
return newClient(context.Background(), fn)
|
||||
return newClient(context.Background(), &cfg, fn)
|
||||
}
|
||||
|
||||
func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
|
||||
@ -176,11 +176,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
|
||||
}
|
||||
defer respBody.Close()
|
||||
|
||||
var respmsg jsonrpcMessage
|
||||
if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil {
|
||||
var resp jsonrpcMessage
|
||||
batch := [1]*jsonrpcMessage{&resp}
|
||||
if err := json.NewDecoder(respBody).Decode(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
op.resp <- &respmsg
|
||||
op.resp <- batch[:]
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -191,16 +192,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr
|
||||
return err
|
||||
}
|
||||
defer respBody.Close()
|
||||
var respmsgs []jsonrpcMessage
|
||||
|
||||
var respmsgs []*jsonrpcMessage
|
||||
if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(respmsgs) != len(msgs) {
|
||||
return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult)
|
||||
}
|
||||
for i := 0; i < len(respmsgs); i++ {
|
||||
op.resp <- &respmsgs[i]
|
||||
}
|
||||
op.resp <- respmsgs
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,8 @@ import (
|
||||
// DialInProc attaches an in-process connection to the given RPC server.
|
||||
func DialInProc(handler *Server) *Client {
|
||||
initctx := context.Background()
|
||||
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
|
||||
cfg := new(clientConfig)
|
||||
c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) {
|
||||
p1, p2 := net.Pipe()
|
||||
go handler.ServeCodec(NewCodec(p1), 0)
|
||||
return NewCodec(p2), nil
|
||||
|
@ -46,7 +46,8 @@ func (s *Server) ServeListener(l net.Listener) error {
|
||||
// The context is used for the initial connection establishment. It does not
|
||||
// affect subsequent interactions with the client.
|
||||
func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
|
||||
return newClient(ctx, newClientTransportIPC(endpoint))
|
||||
cfg := new(clientConfig)
|
||||
return newClient(ctx, cfg, newClientTransportIPC(endpoint))
|
||||
}
|
||||
|
||||
func newClientTransportIPC(endpoint string) reconnectFunc {
|
||||
|
@ -49,6 +49,8 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
codecs map[ServerCodec]struct{}
|
||||
run atomic.Bool
|
||||
batchItemLimit int
|
||||
batchResponseLimit int
|
||||
}
|
||||
|
||||
// NewServer creates a new server instance with no registered handlers.
|
||||
@ -65,6 +67,17 @@ func NewServer() *Server {
|
||||
return server
|
||||
}
|
||||
|
||||
// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit'
|
||||
// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of
|
||||
// response bytes across all requests in a batch.
|
||||
//
|
||||
// This method should be called before processing any requests via ServeCodec, ServeHTTP,
|
||||
// ServeListener etc.
|
||||
func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) {
|
||||
s.batchItemLimit = itemLimit
|
||||
s.batchResponseLimit = maxResponseSize
|
||||
}
|
||||
|
||||
// RegisterName creates a service for the given receiver type under the given name. When no
|
||||
// methods on the given receiver match the criteria to be either a RPC method or a
|
||||
// subscription an error is returned. Otherwise a new service is created and added to the
|
||||
@ -86,7 +99,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
|
||||
}
|
||||
defer s.untrackCodec(codec)
|
||||
|
||||
c := initClient(codec, s.idgen, &s.services)
|
||||
cfg := &clientConfig{
|
||||
idgen: s.idgen,
|
||||
batchItemLimit: s.batchItemLimit,
|
||||
batchResponseLimit: s.batchResponseLimit,
|
||||
}
|
||||
c := initClient(codec, &s.services, cfg)
|
||||
<-codec.closed()
|
||||
c.Close()
|
||||
}
|
||||
@ -118,7 +136,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
|
||||
return
|
||||
}
|
||||
|
||||
h := newHandler(ctx, codec, s.idgen, &s.services)
|
||||
h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit)
|
||||
h.allowSubscribe = false
|
||||
defer h.close(io.EOF, nil)
|
||||
|
||||
|
@ -70,6 +70,7 @@ func TestServer(t *testing.T) {
|
||||
|
||||
func runTestScript(t *testing.T, file string) {
|
||||
server := newTestServer()
|
||||
server.SetBatchLimits(4, 100000)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerBatchResponseSizeLimit(t *testing.T) {
|
||||
server := newTestServer()
|
||||
defer server.Stop()
|
||||
server.SetBatchLimits(100, 60)
|
||||
var (
|
||||
batch []BatchElem
|
||||
client = DialInProc(server)
|
||||
)
|
||||
for i := 0; i < 5; i++ {
|
||||
batch = append(batch, BatchElem{
|
||||
Method: "test_echo",
|
||||
Args: []any{"x", 1},
|
||||
Result: new(echoResult),
|
||||
})
|
||||
}
|
||||
if err := client.BatchCall(batch); err != nil {
|
||||
t.Fatal("error sending batch:", err)
|
||||
}
|
||||
for i := range batch {
|
||||
// We expect the first two queries to be ok, but after that the size limit takes effect.
|
||||
if i < 2 {
|
||||
if batch[i].Error != nil {
|
||||
t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// After two, we expect an error.
|
||||
re, ok := batch[i].Error.(Error)
|
||||
if !ok {
|
||||
t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error)
|
||||
}
|
||||
wantedCode := errcodeResponseTooLarge
|
||||
if re.ErrorCode() != wantedCode {
|
||||
t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -32,7 +32,8 @@ func DialStdIO(ctx context.Context) (*Client, error) {
|
||||
|
||||
// DialIO creates a client which uses the given IO channels
|
||||
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
|
||||
return newClient(ctx, newClientTransportIO(in, out))
|
||||
cfg := new(clientConfig)
|
||||
return newClient(ctx, cfg, newClientTransportIO(in, out))
|
||||
}
|
||||
|
||||
func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {
|
||||
|
13
rpc/testdata/invalid-batch-toolarge.js
vendored
Normal file
13
rpc/testdata/invalid-batch-toolarge.js
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
// This file checks the behavior of the batch item limit code.
|
||||
// In tests, the batch item limit is set to 4. So to trigger the error,
|
||||
// all batches in this file have 5 elements.
|
||||
|
||||
// For batches that do not contain any calls, a response message with "id" == null
|
||||
// is returned.
|
||||
|
||||
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
|
||||
<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}]
|
||||
|
||||
// For batches with at least one call, the call's "id" is used.
|
||||
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
|
||||
<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}]
|
@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newClient(ctx, connect)
|
||||
return newClient(ctx, cfg, connect)
|
||||
}
|
||||
|
||||
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
|
||||
@ -214,7 +214,7 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newClient(ctx, connect)
|
||||
return newClient(ctx, cfg, connect)
|
||||
}
|
||||
|
||||
func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user