// infra/op-txproxy/auth_handler_test.go

package op_txproxy
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
oprpc "github.com/ethereum-optimism/optimism/op-service/rpc"
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rpc"
"github.com/stretchr/testify/require"
)
// pingHandler is a stub downstream handler: it ignores the incoming request
// and writes the literal body "ping" without setting a status explicitly.
var pingHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
	fmt.Fprint(w, "ping")
})
// TestAuthHandlerMissingAuth checks that a request with no auth header, sent
// through an authHandler that has no header key configured, reaches the next
// handler and its response is returned as-is.
func TestAuthHandlerMissingAuth(t *testing.T) {
	req, _ := http.NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()

	h := authHandler{next: pingHandler}
	h.ServeHTTP(rec, req)

	// Seeing pingHandler's reply means the request was simply forwarded.
	require.Equal(t, http.StatusOK, rec.Code)
	require.Equal(t, "ping", rec.Body.String())
}
// TestAuthHandlerBadHeader checks that the auth header value "foobarbaz" is
// answered with 400 Bad Request.
// NOTE(review): presumably rejected because it lacks the "address:signature"
// shape the other tests use; confirm against the handler.
func TestAuthHandlerBadHeader(t *testing.T) {
	req, _ := http.NewRequest("GET", "/", nil)
	req.Header.Set("auth", "foobarbaz")
	rec := httptest.NewRecorder()

	h := authHandler{headerKey: "auth", next: pingHandler}
	h.ServeHTTP(rec, req)

	require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAuthHandlerBadSignature(t *testing.T) {
handler := authHandler{headerKey: "auth", next: pingHandler}
rr := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("auth", fmt.Sprintf("%s:%s", common.HexToAddress("0xa"), "foobar"))
handler.ServeHTTP(rr, r)
require.Equal(t, http.StatusBadRequest, rr.Code)
}
// TestAuthHandlerMismatchedCaller checks that a signature over the request
// body is answered with 400 Bad Request when the address claimed in the auth
// header (0xa) is not the address of the key that produced the signature.
func TestAuthHandlerMismatchedCaller(t *testing.T) {
	handler := authHandler{headerKey: "auth", next: pingHandler}

	rr := httptest.NewRecorder()
	r, err := http.NewRequest("GET", "/", strings.NewReader("body"))
	require.NoError(t, err)

	privKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	sig, err := crypto.Sign(accounts.TextHash([]byte("body")), privKey)
	require.NoError(t, err)

	// Hex-encode the signature the same way the passing tests in this file do.
	// Formatting the raw bytes with %s puts arbitrary binary (possibly
	// including ':') into the header, so the request would likely be rejected
	// for a malformed signature rather than for the caller mismatch this test
	// is meant to cover.
	r.Header.Set("auth", fmt.Sprintf("%s:%s", common.HexToAddress("0xa"), common.Bytes2Hex(sig)))

	handler.ServeHTTP(rr, r)
	require.Equal(t, http.StatusBadRequest, rr.Code)
}
// TestAuthHandlerSetContext checks that a request whose auth header carries
// the signer's address and a hex-encoded signature over the body is forwarded,
// and that the next handler can read an AuthContext from the request context
// whose Caller is the signer's address.
func TestAuthHandlerSetContext(t *testing.T) {
	// Captured by the downstream handler so it can be inspected afterwards.
	var ctx *AuthContext
	ctxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx = AuthFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	})

	handler := authHandler{headerKey: "auth", next: ctxHandler}

	rr := httptest.NewRecorder()
	body := bytes.NewBufferString("body")
	r, err := http.NewRequest("GET", "/", body)
	require.NoError(t, err)

	// Fail fast on setup errors instead of letting a nil key or signature show
	// up later as an unexplained status mismatch.
	privKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	// body has not been read yet, so Bytes() still returns the full payload.
	sig, err := crypto.Sign(accounts.TextHash(body.Bytes()), privKey)
	require.NoError(t, err)

	addr := crypto.PubkeyToAddress(privKey.PublicKey)
	r.Header.Set("auth", fmt.Sprintf("%s:%s", addr, common.Bytes2Hex(sig)))

	handler.ServeHTTP(rr, r)
	require.Equal(t, http.StatusOK, rr.Code)
	require.NotNil(t, ctx)
	require.Equal(t, addr, ctx.Caller)
}
// TestAuthHandlerRpcMiddleware starts an RPC server with AuthMiddleware("auth")
// installed and checks, through a real RPC client, that a call without the
// auth header succeeds while a call with the header set to "foobar" fails.
func TestAuthHandlerRpcMiddleware(t *testing.T) {
	rpcServer := oprpc.NewServer("127.0.0.1", 0, "", oprpc.WithMiddleware(AuthMiddleware("auth")))
	require.NoError(t, rpcServer.Start())
	t.Cleanup(func() { _ = rpcServer.Stop() })

	url := fmt.Sprintf("http://%s", rpcServer.Endpoint())
	clnt, err := rpc.Dial(url)
	require.NoError(t, err)
	defer clnt.Close()

	// pass without auth (default handler does not deny)
	err = clnt.CallContext(context.Background(), nil, "rpc_modules")
	require.NoError(t, err)

	// denied with bad auth header
	clnt.SetHeader("auth", "foobar")
	err = clnt.CallContext(context.Background(), nil, "rpc_modules")
	require.Error(t, err)
}
// TestAuthHandlerRequestBodyLimit sends a body of defaultBodyLimit bytes plus
// ten extra bytes, signed over only the first defaultBodyLimit bytes. It
// checks that the request is accepted and that the next handler still reads
// the complete body, excess included.
func TestAuthHandlerRequestBodyLimit(t *testing.T) {
	// Captured by the downstream handler so they can be asserted afterwards.
	var (
		body    []byte
		readErr error
	)
	bodyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, readErr = io.ReadAll(r.Body)
		w.WriteHeader(http.StatusOK)
	})

	handler := authHandler{headerKey: "auth", next: bodyHandler}

	// only up to limit is read when validating the request body
	authBody := strings.Repeat("*", defaultBodyLimit)
	excess := strings.Repeat("-", 10)

	rr := httptest.NewRecorder()
	r, err := http.NewRequest("GET", "/", strings.NewReader(authBody+excess))
	require.NoError(t, err)

	// sign over just the auth body
	privKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	sig, err := crypto.Sign(accounts.TextHash([]byte(authBody)), privKey)
	require.NoError(t, err)

	addr := crypto.PubkeyToAddress(privKey.PublicKey)
	r.Header.Set("auth", fmt.Sprintf("%s:%s", addr, common.Bytes2Hex(sig)))

	// Auth handler successfully only parses through the max body limit
	handler.ServeHTTP(rr, r)
	require.Equal(t, http.StatusOK, rr.Code, rr.Body.String())

	// The next handler has the full request body present
	require.NoError(t, readErr)
	require.Len(t, body, len(authBody)+len(excess))
}