rpc: implement full bi-directional communication (#18471)

New APIs added:

    client.RegisterName(namespace, service) // makes service available to server
    client.Notify(ctx, method, args...)     // sends a notification
    ClientFromContext(ctx)                  // to get a client in handler method

This is essentially a rewrite of the server-side code. JSON-RPC
processing code is now the same on both server and client side. Many
minor issues were fixed in the process and there is a new test suite for
JSON-RPC spec compliance (and non-compliance in some cases).

List of behavior changes:

- Method handlers are now called with a per-request context instead of a
  per-connection context. The context is canceled right after the method
  returns.
- Subscription error channels are always closed when the connection
  ends. There is no need to also wait on the Notifier's Closed channel
  to detect whether the subscription has ended.
- Client now omits "params" instead of sending "params": null when there
  are no arguments to a call. The previous behavior was not compliant
  with the spec. The server still accepts "params": null.
- Floating point numbers are allowed as "id". The spec doesn't allow
  them, but we handle request "id" as json.RawMessage and guarantee that
  the same number will be sent back.
- Logging is improved significantly. There is now a message at DEBUG
  level for each RPC call served.
This commit is contained in:
Felix Lange 2019-02-04 13:47:34 +01:00 committed by GitHub
parent ec3432bccb
commit 245f3146c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 2211 additions and 2169 deletions

@ -18,17 +18,13 @@ package rpc
import ( import (
"bytes" "bytes"
"container/list"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"net/url" "net/url"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -39,13 +35,14 @@ var (
ErrClientQuit = errors.New("client is closed") ErrClientQuit = errors.New("client is closed")
ErrNoResult = errors.New("no result in JSON-RPC response") ErrNoResult = errors.New("no result in JSON-RPC response")
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
errClientReconnected = errors.New("client reconnected")
errDead = errors.New("connection lost")
) )
const ( const (
// Timeouts // Timeouts
tcpKeepAliveInterval = 30 * time.Second tcpKeepAliveInterval = 30 * time.Second
defaultDialTimeout = 10 * time.Second // used when dialing if the context has no deadline defaultDialTimeout = 10 * time.Second // used if context has no deadline
defaultWriteTimeout = 10 * time.Second // used for calls if the context has no deadline
subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls
) )
@ -76,56 +73,57 @@ type BatchElem struct {
Error error Error error
} }
// A value of this type can a JSON-RPC request, notification, successful response or
// error response. Which one it is depends on the fields.
type jsonrpcMessage struct {
Version string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Error *jsonError `json:"error,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
}
func (msg *jsonrpcMessage) isNotification() bool {
return msg.ID == nil && msg.Method != ""
}
func (msg *jsonrpcMessage) isResponse() bool {
return msg.hasValidID() && msg.Method == "" && len(msg.Params) == 0
}
func (msg *jsonrpcMessage) hasValidID() bool {
return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '['
}
func (msg *jsonrpcMessage) String() string {
b, _ := json.Marshal(msg)
return string(b)
}
// Client represents a connection to an RPC server. // Client represents a connection to an RPC server.
type Client struct { type Client struct {
idCounter uint32 idgen func() ID // for subscriptions
connectFunc func(ctx context.Context) (net.Conn, error)
isHTTP bool isHTTP bool
services *serviceRegistry
// writeConn is only safe to access outside dispatch, with the idCounter uint32
// write lock held. The write lock is taken by sending on
// requestOp and released by sending on sendDone. // This function, if non-nil, is called when the connection is lost.
writeConn net.Conn reconnectFunc reconnectFunc
// writeConn is used for writing to the connection on the caller's goroutine. It should
// only be accessed outside of dispatch, with the write lock held. The write lock is
// taken by sending on requestOp and released by sending on sendDone.
writeConn jsonWriter
// for dispatch // for dispatch
close chan struct{} close chan struct{}
closing chan struct{} // closed when client is quitting closing chan struct{} // closed when client is quitting
didClose chan struct{} // closed when client quits didClose chan struct{} // closed when client quits
reconnected chan net.Conn // where write/reconnect sends the new connection reconnected chan ServerCodec // where write/reconnect sends the new connection
readOp chan readOp // read messages
readErr chan error // errors from read readErr chan error // errors from read
readResp chan []*jsonrpcMessage // valid messages from read reqInit chan *requestOp // register response IDs, takes write lock
requestOp chan *requestOp // for registering response IDs reqSent chan error // signals write completion, releases write lock
sendDone chan error // signals write completion, releases write lock reqTimeout chan *requestOp // removes response IDs when call timeout expires
respWait map[string]*requestOp // active requests }
subs map[string]*ClientSubscription // active subscriptions
type reconnectFunc func(ctx context.Context) (ServerCodec, error)
type clientContextKey struct{}
type clientConn struct {
codec ServerCodec
handler *handler
}
func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
handler := newHandler(ctx, conn, c.idgen, c.services)
return &clientConn{conn, handler}
}
func (cc *clientConn) close(err error, inflightReq *requestOp) {
cc.handler.close(err, inflightReq)
cc.codec.Close()
}
type readOp struct {
msgs []*jsonrpcMessage
batch bool
} }
type requestOp struct { type requestOp struct {
@ -135,9 +133,14 @@ type requestOp struct {
sub *ClientSubscription // only set for EthSubscribe requests sub *ClientSubscription // only set for EthSubscribe requests
} }
func (op *requestOp) wait(ctx context.Context) (*jsonrpcMessage, error) { func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// Send the timeout to dispatch so it can remove the request IDs.
select {
case c.reqTimeout <- op:
case <-c.closing:
}
return nil, ctx.Err() return nil, ctx.Err()
case resp := <-op.resp: case resp := <-op.resp:
return resp, op.err return resp, op.err
@ -181,36 +184,57 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) {
} }
} }
func newClient(initctx context.Context, connectFunc func(context.Context) (net.Conn, error)) (*Client, error) { // Client retrieves the client from the context, if any. This can be used to perform
conn, err := connectFunc(initctx) // 'reverse calls' in a handler method.
func ClientFromContext(ctx context.Context) (*Client, bool) {
client, ok := ctx.Value(clientContextKey{}).(*Client)
return client, ok
}
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
conn, err := connect(initctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
c.reconnectFunc = connect
return c, nil
}
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
_, isHTTP := conn.(*httpConn) _, isHTTP := conn.(*httpConn)
c := &Client{ c := &Client{
writeConn: conn, idgen: idgen,
isHTTP: isHTTP, isHTTP: isHTTP,
connectFunc: connectFunc, services: services,
writeConn: conn,
close: make(chan struct{}), close: make(chan struct{}),
closing: make(chan struct{}), closing: make(chan struct{}),
didClose: make(chan struct{}), didClose: make(chan struct{}),
reconnected: make(chan net.Conn), reconnected: make(chan ServerCodec),
readOp: make(chan readOp),
readErr: make(chan error), readErr: make(chan error),
readResp: make(chan []*jsonrpcMessage), reqInit: make(chan *requestOp),
requestOp: make(chan *requestOp), reqSent: make(chan error, 1),
sendDone: make(chan error, 1), reqTimeout: make(chan *requestOp),
respWait: make(map[string]*requestOp),
subs: make(map[string]*ClientSubscription),
} }
if !isHTTP { if !isHTTP {
go c.dispatch(conn) go c.dispatch(conn)
} }
return c, nil return c
}
// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either a RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
// service collection this client provides to the server.
func (c *Client) RegisterName(name string, receiver interface{}) error {
return c.services.registerName(name, receiver)
} }
func (c *Client) nextID() json.RawMessage { func (c *Client) nextID() json.RawMessage {
id := atomic.AddUint32(&c.idCounter, 1) id := atomic.AddUint32(&c.idCounter, 1)
return []byte(strconv.FormatUint(uint64(id), 10)) return strconv.AppendUint(nil, uint64(id), 10)
} }
// SupportedModules calls the rpc_modules method, retrieving the list of // SupportedModules calls the rpc_modules method, retrieving the list of
@ -267,7 +291,7 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
} }
// dispatch has accepted the request and will close the channel when it quits. // dispatch has accepted the request and will close the channel when it quits.
switch resp, err := op.wait(ctx); { switch resp, err := op.wait(ctx, c); {
case err != nil: case err != nil:
return err return err
case resp.Error != nil: case resp.Error != nil:
@ -325,7 +349,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
// Wait for all responses to come back. // Wait for all responses to come back.
for n := 0; n < len(b) && err == nil; n++ { for n := 0; n < len(b) && err == nil; n++ {
var resp *jsonrpcMessage var resp *jsonrpcMessage
resp, err = op.wait(ctx) resp, err = op.wait(ctx, c)
if err != nil { if err != nil {
break break
} }
@ -352,6 +376,22 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
return err return err
} }
// Notify sends a notification, i.e. a method call that doesn't expect a response.
func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) error {
op := new(requestOp)
msg, err := c.newMessage(method, args...)
if err != nil {
return err
}
msg.ID = nil
if c.isHTTP {
return c.sendHTTP(ctx, op, msg)
} else {
return c.send(ctx, op, msg)
}
}
// EthSubscribe registers a subscripion under the "eth" namespace. // EthSubscribe registers a subscripion under the "eth" namespace.
func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) {
return c.Subscribe(ctx, "eth", channel, args...) return c.Subscribe(ctx, "eth", channel, args...)
@ -402,30 +442,30 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
if err := c.send(ctx, op, msg); err != nil { if err := c.send(ctx, op, msg); err != nil {
return nil, err return nil, err
} }
if _, err := op.wait(ctx); err != nil { if _, err := op.wait(ctx, c); err != nil {
return nil, err return nil, err
} }
return op.sub, nil return op.sub, nil
} }
func (c *Client) newMessage(method string, paramsIn ...interface{}) (*jsonrpcMessage, error) { func (c *Client) newMessage(method string, paramsIn ...interface{}) (*jsonrpcMessage, error) {
params, err := json.Marshal(paramsIn) msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method}
if err != nil { if paramsIn != nil { // prevent sending "params":null
var err error
if msg.Params, err = json.Marshal(paramsIn); err != nil {
return nil, err return nil, err
} }
return &jsonrpcMessage{Version: "2.0", ID: c.nextID(), Method: method, Params: params}, nil }
return msg, nil
} }
// send registers op with the dispatch loop, then sends msg on the connection. // send registers op with the dispatch loop, then sends msg on the connection.
// if sending fails, op is deregistered. // if sending fails, op is deregistered.
func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error { func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error {
select { select {
case c.requestOp <- op: case c.reqInit <- op:
log.Trace("", "msg", log.Lazy{Fn: func() string {
return fmt.Sprint("sending ", msg)
}})
err := c.write(ctx, msg) err := c.write(ctx, msg)
c.sendDone <- err c.reqSent <- err
return err return err
case <-ctx.Done(): case <-ctx.Done():
// This can happen if the client is overloaded or unable to keep up with // This can happen if the client is overloaded or unable to keep up with
@ -433,25 +473,17 @@ func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error
return ctx.Err() return ctx.Err()
case <-c.closing: case <-c.closing:
return ErrClientQuit return ErrClientQuit
case <-c.didClose:
return ErrClientQuit
} }
} }
func (c *Client) write(ctx context.Context, msg interface{}) error { func (c *Client) write(ctx context.Context, msg interface{}) error {
deadline, ok := ctx.Deadline()
if !ok {
deadline = time.Now().Add(defaultWriteTimeout)
}
// The previous write failed. Try to establish a new connection. // The previous write failed. Try to establish a new connection.
if c.writeConn == nil { if c.writeConn == nil {
if err := c.reconnect(ctx); err != nil { if err := c.reconnect(ctx); err != nil {
return err return err
} }
} }
c.writeConn.SetWriteDeadline(deadline) err := c.writeConn.Write(ctx, msg)
err := json.NewEncoder(c.writeConn).Encode(msg)
c.writeConn.SetWriteDeadline(time.Time{})
if err != nil { if err != nil {
c.writeConn = nil c.writeConn = nil
} }
@ -459,9 +491,18 @@ func (c *Client) write(ctx context.Context, msg interface{}) error {
} }
func (c *Client) reconnect(ctx context.Context) error { func (c *Client) reconnect(ctx context.Context) error {
newconn, err := c.connectFunc(ctx) if c.reconnectFunc == nil {
return errDead
}
if _, ok := ctx.Deadline(); !ok {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, defaultDialTimeout)
defer cancel()
}
newconn, err := c.reconnectFunc(ctx)
if err != nil { if err != nil {
log.Trace(fmt.Sprintf("reconnect failed: %v", err)) log.Trace("RPC client reconnect failed", "err", err)
return err return err
} }
select { select {
@ -477,322 +518,107 @@ func (c *Client) reconnect(ctx context.Context) error {
// dispatch is the main loop of the client. // dispatch is the main loop of the client.
// It sends read messages to waiting calls to Call and BatchCall // It sends read messages to waiting calls to Call and BatchCall
// and subscription notifications to registered subscriptions. // and subscription notifications to registered subscriptions.
func (c *Client) dispatch(conn net.Conn) { func (c *Client) dispatch(codec ServerCodec) {
// Spawn the initial read loop.
go c.read(conn)
var ( var (
lastOp *requestOp // tracks last send operation lastOp *requestOp // tracks last send operation
requestOpLock = c.requestOp // nil while the send lock is held reqInitLock = c.reqInit // nil while the send lock is held
reading = true // if true, a read loop is running conn = c.newClientConn(codec)
reading = true
) )
defer close(c.didClose)
defer func() { defer func() {
close(c.closing) close(c.closing)
c.closeRequestOps(ErrClientQuit)
conn.Close()
if reading { if reading {
// Empty read channels until read is dead. conn.close(ErrClientQuit, nil)
for { c.drainRead()
select {
case <-c.readResp:
case <-c.readErr:
return
}
}
} }
close(c.didClose)
}() }()
// Spawn the initial read loop.
go c.read(codec)
for { for {
select { select {
case <-c.close: case <-c.close:
return return
// Read path. // Read path:
case batch := <-c.readResp: case op := <-c.readOp:
for _, msg := range batch { if op.batch {
switch { conn.handler.handleBatch(op.msgs)
case msg.isNotification(): } else {
log.Trace("", "msg", log.Lazy{Fn: func() string { conn.handler.handleMsg(op.msgs[0])
return fmt.Sprint("<-readResp: notification ", msg)
}})
c.handleNotification(msg)
case msg.isResponse():
log.Trace("", "msg", log.Lazy{Fn: func() string {
return fmt.Sprint("<-readResp: response ", msg)
}})
c.handleResponse(msg)
default:
log.Debug("", "msg", log.Lazy{Fn: func() string {
return fmt.Sprint("<-readResp: dropping weird message", msg)
}})
// TODO: maybe close
}
} }
case err := <-c.readErr: case err := <-c.readErr:
log.Debug("<-readErr", "err", err) conn.handler.log.Debug("RPC connection read error", "err", err)
c.closeRequestOps(err) conn.close(err, lastOp)
conn.Close()
reading = false reading = false
case newconn := <-c.reconnected: // Reconnect:
log.Debug("<-reconnected", "reading", reading, "remote", conn.RemoteAddr()) case newcodec := <-c.reconnected:
log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.RemoteAddr())
if reading { if reading {
// Wait for the previous read loop to exit. This is a rare case. // Wait for the previous read loop to exit. This is a rare case which
conn.Close() // happens if this loop isn't notified in time after the connection breaks.
<-c.readErr // In those cases the caller will notice first and reconnect. Closing the
// handler terminates all waiting requests (closing op.resp) except for
// lastOp, which will be transferred to the new handler.
conn.close(errClientReconnected, lastOp)
c.drainRead()
} }
go c.read(newconn) go c.read(newcodec)
reading = true reading = true
conn = newconn conn = c.newClientConn(newcodec)
// Re-register the in-flight request on the new handler
// because that's where it will be sent.
conn.handler.addRequestOp(lastOp)
// Send path. // Send path:
case op := <-requestOpLock: case op := <-reqInitLock:
// Stop listening for further send ops until the current one is done. // Stop listening for further requests until the current one has been sent.
requestOpLock = nil reqInitLock = nil
lastOp = op lastOp = op
for _, id := range op.ids { conn.handler.addRequestOp(op)
c.respWait[string(id)] = op
}
case err := <-c.sendDone: case err := <-c.reqSent:
if err != nil { if err != nil {
// Remove response handlers for the last send. We remove those here // Remove response handlers for the last send. When the read loop
// because the error is already handled in Call or BatchCall. When the // goes down, it will signal all other current operations.
// read loop goes down, it will signal all other current operations. conn.handler.removeRequestOp(lastOp)
for _, id := range lastOp.ids {
delete(c.respWait, string(id))
} }
} // Let the next request in.
// Listen for send ops again. reqInitLock = c.reqInit
requestOpLock = c.requestOp
lastOp = nil lastOp = nil
case op := <-c.reqTimeout:
conn.handler.removeRequestOp(op)
} }
} }
} }
// closeRequestOps unblocks pending send ops and active subscriptions. // drainRead drops read messages until an error occurs.
func (c *Client) closeRequestOps(err error) { func (c *Client) drainRead() {
didClose := make(map[*requestOp]bool)
for id, op := range c.respWait {
// Remove the op so that later calls will not close op.resp again.
delete(c.respWait, id)
if !didClose[op] {
op.err = err
close(op.resp)
didClose[op] = true
}
}
for id, sub := range c.subs {
delete(c.subs, id)
sub.quitWithError(err, false)
}
}
func (c *Client) handleNotification(msg *jsonrpcMessage) {
if !strings.HasSuffix(msg.Method, notificationMethodSuffix) {
log.Debug("dropping non-subscription message", "msg", msg)
return
}
var subResult struct {
ID string `json:"subscription"`
Result json.RawMessage `json:"result"`
}
if err := json.Unmarshal(msg.Params, &subResult); err != nil {
log.Debug("dropping invalid subscription message", "msg", msg)
return
}
if c.subs[subResult.ID] != nil {
c.subs[subResult.ID].deliver(subResult.Result)
}
}
func (c *Client) handleResponse(msg *jsonrpcMessage) {
op := c.respWait[string(msg.ID)]
if op == nil {
log.Debug("unsolicited response", "msg", msg)
return
}
delete(c.respWait, string(msg.ID))
// For normal responses, just forward the reply to Call/BatchCall.
if op.sub == nil {
op.resp <- msg
return
}
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
defer close(op.resp)
if msg.Error != nil {
op.err = msg.Error
return
}
if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
go op.sub.start()
c.subs[op.sub.subid] = op.sub
}
}
// Reading happens on a dedicated goroutine.
func (c *Client) read(conn net.Conn) error {
var (
buf json.RawMessage
dec = json.NewDecoder(conn)
)
readMessage := func() (rs []*jsonrpcMessage, err error) {
buf = buf[:0]
if err = dec.Decode(&buf); err != nil {
return nil, err
}
if isBatch(buf) {
err = json.Unmarshal(buf, &rs)
} else {
rs = make([]*jsonrpcMessage, 1)
err = json.Unmarshal(buf, &rs[0])
}
return rs, err
}
for { for {
resp, err := readMessage() select {
case <-c.readOp:
case <-c.readErr:
return
}
}
}
// read decodes RPC messages from a codec, feeding them into dispatch.
func (c *Client) read(codec ServerCodec) {
for {
msgs, batch, err := codec.Read()
if _, ok := err.(*json.SyntaxError); ok {
codec.Write(context.Background(), errorMessage(&parseError{err.Error()}))
}
if err != nil { if err != nil {
c.readErr <- err c.readErr <- err
return err return
} }
c.readResp <- resp c.readOp <- readOp{msgs, batch}
} }
} }
// Subscriptions.
// A ClientSubscription represents a subscription established through EthSubscribe.
type ClientSubscription struct {
client *Client
etype reflect.Type
channel reflect.Value
namespace string
subid string
in chan json.RawMessage
quitOnce sync.Once // ensures quit is closed once
quit chan struct{} // quit is closed when the subscription exits
errOnce sync.Once // ensures err is closed once
err chan error
}
func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription {
sub := &ClientSubscription{
client: c,
namespace: namespace,
etype: channel.Type().Elem(),
channel: channel,
quit: make(chan struct{}),
err: make(chan error, 1),
in: make(chan json.RawMessage),
}
return sub
}
// Err returns the subscription error channel. The intended use of Err is to schedule
// resubscription when the client connection is closed unexpectedly.
//
// The error channel receives a value when the subscription has ended due
// to an error. The received error is nil if Close has been called
// on the underlying client and no other error has occurred.
//
// The error channel is closed when Unsubscribe is called on the subscription.
func (sub *ClientSubscription) Err() <-chan error {
return sub.err
}
// Unsubscribe unsubscribes the notification and closes the error channel.
// It can safely be called more than once.
func (sub *ClientSubscription) Unsubscribe() {
sub.quitWithError(nil, true)
sub.errOnce.Do(func() { close(sub.err) })
}
func (sub *ClientSubscription) quitWithError(err error, unsubscribeServer bool) {
sub.quitOnce.Do(func() {
// The dispatch loop won't be able to execute the unsubscribe call
// if it is blocked on deliver. Close sub.quit first because it
// unblocks deliver.
close(sub.quit)
if unsubscribeServer {
sub.requestUnsubscribe()
}
if err != nil {
if err == ErrClientQuit {
err = nil // Adhere to subscription semantics.
}
sub.err <- err
}
})
}
func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) {
select {
case sub.in <- result:
return true
case <-sub.quit:
return false
}
}
func (sub *ClientSubscription) start() {
sub.quitWithError(sub.forward())
}
func (sub *ClientSubscription) forward() (err error, unsubscribeServer bool) {
cases := []reflect.SelectCase{
{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)},
{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)},
{Dir: reflect.SelectSend, Chan: sub.channel},
}
buffer := list.New()
defer buffer.Init()
for {
var chosen int
var recv reflect.Value
if buffer.Len() == 0 {
// Idle, omit send case.
chosen, recv, _ = reflect.Select(cases[:2])
} else {
// Non-empty buffer, send the first queued item.
cases[2].Send = reflect.ValueOf(buffer.Front().Value)
chosen, recv, _ = reflect.Select(cases)
}
switch chosen {
case 0: // <-sub.quit
return nil, false
case 1: // <-sub.in
val, err := sub.unmarshal(recv.Interface().(json.RawMessage))
if err != nil {
return err, true
}
if buffer.Len() == maxClientSubscriptionBuffer {
return ErrSubscriptionQueueOverflow, true
}
buffer.PushBack(val)
case 2: // sub.channel<-
cases[2].Send = reflect.Value{} // Don't hold onto the value.
buffer.Remove(buffer.Front())
}
}
}
func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) {
val := reflect.New(sub.etype)
err := json.Unmarshal(result, val.Interface())
return val.Elem().Interface(), err
}
func (sub *ClientSubscription) requestUnsubscribe() error {
var result interface{}
return sub.client.Call(&result, sub.namespace+unsubscribeMethodSuffix, sub.subid)
}

@ -35,13 +35,13 @@ import (
) )
func TestClientRequest(t *testing.T) { func TestClientRequest(t *testing.T) {
server := newTestServer("service", new(Service)) server := newTestServer()
defer server.Stop() defer server.Stop()
client := DialInProc(server) client := DialInProc(server)
defer client.Close() defer client.Close()
var resp Result var resp Result
if err := client.Call(&resp, "service_echo", "hello", 10, &Args{"world"}); err != nil { if err := client.Call(&resp, "test_echo", "hello", 10, &Args{"world"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(resp, Result{"hello", 10, &Args{"world"}}) { if !reflect.DeepEqual(resp, Result{"hello", 10, &Args{"world"}}) {
@ -50,19 +50,19 @@ func TestClientRequest(t *testing.T) {
} }
func TestClientBatchRequest(t *testing.T) { func TestClientBatchRequest(t *testing.T) {
server := newTestServer("service", new(Service)) server := newTestServer()
defer server.Stop() defer server.Stop()
client := DialInProc(server) client := DialInProc(server)
defer client.Close() defer client.Close()
batch := []BatchElem{ batch := []BatchElem{
{ {
Method: "service_echo", Method: "test_echo",
Args: []interface{}{"hello", 10, &Args{"world"}}, Args: []interface{}{"hello", 10, &Args{"world"}},
Result: new(Result), Result: new(Result),
}, },
{ {
Method: "service_echo", Method: "test_echo",
Args: []interface{}{"hello2", 11, &Args{"world"}}, Args: []interface{}{"hello2", 11, &Args{"world"}},
Result: new(Result), Result: new(Result),
}, },
@ -77,12 +77,12 @@ func TestClientBatchRequest(t *testing.T) {
} }
wantResult := []BatchElem{ wantResult := []BatchElem{
{ {
Method: "service_echo", Method: "test_echo",
Args: []interface{}{"hello", 10, &Args{"world"}}, Args: []interface{}{"hello", 10, &Args{"world"}},
Result: &Result{"hello", 10, &Args{"world"}}, Result: &Result{"hello", 10, &Args{"world"}},
}, },
{ {
Method: "service_echo", Method: "test_echo",
Args: []interface{}{"hello2", 11, &Args{"world"}}, Args: []interface{}{"hello2", 11, &Args{"world"}},
Result: &Result{"hello2", 11, &Args{"world"}}, Result: &Result{"hello2", 11, &Args{"world"}},
}, },
@ -90,7 +90,7 @@ func TestClientBatchRequest(t *testing.T) {
Method: "no_such_method", Method: "no_such_method",
Args: []interface{}{1, 2, 3}, Args: []interface{}{1, 2, 3},
Result: new(int), Result: new(int),
Error: &jsonError{Code: -32601, Message: "The method no_such_method_ does not exist/is not available"}, Error: &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"},
}, },
} }
if !reflect.DeepEqual(batch, wantResult) { if !reflect.DeepEqual(batch, wantResult) {
@ -98,6 +98,17 @@ func TestClientBatchRequest(t *testing.T) {
} }
} }
func TestClientNotify(t *testing.T) {
server := newTestServer()
defer server.Stop()
client := DialInProc(server)
defer client.Close()
if err := client.Notify(context.Background(), "test_echo", "hello", 10, &Args{"world"}); err != nil {
t.Fatal(err)
}
}
// func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } // func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) }
func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) }
func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) }
@ -106,7 +117,12 @@ func TestClientCancelIPC(t *testing.T) { testClientCancel("ipc", t) }
// This test checks that requests made through CallContext can be canceled by canceling // This test checks that requests made through CallContext can be canceled by canceling
// the context. // the context.
func testClientCancel(transport string, t *testing.T) { func testClientCancel(transport string, t *testing.T) {
server := newTestServer("service", new(Service)) // These tests take a lot of time, run them all at once.
// You probably want to run with -parallel 1 or comment out
// the call to t.Parallel if you enable the logging.
t.Parallel()
server := newTestServer()
defer server.Stop() defer server.Stop()
// What we want to achieve is that the context gets canceled // What we want to achieve is that the context gets canceled
@ -142,11 +158,6 @@ func testClientCancel(transport string, t *testing.T) {
panic("unknown transport: " + transport) panic("unknown transport: " + transport)
} }
// These tests take a lot of time, run them all at once.
// You probably want to run with -parallel 1 or comment out
// the call to t.Parallel if you enable the logging.
t.Parallel()
// The actual test starts here. // The actual test starts here.
var ( var (
wg sync.WaitGroup wg sync.WaitGroup
@ -174,7 +185,8 @@ func testClientCancel(transport string, t *testing.T) {
} }
// Now perform a call with the context. // Now perform a call with the context.
// The key thing here is that no call will ever complete successfully. // The key thing here is that no call will ever complete successfully.
err := client.CallContext(ctx, nil, "service_sleep", 2*maxContextCancelTimeout) sleepTime := maxContextCancelTimeout + 20*time.Millisecond
err := client.CallContext(ctx, nil, "test_sleep", sleepTime)
if err != nil { if err != nil {
log.Debug(fmt.Sprint("got expected error:", err)) log.Debug(fmt.Sprint("got expected error:", err))
} else { } else {
@ -191,7 +203,7 @@ func testClientCancel(transport string, t *testing.T) {
} }
func TestClientSubscribeInvalidArg(t *testing.T) { func TestClientSubscribeInvalidArg(t *testing.T) {
server := newTestServer("service", new(Service)) server := newTestServer()
defer server.Stop() defer server.Stop()
client := DialInProc(server) client := DialInProc(server)
defer client.Close() defer client.Close()
@ -221,14 +233,14 @@ func TestClientSubscribeInvalidArg(t *testing.T) {
} }
func TestClientSubscribe(t *testing.T) { func TestClientSubscribe(t *testing.T) {
server := newTestServer("eth", new(NotificationTestService)) server := newTestServer()
defer server.Stop() defer server.Stop()
client := DialInProc(server) client := DialInProc(server)
defer client.Close() defer client.Close()
nc := make(chan int) nc := make(chan int)
count := 10 count := 10
sub, err := client.EthSubscribe(context.Background(), nc, "someSubscription", count, 0) sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", count, 0)
if err != nil { if err != nil {
t.Fatal("can't subscribe:", err) t.Fatal("can't subscribe:", err)
} }
@ -251,46 +263,17 @@ func TestClientSubscribe(t *testing.T) {
} }
} }
func TestClientSubscribeCustomNamespace(t *testing.T) { // In this test, the connection drops while Subscribe is waiting for a response.
namespace := "custom"
server := newTestServer(namespace, new(NotificationTestService))
defer server.Stop()
client := DialInProc(server)
defer client.Close()
nc := make(chan int)
count := 10
sub, err := client.Subscribe(context.Background(), namespace, nc, "someSubscription", count, 0)
if err != nil {
t.Fatal("can't subscribe:", err)
}
for i := 0; i < count; i++ {
if val := <-nc; val != i {
t.Fatalf("value mismatch: got %d, want %d", val, i)
}
}
sub.Unsubscribe()
select {
case v := <-nc:
t.Fatal("received value after unsubscribe:", v)
case err := <-sub.Err():
if err != nil {
t.Fatalf("Err returned a non-nil error after explicit unsubscribe: %q", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("subscription not closed within 1s after unsubscribe")
}
}
// In this test, the connection drops while EthSubscribe is
// waiting for a response.
func TestClientSubscribeClose(t *testing.T) { func TestClientSubscribeClose(t *testing.T) {
service := &NotificationTestService{ server := newTestServer()
service := &notificationTestService{
gotHangSubscriptionReq: make(chan struct{}), gotHangSubscriptionReq: make(chan struct{}),
unblockHangSubscription: make(chan struct{}), unblockHangSubscription: make(chan struct{}),
} }
server := newTestServer("eth", service) if err := server.RegisterName("nftest2", service); err != nil {
t.Fatal(err)
}
defer server.Stop() defer server.Stop()
client := DialInProc(server) client := DialInProc(server)
defer client.Close() defer client.Close()
@ -302,7 +285,7 @@ func TestClientSubscribeClose(t *testing.T) {
err error err error
) )
go func() { go func() {
sub, err = client.EthSubscribe(context.Background(), nc, "hangSubscription", 999) sub, err = client.Subscribe(context.Background(), "nftest2", nc, "hangSubscription", 999)
errc <- err errc <- err
}() }()
@ -313,27 +296,26 @@ func TestClientSubscribeClose(t *testing.T) {
select { select {
case err := <-errc: case err := <-errc:
if err == nil { if err == nil {
t.Errorf("EthSubscribe returned nil error after Close") t.Errorf("Subscribe returned nil error after Close")
} }
if sub != nil { if sub != nil {
t.Error("EthSubscribe returned non-nil subscription after Close") t.Error("Subscribe returned non-nil subscription after Close")
} }
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Fatalf("EthSubscribe did not return within 1s after Close") t.Fatalf("Subscribe did not return within 1s after Close")
} }
} }
// This test reproduces https://github.com/ethereum/go-ethereum/issues/17837 where the // This test reproduces https://github.com/ethereum/go-ethereum/issues/17837 where the
// client hangs during shutdown when Unsubscribe races with Client.Close. // client hangs during shutdown when Unsubscribe races with Client.Close.
func TestClientCloseUnsubscribeRace(t *testing.T) { func TestClientCloseUnsubscribeRace(t *testing.T) {
service := &NotificationTestService{} server := newTestServer()
server := newTestServer("eth", service)
defer server.Stop() defer server.Stop()
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
client := DialInProc(server) client := DialInProc(server)
nc := make(chan int) nc := make(chan int)
sub, err := client.EthSubscribe(context.Background(), nc, "someSubscription", 3, 1) sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -350,7 +332,7 @@ func TestClientCloseUnsubscribeRace(t *testing.T) {
// This test checks that Client doesn't lock up when a single subscriber // This test checks that Client doesn't lock up when a single subscriber
// doesn't read subscription events. // doesn't read subscription events.
func TestClientNotificationStorm(t *testing.T) { func TestClientNotificationStorm(t *testing.T) {
server := newTestServer("eth", new(NotificationTestService)) server := newTestServer()
defer server.Stop() defer server.Stop()
doTest := func(count int, wantError bool) { doTest := func(count int, wantError bool) {
@ -362,7 +344,7 @@ func TestClientNotificationStorm(t *testing.T) {
// Subscribe on the server. It will start sending many notifications // Subscribe on the server. It will start sending many notifications
// very quickly. // very quickly.
nc := make(chan int) nc := make(chan int)
sub, err := client.EthSubscribe(ctx, nc, "someSubscription", count, 0) sub, err := client.Subscribe(ctx, "nftest", nc, "someSubscription", count, 0)
if err != nil { if err != nil {
t.Fatal("can't subscribe:", err) t.Fatal("can't subscribe:", err)
} }
@ -384,7 +366,7 @@ func TestClientNotificationStorm(t *testing.T) {
return return
} }
var r int var r int
err := client.CallContext(ctx, &r, "eth_echo", i) err := client.CallContext(ctx, &r, "nftest_echo", i)
if err != nil { if err != nil {
if !wantError { if !wantError {
t.Fatalf("(%d/%d) call error: %v", i, count, err) t.Fatalf("(%d/%d) call error: %v", i, count, err)
@ -399,7 +381,7 @@ func TestClientNotificationStorm(t *testing.T) {
} }
func TestClientHTTP(t *testing.T) { func TestClientHTTP(t *testing.T) {
server := newTestServer("service", new(Service)) server := newTestServer()
defer server.Stop() defer server.Stop()
client, hs := httpTestClient(server, "http", nil) client, hs := httpTestClient(server, "http", nil)
@ -416,7 +398,7 @@ func TestClientHTTP(t *testing.T) {
for i := range results { for i := range results {
i := i i := i
go func() { go func() {
errc <- client.Call(&results[i], "service_echo", errc <- client.Call(&results[i], "test_echo",
wantResult.String, wantResult.Int, wantResult.Args) wantResult.String, wantResult.Int, wantResult.Args)
}() }()
} }
@ -445,16 +427,16 @@ func TestClientHTTP(t *testing.T) {
func TestClientReconnect(t *testing.T) { func TestClientReconnect(t *testing.T) {
startServer := func(addr string) (*Server, net.Listener) { startServer := func(addr string) (*Server, net.Listener) {
srv := newTestServer("service", new(Service)) srv := newTestServer()
l, err := net.Listen("tcp", addr) l, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal("can't listen:", err)
} }
go http.Serve(l, srv.WebsocketHandler([]string{"*"})) go http.Serve(l, srv.WebsocketHandler([]string{"*"}))
return srv, l return srv, l
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second)
defer cancel() defer cancel()
// Start a server and corresponding client. // Start a server and corresponding client.
@ -466,21 +448,22 @@ func TestClientReconnect(t *testing.T) {
// Perform a call. This should work because the server is up. // Perform a call. This should work because the server is up.
var resp Result var resp Result
if err := client.CallContext(ctx, &resp, "service_echo", "", 1, nil); err != nil { if err := client.CallContext(ctx, &resp, "test_echo", "", 1, nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Shut down the server and try calling again. It shouldn't work. // Shut down the server and allow for some cool down time so we can listen on the same
// address again.
l1.Close() l1.Close()
s1.Stop() s1.Stop()
if err := client.CallContext(ctx, &resp, "service_echo", "", 2, nil); err == nil { time.Sleep(2 * time.Second)
// Try calling again. It shouldn't work.
if err := client.CallContext(ctx, &resp, "test_echo", "", 2, nil); err == nil {
t.Error("successful call while the server is down") t.Error("successful call while the server is down")
t.Logf("resp: %#v", resp) t.Logf("resp: %#v", resp)
} }
// Allow for some cool down time so we can listen on the same address again.
time.Sleep(2 * time.Second)
// Start it up again and call again. The connection should be reestablished. // Start it up again and call again. The connection should be reestablished.
// We spawn multiple calls here to check whether this hangs somehow. // We spawn multiple calls here to check whether this hangs somehow.
s2, l2 := startServer(l1.Addr().String()) s2, l2 := startServer(l1.Addr().String())
@ -493,7 +476,7 @@ func TestClientReconnect(t *testing.T) {
go func() { go func() {
<-start <-start
var resp Result var resp Result
errors <- client.CallContext(ctx, &resp, "service_echo", "", 3, nil) errors <- client.CallContext(ctx, &resp, "test_echo", "", 3, nil)
}() }()
} }
close(start) close(start)
@ -503,20 +486,12 @@ func TestClientReconnect(t *testing.T) {
errcount++ errcount++
} }
} }
t.Log("err:", err) t.Logf("%d errors, last error: %v", errcount, err)
if errcount > 1 { if errcount > 1 {
t.Errorf("expected one error after disconnect, got %d", errcount) t.Errorf("expected one error after disconnect, got %d", errcount)
} }
} }
func newTestServer(serviceName string, service interface{}) *Server {
server := NewServer()
if err := server.RegisterName(serviceName, service); err != nil {
panic(err)
}
return server
}
func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) { func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) {
// Create the HTTP server. // Create the HTTP server.
var hs *httptest.Server var hs *httptest.Server

@ -15,43 +15,49 @@
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
/* /*
Package rpc provides access to the exported methods of an object across a network
or other I/O connection. After creating a server instance objects can be registered, Package rpc implements bi-directional JSON-RPC 2.0 on multiple transports.
making it visible from the outside. Exported methods that follow specific
conventions can be called remotely. It also has support for the publish/subscribe It provides access to the exported methods of an object across a network or other I/O
pattern. connection. After creating a server or client instance, objects can be registered to make
them visible as 'services'. Exported methods that follow specific conventions can be
called remotely. It also has support for the publish/subscribe pattern.
RPC Methods
Methods that satisfy the following criteria are made available for remote access: Methods that satisfy the following criteria are made available for remote access:
- object must be exported
- method must be exported - method must be exported
- method returns 0, 1 (response or error) or 2 (response and error) values - method returns 0, 1 (response or error) or 2 (response and error) values
- method argument(s) must be exported or builtin types - method argument(s) must be exported or builtin types
- method returned value(s) must be exported or builtin types - method returned value(s) must be exported or builtin types
An example method: An example method:
func (s *CalcService) Add(a, b int) (int, error) func (s *CalcService) Add(a, b int) (int, error)
When the returned error isn't nil the returned integer is ignored and the error is When the returned error isn't nil the returned integer is ignored and the error is sent
sent back to the client. Otherwise the returned integer is sent back to the client. back to the client. Otherwise the returned integer is sent back to the client.
Optional arguments are supported by accepting pointer values as arguments. E.g. Optional arguments are supported by accepting pointer values as arguments. E.g. if we want
if we want to do the addition in an optional finite field we can accept a mod to do the addition in an optional finite field we can accept a mod argument as pointer
argument as pointer value. value.
func (s *CalService) Add(a, b int, mod *int) (int, error) func (s *CalcService) Add(a, b int, mod *int) (int, error)
This RPC method can be called with 2 integers and a null value as third argument. This RPC method can be called with 2 integers and a null value as third argument. In that
In that case the mod argument will be nil. Or it can be called with 3 integers, case the mod argument will be nil. Or it can be called with 3 integers, in that case mod
in that case mod will be pointing to the given third argument. Since the optional will be pointing to the given third argument. Since the optional argument is the last
argument is the last argument the RPC package will also accept 2 integers as argument the RPC package will also accept 2 integers as arguments. It will pass the mod
arguments. It will pass the mod argument as nil to the RPC method. argument as nil to the RPC method.
The server offers the ServeCodec method which accepts a ServerCodec instance. It will The server offers the ServeCodec method which accepts a ServerCodec instance. It will read
read requests from the codec, process the request and sends the response back to the requests from the codec, process the request and sends the response back to the client
client using the codec. The server can execute requests concurrently. Responses using the codec. The server can execute requests concurrently. Responses can be sent back
can be sent back to the client out of order. to the client out of order.
An example server which uses the JSON codec: An example server which uses the JSON codec:
type CalculatorService struct {} type CalculatorService struct {}
func (s *CalculatorService) Add(a, b int) int { func (s *CalculatorService) Add(a, b int) int {
@ -73,26 +79,40 @@ An example server which uses the JSON codec:
for { for {
c, _ := l.AcceptUnix() c, _ := l.AcceptUnix()
codec := v2.NewJSONCodec(c) codec := v2.NewJSONCodec(c)
go server.ServeCodec(codec) go server.ServeCodec(codec, 0)
} }
Subscriptions
The package also supports the publish subscribe pattern through the use of subscriptions. The package also supports the publish subscribe pattern through the use of subscriptions.
A method that is considered eligible for notifications must satisfy the following criteria: A method that is considered eligible for notifications must satisfy the following
- object must be exported criteria:
- method must be exported - method must be exported
- first method argument type must be context.Context - first method argument type must be context.Context
- method argument(s) must be exported or builtin types - method argument(s) must be exported or builtin types
- method must return the tuple Subscription, error - method must have return types (rpc.Subscription, error)
An example method: An example method:
func (s *BlockChainService) NewBlocks(ctx context.Context) (Subscription, error) {
func (s *BlockChainService) NewBlocks(ctx context.Context) (rpc.Subscription, error) {
... ...
} }
Subscriptions are deleted when: When the service containing the subscription method is registered to the server, for
- the user sends an unsubscribe request example under the "blockchain" namespace, a subscription is created by calling the
- the connection which was used to create the subscription is closed. This can be initiated "blockchain_subscribe" method.
by the client and server. The server will close the connection on a write error or when
the queue of buffered notifications gets too big. Subscriptions are deleted when the user sends an unsubscribe request or when the
connection which was used to create the subscription is closed. This can be initiated by
the client and server. The server will close the connection for any write error.
For more information about subscriptions, see https://github.com/ethereum/go-ethereum/wiki/RPC-PUB-SUB.
Reverse Calls
In any method handler, an instance of rpc.Client can be accessed through the
ClientFromContext method. Using this client instance, server-to-client method calls can be
performed on the RPC connection.
*/ */
package rpc package rpc

@ -18,18 +18,31 @@ package rpc
import "fmt" import "fmt"
// request is for an unknown service const defaultErrorCode = -32000
type methodNotFoundError struct {
service string type methodNotFoundError struct{ method string }
method string
}
func (e *methodNotFoundError) ErrorCode() int { return -32601 } func (e *methodNotFoundError) ErrorCode() int { return -32601 }
func (e *methodNotFoundError) Error() string { func (e *methodNotFoundError) Error() string {
return fmt.Sprintf("The method %s%s%s does not exist/is not available", e.service, serviceMethodSeparator, e.method) return fmt.Sprintf("the method %s does not exist/is not available", e.method)
} }
type subscriptionNotFoundError struct{ namespace, subscription string }
func (e *subscriptionNotFoundError) ErrorCode() int { return -32601 }
func (e *subscriptionNotFoundError) Error() string {
return fmt.Sprintf("no %q subscription in %s namespace", e.subscription, e.namespace)
}
// Invalid JSON was received by the server.
type parseError struct{ message string }
func (e *parseError) ErrorCode() int { return -32700 }
func (e *parseError) Error() string { return e.message }
// received message isn't a valid request // received message isn't a valid request
type invalidRequestError struct{ message string } type invalidRequestError struct{ message string }
@ -50,17 +63,3 @@ type invalidParamsError struct{ message string }
func (e *invalidParamsError) ErrorCode() int { return -32602 } func (e *invalidParamsError) ErrorCode() int { return -32602 }
func (e *invalidParamsError) Error() string { return e.message } func (e *invalidParamsError) Error() string { return e.message }
// logic error, callback returned an error
type callbackError struct{ message string }
func (e *callbackError) ErrorCode() int { return -32000 }
func (e *callbackError) Error() string { return e.message }
// issued when a request is received after the server is issued to stop.
type shutdownError struct{}
func (e *shutdownError) ErrorCode() int { return -32000 }
func (e *shutdownError) Error() string { return "server is shutting down" }

397
rpc/handler.go Normal file

@ -0,0 +1,397 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"context"
"encoding/json"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
)
// handler handles JSON-RPC messages. There is one handler per connection. Note that
// handler is not safe for concurrent use. Message handling never blocks indefinitely
// because RPCs are processed on background goroutines launched by handler.
//
// The entry points for incoming messages are:
//
// h.handleMsg(message)
// h.handleBatch(message)
//
// Outgoing calls use the requestOp struct. Register the request before sending it
// on the connection:
//
// op := &requestOp{ids: ...}
// h.addRequestOp(op)
//
// Now send the request, then wait for the reply to be delivered through handleMsg:
//
// if err := op.wait(...); err != nil {
// h.removeRequestOp(op) // timeout, etc.
// }
//
type handler struct {
reg *serviceRegistry
unsubscribeCb *callback
idgen func() ID // subscription ID generator
respWait map[string]*requestOp // active client requests
clientSubs map[string]*ClientSubscription // active client subscriptions
callWG sync.WaitGroup // pending call goroutines
rootCtx context.Context // canceled by close()
cancelRoot func() // cancel function for rootCtx
conn jsonWriter // where responses will be sent
log log.Logger
allowSubscribe bool
subLock sync.Mutex
serverSubs map[ID]*Subscription
}
type callProc struct {
ctx context.Context
notifiers []*Notifier
}
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
rootCtx, cancelRoot := context.WithCancel(connCtx)
h := &handler{
reg: reg,
idgen: idgen,
conn: conn,
respWait: make(map[string]*requestOp),
clientSubs: make(map[string]*ClientSubscription),
rootCtx: rootCtx,
cancelRoot: cancelRoot,
allowSubscribe: true,
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
}
if conn.RemoteAddr() != "" {
h.log = h.log.New("conn", conn.RemoteAddr())
}
h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe))
return h
}
// handleBatch executes all messages in a batch and returns the responses.
func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
// Emit error response for empty batches:
if len(msgs) == 0 {
h.startCallProc(func(cp *callProc) {
h.conn.Write(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
})
return
}
// Handle non-call messages first:
calls := make([]*jsonrpcMessage, 0, len(msgs))
for _, msg := range msgs {
if handled := h.handleImmediate(msg); !handled {
calls = append(calls, msg)
}
}
if len(calls) == 0 {
return
}
// Process calls on a goroutine because they may block indefinitely:
h.startCallProc(func(cp *callProc) {
answers := make([]*jsonrpcMessage, 0, len(msgs))
for _, msg := range calls {
if answer := h.handleCallMsg(cp, msg); answer != nil {
answers = append(answers, answer)
}
}
h.addSubscriptions(cp.notifiers)
if len(answers) > 0 {
h.conn.Write(cp.ctx, answers)
}
for _, n := range cp.notifiers {
n.activate()
}
})
}
// handleMsg handles a single message.
func (h *handler) handleMsg(msg *jsonrpcMessage) {
if ok := h.handleImmediate(msg); ok {
return
}
h.startCallProc(func(cp *callProc) {
answer := h.handleCallMsg(cp, msg)
h.addSubscriptions(cp.notifiers)
if answer != nil {
h.conn.Write(cp.ctx, answer)
}
for _, n := range cp.notifiers {
n.activate()
}
})
}
// close cancels all requests except for inflightReq and waits for
// call goroutines to shut down.
func (h *handler) close(err error, inflightReq *requestOp) {
h.cancelAllRequests(err, inflightReq)
h.cancelRoot()
h.callWG.Wait()
h.cancelServerSubscriptions(err)
}
// addRequestOp registers a request operation.
func (h *handler) addRequestOp(op *requestOp) {
for _, id := range op.ids {
h.respWait[string(id)] = op
}
}
// removeRequestOps stops waiting for the given request IDs.
func (h *handler) removeRequestOp(op *requestOp) {
for _, id := range op.ids {
delete(h.respWait, string(id))
}
}
// cancelAllRequests unblocks and removes pending requests and active subscriptions.
func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) {
didClose := make(map[*requestOp]bool)
if inflightReq != nil {
didClose[inflightReq] = true
}
for id, op := range h.respWait {
// Remove the op so that later calls will not close op.resp again.
delete(h.respWait, id)
if !didClose[op] {
op.err = err
close(op.resp)
didClose[op] = true
}
}
for id, sub := range h.clientSubs {
delete(h.clientSubs, id)
sub.quitWithError(err, false)
}
}
func (h *handler) addSubscriptions(nn []*Notifier) {
h.subLock.Lock()
defer h.subLock.Unlock()
for _, n := range nn {
if sub := n.takeSubscription(); sub != nil {
h.serverSubs[sub.ID] = sub
}
}
}
// cancelServerSubscriptions removes all subscriptions and closes their error channels.
func (h *handler) cancelServerSubscriptions(err error) {
h.subLock.Lock()
defer h.subLock.Unlock()
for id, s := range h.serverSubs {
s.err <- err
close(s.err)
delete(h.serverSubs, id)
}
}
// startCallProc runs fn in a new goroutine and starts tracking it in the h.calls wait group.
func (h *handler) startCallProc(fn func(*callProc)) {
h.callWG.Add(1)
go func() {
ctx, cancel := context.WithCancel(h.rootCtx)
defer h.callWG.Done()
defer cancel()
fn(&callProc{ctx: ctx})
}()
}
// handleImmediate executes non-call messages. It returns false if the message is a
// call or requires a reply.
func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
start := time.Now()
switch {
case msg.isNotification():
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
h.handleSubscriptionResult(msg)
return true
}
return false
case msg.isResponse():
h.handleResponse(msg)
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start))
return true
default:
return false
}
}
// handleSubscriptionResult processes subscription notifications.
func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
var result subscriptionResult
if err := json.Unmarshal(msg.Params, &result); err != nil {
h.log.Debug("Dropping invalid subscription message")
return
}
if h.clientSubs[result.ID] != nil {
h.clientSubs[result.ID].deliver(result.Result)
}
}
// handleResponse processes method call responses.
func (h *handler) handleResponse(msg *jsonrpcMessage) {
op := h.respWait[string(msg.ID)]
if op == nil {
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
return
}
delete(h.respWait, string(msg.ID))
// For normal responses, just forward the reply to Call/BatchCall.
if op.sub == nil {
op.resp <- msg
return
}
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
defer close(op.resp)
if msg.Error != nil {
op.err = msg.Error
return
}
if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
go op.sub.start()
h.clientSubs[op.sub.subid] = op.sub
}
}
// handleCallMsg executes a call message and returns the answer.
func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
start := time.Now()
switch {
case msg.isNotification():
h.handleCall(ctx, msg)
h.log.Debug("Served "+msg.Method, "t", time.Since(start))
return nil
case msg.isCall():
resp := h.handleCall(ctx, msg)
if resp.Error != nil {
h.log.Info("Served "+msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start), "err", resp.Error.Message)
} else {
h.log.Debug("Served "+msg.Method, "reqid", idForLog{msg.ID}, "t", time.Since(start))
}
return resp
case msg.hasValidID():
return msg.errorResponse(&invalidRequestError{"invalid request"})
default:
return errorMessage(&invalidRequestError{"invalid request"})
}
}
// handleCall processes method calls.
func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
if msg.isSubscribe() {
return h.handleSubscribe(cp, msg)
}
var callb *callback
if msg.isUnsubscribe() {
callb = h.unsubscribeCb
} else {
callb = h.reg.callback(msg.Method)
}
if callb == nil {
return msg.errorResponse(&methodNotFoundError{method: msg.Method})
}
args, err := parsePositionalArguments(msg.Params, callb.argTypes)
if err != nil {
return msg.errorResponse(&invalidParamsError{err.Error()})
}
return h.runMethod(cp.ctx, msg, callb, args)
}
// handleSubscribe processes *_subscribe method calls.
func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
if !h.allowSubscribe {
return msg.errorResponse(ErrNotificationsUnsupported)
}
// Subscription method name is first argument.
name, err := parseSubscriptionName(msg.Params)
if err != nil {
return msg.errorResponse(&invalidParamsError{err.Error()})
}
namespace := msg.namespace()
callb := h.reg.subscription(namespace, name)
if callb == nil {
return msg.errorResponse(&subscriptionNotFoundError{namespace, name})
}
// Parse subscription name arg too, but remove it before calling the callback.
argTypes := append([]reflect.Type{stringType}, callb.argTypes...)
args, err := parsePositionalArguments(msg.Params, argTypes)
if err != nil {
return msg.errorResponse(&invalidParamsError{err.Error()})
}
args = args[1:]
// Install notifier in context so the subscription handler can find it.
n := &Notifier{h: h, namespace: namespace}
cp.notifiers = append(cp.notifiers, n)
ctx := context.WithValue(cp.ctx, notifierKey{}, n)
return h.runMethod(ctx, msg, callb, args)
}
// runMethod runs the Go callback for an RPC method.
func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage {
result, err := callb.call(ctx, msg.Method, args)
if err != nil {
return msg.errorResponse(err)
}
return msg.response(result)
}
// unsubscribe is the callback function for all *_unsubscribe calls.
func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) {
h.subLock.Lock()
defer h.subLock.Unlock()
s := h.serverSubs[id]
if s == nil {
return false, ErrSubscriptionNotFound
}
close(s.err)
delete(h.serverSubs, id)
return true, nil
}
type idForLog struct{ json.RawMessage }
func (id idForLog) String() string {
if s, err := strconv.Unquote(string(id.RawMessage)); err == nil {
return s
}
return string(id.RawMessage)
}

@ -37,38 +37,39 @@ import (
const ( const (
maxRequestContentLength = 1024 * 512 maxRequestContentLength = 1024 * 512
contentType = "application/json"
) )
var (
// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13 // https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13
acceptedContentTypes = []string{"application/json", "application/json-rpc", "application/jsonrequest"} var acceptedContentTypes = []string{contentType, "application/json-rpc", "application/jsonrequest"}
contentType = acceptedContentTypes[0]
nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0")
)
type httpConn struct { type httpConn struct {
client *http.Client client *http.Client
req *http.Request req *http.Request
closeOnce sync.Once closeOnce sync.Once
closed chan struct{} closed chan interface{}
} }
// httpConn is treated specially by Client. // httpConn is treated specially by Client.
func (hc *httpConn) LocalAddr() net.Addr { return nullAddr } func (hc *httpConn) Write(context.Context, interface{}) error {
func (hc *httpConn) RemoteAddr() net.Addr { return nullAddr } panic("Write called on httpConn")
func (hc *httpConn) SetReadDeadline(time.Time) error { return nil }
func (hc *httpConn) SetWriteDeadline(time.Time) error { return nil }
func (hc *httpConn) SetDeadline(time.Time) error { return nil }
func (hc *httpConn) Write([]byte) (int, error) { panic("Write called") }
func (hc *httpConn) Read(b []byte) (int, error) {
<-hc.closed
return 0, io.EOF
} }
func (hc *httpConn) Close() error { func (hc *httpConn) RemoteAddr() string {
return hc.req.URL.String()
}
func (hc *httpConn) Read() ([]*jsonrpcMessage, bool, error) {
<-hc.closed
return nil, false, io.EOF
}
func (hc *httpConn) Close() {
hc.closeOnce.Do(func() { close(hc.closed) }) hc.closeOnce.Do(func() { close(hc.closed) })
return nil }
func (hc *httpConn) Closed() <-chan interface{} {
return hc.closed
} }
// HTTPTimeouts represents the configuration params for the HTTP RPC server. // HTTPTimeouts represents the configuration params for the HTTP RPC server.
@ -114,8 +115,8 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
req.Header.Set("Accept", contentType) req.Header.Set("Accept", contentType)
initctx := context.Background() initctx := context.Background()
return newClient(initctx, func(context.Context) (net.Conn, error) { return newClient(initctx, func(context.Context) (ServerCodec, error) {
return &httpConn{client: client, req: req, closed: make(chan struct{})}, nil return &httpConn{client: client, req: req, closed: make(chan interface{})}, nil
}) })
} }
@ -184,17 +185,30 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
return resp.Body, nil return resp.Body, nil
} }
// httpReadWriteNopCloser wraps a io.Reader and io.Writer with a NOP Close method. // httpServerConn turns a HTTP connection into a Conn.
type httpReadWriteNopCloser struct { type httpServerConn struct {
io.Reader io.Reader
io.Writer io.Writer
r *http.Request
} }
// Close does nothing and returns always nil func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec {
func (t *httpReadWriteNopCloser) Close() error { body := io.LimitReader(r.Body, maxRequestContentLength)
return nil conn := &httpServerConn{Reader: body, Writer: w, r: r}
return NewJSONCodec(conn)
} }
// Close does nothing and always returns nil.
func (t *httpServerConn) Close() error { return nil }
// RemoteAddr returns the peer address of the underlying connection.
func (t *httpServerConn) RemoteAddr() string {
return t.r.RemoteAddr
}
// SetWriteDeadline does nothing and always returns nil.
func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil }
// NewHTTPServer creates a new HTTP RPC server around an API provider. // NewHTTPServer creates a new HTTP RPC server around an API provider.
// //
// Deprecated: Server implements http.Handler // Deprecated: Server implements http.Handler
@ -226,7 +240,7 @@ func NewHTTPServer(cors []string, vhosts []string, timeouts HTTPTimeouts, srv ht
} }
// ServeHTTP serves JSON-RPC requests over HTTP. // ServeHTTP serves JSON-RPC requests over HTTP.
func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Permit dumb empty requests for remote health-checks (AWS) // Permit dumb empty requests for remote health-checks (AWS)
if r.Method == http.MethodGet && r.ContentLength == 0 && r.URL.RawQuery == "" { if r.Method == http.MethodGet && r.ContentLength == 0 && r.URL.RawQuery == "" {
return return
@ -249,12 +263,10 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx = context.WithValue(ctx, "Origin", origin) ctx = context.WithValue(ctx, "Origin", origin)
} }
body := io.LimitReader(r.Body, maxRequestContentLength)
codec := NewJSONCodec(&httpReadWriteNopCloser{body, w})
defer codec.Close()
w.Header().Set("content-type", contentType) w.Header().Set("content-type", contentType)
srv.ServeSingleRequest(ctx, codec, OptionMethodInvocation) codec := newHTTPServerConn(r, w)
defer codec.Close()
s.serveSingleRequest(ctx, codec)
} }
// validateRequest returns a non-zero response code and error message if the // validateRequest returns a non-zero response code and error message if the

@ -24,10 +24,10 @@ import (
// DialInProc attaches an in-process connection to the given RPC server. // DialInProc attaches an in-process connection to the given RPC server.
func DialInProc(handler *Server) *Client { func DialInProc(handler *Server) *Client {
initctx := context.Background() initctx := context.Background()
c, _ := newClient(initctx, func(context.Context) (net.Conn, error) { c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
p1, p2 := net.Pipe() p1, p2 := net.Pipe()
go handler.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions) go handler.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions)
return p2, nil return NewJSONCodec(p2), nil
}) })
return c return c
} }

@ -25,17 +25,17 @@ import (
) )
// ServeListener accepts connections on l, serving JSON-RPC on them. // ServeListener accepts connections on l, serving JSON-RPC on them.
func (srv *Server) ServeListener(l net.Listener) error { func (s *Server) ServeListener(l net.Listener) error {
for { for {
conn, err := l.Accept() conn, err := l.Accept()
if netutil.IsTemporaryError(err) { if netutil.IsTemporaryError(err) {
log.Warn("IPC accept error", "err", err) log.Warn("RPC accept error", "err", err)
continue continue
} else if err != nil { } else if err != nil {
return err return err
} }
log.Trace("IPC accepted connection") log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr())
go srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) go s.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
} }
} }
@ -46,7 +46,11 @@ func (srv *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not // The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client. // affect subsequent interactions with the client.
func DialIPC(ctx context.Context, endpoint string) (*Client, error) { func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
return newClient(ctx, func(ctx context.Context) (net.Conn, error) { return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
return newIPCConnection(ctx, endpoint) conn, err := newIPCConnection(ctx, endpoint)
if err != nil {
return nil, err
}
return NewJSONCodec(conn), err
}) })
} }

@ -18,36 +18,104 @@ package rpc
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
"strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/ethereum/go-ethereum/log"
) )
const ( const (
jsonrpcVersion = "2.0" vsn = "2.0"
serviceMethodSeparator = "_" serviceMethodSeparator = "_"
subscribeMethodSuffix = "_subscribe" subscribeMethodSuffix = "_subscribe"
unsubscribeMethodSuffix = "_unsubscribe" unsubscribeMethodSuffix = "_unsubscribe"
notificationMethodSuffix = "_subscription" notificationMethodSuffix = "_subscription"
defaultWriteTimeout = 10 * time.Second // used if context has no deadline
) )
type jsonRequest struct { var null = json.RawMessage("null")
Method string `json:"method"`
Version string `json:"jsonrpc"` type subscriptionResult struct {
Id json.RawMessage `json:"id,omitempty"` ID string `json:"subscription"`
Payload json.RawMessage `json:"params,omitempty"` Result json.RawMessage `json:"result,omitempty"`
} }
type jsonSuccessResponse struct { // A value of this type can a JSON-RPC request, notification, successful response or
Version string `json:"jsonrpc"` // error response. Which one it is depends on the fields.
Id interface{} `json:"id,omitempty"` type jsonrpcMessage struct {
Result interface{} `json:"result"` Version string `json:"jsonrpc,omitempty"`
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Error *jsonError `json:"error,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
}
func (msg *jsonrpcMessage) isNotification() bool {
return msg.ID == nil && msg.Method != ""
}
func (msg *jsonrpcMessage) isCall() bool {
return msg.hasValidID() && msg.Method != ""
}
func (msg *jsonrpcMessage) isResponse() bool {
return msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil)
}
func (msg *jsonrpcMessage) hasValidID() bool {
return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '['
}
func (msg *jsonrpcMessage) isSubscribe() bool {
return strings.HasSuffix(msg.Method, subscribeMethodSuffix)
}
func (msg *jsonrpcMessage) isUnsubscribe() bool {
return strings.HasSuffix(msg.Method, unsubscribeMethodSuffix)
}
func (msg *jsonrpcMessage) namespace() string {
elem := strings.SplitN(msg.Method, serviceMethodSeparator, 2)
return elem[0]
}
func (msg *jsonrpcMessage) String() string {
b, _ := json.Marshal(msg)
return string(b)
}
func (msg *jsonrpcMessage) errorResponse(err error) *jsonrpcMessage {
resp := errorMessage(err)
resp.ID = msg.ID
return resp
}
func (msg *jsonrpcMessage) response(result interface{}) *jsonrpcMessage {
enc, err := json.Marshal(result)
if err != nil {
// TODO: wrap with 'internal server error'
return msg.errorResponse(err)
}
return &jsonrpcMessage{Version: vsn, ID: msg.ID, Result: enc}
}
func errorMessage(err error) *jsonrpcMessage {
msg := &jsonrpcMessage{Version: vsn, ID: null, Error: &jsonError{
Code: defaultErrorCode,
Message: err.Error(),
}}
ec, ok := err.(Error)
if ok {
msg.Error.Code = ec.ErrorCode()
}
return msg
} }
type jsonError struct { type jsonError struct {
@ -56,35 +124,6 @@ type jsonError struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
type jsonErrResponse struct {
Version string `json:"jsonrpc"`
Id interface{} `json:"id,omitempty"`
Error jsonError `json:"error"`
}
type jsonSubscription struct {
Subscription string `json:"subscription"`
Result interface{} `json:"result,omitempty"`
}
type jsonNotification struct {
Version string `json:"jsonrpc"`
Method string `json:"method"`
Params jsonSubscription `json:"params"`
}
// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It
// also has support for parsing arguments and serializing (result) objects.
type jsonCodec struct {
closer sync.Once // close closed channel once
closed chan interface{} // closed on Close
decMu sync.Mutex // guards the decoder
decode func(v interface{}) error // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder
encode func(v interface{}) error // encoder to allow multiple transports
rw io.ReadWriteCloser // connection
}
func (err *jsonError) Error() string { func (err *jsonError) Error() string {
if err.Message == "" { if err.Message == "" {
return fmt.Sprintf("json-rpc error %d", err.Code) return fmt.Sprintf("json-rpc error %d", err.Code)
@ -96,34 +135,126 @@ func (err *jsonError) ErrorCode() int {
return err.Code return err.Code
} }
// Conn is a subset of the methods of net.Conn which are sufficient for ServerCodec.
type Conn interface {
io.ReadWriteCloser
SetWriteDeadline(time.Time) error
}
// ConnRemoteAddr wraps the RemoteAddr operation, which returns a description
// of the peer address of a connection. If a Conn also implements ConnRemoteAddr, this
// description is used in log messages.
type ConnRemoteAddr interface {
RemoteAddr() string
}
// connWithRemoteAddr overrides the remote address of a connection.
type connWithRemoteAddr struct {
Conn
addr string
}
func (c connWithRemoteAddr) RemoteAddr() string { return c.addr }
// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has
// support for parsing arguments and serializing (result) objects.
type jsonCodec struct {
remoteAddr string
closer sync.Once // close closed channel once
closed chan interface{} // closed on Close
decode func(v interface{}) error // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder
encode func(v interface{}) error // encoder to allow multiple transports
conn Conn
}
// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based // NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
// on explicitly given encoding and decoding methods. // on explicitly given encoding and decoding methods.
func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec { func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec {
return &jsonCodec{ codec := &jsonCodec{
closed: make(chan interface{}), closed: make(chan interface{}),
encode: encode, encode: encode,
decode: decode, decode: decode,
rw: rwc, conn: conn,
} }
if ra, ok := conn.(ConnRemoteAddr); ok {
codec.remoteAddr = ra.RemoteAddr()
}
return codec
} }
// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0. // NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0.
func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec { func NewJSONCodec(conn Conn) ServerCodec {
enc := json.NewEncoder(rwc) enc := json.NewEncoder(conn)
dec := json.NewDecoder(rwc) dec := json.NewDecoder(conn)
dec.UseNumber() dec.UseNumber()
return NewCodec(conn, enc.Encode, dec.Decode)
return &jsonCodec{
closed: make(chan interface{}),
encode: enc.Encode,
decode: dec.Decode,
rw: rwc,
} }
func (c *jsonCodec) RemoteAddr() string {
return c.remoteAddr
}
func (c *jsonCodec) Read() (msg []*jsonrpcMessage, batch bool, err error) {
// Decode the next JSON object in the input stream.
// This verifies basic syntax, etc.
var rawmsg json.RawMessage
if err := c.decode(&rawmsg); err != nil {
return nil, false, err
}
msg, batch = parseMessage(rawmsg)
return msg, batch, nil
}
// Write sends a message to client.
func (c *jsonCodec) Write(ctx context.Context, v interface{}) error {
c.encMu.Lock()
defer c.encMu.Unlock()
deadline, ok := ctx.Deadline()
if !ok {
deadline = time.Now().Add(defaultWriteTimeout)
}
c.conn.SetWriteDeadline(deadline)
return c.encode(v)
}
// Close the underlying connection
func (c *jsonCodec) Close() {
c.closer.Do(func() {
close(c.closed)
c.conn.Close()
})
}
// Closed returns a channel which will be closed when Close is called
func (c *jsonCodec) Closed() <-chan interface{} {
return c.closed
}
// parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error
// checks in this function because the raw message has already been syntax-checked when it
// is called. Any non-JSON-RPC messages in the input return the zero value of
// jsonrpcMessage.
func parseMessage(raw json.RawMessage) ([]*jsonrpcMessage, bool) {
if !isBatch(raw) {
msgs := []*jsonrpcMessage{{}}
json.Unmarshal(raw, &msgs[0])
return msgs, false
}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.Token() // skip '['
var msgs []*jsonrpcMessage
for dec.More() {
msgs = append(msgs, new(jsonrpcMessage))
dec.Decode(&msgs[len(msgs)-1])
}
return msgs, true
} }
// isBatch returns true when the first non-whitespace characters is '[' // isBatch returns true when the first non-whitespace characters is '['
func isBatch(msg json.RawMessage) bool { func isBatch(raw json.RawMessage) bool {
for _, c := range msg { for _, c := range raw {
// skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt) // skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt)
if c == 0x20 || c == 0x09 || c == 0x0a || c == 0x0d { if c == 0x20 || c == 0x09 || c == 0x0a || c == 0x0d {
continue continue
@ -133,231 +264,67 @@ func isBatch(msg json.RawMessage) bool {
return false return false
} }
// ReadRequestHeaders will read new requests without parsing the arguments. It will
// return a collection of requests, an indication if these requests are in batch
// form or an error when the incoming message could not be read/parsed.
func (c *jsonCodec) ReadRequestHeaders() ([]rpcRequest, bool, Error) {
c.decMu.Lock()
defer c.decMu.Unlock()
var incomingMsg json.RawMessage
if err := c.decode(&incomingMsg); err != nil {
return nil, false, &invalidRequestError{err.Error()}
}
if isBatch(incomingMsg) {
return parseBatchRequest(incomingMsg)
}
return parseRequest(incomingMsg)
}
// checkReqId returns an error when the given reqId isn't valid for RPC method calls.
// valid id's are strings, numbers or null
func checkReqId(reqId json.RawMessage) error {
if len(reqId) == 0 {
return fmt.Errorf("missing request id")
}
if _, err := strconv.ParseFloat(string(reqId), 64); err == nil {
return nil
}
var str string
if err := json.Unmarshal(reqId, &str); err == nil {
return nil
}
return fmt.Errorf("invalid request id")
}
// parseRequest will parse a single request from the given RawMessage. It will return
// the parsed request, an indication if the request was a batch or an error when
// the request could not be parsed.
func parseRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error) {
var in jsonRequest
if err := json.Unmarshal(incomingMsg, &in); err != nil {
return nil, false, &invalidMessageError{err.Error()}
}
if err := checkReqId(in.Id); err != nil {
return nil, false, &invalidMessageError{err.Error()}
}
// subscribe are special, they will always use `subscribeMethod` as first param in the payload
if strings.HasSuffix(in.Method, subscribeMethodSuffix) {
reqs := []rpcRequest{{id: &in.Id, isPubSub: true}}
if len(in.Payload) > 0 {
// first param must be subscription name
var subscribeMethod [1]string
if err := json.Unmarshal(in.Payload, &subscribeMethod); err != nil {
log.Debug(fmt.Sprintf("Unable to parse subscription method: %v\n", err))
return nil, false, &invalidRequestError{"Unable to parse subscription request"}
}
reqs[0].service, reqs[0].method = strings.TrimSuffix(in.Method, subscribeMethodSuffix), subscribeMethod[0]
reqs[0].params = in.Payload
return reqs, false, nil
}
return nil, false, &invalidRequestError{"Unable to parse subscription request"}
}
if strings.HasSuffix(in.Method, unsubscribeMethodSuffix) {
return []rpcRequest{{id: &in.Id, isPubSub: true,
method: in.Method, params: in.Payload}}, false, nil
}
elems := strings.Split(in.Method, serviceMethodSeparator)
if len(elems) != 2 {
return nil, false, &methodNotFoundError{in.Method, ""}
}
// regular RPC call
if len(in.Payload) == 0 {
return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id}}, false, nil
}
return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id, params: in.Payload}}, false, nil
}
// parseBatchRequest will parse a batch request into a collection of requests from the given RawMessage, an indication
// if the request was a batch or an error when the request could not be read.
func parseBatchRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error) {
var in []jsonRequest
if err := json.Unmarshal(incomingMsg, &in); err != nil {
return nil, false, &invalidMessageError{err.Error()}
}
requests := make([]rpcRequest, len(in))
for i, r := range in {
if err := checkReqId(r.Id); err != nil {
return nil, false, &invalidMessageError{err.Error()}
}
id := &in[i].Id
// subscribe are special, they will always use `subscriptionMethod` as first param in the payload
if strings.HasSuffix(r.Method, subscribeMethodSuffix) {
requests[i] = rpcRequest{id: id, isPubSub: true}
if len(r.Payload) > 0 {
// first param must be subscription name
var subscribeMethod [1]string
if err := json.Unmarshal(r.Payload, &subscribeMethod); err != nil {
log.Debug(fmt.Sprintf("Unable to parse subscription method: %v\n", err))
return nil, false, &invalidRequestError{"Unable to parse subscription request"}
}
requests[i].service, requests[i].method = strings.TrimSuffix(r.Method, subscribeMethodSuffix), subscribeMethod[0]
requests[i].params = r.Payload
continue
}
return nil, true, &invalidRequestError{"Unable to parse (un)subscribe request arguments"}
}
if strings.HasSuffix(r.Method, unsubscribeMethodSuffix) {
requests[i] = rpcRequest{id: id, isPubSub: true, method: r.Method, params: r.Payload}
continue
}
if len(r.Payload) == 0 {
requests[i] = rpcRequest{id: id, params: nil}
} else {
requests[i] = rpcRequest{id: id, params: r.Payload}
}
if elem := strings.Split(r.Method, serviceMethodSeparator); len(elem) == 2 {
requests[i].service, requests[i].method = elem[0], elem[1]
} else {
requests[i].err = &methodNotFoundError{r.Method, ""}
}
}
return requests, true, nil
}
// ParseRequestArguments tries to parse the given params (json.RawMessage) with the given
// types. It returns the parsed values or an error when the parsing failed.
func (c *jsonCodec) ParseRequestArguments(argTypes []reflect.Type, params interface{}) ([]reflect.Value, Error) {
if args, ok := params.(json.RawMessage); !ok {
return nil, &invalidParamsError{"Invalid params supplied"}
} else {
return parsePositionalArguments(args, argTypes)
}
}
// parsePositionalArguments tries to parse the given args to an array of values with the // parsePositionalArguments tries to parse the given args to an array of values with the
// given types. It returns the parsed values or an error when the args could not be // given types. It returns the parsed values or an error when the args could not be
// parsed. Missing optional arguments are returned as reflect.Zero values. // parsed. Missing optional arguments are returned as reflect.Zero values.
func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, Error) { func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, error) {
// Read beginning of the args array.
dec := json.NewDecoder(bytes.NewReader(rawArgs)) dec := json.NewDecoder(bytes.NewReader(rawArgs))
if tok, _ := dec.Token(); tok != json.Delim('[') { var args []reflect.Value
return nil, &invalidParamsError{"non-array args"} tok, err := dec.Token()
switch {
case err == io.EOF || tok == nil && err == nil:
// "params" is optional and may be empty. Also allow "params":null even though it's
// not in the spec because our own client used to send it.
case err != nil:
return nil, err
case tok == json.Delim('['):
// Read argument array.
if args, err = parseArgumentArray(dec, types); err != nil {
return nil, err
} }
// Read args. default:
args := make([]reflect.Value, 0, len(types)) return nil, errors.New("non-array args")
for i := 0; dec.More(); i++ {
if i >= len(types) {
return nil, &invalidParamsError{fmt.Sprintf("too many arguments, want at most %d", len(types))}
}
argval := reflect.New(types[i])
if err := dec.Decode(argval.Interface()); err != nil {
return nil, &invalidParamsError{fmt.Sprintf("invalid argument %d: %v", i, err)}
}
if argval.IsNil() && types[i].Kind() != reflect.Ptr {
return nil, &invalidParamsError{fmt.Sprintf("missing value for required argument %d", i)}
}
args = append(args, argval.Elem())
}
// Read end of args array.
if _, err := dec.Token(); err != nil {
return nil, &invalidParamsError{err.Error()}
} }
// Set any missing args to nil. // Set any missing args to nil.
for i := len(args); i < len(types); i++ { for i := len(args); i < len(types); i++ {
if types[i].Kind() != reflect.Ptr { if types[i].Kind() != reflect.Ptr {
return nil, &invalidParamsError{fmt.Sprintf("missing value for required argument %d", i)} return nil, fmt.Errorf("missing value for required argument %d", i)
} }
args = append(args, reflect.Zero(types[i])) args = append(args, reflect.Zero(types[i]))
} }
return args, nil return args, nil
} }
// CreateResponse will create a JSON-RPC success response with the given id and reply as result. func parseArgumentArray(dec *json.Decoder, types []reflect.Type) ([]reflect.Value, error) {
func (c *jsonCodec) CreateResponse(id interface{}, reply interface{}) interface{} { args := make([]reflect.Value, 0, len(types))
return &jsonSuccessResponse{Version: jsonrpcVersion, Id: id, Result: reply} for i := 0; dec.More(); i++ {
if i >= len(types) {
return args, fmt.Errorf("too many arguments, want at most %d", len(types))
}
argval := reflect.New(types[i])
if err := dec.Decode(argval.Interface()); err != nil {
return args, fmt.Errorf("invalid argument %d: %v", i, err)
}
if argval.IsNil() && types[i].Kind() != reflect.Ptr {
return args, fmt.Errorf("missing value for required argument %d", i)
}
args = append(args, argval.Elem())
}
// Read end of args array.
_, err := dec.Token()
return args, err
} }
// CreateErrorResponse will create a JSON-RPC error response with the given id and error. // parseSubscriptionName extracts the subscription name from an encoded argument array.
func (c *jsonCodec) CreateErrorResponse(id interface{}, err Error) interface{} { func parseSubscriptionName(rawArgs json.RawMessage) (string, error) {
return &jsonErrResponse{Version: jsonrpcVersion, Id: id, Error: jsonError{Code: err.ErrorCode(), Message: err.Error()}} dec := json.NewDecoder(bytes.NewReader(rawArgs))
if tok, _ := dec.Token(); tok != json.Delim('[') {
return "", errors.New("non-array args")
} }
v, _ := dec.Token()
// CreateErrorResponseWithInfo will create a JSON-RPC error response with the given id and error. method, ok := v.(string)
// info is optional and contains additional information about the error. When an empty string is passed it is ignored. if !ok {
func (c *jsonCodec) CreateErrorResponseWithInfo(id interface{}, err Error, info interface{}) interface{} { return "", errors.New("expected subscription name as first argument")
return &jsonErrResponse{Version: jsonrpcVersion, Id: id,
Error: jsonError{Code: err.ErrorCode(), Message: err.Error(), Data: info}}
} }
return method, nil
// CreateNotification will create a JSON-RPC notification with the given subscription id and event as params.
func (c *jsonCodec) CreateNotification(subid, namespace string, event interface{}) interface{} {
return &jsonNotification{Version: jsonrpcVersion, Method: namespace + notificationMethodSuffix,
Params: jsonSubscription{Subscription: subid, Result: event}}
}
// Write message to client
func (c *jsonCodec) Write(res interface{}) error {
c.encMu.Lock()
defer c.encMu.Unlock()
return c.encode(res)
}
// Close the underlying connection
func (c *jsonCodec) Close() {
c.closer.Do(func() {
close(c.closed)
c.rw.Close()
})
}
// Closed returns a channel which will be closed when Close is called
func (c *jsonCodec) Closed() <-chan interface{} {
return c.closed
} }

