rpc: add SetHeader method to Client (#21392)

Resolves #20163

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
rene 2020-08-03 14:08:42 +02:00 committed by GitHub
parent 9c2ac6fbd5
commit 290d6bd903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 10 deletions

@ -85,7 +85,7 @@ type Client struct {
// writeConn is used for writing to the connection on the caller's goroutine. It should // writeConn is used for writing to the connection on the caller's goroutine. It should
// only be accessed outside of dispatch, with the write lock held. The write lock is // only be accessed outside of dispatch, with the write lock held. The write lock is
// taken by sending on requestOp and released by sending on sendDone. // taken by sending on reqInit and released by sending on reqSent.
writeConn jsonWriter writeConn jsonWriter
// for dispatch // for dispatch
@ -260,6 +260,19 @@ func (c *Client) Close() {
} }
} }
// SetHeader adds a custom HTTP header to the client's requests.
// This method only works for clients using HTTP, it doesn't have
// any effect for clients using another transport.
func (c *Client) SetHeader(key, value string) {
if !c.isHTTP {
return
}
conn := c.writeConn.(*httpConn)
conn.mu.Lock()
conn.headers.Set(key, value)
conn.mu.Unlock()
}
// Call performs a JSON-RPC call with the given arguments and unmarshals into // Call performs a JSON-RPC call with the given arguments and unmarshals into
// result if no error occurred. // result if no error occurred.
// //

@ -26,6 +26,7 @@ import (
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -429,6 +430,42 @@ func TestClientNotificationStorm(t *testing.T) {
doTest(23000, true) doTest(23000, true)
} }
func TestClientSetHeader(t *testing.T) {
var gotHeader bool
srv := newTestServer()
httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("test") == "ok" {
gotHeader = true
}
srv.ServeHTTP(w, r)
}))
defer httpsrv.Close()
defer srv.Stop()
client, err := Dial(httpsrv.URL)
if err != nil {
t.Fatal(err)
}
defer client.Close()
client.SetHeader("test", "ok")
if _, err := client.SupportedModules(); err != nil {
t.Fatal(err)
}
if !gotHeader {
t.Fatal("client did not set custom header")
}
// Check that Content-Type can be replaced.
client.SetHeader("content-type", "application/x-garbage")
_, err = client.SupportedModules()
if err == nil {
t.Fatal("no error for invalid content-type header")
} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
t.Fatalf("error is not related to content-type: %q", err)
}
}
func TestClientHTTP(t *testing.T) { func TestClientHTTP(t *testing.T) {
server := newTestServer() server := newTestServer()
defer server.Stop() defer server.Stop()

@ -26,6 +26,7 @@ import (
"io/ioutil" "io/ioutil"
"mime" "mime"
"net/http" "net/http"
"net/url"
"sync" "sync"
"time" "time"
) )
@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
type httpConn struct { type httpConn struct {
client *http.Client client *http.Client
req *http.Request url string
closeOnce sync.Once closeOnce sync.Once
closeCh chan interface{} closeCh chan interface{}
mu sync.Mutex // protects headers
headers http.Header
} }
// httpConn is treated specially by Client. // httpConn is treated specially by Client.
@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error {
} }
func (hc *httpConn) remoteAddr() string { func (hc *httpConn) remoteAddr() string {
return hc.req.URL.String() return hc.url
} }
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) { func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client. // using the provided HTTP Client.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
req, err := http.NewRequest(http.MethodPost, endpoint, nil) // Sanity check URL so we don't end up with a client that will fail every request.
_, err := url.Parse(endpoint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", contentType)
req.Header.Set("Accept", contentType)
initctx := context.Background() initctx := context.Background()
headers := make(http.Header, 2)
headers.Set("accept", contentType)
headers.Set("content-type", contentType)
return newClient(initctx, func(context.Context) (ServerCodec, error) { return newClient(initctx, func(context.Context) (ServerCodec, error) {
return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil hc := &httpConn{
client: client,
headers: headers,
url: endpoint,
closeCh: make(chan interface{}),
}
return hc, nil
}) })
} }
@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
if respBody != nil { if respBody != nil {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if _, err2 := buf.ReadFrom(respBody); err2 == nil { if _, err2 := buf.ReadFrom(respBody); err2 == nil {
return fmt.Errorf("%v %v", err, buf.String()) return fmt.Errorf("%v: %v", err, buf.String())
} }
} }
return err return err
@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
if err != nil { if err != nil {
return nil, err return nil, err
} }
req := hc.req.WithContext(ctx) req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
req.Body = ioutil.NopCloser(bytes.NewReader(body)) if err != nil {
return nil, err
}
req.ContentLength = int64(len(body)) req.ContentLength = int64(len(body))
// set headers
hc.mu.Lock()
req.Header = hc.headers.Clone()
hc.mu.Unlock()
// do request
resp, err := hc.client.Do(req) resp, err := hc.client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err