rpc: add SetHeader method to Client (#21392)
Resolves #20163 Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
parent
9c2ac6fbd5
commit
290d6bd903
@ -85,7 +85,7 @@ type Client struct {
|
|||||||
|
|
||||||
// writeConn is used for writing to the connection on the caller's goroutine. It should
|
// writeConn is used for writing to the connection on the caller's goroutine. It should
|
||||||
// only be accessed outside of dispatch, with the write lock held. The write lock is
|
// only be accessed outside of dispatch, with the write lock held. The write lock is
|
||||||
// taken by sending on requestOp and released by sending on sendDone.
|
// taken by sending on reqInit and released by sending on reqSent.
|
||||||
writeConn jsonWriter
|
writeConn jsonWriter
|
||||||
|
|
||||||
// for dispatch
|
// for dispatch
|
||||||
@ -260,6 +260,19 @@ func (c *Client) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHeader adds a custom HTTP header to the client's requests.
|
||||||
|
// This method only works for clients using HTTP, it doesn't have
|
||||||
|
// any effect for clients using another transport.
|
||||||
|
func (c *Client) SetHeader(key, value string) {
|
||||||
|
if !c.isHTTP {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn := c.writeConn.(*httpConn)
|
||||||
|
conn.mu.Lock()
|
||||||
|
conn.headers.Set(key, value)
|
||||||
|
conn.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Call performs a JSON-RPC call with the given arguments and unmarshals into
|
// Call performs a JSON-RPC call with the given arguments and unmarshals into
|
||||||
// result if no error occurred.
|
// result if no error occurred.
|
||||||
//
|
//
|
||||||
|
@ -26,6 +26,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -429,6 +430,42 @@ func TestClientNotificationStorm(t *testing.T) {
|
|||||||
doTest(23000, true)
|
doTest(23000, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientSetHeader(t *testing.T) {
|
||||||
|
var gotHeader bool
|
||||||
|
srv := newTestServer()
|
||||||
|
httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("test") == "ok" {
|
||||||
|
gotHeader = true
|
||||||
|
}
|
||||||
|
srv.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer httpsrv.Close()
|
||||||
|
defer srv.Stop()
|
||||||
|
|
||||||
|
client, err := Dial(httpsrv.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
client.SetHeader("test", "ok")
|
||||||
|
if _, err := client.SupportedModules(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !gotHeader {
|
||||||
|
t.Fatal("client did not set custom header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that Content-Type can be replaced.
|
||||||
|
client.SetHeader("content-type", "application/x-garbage")
|
||||||
|
_, err = client.SupportedModules()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("no error for invalid content-type header")
|
||||||
|
} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
|
||||||
|
t.Fatalf("error is not related to content-type: %q", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClientHTTP(t *testing.T) {
|
func TestClientHTTP(t *testing.T) {
|
||||||
server := newTestServer()
|
server := newTestServer()
|
||||||
defer server.Stop()
|
defer server.Stop()
|
||||||
|
37
rpc/http.go
37
rpc/http.go
@ -26,6 +26,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
|
|||||||
|
|
||||||
type httpConn struct {
|
type httpConn struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
req *http.Request
|
url string
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
closeCh chan interface{}
|
closeCh chan interface{}
|
||||||
|
mu sync.Mutex // protects headers
|
||||||
|
headers http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
// httpConn is treated specially by Client.
|
// httpConn is treated specially by Client.
|
||||||
@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hc *httpConn) remoteAddr() string {
|
func (hc *httpConn) remoteAddr() string {
|
||||||
return hc.req.URL.String()
|
return hc.url
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
|
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
|
||||||
@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
|
|||||||
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
|
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
|
||||||
// using the provided HTTP Client.
|
// using the provided HTTP Client.
|
||||||
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
|
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
|
||||||
req, err := http.NewRequest(http.MethodPost, endpoint, nil)
|
// Sanity check URL so we don't end up with a client that will fail every request.
|
||||||
|
_, err := url.Parse(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", contentType)
|
|
||||||
req.Header.Set("Accept", contentType)
|
|
||||||
|
|
||||||
initctx := context.Background()
|
initctx := context.Background()
|
||||||
|
headers := make(http.Header, 2)
|
||||||
|
headers.Set("accept", contentType)
|
||||||
|
headers.Set("content-type", contentType)
|
||||||
return newClient(initctx, func(context.Context) (ServerCodec, error) {
|
return newClient(initctx, func(context.Context) (ServerCodec, error) {
|
||||||
return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
|
hc := &httpConn{
|
||||||
|
client: client,
|
||||||
|
headers: headers,
|
||||||
|
url: endpoint,
|
||||||
|
closeCh: make(chan interface{}),
|
||||||
|
}
|
||||||
|
return hc, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
|
|||||||
if respBody != nil {
|
if respBody != nil {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
if _, err2 := buf.ReadFrom(respBody); err2 == nil {
|
if _, err2 := buf.ReadFrom(respBody); err2 == nil {
|
||||||
return fmt.Errorf("%v %v", err, buf.String())
|
return fmt.Errorf("%v: %v", err, buf.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req := hc.req.WithContext(ctx)
|
req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
|
||||||
req.Body = ioutil.NopCloser(bytes.NewReader(body))
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
req.ContentLength = int64(len(body))
|
req.ContentLength = int64(len(body))
|
||||||
|
|
||||||
|
// set headers
|
||||||
|
hc.mu.Lock()
|
||||||
|
req.Header = hc.headers.Clone()
|
||||||
|
hc.mu.Unlock()
|
||||||
|
|
||||||
|
// do request
|
||||||
resp, err := hc.client.Do(req)
|
resp, err := hc.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
Reference in New Issue
Block a user