@ -1,178 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"bufio"
"bytes"
"encoding/json"
"reflect"
"strconv"
"testing"
)
type RWC struct {
*bufio.ReadWriter
}
func (rwc *RWC) Close() error {
return nil
}
func TestJSONRequestParsing(t *testing.T) {
server := NewServer()
service := new(Service)
if err := server.RegisterName("calc", service); err != nil {
t.Fatalf("%v", err)
}
req := bytes.NewBufferString(`{"id": 1234, "jsonrpc": "2.0", "method": "calc_add", "params": [11, 22]}`)
var str string
reply := bytes.NewBufferString(str)
rw := &RWC{bufio.NewReadWriter(bufio.NewReader(req), bufio.NewWriter(reply))}
codec := NewJSONCodec(rw)
requests, batch, err := codec.ReadRequestHeaders()
if err != nil {
t.Fatalf("%v", err)
}
if batch {
t.Fatalf("Request isn't a batch")
}
if len(requests) != 1 {
t.Fatalf("Expected 1 request but got %d requests - %v", len(requests), requests)
}
if requests[0].service != "calc" {
t.Fatalf("Expected service 'calc' but got '%s'", requests[0].service)
}
if requests[0].method != "add" {
t.Fatalf("Expected method 'Add' but got '%s'", requests[0].method)
}
if rawId, ok := requests[0].id.(*json.RawMessage); ok {
id, e := strconv.ParseInt(string(*rawId), 0, 64)
if e != nil {
t.Fatalf("%v", e)
}
if id != 1234 {
t.Fatalf("Expected id 1234 but got %d", id)
}
} else {
t.Fatalf("invalid request, expected *json.RawMesage got %T", requests[0].id)
}
var arg int
args := []reflect.Type{reflect.TypeOf(arg), reflect.TypeOf(arg)}
v, err := codec.ParseRequestArguments(args, requests[0].params)
if err != nil {
t.Fatalf("%v", err)
}
if len(v) != 2 {
t.Fatalf("Expected 2 argument values, got %d", len(v))
}
if v[0].Int() != 11 || v[1].Int() != 22 {
t.Fatalf("expected %d == 11 && %d == 22", v[0].Int(), v[1].Int())
}
}
func TestJSONRequestParamsParsing(t *testing.T) {
var (
stringT = reflect.TypeOf("")
intT = reflect.TypeOf(0)
intPtrT = reflect.TypeOf(new(int))
stringV = reflect.ValueOf("abc")
i = 1
intV = reflect.ValueOf(i)
intPtrV = reflect.ValueOf(&i)
)
var validTests = []struct {
input string
argTypes []reflect.Type
expected []reflect.Value
}{
{`[]`, []reflect.Type{}, []reflect.Value{}},
{`[]`, []reflect.Type{intPtrT}, []reflect.Value{intPtrV}},
{`[1]`, []reflect.Type{intT}, []reflect.Value{intV}},
{`[1,"abc"]`, []reflect.Type{intT, stringT}, []reflect.Value{intV, stringV}},
{`[null]`, []reflect.Type{intPtrT}, []reflect.Value{intPtrV}},
{`[null,"abc"]`, []reflect.Type{intPtrT, stringT, intPtrT}, []reflect.Value{intPtrV, stringV, intPtrV}},
{`[null,"abc",null]`, []reflect.Type{intPtrT, stringT, intPtrT}, []reflect.Value{intPtrV, stringV, intPtrV}},
}
codec := jsonCodec{}
for _, test := range validTests {
params := (json.RawMessage)([]byte(test.input))
args, err := codec.ParseRequestArguments(test.argTypes, params)
if err != nil {
t.Fatal(err)
}
var match []interface{}
json.Unmarshal([]byte(test.input), &match)
if len(args) != len(test.argTypes) {
t.Fatalf("expected %d parsed args, got %d", len(test.argTypes), len(args))
}
for i, arg := range args {
expected := test.expected[i]
if arg.Kind() != expected.Kind() {
t.Errorf("expected type for param %d in %s", i, test.input)
}
if arg.Kind() == reflect.Int && arg.Int() != expected.Int() {
t.Errorf("expected int(%d), got int(%d) in %s", expected.Int(), arg.Int(), test.input)
}
if arg.Kind() == reflect.String && arg.String() != expected.String() {
t.Errorf("expected string(%s), got string(%s) in %s", expected.String(), arg.String(), test.input)
}
}
}
var invalidTests = []struct {
input string
argTypes []reflect.Type
}{
{`[]`, []reflect.Type{intT}},
{`[null]`, []reflect.Type{intT}},
{`[1]`, []reflect.Type{stringT}},
{`[1,2]`, []reflect.Type{stringT}},
{`["abc", null]`, []reflect.Type{stringT, intT}},
}
for i, test := range invalidTests {
if _, err := codec.ParseRequestArguments(test.argTypes, test.input); err == nil {
t.Errorf("expected test %d - %s to fail", i, test.input)
}
}
}

@ -18,11 +18,7 @@ package rpc
import ( import (
"context" "context"
"fmt" "io"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic" "sync/atomic"
mapset "github.com/deckarep/golang-set" mapset "github.com/deckarep/golang-set"
@ -31,7 +27,9 @@ import (
const MetadataApi = "rpc" const MetadataApi = "rpc"
// CodecOption specifies which type of messages this codec supports // CodecOption specifies which type of messages a codec supports.
//
// Deprecated: this option is no longer honored by Server.
type CodecOption int type CodecOption int
const ( const (
@ -42,22 +40,94 @@ const (
OptionSubscriptions = 1 << iota // support pub sub OptionSubscriptions = 1 << iota // support pub sub
) )
// NewServer will create a new server instance with no registered handlers. // Server is an RPC server.
func NewServer() *Server { type Server struct {
server := &Server{ services serviceRegistry
services: make(serviceRegistry), idgen func() ID
codecs: mapset.NewSet(), run int32
run: 1, codecs mapset.Set
} }
// register a default service which will provide meta information about the RPC service such as the services and // NewServer creates a new server instance with no registered handlers.
// methods it offers. func NewServer() *Server {
server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1}
// Register the default service providing meta information about the RPC service such
// as the services and methods it offers.
rpcService := &RPCService{server} rpcService := &RPCService{server}
server.RegisterName(MetadataApi, rpcService) server.RegisterName(MetadataApi, rpcService)
return server return server
} }
// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either a RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
// service collection this server provides to clients.
func (s *Server) RegisterName(name string, receiver interface{}) error {
return s.services.registerName(name, receiver)
}
// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes
// the response back using the given codec. It will block until the codec is closed or the
// server is stopped. In either case the codec is closed.
//
// Note that codec options are no longer supported.
func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
defer codec.Close()
// Don't serve if server is stopped.
if atomic.LoadInt32(&s.run) == 0 {
return
}
// Add the codec to the set so it can be closed by Stop.
s.codecs.Add(codec)
defer s.codecs.Remove(codec)
c := initClient(codec, s.idgen, &s.services)
<-codec.Closed()
c.Close()
}
// serveSingleRequest reads and processes a single RPC request from the given codec. This
// is used to serve HTTP connections. Subscriptions and reverse calls are not allowed in
// this mode.
func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
// Don't serve if server is stopped.
if atomic.LoadInt32(&s.run) == 0 {
return
}
h := newHandler(ctx, codec, s.idgen, &s.services)
h.allowSubscribe = false
defer h.close(io.EOF, nil)
reqs, batch, err := codec.Read()
if err != nil {
if err != io.EOF {
codec.Write(ctx, errorMessage(&invalidMessageError{"parse error"}))
}
return
}
if batch {
h.handleBatch(reqs)
} else {
h.handleMsg(reqs[0])
}
}
// Stop stops reading new requests, waits for stopPendingRequestTimeout to allow pending
// requests to finish, then closes all codecs which will cancel pending requests and
// subscriptions.
func (s *Server) Stop() {
if atomic.CompareAndSwapInt32(&s.run, 1, 0) {
log.Debug("RPC server shutting down")
s.codecs.Each(func(c interface{}) bool {
c.(ServerCodec).Close()
return true
})
}
}
// RPCService gives meta information about the server. // RPCService gives meta information about the server.
// e.g. gives information about the loaded modules. // e.g. gives information about the loaded modules.
type RPCService struct { type RPCService struct {
@ -66,377 +136,12 @@ type RPCService struct {
// Modules returns the list of RPC services with their version number // Modules returns the list of RPC services with their version number
func (s *RPCService) Modules() map[string]string { func (s *RPCService) Modules() map[string]string {
s.server.services.mu.Lock()
defer s.server.services.mu.Unlock()
modules := make(map[string]string) modules := make(map[string]string)
for name := range s.server.services { for name := range s.server.services.services {
modules[name] = "1.0" modules[name] = "1.0"
} }
return modules return modules
} }
// RegisterName will create a service for the given rcvr type under the given name. When no methods on the given rcvr
// match the criteria to be either a RPC method or a subscription an error is returned. Otherwise a new service is
// created and added to the service collection this server instance serves.
func (s *Server) RegisterName(name string, rcvr interface{}) error {
if s.services == nil {
s.services = make(serviceRegistry)
}
svc := new(service)
svc.typ = reflect.TypeOf(rcvr)
rcvrVal := reflect.ValueOf(rcvr)
if name == "" {
return fmt.Errorf("no service name for type %s", svc.typ.String())
}
if !isExported(reflect.Indirect(rcvrVal).Type().Name()) {
return fmt.Errorf("%s is not exported", reflect.Indirect(rcvrVal).Type().Name())
}
methods, subscriptions := suitableCallbacks(rcvrVal, svc.typ)
if len(methods) == 0 && len(subscriptions) == 0 {
return fmt.Errorf("Service %T doesn't have any suitable methods/subscriptions to expose", rcvr)
}
// already a previous service register under given name, merge methods/subscriptions
if regsvc, present := s.services[name]; present {
for _, m := range methods {
regsvc.callbacks[formatName(m.method.Name)] = m
}
for _, s := range subscriptions {
regsvc.subscriptions[formatName(s.method.Name)] = s
}
return nil
}
svc.name = name
svc.callbacks, svc.subscriptions = methods, subscriptions
s.services[svc.name] = svc
return nil
}
// serveRequest will reads requests from the codec, calls the RPC callback and
// writes the response to the given codec.
//
// If singleShot is true it will process a single request, otherwise it will handle
// requests until the codec returns an error when reading a request (in most cases
// an EOF). It executes requests in parallel when singleShot is false.
func (s *Server) serveRequest(ctx context.Context, codec ServerCodec, singleShot bool, options CodecOption) error {
var pend sync.WaitGroup
defer func() {
if err := recover(); err != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
log.Error(string(buf))
}
s.codecsMu.Lock()
s.codecs.Remove(codec)
s.codecsMu.Unlock()
}()
// ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// if the codec supports notification include a notifier that callbacks can use
// to send notification to clients. It is tied to the codec/connection. If the
// connection is closed the notifier will stop and cancels all active subscriptions.
if options&OptionSubscriptions == OptionSubscriptions {
ctx = context.WithValue(ctx, notifierKey{}, newNotifier(codec))
}
s.codecsMu.Lock()
if atomic.LoadInt32(&s.run) != 1 { // server stopped
s.codecsMu.Unlock()
return &shutdownError{}
}
s.codecs.Add(codec)
s.codecsMu.Unlock()
// test if the server is ordered to stop
for atomic.LoadInt32(&s.run) == 1 {
reqs, batch, err := s.readRequest(codec)
if err != nil {
// If a parsing error occurred, send an error
if err.Error() != "EOF" {
log.Debug(fmt.Sprintf("read error %v\n", err))
codec.Write(codec.CreateErrorResponse(nil, err))
}
// Error or end of stream, wait for requests and tear down
pend.Wait()
return nil
}
// check if server is ordered to shutdown and return an error
// telling the client that his request failed.
if atomic.LoadInt32(&s.run) != 1 {
err = &shutdownError{}
if batch {
resps := make([]interface{}, len(reqs))
for i, r := range reqs {
resps[i] = codec.CreateErrorResponse(&r.id, err)
}
codec.Write(resps)
} else {
codec.Write(codec.CreateErrorResponse(&reqs[0].id, err))
}
return nil
}
// If a single shot request is executing, run and return immediately
if singleShot {
if batch {
s.execBatch(ctx, codec, reqs)
} else {
s.exec(ctx, codec, reqs[0])
}
return nil
}
// For multi-shot connections, start a goroutine to serve and loop back
pend.Add(1)
go func(reqs []*serverRequest, batch bool) {
defer pend.Done()
if batch {
s.execBatch(ctx, codec, reqs)
} else {
s.exec(ctx, codec, reqs[0])
}
}(reqs, batch)
}
return nil
}
// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes the
// response back using the given codec. It will block until the codec is closed or the server is
// stopped. In either case the codec is closed.
func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
defer codec.Close()
s.serveRequest(context.Background(), codec, false, options)
}
// ServeSingleRequest reads and processes a single RPC request from the given codec. It will not
// close the codec unless a non-recoverable error has occurred. Note, this method will return after
// a single request has been processed!
func (s *Server) ServeSingleRequest(ctx context.Context, codec ServerCodec, options CodecOption) {
s.serveRequest(ctx, codec, true, options)
}
// Stop will stop reading new requests, wait for stopPendingRequestTimeout to allow pending requests to finish,
// close all codecs which will cancel pending requests/subscriptions.
func (s *Server) Stop() {
if atomic.CompareAndSwapInt32(&s.run, 1, 0) {
log.Debug("RPC Server shutdown initiatied")
s.codecsMu.Lock()
defer s.codecsMu.Unlock()
s.codecs.Each(func(c interface{}) bool {
c.(ServerCodec).Close()
return true
})
}
}
// createSubscription will call the subscription callback and returns the subscription id or error.
func (s *Server) createSubscription(ctx context.Context, c ServerCodec, req *serverRequest) (ID, error) {
// subscription have as first argument the context following optional arguments
args := []reflect.Value{req.callb.rcvr, reflect.ValueOf(ctx)}
args = append(args, req.args...)
reply := req.callb.method.Func.Call(args)
if !reply[1].IsNil() { // subscription creation failed
return "", reply[1].Interface().(error)
}
return reply[0].Interface().(*Subscription).ID, nil
}
// handle executes a request and returns the response from the callback.
func (s *Server) handle(ctx context.Context, codec ServerCodec, req *serverRequest) (interface{}, func()) {
if req.err != nil {
return codec.CreateErrorResponse(&req.id, req.err), nil
}
if req.isUnsubscribe { // cancel subscription, first param must be the subscription id
if len(req.args) >= 1 && req.args[0].Kind() == reflect.String {
notifier, supported := NotifierFromContext(ctx)
if !supported { // interface doesn't support subscriptions (e.g. http)
return codec.CreateErrorResponse(&req.id, &callbackError{ErrNotificationsUnsupported.Error()}), nil
}
subid := ID(req.args[0].String())
if err := notifier.unsubscribe(subid); err != nil {
return codec.CreateErrorResponse(&req.id, &callbackError{err.Error()}), nil
}
return codec.CreateResponse(req.id, true), nil
}
return codec.CreateErrorResponse(&req.id, &invalidParamsError{"Expected subscription id as first argument"}), nil
}
if req.callb.isSubscribe {
subid, err := s.createSubscription(ctx, codec, req)
if err != nil {
return codec.CreateErrorResponse(&req.id, &callbackError{err.Error()}), nil
}
// active the subscription after the sub id was successfully sent to the client
activateSub := func() {
notifier, _ := NotifierFromContext(ctx)
notifier.activate(subid, req.svcname)
}
return codec.CreateResponse(req.id, subid), activateSub
}
// regular RPC call, prepare arguments
if len(req.args) != len(req.callb.argTypes) {
rpcErr := &invalidParamsError{fmt.Sprintf("%s%s%s expects %d parameters, got %d",
req.svcname, serviceMethodSeparator, req.callb.method.Name,
len(req.callb.argTypes), len(req.args))}
return codec.CreateErrorResponse(&req.id, rpcErr), nil
}
arguments := []reflect.Value{req.callb.rcvr}
if req.callb.hasCtx {
arguments = append(arguments, reflect.ValueOf(ctx))
}
if len(req.args) > 0 {
arguments = append(arguments, req.args...)
}
// execute RPC method and return result
reply := req.callb.method.Func.Call(arguments)
if len(reply) == 0 {
return codec.CreateResponse(req.id, nil), nil
}
if req.callb.errPos >= 0 { // test if method returned an error
if !reply[req.callb.errPos].IsNil() {
e := reply[req.callb.errPos].Interface().(error)
res := codec.CreateErrorResponse(&req.id, &callbackError{e.Error()})
return res, nil
}
}
return codec.CreateResponse(req.id, reply[0].Interface()), nil
}
// exec executes the given request and writes the result back using the codec.
func (s *Server) exec(ctx context.Context, codec ServerCodec, req *serverRequest) {
var response interface{}
var callback func()
if req.err != nil {
response = codec.CreateErrorResponse(&req.id, req.err)
} else {
response, callback = s.handle(ctx, codec, req)
}
if err := codec.Write(response); err != nil {
log.Error(fmt.Sprintf("%v\n", err))
codec.Close()
}
// when request was a subscribe request this allows these subscriptions to be actived
if callback != nil {
callback()
}
}
// execBatch executes the given requests and writes the result back using the codec.
// It will only write the response back when the last request is processed.
func (s *Server) execBatch(ctx context.Context, codec ServerCodec, requests []*serverRequest) {
responses := make([]interface{}, len(requests))
var callbacks []func()
for i, req := range requests {
if req.err != nil {
responses[i] = codec.CreateErrorResponse(&req.id, req.err)
} else {
var callback func()
if responses[i], callback = s.handle(ctx, codec, req); callback != nil {
callbacks = append(callbacks, callback)
}
}
}
if err := codec.Write(responses); err != nil {
log.Error(fmt.Sprintf("%v\n", err))
codec.Close()
}
// when request holds one of more subscribe requests this allows these subscriptions to be activated
for _, c := range callbacks {
c()
}
}
// readRequest requests the next (batch) request from the codec. It will return the collection
// of requests, an indication if the request was a batch, the invalid request identifier and an
// error when the request could not be read/parsed.
func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, Error) {
reqs, batch, err := codec.ReadRequestHeaders()
if err != nil {
return nil, batch, err
}
requests := make([]*serverRequest, len(reqs))
// verify requests
for i, r := range reqs {
var ok bool
var svc *service
if r.err != nil {
requests[i] = &serverRequest{id: r.id, err: r.err}
continue
}
if r.isPubSub && strings.HasSuffix(r.method, unsubscribeMethodSuffix) {
requests[i] = &serverRequest{id: r.id, isUnsubscribe: true}
argTypes := []reflect.Type{reflect.TypeOf("")} // expect subscription id as first arg
if args, err := codec.ParseRequestArguments(argTypes, r.params); err == nil {
requests[i].args = args
} else {
requests[i].err = &invalidParamsError{err.Error()}
}
continue
}
if svc, ok = s.services[r.service]; !ok { // rpc method isn't available
requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}}
continue
}
if r.isPubSub { // eth_subscribe, r.method contains the subscription method name
if callb, ok := svc.subscriptions[r.method]; ok {
requests[i] = &serverRequest{id: r.id, svcname: svc.name, callb: callb}
if r.params != nil && len(callb.argTypes) > 0 {
argTypes := []reflect.Type{reflect.TypeOf("")}
argTypes = append(argTypes, callb.argTypes...)
if args, err := codec.ParseRequestArguments(argTypes, r.params); err == nil {
requests[i].args = args[1:] // first one is service.method name which isn't an actual argument
} else {
requests[i].err = &invalidParamsError{err.Error()}
}
}
} else {
requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}}
}
continue
}
if callb, ok := svc.callbacks[r.method]; ok { // lookup RPC method
requests[i] = &serverRequest{id: r.id, svcname: svc.name, callb: callb}
if r.params != nil && len(callb.argTypes) > 0 {
if args, err := codec.ParseRequestArguments(callb.argTypes, r.params); err == nil {
requests[i].args = args
} else {
requests[i].err = &invalidParamsError{err.Error()}
}
}
continue
}
requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}}
}
return requests, batch, nil
}

@ -17,146 +17,136 @@
package rpc package rpc
import ( import (
"context" "bufio"
"encoding/json" "bytes"
"io"
"io/ioutil"
"net" "net"
"reflect" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
) )
type Service struct{}
type Args struct {
S string
}
func (s *Service) NoArgsRets() {
}
type Result struct {
String string
Int int
Args *Args
}
func (s *Service) Echo(str string, i int, args *Args) Result {
return Result{str, i, args}
}
func (s *Service) EchoWithCtx(ctx context.Context, str string, i int, args *Args) Result {
return Result{str, i, args}
}
func (s *Service) Sleep(ctx context.Context, duration time.Duration) {
select {
case <-time.After(duration):
case <-ctx.Done():
}
}
func (s *Service) Rets() (string, error) {
return "", nil
}
func (s *Service) InvalidRets1() (error, string) {
return nil, ""
}
func (s *Service) InvalidRets2() (string, string) {
return "", ""
}
func (s *Service) InvalidRets3() (string, string, error) {
return "", "", nil
}
func (s *Service) Subscription(ctx context.Context) (*Subscription, error) {
return nil, nil
}
func TestServerRegisterName(t *testing.T) { func TestServerRegisterName(t *testing.T) {
server := NewServer() server := NewServer()
service := new(Service) service := new(testService)
if err := server.RegisterName("calc", service); err != nil {
t.Fatalf("%v", err)
}
if len(server.services) != 2 {
t.Fatalf("Expected 2 service entries, got %d", len(server.services))
}
svc, ok := server.services["calc"]
if !ok {
t.Fatalf("Expected service calc to be registered")
}
if len(svc.callbacks) != 5 {
t.Errorf("Expected 5 callbacks for service 'calc', got %d", len(svc.callbacks))
}
if len(svc.subscriptions) != 1 {
t.Errorf("Expected 1 subscription for service 'calc', got %d", len(svc.subscriptions))
}
}
func testServerMethodExecution(t *testing.T, method string) {
server := NewServer()
service := new(Service)
if err := server.RegisterName("test", service); err != nil { if err := server.RegisterName("test", service); err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }
stringArg := "string arg" if len(server.services.services) != 2 {
intArg := 1122 t.Fatalf("Expected 2 service entries, got %d", len(server.services.services))
argsArg := &Args{"abcde"} }
params := []interface{}{stringArg, intArg, argsArg}
request := map[string]interface{}{ svc, ok := server.services.services["test"]
"id": 12345, if !ok {
"method": "test_" + method, t.Fatalf("Expected service calc to be registered")
"version": "2.0", }
"params": params,
wantCallbacks := 7
if len(svc.callbacks) != wantCallbacks {
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
}
}
func TestServer(t *testing.T) {
files, err := ioutil.ReadDir("testdata")
if err != nil {
t.Fatal("where'd my testdata go?")
}
for _, f := range files {
if f.IsDir() || strings.HasPrefix(f.Name(), ".") {
continue
}
path := filepath.Join("testdata", f.Name())
name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name()))
t.Run(name, func(t *testing.T) {
runTestScript(t, path)
})
}
}
func runTestScript(t *testing.T, file string) {
server := newTestServer()
content, err := ioutil.ReadFile(file)
if err != nil {
t.Fatal(err)
} }
clientConn, serverConn := net.Pipe() clientConn, serverConn := net.Pipe()
defer clientConn.Close() defer clientConn.Close()
go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation) readbuf := bufio.NewReader(clientConn)
for _, line := range strings.Split(string(content), "\n") {
out := json.NewEncoder(clientConn) line = strings.TrimSpace(line)
in := json.NewDecoder(clientConn) switch {
case len(line) == 0 || strings.HasPrefix(line, "//"):
if err := out.Encode(request); err != nil { // skip comments, blank lines
t.Fatal(err) continue
case strings.HasPrefix(line, "--> "):
t.Log(line)
// write to connection
clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second))
if _, err := io.WriteString(clientConn, line[4:]+"\n"); err != nil {
t.Fatalf("write error: %v", err)
} }
case strings.HasPrefix(line, "<-- "):
response := jsonSuccessResponse{Result: &Result{}} t.Log(line)
if err := in.Decode(&response); err != nil { want := line[4:]
t.Fatal(err) // read line from connection and compare text
clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
sent, err := readbuf.ReadString('\n')
if err != nil {
t.Fatalf("read error: %v", err)
} }
sent = strings.TrimRight(sent, "\r\n")
if result, ok := response.Result.(*Result); ok { if sent != want {
if result.String != stringArg { t.Errorf("wrong line from server\ngot: %s\nwant: %s", sent, want)
t.Errorf("expected %s, got : %s\n", stringArg, result.String)
} }
if result.Int != intArg { default:
t.Errorf("expected %d, got %d\n", intArg, result.Int) panic("invalid line in test script: " + line)
} }
if !reflect.DeepEqual(result.Args, argsArg) {
t.Errorf("expected %v, got %v\n", argsArg, result)
}
} else {
t.Fatalf("invalid response: expected *Result - got: %T", response.Result)
} }
} }
func TestServerMethodExecution(t *testing.T) { // This test checks that responses are delivered for very short-lived connections that
testServerMethodExecution(t, "echo") // only carry a single request.
} func TestServerShortLivedConn(t *testing.T) {
server := newTestServer()
defer server.Stop()
func TestServerMethodWithCtx(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0")
testServerMethodExecution(t, "echoWithCtx") if err != nil {
t.Fatal("can't listen:", err)
}
defer listener.Close()
go server.ServeListener(listener)
var (
request = `{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}` + "\n"
wantResp = `{"jsonrpc":"2.0","id":1,"result":{"nftest":"1.0","rpc":"1.0","test":"1.0"}}` + "\n"
deadline = time.Now().Add(10 * time.Second)
)
for i := 0; i < 20; i++ {
conn, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
t.Fatal("can't dial:", err)
}
defer conn.Close()
conn.SetDeadline(deadline)
// Write the request, then half-close the connection so the server stops reading.
conn.Write([]byte(request))
conn.(*net.TCPConn).CloseWrite()
// Now try to get the response.
buf := make([]byte, 2000)
n, err := conn.Read(buf)
if err != nil {
t.Fatal("read error:", err)
}
if !bytes.Equal(buf[:n], []byte(wantResp)) {
t.Fatalf("wrong response: %s", buf[:n])
}
}
} }

285
rpc/service.go Normal file

@ -0,0 +1,285 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"context"
"errors"
"fmt"
"reflect"
"runtime"
"strings"
"sync"
"unicode"
"unicode/utf8"
"github.com/ethereum/go-ethereum/log"
)
var (
contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
errorType = reflect.TypeOf((*error)(nil)).Elem()
subscriptionType = reflect.TypeOf(Subscription{})
stringType = reflect.TypeOf("")
)
type serviceRegistry struct {
mu sync.Mutex
services map[string]service
}
// service represents a registered object.
type service struct {
name string // name for service
callbacks map[string]*callback // registered handlers
subscriptions map[string]*callback // available subscriptions/notifications
}
// callback is a method callback which was registered in the server
type callback struct {
fn reflect.Value // the function
rcvr reflect.Value // receiver object of method, set if fn is method
argTypes []reflect.Type // input argument types
hasCtx bool // method's first argument is a context (not included in argTypes)
errPos int // err return idx, of -1 when method cannot return error
isSubscribe bool // true if this is a subscription callback
}
func (r *serviceRegistry) registerName(name string, rcvr interface{}) error {
rcvrVal := reflect.ValueOf(rcvr)
if name == "" {
return fmt.Errorf("no service name for type %s", rcvrVal.Type().String())
}
callbacks := suitableCallbacks(rcvrVal)
if len(callbacks) == 0 {
return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr)
}
r.mu.Lock()
defer r.mu.Unlock()
if r.services == nil {
r.services = make(map[string]service)
}
svc, ok := r.services[name]
if !ok {
svc = service{
name: name,
callbacks: make(map[string]*callback),
subscriptions: make(map[string]*callback),
}
r.services[name] = svc
}
for name, cb := range callbacks {
if cb.isSubscribe {
svc.subscriptions[name] = cb
} else {
svc.callbacks[name] = cb
}
}
return nil
}
// callback returns the callback corresponding to the given RPC method name.
func (r *serviceRegistry) callback(method string) *callback {
elem := strings.SplitN(method, serviceMethodSeparator, 2)
if len(elem) != 2 {
return nil
}
r.mu.Lock()
defer r.mu.Unlock()
return r.services[elem[0]].callbacks[elem[1]]
}
// subscription returns a subscription callback in the given service.
func (r *serviceRegistry) subscription(service, name string) *callback {
r.mu.Lock()
defer r.mu.Unlock()
return r.services[service].subscriptions[name]
}
// suitableCallbacks iterates over the methods of the given type. It determines if a method
// satisfies the criteria for a RPC callback or a subscription callback and adds it to the
// collection of callbacks. See server documentation for a summary of these criteria.
func suitableCallbacks(receiver reflect.Value) map[string]*callback {
typ := receiver.Type()
callbacks := make(map[string]*callback)
for m := 0; m < typ.NumMethod(); m++ {
method := typ.Method(m)
if method.PkgPath != "" {
continue // method not exported
}
cb := newCallback(receiver, method.Func)
if cb == nil {
continue // function invalid
}
name := formatName(method.Name)
callbacks[name] = cb
}
return callbacks
}
// newCallback turns fn (a function) into a callback object. It returns nil if the function
// is unsuitable as an RPC callback.
func newCallback(receiver, fn reflect.Value) *callback {
fntype := fn.Type()
c := &callback{fn: fn, rcvr: receiver, errPos: -1, isSubscribe: isPubSub(fntype)}
// Determine parameter types. They must all be exported or builtin types.
c.makeArgTypes()
if !allExportedOrBuiltin(c.argTypes) {
return nil
}
// Verify return types. The function must return at most one error
// and/or one other non-error value.
outs := make([]reflect.Type, fntype.NumOut())
for i := 0; i < fntype.NumOut(); i++ {
outs[i] = fntype.Out(i)
}
if len(outs) > 2 || !allExportedOrBuiltin(outs) {
return nil
}
// If an error is returned, it must be the last returned value.
switch {
case len(outs) == 1 && isErrorType(outs[0]):
c.errPos = 0
case len(outs) == 2:
if isErrorType(outs[0]) || !isErrorType(outs[1]) {
return nil
}
c.errPos = 1
}
return c
}
// makeArgTypes composes the argTypes list.
func (c *callback) makeArgTypes() {
fntype := c.fn.Type()
// Skip receiver and context.Context parameter (if present).
firstArg := 0
if c.rcvr.IsValid() {
firstArg++
}
if fntype.NumIn() > firstArg && fntype.In(firstArg) == contextType {
c.hasCtx = true
firstArg++
}
// Add all remaining parameters.
c.argTypes = make([]reflect.Type, fntype.NumIn()-firstArg)
for i := firstArg; i < fntype.NumIn(); i++ {
c.argTypes[i-firstArg] = fntype.In(i)
}
}
// call invokes the callback.
func (c *callback) call(ctx context.Context, method string, args []reflect.Value) (res interface{}, errRes error) {
// Create the argument slice.
fullargs := make([]reflect.Value, 0, 2+len(args))
if c.rcvr.IsValid() {
fullargs = append(fullargs, c.rcvr)
}
if c.hasCtx {
fullargs = append(fullargs, reflect.ValueOf(ctx))
}
fullargs = append(fullargs, args...)
// Catch panic while running the callback.
defer func() {
if err := recover(); err != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
log.Error("RPC method " + method + " crashed: " + fmt.Sprintf("%v\n%s", err, buf))
errRes = errors.New("method handler crashed")
}
}()
// Run the callback.
results := c.fn.Call(fullargs)
if len(results) == 0 {
return nil, nil
}
if c.errPos >= 0 && !results[c.errPos].IsNil() {
// Method has returned non-nil error value.
err := results[c.errPos].Interface().(error)
return reflect.Value{}, err
}
return results[0].Interface(), nil
}
// Is this an exported - upper case - name?
func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune)
}
// Are all those types exported or built-in?
func allExportedOrBuiltin(types []reflect.Type) bool {
for _, typ := range types {
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
// PkgPath will be non-empty even for an exported type,
// so we need to check the type name as well.
if !isExported(typ.Name()) && typ.PkgPath() != "" {
return false
}
}
return true
}
// Is t context.Context or *context.Context?
func isContextType(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t == contextType
}
// Does t satisfy the error interface?
func isErrorType(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t.Implements(errorType)
}
// Is t Subscription or *Subscription?
func isSubscriptionType(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t == subscriptionType
}
// isPubSub tests whether the given method has as as first argument a context.Context and
// returns the pair (Subscription, error).
func isPubSub(methodType reflect.Type) bool {
// numIn(0) is the receiver type
if methodType.NumIn() < 2 || methodType.NumOut() != 2 {
return false
}
return isContextType(methodType.In(1)) &&
isSubscriptionType(methodType.Out(0)) &&
isErrorType(methodType.Out(1))
}
// formatName converts to first character of name to lowercase.
func formatName(name string) string {
ret := []rune(name)
if len(ret) > 0 {
ret[0] = unicode.ToLower(ret[0])
}
return string(ret)
}

@ -26,8 +26,8 @@ import (
// DialStdIO creates a client on stdin/stdout. // DialStdIO creates a client on stdin/stdout.
func DialStdIO(ctx context.Context) (*Client, error) { func DialStdIO(ctx context.Context) (*Client, error) {
return newClient(ctx, func(_ context.Context) (net.Conn, error) { return newClient(ctx, func(_ context.Context) (ServerCodec, error) {
return stdioConn{}, nil return NewJSONCodec(stdioConn{}), nil
}) })
} }
@ -45,20 +45,8 @@ func (io stdioConn) Close() error {
return nil return nil
} }
func (io stdioConn) LocalAddr() net.Addr { func (io stdioConn) RemoteAddr() string {
return &net.UnixAddr{Name: "stdio", Net: "stdio"} return "/dev/stdin"
}
func (io stdioConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "stdio"}
}
func (io stdioConn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "stdio", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (io stdioConn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "stdio", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
} }
func (io stdioConn) SetWriteDeadline(t time.Time) error { func (io stdioConn) SetWriteDeadline(t time.Time) error {

@ -17,9 +17,19 @@
package rpc package rpc
import ( import (
"bufio"
"container/list"
"context" "context"
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors" "errors"
"math/rand"
"reflect"
"strings"
"sync" "sync"
"time"
) )
var ( var (
@ -29,10 +39,147 @@ var (
ErrSubscriptionNotFound = errors.New("subscription not found") ErrSubscriptionNotFound = errors.New("subscription not found")
) )
var globalGen = randomIDGenerator()
// ID defines a pseudo random number that is used to identify RPC subscriptions. // ID defines a pseudo random number that is used to identify RPC subscriptions.
type ID string type ID string
// a Subscription is created by a notifier and tight to that notifier. The client can use // NewID returns a new, random ID.
func NewID() ID {
return globalGen()
}
// randomIDGenerator returns a function generates a random IDs.
func randomIDGenerator() func() ID {
seed, err := binary.ReadVarint(bufio.NewReader(crand.Reader))
if err != nil {
seed = int64(time.Now().Nanosecond())
}
var (
mu sync.Mutex
rng = rand.New(rand.NewSource(seed))
)
return func() ID {
mu.Lock()
defer mu.Unlock()
id := make([]byte, 16)
rng.Read(id)
return encodeID(id)
}
}
func encodeID(b []byte) ID {
id := hex.EncodeToString(b)
id = strings.TrimLeft(id, "0")
if id == "" {
id = "0" // ID's are RPC quantities, no leading zero's and 0 is 0x0.
}
return ID("0x" + id)
}
type notifierKey struct{}
// NotifierFromContext returns the Notifier value stored in ctx, if any.
func NotifierFromContext(ctx context.Context) (*Notifier, bool) {
n, ok := ctx.Value(notifierKey{}).(*Notifier)
return n, ok
}
// Notifier is tied to a RPC connection that supports subscriptions.
// Server callbacks use the notifier to send notifications.
type Notifier struct {
h *handler
namespace string
mu sync.Mutex
sub *Subscription
buffer []json.RawMessage
callReturned bool
activated bool
}
// CreateSubscription returns a new subscription that is coupled to the
// RPC connection. By default subscriptions are inactive and notifications
// are dropped until the subscription is marked as active. This is done
// by the RPC server after the subscription ID is send to the client.
func (n *Notifier) CreateSubscription() *Subscription {
n.mu.Lock()
defer n.mu.Unlock()
if n.sub != nil {
panic("can't create multiple subscriptions with Notifier")
} else if n.callReturned {
panic("can't create subscription after subscribe call has returned")
}
n.sub = &Subscription{ID: n.h.idgen(), namespace: n.namespace, err: make(chan error, 1)}
return n.sub
}
// Notify sends a notification to the client with the given data as payload.
// If an error occurs the RPC connection is closed and the error is returned.
func (n *Notifier) Notify(id ID, data interface{}) error {
enc, err := json.Marshal(data)
if err != nil {
return err
}
n.mu.Lock()
defer n.mu.Unlock()
if n.sub == nil {
panic("can't Notify before subscription is created")
} else if n.sub.ID != id {
panic("Notify with wrong ID")
}
if n.activated {
return n.send(n.sub, enc)
}
n.buffer = append(n.buffer, enc)
return nil
}
// Closed returns a channel that is closed when the RPC connection is closed.
// Deprecated: use subscription error channel
func (n *Notifier) Closed() <-chan interface{} {
return n.h.conn.Closed()
}
// takeSubscription returns the subscription (if one has been created). No subscription can
// be created after this call.
func (n *Notifier) takeSubscription() *Subscription {
n.mu.Lock()
defer n.mu.Unlock()
n.callReturned = true
return n.sub
}
// acticate is called after the subscription ID was sent to client. Notifications are
// buffered before activation. This prevents notifications being sent to the client before
// the subscription ID is sent to the client.
func (n *Notifier) activate() error {
n.mu.Lock()
defer n.mu.Unlock()
for _, data := range n.buffer {
if err := n.send(n.sub, data); err != nil {
return err
}
}
n.activated = true
return nil
}
func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
ctx := context.Background()
return n.h.conn.Write(ctx, &jsonrpcMessage{
Version: vsn,
Method: n.namespace + notificationMethodSuffix,
Params: params,
})
}
// A Subscription is created by a notifier and tight to that notifier. The client can use
// this subscription to wait for an unsubscribe request for the client, see Err(). // this subscription to wait for an unsubscribe request for the client, see Err().
type Subscription struct { type Subscription struct {
ID ID ID ID
@ -45,105 +192,136 @@ func (s *Subscription) Err() <-chan error {
return s.err return s.err
} }
// notifierKey is used to store a notifier within the connection context. // MarshalJSON marshals a subscription as its ID.
type notifierKey struct{} func (s *Subscription) MarshalJSON() ([]byte, error) {
return json.Marshal(s.ID)
// Notifier is tight to a RPC connection that supports subscriptions.
// Server callbacks use the notifier to send notifications.
type Notifier struct {
codec ServerCodec
subMu sync.Mutex
active map[ID]*Subscription
inactive map[ID]*Subscription
buffer map[ID][]interface{} // unsent notifications of inactive subscriptions
} }
// newNotifier creates a new notifier that can be used to send subscription // ClientSubscription is a subscription established through the Client's Subscribe or
// notifications to the client. // EthSubscribe methods.
func newNotifier(codec ServerCodec) *Notifier { type ClientSubscription struct {
return &Notifier{ client *Client
codec: codec, etype reflect.Type
active: make(map[ID]*Subscription), channel reflect.Value
inactive: make(map[ID]*Subscription), namespace string
buffer: make(map[ID][]interface{}), subid string
} in chan json.RawMessage
quitOnce sync.Once // ensures quit is closed once
quit chan struct{} // quit is closed when the subscription exits
errOnce sync.Once // ensures err is closed once
err chan error
} }
// NotifierFromContext returns the Notifier value stored in ctx, if any. func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription {
func NotifierFromContext(ctx context.Context) (*Notifier, bool) { sub := &ClientSubscription{
n, ok := ctx.Value(notifierKey{}).(*Notifier) client: c,
return n, ok namespace: namespace,
etype: channel.Type().Elem(),
channel: channel,
quit: make(chan struct{}),
err: make(chan error, 1),
in: make(chan json.RawMessage),
}
return sub
} }
// CreateSubscription returns a new subscription that is coupled to the // Err returns the subscription error channel. The intended use of Err is to schedule
// RPC connection. By default subscriptions are inactive and notifications // resubscription when the client connection is closed unexpectedly.
// are dropped until the subscription is marked as active. This is done //
// by the RPC server after the subscription ID is send to the client. // The error channel receives a value when the subscription has ended due
func (n *Notifier) CreateSubscription() *Subscription { // to an error. The received error is nil if Close has been called
s := &Subscription{ID: NewID(), err: make(chan error)} // on the underlying client and no other error has occurred.
n.subMu.Lock() //
n.inactive[s.ID] = s // The error channel is closed when Unsubscribe is called on the subscription.
n.subMu.Unlock() func (sub *ClientSubscription) Err() <-chan error {
return s return sub.err
} }
// Notify sends a notification to the client with the given data as payload. // Unsubscribe unsubscribes the notification and closes the error channel.
// If an error occurs the RPC connection is closed and the error is returned. // It can safely be called more than once.
func (n *Notifier) Notify(id ID, data interface{}) error { func (sub *ClientSubscription) Unsubscribe() {
n.subMu.Lock() sub.quitWithError(nil, true)
defer n.subMu.Unlock() sub.errOnce.Do(func() { close(sub.err) })
if sub, active := n.active[id]; active {
n.send(sub, data)
} else {
n.buffer[id] = append(n.buffer[id], data)
}
return nil
} }
func (n *Notifier) send(sub *Subscription, data interface{}) error { func (sub *ClientSubscription) quitWithError(err error, unsubscribeServer bool) {
notification := n.codec.CreateNotification(string(sub.ID), sub.namespace, data) sub.quitOnce.Do(func() {
err := n.codec.Write(notification) // The dispatch loop won't be able to execute the unsubscribe call
// if it is blocked on deliver. Close sub.quit first because it
// unblocks deliver.
close(sub.quit)
if unsubscribeServer {
sub.requestUnsubscribe()
}
if err != nil { if err != nil {
n.codec.Close() if err == ErrClientQuit {
err = nil // Adhere to subscription semantics.
} }
return err sub.err <- err
}
})
} }
// Closed returns a channel that is closed when the RPC connection is closed. func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) {
func (n *Notifier) Closed() <-chan interface{} { select {
return n.codec.Closed() case sub.in <- result:
return true
case <-sub.quit:
return false
}
} }
// unsubscribe a subscription. func (sub *ClientSubscription) start() {
// If the subscription could not be found ErrSubscriptionNotFound is returned. sub.quitWithError(sub.forward())
func (n *Notifier) unsubscribe(id ID) error {
n.subMu.Lock()
defer n.subMu.Unlock()
if s, found := n.active[id]; found {
close(s.err)
delete(n.active, id)
return nil
}
return ErrSubscriptionNotFound
} }
// activate enables a subscription. Until a subscription is enabled all func (sub *ClientSubscription) forward() (err error, unsubscribeServer bool) {
// notifications are dropped. This method is called by the RPC server after cases := []reflect.SelectCase{
// the subscription ID was sent to client. This prevents notifications being {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)},
// send to the client before the subscription ID is send to the client. {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)},
func (n *Notifier) activate(id ID, namespace string) { {Dir: reflect.SelectSend, Chan: sub.channel},
n.subMu.Lock() }
defer n.subMu.Unlock() buffer := list.New()
defer buffer.Init()
for {
var chosen int
var recv reflect.Value
if buffer.Len() == 0 {
// Idle, omit send case.
chosen, recv, _ = reflect.Select(cases[:2])
} else {
// Non-empty buffer, send the first queued item.
cases[2].Send = reflect.ValueOf(buffer.Front().Value)
chosen, recv, _ = reflect.Select(cases)
}
if sub, found := n.inactive[id]; found { switch chosen {
sub.namespace = namespace case 0: // <-sub.quit
n.active[id] = sub return nil, false
delete(n.inactive, id) case 1: // <-sub.in
// Send buffered notifications. val, err := sub.unmarshal(recv.Interface().(json.RawMessage))
for _, data := range n.buffer[id] { if err != nil {
n.send(sub, data) return err, true
} }
delete(n.buffer, id) if buffer.Len() == maxClientSubscriptionBuffer {
return ErrSubscriptionQueueOverflow, true
}
buffer.PushBack(val)
case 2: // sub.channel<-
cases[2].Send = reflect.Value{} // Don't hold onto the value.
buffer.Remove(buffer.Front())
} }
} }
}
func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) {
val := reflect.New(sub.etype)
err := json.Unmarshal(result, val.Interface())
return val.Elem().Interface(), err
}
func (sub *ClientSubscription) requestUnsubscribe() error {
var result interface{}
return sub.client.Call(&result, sub.namespace+unsubscribeMethodSuffix, sub.subid)
}

@ -17,232 +17,62 @@
package rpc package rpc
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"sync" "strings"
"testing" "testing"
"time" "time"
) )
type NotificationTestService struct { func TestNewID(t *testing.T) {
mu sync.Mutex hexchars := "0123456789ABCDEFabcdef"
unsubscribed chan string for i := 0; i < 100; i++ {
gotHangSubscriptionReq chan struct{} id := string(NewID())
unblockHangSubscription chan struct{} if !strings.HasPrefix(id, "0x") {
t.Fatalf("invalid ID prefix, want '0x...', got %s", id)
} }
func (s *NotificationTestService) Echo(i int) int { id = id[2:]
return i if len(id) == 0 || len(id) > 32 {
t.Fatalf("invalid ID length, want len(id) > 0 && len(id) <= 32), got %d", len(id))
} }
func (s *NotificationTestService) Unsubscribe(subid string) { for i := 0; i < len(id); i++ {
if s.unsubscribed != nil { if strings.IndexByte(hexchars, id[i]) == -1 {
s.unsubscribed <- subid t.Fatalf("unexpected byte, want any valid hex char, got %c", id[i])
} }
} }
func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) {
notifier, supported := NotifierFromContext(ctx)
if !supported {
return nil, ErrNotificationsUnsupported
}
// by explicitly creating an subscription we make sure that the subscription id is send back to the client
// before the first subscription.Notify is called. Otherwise the events might be send before the response
// for the eth_subscribe method.
subscription := notifier.CreateSubscription()
go func() {
// test expects n events, if we begin sending event immediately some events
// will probably be dropped since the subscription ID might not be send to
// the client.
for i := 0; i < n; i++ {
if err := notifier.Notify(subscription.ID, val+i); err != nil {
return
}
}
select {
case <-notifier.Closed():
case <-subscription.Err():
}
if s.unsubscribed != nil {
s.unsubscribed <- string(subscription.ID)
}
}()
return subscription, nil
}
// HangSubscription blocks on s.unblockHangSubscription before
// sending anything.
func (s *NotificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) {
notifier, supported := NotifierFromContext(ctx)
if !supported {
return nil, ErrNotificationsUnsupported
}
s.gotHangSubscriptionReq <- struct{}{}
<-s.unblockHangSubscription
subscription := notifier.CreateSubscription()
go func() {
notifier.Notify(subscription.ID, val)
}()
return subscription, nil
}
func TestNotifications(t *testing.T) {
server := NewServer()
service := &NotificationTestService{unsubscribed: make(chan string)}
if err := server.RegisterName("eth", service); err != nil {
t.Fatalf("unable to register test service %v", err)
}
clientConn, serverConn := net.Pipe()
go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
out := json.NewEncoder(clientConn)
in := json.NewDecoder(clientConn)
n := 5
val := 12345
request := map[string]interface{}{
"id": 1,
"method": "eth_subscribe",
"version": "2.0",
"params": []interface{}{"someSubscription", n, val},
}
// create subscription
if err := out.Encode(request); err != nil {
t.Fatal(err)
}
var subid string
response := jsonSuccessResponse{Result: subid}
if err := in.Decode(&response); err != nil {
t.Fatal(err)
}
var ok bool
if _, ok = response.Result.(string); !ok {
t.Fatalf("expected subscription id, got %T", response.Result)
}
for i := 0; i < n; i++ {
var notification jsonNotification
if err := in.Decode(&notification); err != nil {
t.Fatalf("%v", err)
}
if int(notification.Params.Result.(float64)) != val+i {
t.Fatalf("expected %d, got %d", val+i, notification.Params.Result)
}
}
clientConn.Close() // causes notification unsubscribe callback to be called
select {
case <-service.unsubscribed:
case <-time.After(1 * time.Second):
t.Fatal("Unsubscribe not called after one second")
}
}
func waitForMessages(t *testing.T, in *json.Decoder, successes chan<- jsonSuccessResponse,
failures chan<- jsonErrResponse, notifications chan<- jsonNotification, errors chan<- error) {
// read and parse server messages
for {
var rmsg json.RawMessage
if err := in.Decode(&rmsg); err != nil {
return
}
var responses []map[string]interface{}
if rmsg[0] == '[' {
if err := json.Unmarshal(rmsg, &responses); err != nil {
errors <- fmt.Errorf("Received invalid message: %s", rmsg)
return
}
} else {
var msg map[string]interface{}
if err := json.Unmarshal(rmsg, &msg); err != nil {
errors <- fmt.Errorf("Received invalid message: %s", rmsg)
return
}
responses = append(responses, msg)
}
for _, msg := range responses {
// determine what kind of msg was received and broadcast
// it to over the corresponding channel
if _, found := msg["result"]; found {
successes <- jsonSuccessResponse{
Version: msg["jsonrpc"].(string),
Id: msg["id"],
Result: msg["result"],
}
continue
}
if _, found := msg["error"]; found {
params := msg["params"].(map[string]interface{})
failures <- jsonErrResponse{
Version: msg["jsonrpc"].(string),
Id: msg["id"],
Error: jsonError{int(params["subscription"].(float64)), params["message"].(string), params["data"]},
}
continue
}
if _, found := msg["params"]; found {
params := msg["params"].(map[string]interface{})
notifications <- jsonNotification{
Version: msg["jsonrpc"].(string),
Method: msg["method"].(string),
Params: jsonSubscription{params["subscription"].(string), params["result"]},
}
continue
}
errors <- fmt.Errorf("Received invalid message: %s", msg)
}
} }
} }
// TestSubscriptionMultipleNamespaces ensures that subscriptions can exists func TestSubscriptions(t *testing.T) {
// for multiple different namespaces.
func TestSubscriptionMultipleNamespaces(t *testing.T) {
var ( var (
namespaces = []string{"eth", "shh", "bzz"} namespaces = []string{"eth", "shh", "bzz"}
service = NotificationTestService{} service = &notificationTestService{}
subCount = len(namespaces) * 2 subCount = len(namespaces)
notificationCount = 3 notificationCount = 3
server = NewServer() server = NewServer()
clientConn, serverConn = net.Pipe() clientConn, serverConn = net.Pipe()
out = json.NewEncoder(clientConn) out = json.NewEncoder(clientConn)
in = json.NewDecoder(clientConn) in = json.NewDecoder(clientConn)
successes = make(chan jsonSuccessResponse) successes = make(chan subConfirmation)
failures = make(chan jsonErrResponse) notifications = make(chan subscriptionResult)
notifications = make(chan jsonNotification) errors = make(chan error, subCount*notificationCount+1)
errors = make(chan error, 10)
) )
// setup and start server // setup and start server
for _, namespace := range namespaces { for _, namespace := range namespaces {
if err := server.RegisterName(namespace, &service); err != nil { if err := server.RegisterName(namespace, service); err != nil {
t.Fatalf("unable to register test service %v", err) t.Fatalf("unable to register test service %v", err)
} }
} }
go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions) go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
defer server.Stop() defer server.Stop()
// wait for message and write them to the given channels // wait for message and write them to the given channels
go waitForMessages(t, in, successes, failures, notifications, errors) go waitForMessages(in, successes, notifications, errors)
// create subscriptions one by one // create subscriptions one by one
for i, namespace := range namespaces { for i, namespace := range namespaces {
@ -252,27 +82,11 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) {
"version": "2.0", "version": "2.0",
"params": []interface{}{"someSubscription", notificationCount, i}, "params": []interface{}{"someSubscription", notificationCount, i},
} }
if err := out.Encode(&request); err != nil { if err := out.Encode(&request); err != nil {
t.Fatalf("Could not create subscription: %v", err) t.Fatalf("Could not create subscription: %v", err)
} }
} }
// create all subscriptions in 1 batch
var requests []interface{}
for i, namespace := range namespaces {
requests = append(requests, map[string]interface{}{
"id": i,
"method": fmt.Sprintf("%s_subscribe", namespace),
"version": "2.0",
"params": []interface{}{"someSubscription", notificationCount, i},
})
}
if err := out.Encode(&requests); err != nil {
t.Fatalf("Could not create subscription in batch form: %v", err)
}
timeout := time.After(30 * time.Second) timeout := time.After(30 * time.Second)
subids := make(map[string]string, subCount) subids := make(map[string]string, subCount)
count := make(map[string]int, subCount) count := make(map[string]int, subCount)
@ -285,17 +99,14 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) {
} }
return done return done
} }
for !allReceived() { for !allReceived() {
select { select {
case suc := <-successes: // subscription created case confirmation := <-successes: // subscription created
subids[namespaces[int(suc.Id.(float64))]] = suc.Result.(string) subids[namespaces[confirmation.reqid]] = string(confirmation.subid)
case notification := <-notifications: case notification := <-notifications:
count[notification.Params.Subscription]++ count[notification.ID]++
case err := <-errors: case err := <-errors:
t.Fatal(err) t.Fatal(err)
case failure := <-failures:
t.Errorf("received error: %v", failure.Error)
case <-timeout: case <-timeout:
for _, namespace := range namespaces { for _, namespace := range namespaces {
subid, found := subids[namespace] subid, found := subids[namespace]
@ -311,3 +122,85 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) {
} }
} }
} }
// This test checks that unsubscribing works.
func TestServerUnsubscribe(t *testing.T) {
// Start the server.
server := newTestServer()
service := &notificationTestService{unsubscribed: make(chan string)}
server.RegisterName("nftest2", service)
p1, p2 := net.Pipe()
go server.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions)
p2.SetDeadline(time.Now().Add(10 * time.Second))
// Subscribe.
p2.Write([]byte(`{"jsonrpc":"2.0","id":1,"method":"nftest2_subscribe","params":["someSubscription",0,10]}`))
// Handle received messages.
resps := make(chan subConfirmation)
notifications := make(chan subscriptionResult)
errors := make(chan error)
go waitForMessages(json.NewDecoder(p2), resps, notifications, errors)
// Receive the subscription ID.
var sub subConfirmation
select {
case sub = <-resps:
case err := <-errors:
t.Fatal(err)
}
// Unsubscribe and check that it is handled on the server side.
p2.Write([]byte(`{"jsonrpc":"2.0","method":"nftest2_unsubscribe","params":["` + sub.subid + `"]}`))
for {
select {
case id := <-service.unsubscribed:
if id != string(sub.subid) {
t.Errorf("wrong subscription ID unsubscribed")
}
return
case err := <-errors:
t.Fatal(err)
case <-notifications:
// drop notifications
}
}
}
type subConfirmation struct {
reqid int
subid ID
}
func waitForMessages(in *json.Decoder, successes chan subConfirmation, notifications chan subscriptionResult, errors chan error) {
for {
var msg jsonrpcMessage
if err := in.Decode(&msg); err != nil {
errors <- fmt.Errorf("decode error: %v", err)
return
}
switch {
case msg.isNotification():
var res subscriptionResult
if err := json.Unmarshal(msg.Params, &res); err != nil {
errors <- fmt.Errorf("invalid subscription result: %v", err)
} else {
notifications <- res
}
case msg.isResponse():
var c subConfirmation
if msg.Error != nil {
errors <- msg.Error
} else if err := json.Unmarshal(msg.Result, &c.subid); err != nil {
errors <- fmt.Errorf("invalid response: %v", err)
} else {
json.Unmarshal(msg.ID, &c.reqid)
successes <- c
}
default:
errors <- fmt.Errorf("unrecognized message: %v", msg)
return
}
}
}

7
rpc/testdata/invalid-badid.js vendored Normal file

@ -0,0 +1,7 @@
// This test checks processing of messages with invalid ID.
--> {"id":[],"method":"test_foo"}
<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}
--> {"id":{},"method":"test_foo"}
<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}

14
rpc/testdata/invalid-batch.js vendored Normal file

@ -0,0 +1,14 @@
// This test checks the behavior of batches with invalid elements.
// Empty batches are not allowed. Batches may contain junk.
--> []
<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"empty batch"}}
--> [1]
<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}]
--> [1,2,3]
<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}]
--> [{"jsonrpc":"2.0","id":1,"method":"test_echo","params":["foo",1]},55,{"jsonrpc":"2.0","id":2,"method":"unknown_method"},{"foo":"bar"}]
<-- [{"jsonrpc":"2.0","id":1,"result":{"String":"foo","Int":1,"Args":null}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method unknown_method does not exist/is not available"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}]

7
rpc/testdata/invalid-idonly.js vendored Normal file

@ -0,0 +1,7 @@
// This test checks processing of messages that contain just the ID and nothing else.
--> {"id":1}
<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}}
--> {"jsonrpc":"2.0","id":1}
<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}}

4
rpc/testdata/invalid-nonobj.js vendored Normal file

@ -0,0 +1,4 @@
// This test checks behavior for invalid requests.
--> 1
<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}

5
rpc/testdata/invalid-syntax.json vendored Normal file

@ -0,0 +1,5 @@
// This test checks that an error is written for invalid JSON requests.
--> 'f
<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"invalid character '\\'' looking for beginning of value"}}

8
rpc/testdata/reqresp-batch.js vendored Normal file

@ -0,0 +1,8 @@
// There is no response for all-notification batches.
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
// This test checks regular batch calls.
--> [{"jsonrpc":"2.0","id":2,"method":"test_echo","params":[]}, {"jsonrpc":"2.0","id": 3,"method":"test_echo","params":["x",3]}]
<-- [{"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}},{"jsonrpc":"2.0","id":3,"result":{"String":"x","Int":3,"Args":null}}]

16
rpc/testdata/reqresp-echo.js vendored Normal file

@ -0,0 +1,16 @@
// This test calls the test_echo method.
--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": []}
<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}}
--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x"]}
<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 1"}}
--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3]}
<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":null}}
--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3, {"S": "foo"}]}
<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}}
--> {"jsonrpc": "2.0", "id": 2, "method": "test_echoWithCtx", "params": ["x", 3, {"S": "foo"}]}
<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}}

5
rpc/testdata/reqresp-namedparam.js vendored Normal file

@ -0,0 +1,5 @@
// This test checks that an error response is sent for calls
// with named parameters.
--> {"jsonrpc":"2.0","method":"test_echo","params":{"int":23},"id":3}
<-- {"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"non-array args"}}

4
rpc/testdata/reqresp-noargsrets.js vendored Normal file

@ -0,0 +1,4 @@
// This test calls the test_noArgsRets method.
--> {"jsonrpc": "2.0", "id": "foo", "method": "test_noArgsRets", "params": []}
<-- {"jsonrpc":"2.0","id":"foo","result":null}

4
rpc/testdata/reqresp-nomethod.js vendored Normal file

@ -0,0 +1,4 @@
// This test calls a method that doesn't exist.
--> {"jsonrpc": "2.0", "id": 2, "method": "invalid_method", "params": [2, 3]}
<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method invalid_method does not exist/is not available"}}

4
rpc/testdata/reqresp-noparam.js vendored Normal file

@ -0,0 +1,4 @@
// This test checks that calls with no parameters work.
--> {"jsonrpc":"2.0","method":"test_noArgsRets","id":3}
<-- {"jsonrpc":"2.0","id":3,"result":null}

4
rpc/testdata/reqresp-paramsnull.js vendored Normal file

@ -0,0 +1,4 @@
// This test checks that calls with "params":null work.
--> {"jsonrpc":"2.0","method":"test_noArgsRets","params":null,"id":3}
<-- {"jsonrpc":"2.0","id":3,"result":null}

6
rpc/testdata/revcall.js vendored Normal file

@ -0,0 +1,6 @@
// This test checks reverse calls.
--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBack","params":["foo",[1]]}
<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]}
--> {"jsonrpc":"2.0","id":1,"result":"my result"}
<-- {"jsonrpc":"2.0","id":2,"result":"my result"}

7
rpc/testdata/revcall2.js vendored Normal file

@ -0,0 +1,7 @@
// This test checks reverse calls.
--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBackLater","params":["foo",[1]]}
<-- {"jsonrpc":"2.0","id":2,"result":null}
<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]}
--> {"jsonrpc":"2.0","id":1,"result":"my result"}

12
rpc/testdata/subscription.js vendored Normal file

@ -0,0 +1,12 @@
// This test checks basic subscription support.
--> {"jsonrpc":"2.0","id":1,"method":"nftest_subscribe","params":["someSubscription",5,1]}
<-- {"jsonrpc":"2.0","id":1,"result":"0x1"}
<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":1}}
<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":2}}
<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":3}}
<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":4}}
<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":5}}
--> {"jsonrpc":"2.0","id":2,"method":"nftest_echo","params":[11]}
<-- {"jsonrpc":"2.0","id":2,"result":11}

180
rpc/testservice_test.go Normal file

@ -0,0 +1,180 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"context"
"encoding/binary"
"errors"
"sync"
"time"
)
func newTestServer() *Server {
server := NewServer()
server.idgen = sequentialIDGenerator()
if err := server.RegisterName("test", new(testService)); err != nil {
panic(err)
}
if err := server.RegisterName("nftest", new(notificationTestService)); err != nil {
panic(err)
}
return server
}
func sequentialIDGenerator() func() ID {
var (
mu sync.Mutex
counter uint64
)
return func() ID {
mu.Lock()
defer mu.Unlock()
counter++
id := make([]byte, 8)
binary.BigEndian.PutUint64(id, counter)
return encodeID(id)
}
}
type testService struct{}
type Args struct {
S string
}
type Result struct {
String string
Int int
Args *Args
}
func (s *testService) NoArgsRets() {}
func (s *testService) Echo(str string, i int, args *Args) Result {
return Result{str, i, args}
}
func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *Args) Result {
return Result{str, i, args}
}
func (s *testService) Sleep(ctx context.Context, duration time.Duration) {
time.Sleep(duration)
}
func (s *testService) Rets() (string, error) {
return "", nil
}
func (s *testService) InvalidRets1() (error, string) {
return nil, ""
}
func (s *testService) InvalidRets2() (string, string) {
return "", ""
}
func (s *testService) InvalidRets3() (string, string, error) {
return "", "", nil
}
func (s *testService) CallMeBack(ctx context.Context, method string, args []interface{}) (interface{}, error) {
c, ok := ClientFromContext(ctx)
if !ok {
return nil, errors.New("no client")
}
var result interface{}
err := c.Call(&result, method, args...)
return result, err
}
func (s *testService) CallMeBackLater(ctx context.Context, method string, args []interface{}) error {
c, ok := ClientFromContext(ctx)
if !ok {
return errors.New("no client")
}
go func() {
<-ctx.Done()
var result interface{}
c.Call(&result, method, args...)
}()
return nil
}
func (s *testService) Subscription(ctx context.Context) (*Subscription, error) {
return nil, nil
}
type notificationTestService struct {
unsubscribed chan string
gotHangSubscriptionReq chan struct{}
unblockHangSubscription chan struct{}
}
func (s *notificationTestService) Echo(i int) int {
return i
}
func (s *notificationTestService) Unsubscribe(subid string) {
if s.unsubscribed != nil {
s.unsubscribed <- subid
}
}
func (s *notificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) {
notifier, supported := NotifierFromContext(ctx)
if !supported {
return nil, ErrNotificationsUnsupported
}
// By explicitly creating an subscription we make sure that the subscription id is send
// back to the client before the first subscription.Notify is called. Otherwise the
// events might be send before the response for the *_subscribe method.
subscription := notifier.CreateSubscription()
go func() {
for i := 0; i < n; i++ {
if err := notifier.Notify(subscription.ID, val+i); err != nil {
return
}
}
select {
case <-notifier.Closed():
case <-subscription.Err():
}
if s.unsubscribed != nil {
s.unsubscribed <- string(subscription.ID)
}
}()
return subscription, nil
}
// HangSubscription blocks on s.unblockHangSubscription before sending anything.
func (s *notificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) {
notifier, supported := NotifierFromContext(ctx)
if !supported {
return nil, ErrNotificationsUnsupported
}
s.gotHangSubscriptionReq <- struct{}{}
<-s.unblockHangSubscription
subscription := notifier.CreateSubscription()
go func() {
notifier.Notify(subscription.ID, val)
}()
return subscription, nil
}

@ -17,13 +17,11 @@
package rpc package rpc
import ( import (
"context"
"fmt" "fmt"
"math" "math"
"reflect"
"strings" "strings"
"sync"
mapset "github.com/deckarep/golang-set"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
) )
@ -35,57 +33,6 @@ type API struct {
Public bool // indication if the methods must be considered safe for public use Public bool // indication if the methods must be considered safe for public use
} }
// callback is a method callback which was registered in the server
type callback struct {
rcvr reflect.Value // receiver of method
method reflect.Method // callback
argTypes []reflect.Type // input argument types
hasCtx bool // method's first argument is a context (not included in argTypes)
errPos int // err return idx, of -1 when method cannot return error
isSubscribe bool // indication if the callback is a subscription
}
// service represents a registered object
type service struct {
name string // name for service
typ reflect.Type // receiver type
callbacks callbacks // registered handlers
subscriptions subscriptions // available subscriptions/notifications
}
// serverRequest is an incoming request
type serverRequest struct {
id interface{}
svcname string
callb *callback
args []reflect.Value
isUnsubscribe bool
err Error
}
type serviceRegistry map[string]*service // collection of services
type callbacks map[string]*callback // collection of RPC callbacks
type subscriptions map[string]*callback // collection of subscription callbacks
// Server represents a RPC server
type Server struct {
services serviceRegistry
run int32
codecsMu sync.Mutex
codecs mapset.Set
}
// rpcRequest represents a raw incoming RPC request
type rpcRequest struct {
service string
method string
id interface{}
isPubSub bool
params interface{}
err Error // invalid batch element
}
// Error wraps RPC errors, which contain an error code in addition to the message. // Error wraps RPC errors, which contain an error code in addition to the message.
type Error interface { type Error interface {
Error() string // returns the message Error() string // returns the message
@ -96,24 +43,19 @@ type Error interface {
// a RPC session. Implementations must be go-routine safe since the codec can be called in // a RPC session. Implementations must be go-routine safe since the codec can be called in
// multiple go-routines concurrently. // multiple go-routines concurrently.
type ServerCodec interface { type ServerCodec interface {
// Read next request Read() (msgs []*jsonrpcMessage, isBatch bool, err error)
ReadRequestHeaders() ([]rpcRequest, bool, Error)
// Parse request argument to the given types
ParseRequestArguments(argTypes []reflect.Type, params interface{}) ([]reflect.Value, Error)
// Assemble success response, expects response id and payload
CreateResponse(id interface{}, reply interface{}) interface{}
// Assemble error response, expects response id and error
CreateErrorResponse(id interface{}, err Error) interface{}
// Assemble error response with extra information about the error through info
CreateErrorResponseWithInfo(id interface{}, err Error, info interface{}) interface{}
// Create notification response
CreateNotification(id, namespace string, event interface{}) interface{}
// Write msg to client.
Write(msg interface{}) error
// Close underlying data stream
Close() Close()
// Closed when underlying connection is closed jsonWriter
}
// jsonWriter can write JSON messages to its underlying connection.
// Implementations must be safe for concurrent use.
type jsonWriter interface {
Write(context.Context, interface{}) error
// Closed returns a channel which is closed when the connection is closed.
Closed() <-chan interface{} Closed() <-chan interface{}
// RemoteAddr returns the peer address of the connection.
RemoteAddr() string
} }
type BlockNumber int64 type BlockNumber int64

@ -1,226 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"bufio"
"context"
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"math/rand"
"reflect"
"strings"
"sync"
"time"
"unicode"
"unicode/utf8"
)
var (
subscriptionIDGenMu sync.Mutex
subscriptionIDGen = idGenerator()
)
// Is this an exported - upper case - name?
func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune)
}
// Is this type exported or a builtin?
func isExportedOrBuiltinType(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
// PkgPath will be non-empty even for an exported type,
// so we need to check the type name as well.
return isExported(t.Name()) || t.PkgPath() == ""
}
var contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
// isContextType returns an indication if the given t is of context.Context or *context.Context type
func isContextType(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t == contextType
}
var errorType = reflect.TypeOf((*error)(nil)).Elem()
// Implements this type the error interface
func isErrorType(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t.Implements(errorType)
}
var subscriptionType = reflect.TypeOf((*Subscription)(nil)).Elem()
// isSubscriptionType returns an indication if the given t is of Subscription or *Subscription type
func isSubscriptionType(t reflect.Type) bool {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t == subscriptionType
}
// isPubSub tests whether the given method has as as first argument a context.Context
// and returns the pair (Subscription, error)
func isPubSub(methodType reflect.Type) bool {
// numIn(0) is the receiver type
if methodType.NumIn() < 2 || methodType.NumOut() != 2 {
return false
}
return isContextType(methodType.In(1)) &&
isSubscriptionType(methodType.Out(0)) &&
isErrorType(methodType.Out(1))
}
// formatName will convert to first character to lower case
func formatName(name string) string {
ret := []rune(name)
if len(ret) > 0 {
ret[0] = unicode.ToLower(ret[0])
}
return string(ret)
}
// suitableCallbacks iterates over the methods of the given type. It will determine if a method satisfies the criteria
// for a RPC callback or a subscription callback and adds it to the collection of callbacks or subscriptions. See server
// documentation for a summary of these criteria.
func suitableCallbacks(rcvr reflect.Value, typ reflect.Type) (callbacks, subscriptions) {
callbacks := make(callbacks)
subscriptions := make(subscriptions)
METHODS:
for m := 0; m < typ.NumMethod(); m++ {
method := typ.Method(m)
mtype := method.Type
mname := formatName(method.Name)
if method.PkgPath != "" { // method must be exported
continue
}
var h callback
h.isSubscribe = isPubSub(mtype)
h.rcvr = rcvr
h.method = method
h.errPos = -1
firstArg := 1
numIn := mtype.NumIn()
if numIn >= 2 && mtype.In(1) == contextType {
h.hasCtx = true
firstArg = 2
}
if h.isSubscribe {
h.argTypes = make([]reflect.Type, numIn-firstArg) // skip rcvr type
for i := firstArg; i < numIn; i++ {
argType := mtype.In(i)
if isExportedOrBuiltinType(argType) {
h.argTypes[i-firstArg] = argType
} else {
continue METHODS
}
}
subscriptions[mname] = &h
continue METHODS
}
// determine method arguments, ignore first arg since it's the receiver type
// Arguments must be exported or builtin types
h.argTypes = make([]reflect.Type, numIn-firstArg)
for i := firstArg; i < numIn; i++ {
argType := mtype.In(i)
if !isExportedOrBuiltinType(argType) {
continue METHODS
}
h.argTypes[i-firstArg] = argType
}
// check that all returned values are exported or builtin types
for i := 0; i < mtype.NumOut(); i++ {
if !isExportedOrBuiltinType(mtype.Out(i)) {
continue METHODS
}
}
// when a method returns an error it must be the last returned value
h.errPos = -1
for i := 0; i < mtype.NumOut(); i++ {
if isErrorType(mtype.Out(i)) {
h.errPos = i
break
}
}
if h.errPos >= 0 && h.errPos != mtype.NumOut()-1 {
continue METHODS
}
switch mtype.NumOut() {
case 0, 1, 2:
if mtype.NumOut() == 2 && h.errPos == -1 { // method must one return value and 1 error
continue METHODS
}
callbacks[mname] = &h
}
}
return callbacks, subscriptions
}
// idGenerator helper utility that generates a (pseudo) random sequence of
// bytes that are used to generate identifiers.
func idGenerator() *rand.Rand {
if seed, err := binary.ReadVarint(bufio.NewReader(crand.Reader)); err == nil {
return rand.New(rand.NewSource(seed))
}
return rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
}
// NewID generates a identifier that can be used as an identifier in the RPC interface.
// e.g. filter and subscription identifier.
func NewID() ID {
subscriptionIDGenMu.Lock()
defer subscriptionIDGenMu.Unlock()
id := make([]byte, 16)
for i := 0; i < len(id); i += 7 {
val := subscriptionIDGen.Int63()
for j := 0; i+j < len(id) && j < 7; j++ {
id[i+j] = byte(val)
val >>= 8
}
}
rpcId := hex.EncodeToString(id)
// rpc ID's are RPC quantities, no leading zero's and 0 is 0x0
rpcId = strings.TrimLeft(rpcId, "0")
if rpcId == "" {
rpcId = "0"
}
return ID("0x" + rpcId)
}

@ -1,43 +0,0 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"strings"
"testing"
)
func TestNewID(t *testing.T) {
hexchars := "0123456789ABCDEFabcdef"
for i := 0; i < 100; i++ {
id := string(NewID())
if !strings.HasPrefix(id, "0x") {
t.Fatalf("invalid ID prefix, want '0x...', got %s", id)
}
id = id[2:]
if len(id) == 0 || len(id) > 32 {
t.Fatalf("invalid ID length, want len(id) > 0 && len(id) <= 32), got %d", len(id))
}
for i := 0; i < len(id); i++ {
if strings.IndexByte(hexchars, id[i]) == -1 {
t.Fatalf("unexpected byte, want any valid hex char, got %c", id[i])
}
}
}
}

@ -22,6 +22,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -56,22 +57,37 @@ var websocketJSONCodec = websocket.Codec{
// //
// allowedOrigins should be a comma-separated list of allowed origin URLs. // allowedOrigins should be a comma-separated list of allowed origin URLs.
// To allow connections with any origin, pass "*". // To allow connections with any origin, pass "*".
func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
return websocket.Server{ return websocket.Server{
Handshake: wsHandshakeValidator(allowedOrigins), Handshake: wsHandshakeValidator(allowedOrigins),
Handler: func(conn *websocket.Conn) { Handler: func(conn *websocket.Conn) {
codec := newWebsocketCodec(conn)
s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions)
},
}
}
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
// Create a custom encode/decode pair to enforce payload size and number encoding // Create a custom encode/decode pair to enforce payload size and number encoding
conn.MaxPayloadBytes = maxRequestContentLength conn.MaxPayloadBytes = maxRequestContentLength
encoder := func(v interface{}) error { encoder := func(v interface{}) error {
return websocketJSONCodec.Send(conn, v) return websocketJSONCodec.Send(conn, v)
} }
decoder := func(v interface{}) error { decoder := func(v interface{}) error {
return websocketJSONCodec.Receive(conn, v) return websocketJSONCodec.Receive(conn, v)
} }
srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) rpcconn := Conn(conn)
}, if conn.IsServerConn() {
// Override remote address with the actual socket address because
// package websocket crashes if there is no request origin.
addr := conn.Request().RemoteAddr
if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil {
// Add origin if present.
addr += "(" + wsaddr.URL.String() + ")"
} }
rpcconn = connWithRemoteAddr{conn, addr}
}
return NewCodec(rpcconn, encoder, decoder)
} }
// NewWSServer creates a new websocket RPC server around an API provider. // NewWSServer creates a new websocket RPC server around an API provider.
@ -105,15 +121,16 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http
} }
} }
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
f := func(cfg *websocket.Config, req *http.Request) error { f := func(cfg *websocket.Config, req *http.Request) error {
// Verify origin against whitelist.
origin := strings.ToLower(req.Header.Get("Origin")) origin := strings.ToLower(req.Header.Get("Origin"))
if allowAllOrigins || origins.Contains(origin) { if allowAllOrigins || origins.Contains(origin) {
return nil return nil
} }
log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) log.Warn("Rejected WebSocket connection", "origin", origin)
return fmt.Errorf("origin %s not allowed", origin) return errors.New("origin not allowed")
} }
return f return f
@ -155,8 +172,12 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
return nil, err return nil, err
} }
return newClient(ctx, func(ctx context.Context) (net.Conn, error) { return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
return wsDialContext(ctx, config) conn, err := wsDialContext(ctx, config)
if err != nil {
return nil, err
}
return newWebsocketCodec(conn), nil
}) })
} }