Skip to content

Commit

Permalink
jhttp: Remove the CheckRequest bridge hook (#67)
Browse files Browse the repository at this point in the history
There is no need to have two separate request hooks; consolidate to use
the more general of the two.

Also includes some test code cleanup.
  • Loading branch information
creachadair authored Dec 24, 2021
1 parent 2b5d734 commit 9d2d659
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 164 deletions.
86 changes: 32 additions & 54 deletions jhttp/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import (
// Allowed). If the Content-Type is not application/json, the bridge reports
// 415 (Unsupported Media Type).
//
// If either a CheckRequest or ParseRequest hook is set, these requirements are
// disabled, and the hooks are responsible for checking request structure.
// If a ParseRequest hook is set, these requirements are disabled, and the hook
// is entirely responsible for checking request structure.
//
// If the request completes, whether or not there is an error, the HTTP
// response is 200 (OK) for ordinary requests or 204 (No Response) for
Expand All @@ -38,16 +38,14 @@ import (
// headers. Use jhttp.HTTPRequest to retrieve the request from the context.
type Bridge struct {
local server.Local
checkReq func(*http.Request) error
parseReq func(*http.Request) ([]*jrpc2.Request, error)
}

// ServeHTTP implements the required method of http.Handler.
func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// If neither a check hook nor a parse hook are defined, insist that the
// method is POST and the content-type is application/json. Setting either
// hook disables these checks.
if b.checkReq == nil && b.parseReq == nil {
// If no parse hook is defined, insist that the method is POST and the
// content-type is application/json. Setting a hook disables these checks.
if b.parseReq == nil {
if req.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
Expand All @@ -57,11 +55,6 @@ func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
}
if err := b.checkHTTPRequest(req); err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, err.Error())
return
}
if err := b.serveInternal(w, req); err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, err.Error())
Expand Down Expand Up @@ -128,33 +121,7 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error {
rsp.SetID(inboundID[i])
}

// If there is only a single reply, send it alone; otherwise encode a batch.
// Per the spec (https://www.jsonrpc.org/specification#batch), this is OK;
// we are not required to respond to a batch with an array:
//
// The Server SHOULD respond with an Array containing the corresponding
// Response objects
//
var reply []byte
if len(rsps) == 1 {
reply, err = json.Marshal(rsps[0])
} else {
reply, err = json.Marshal(rsps)
}
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(reply)))
w.Write(reply)
return nil
}

func (b Bridge) checkHTTPRequest(req *http.Request) error {
if b.checkReq != nil {
return b.checkReq(req)
}
return nil
return b.encodeResponses(rsps, w)
}

func (b Bridge) parseHTTPRequest(req *http.Request) ([]*jrpc2.Request, error) {
Expand All @@ -168,6 +135,24 @@ func (b Bridge) parseHTTPRequest(req *http.Request) ([]*jrpc2.Request, error) {
return jrpc2.ParseRequests(body)
}

func (b Bridge) encodeResponses(rsps []*jrpc2.Response, w http.ResponseWriter) error {
// If there is only a single reply, send it alone; otherwise encode a batch.
// Per the spec (https://www.jsonrpc.org/specification#batch), this is OK;
// we are not required to respond to a batch with an array:
//
// The Server SHOULD respond with an Array containing the corresponding
// Response objects
//
data, err := marshalResponses(rsps)
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
_, err = w.Write(data)
return err
}

// Close closes the channel to the server, waits for the server to exit, and
// reports its exit status.
func (b Bridge) Close() error { return b.local.Close() }
Expand All @@ -187,7 +172,6 @@ func NewBridge(mux jrpc2.Assigner, opts *BridgeOptions) Bridge {
Client: opts.clientOptions(),
Server: opts.serverOptions(),
}),
checkReq: opts.checkRequest(),
parseReq: opts.parseRequest(),
}
}
Expand All @@ -201,13 +185,6 @@ type BridgeOptions struct {
// Options for the bridge server (default nil).
Server *jrpc2.ServerOptions

// If non-nil, this function is called to check the HTTP request. If this
// function reports an error, the request is rejected.
//
// Setting this hook disables the default requirement that the request
// method be POST and the content-type be application/json.
CheckRequest func(*http.Request) error

// If non-nil, this function is called to parse JSON-RPC requests from the
// HTTP request. If this function reports an error, the request fails. By
// default, the bridge uses jrpc2.ParseRequests on the HTTP request body.
Expand All @@ -231,13 +208,6 @@ func (o *BridgeOptions) serverOptions() *jrpc2.ServerOptions {
return o.Server
}

func (o *BridgeOptions) checkRequest() func(*http.Request) error {
if o == nil {
return nil
}
return o.CheckRequest
}

func (o *BridgeOptions) parseRequest() func(*http.Request) ([]*jrpc2.Request, error) {
if o == nil {
return nil
Expand All @@ -256,3 +226,11 @@ func HTTPRequest(ctx context.Context) *http.Request {
}
return nil
}

// marshalResponses encodes a batch of JSON-RPC responses into JSON.
func marshalResponses(rsps []*jrpc2.Response) ([]byte, error) {
if len(rsps) == 1 {
return json.Marshal(rsps[0])
}
return json.Marshal(rsps)
}
139 changes: 29 additions & 110 deletions jhttp/jhttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,49 +49,29 @@ func TestBridge(t *testing.T) {

// Verify that a valid POST request succeeds.
t.Run("PostOK", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{
got := mustPost(t, hsrv.URL, `{
"jsonrpc": "2.0",
"id": 1,
"method": "Test1",
"params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"]
}
`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("POST response code: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
}`, http.StatusOK)

const want = `{"jsonrpc":"2.0","id":1,"result":6}`
if got := string(body); got != want {
if got != want {
t.Errorf("POST body: got %#q, want %#q", got, want)
}
})

// Verify that the bridge will accept a batch.
t.Run("PostBatchOK", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`[
got := mustPost(t, hsrv.URL, `[
{"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]},
{"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]}
]
`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("POST response code: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
]`, http.StatusOK)

const want = `[{"jsonrpc":"2.0","id":3,"result":1},` +
`{"jsonrpc":"2.0","id":7,"result":2}]`
if got := string(body); got != want {
if got != want {
t.Errorf("POST body: got %#q, want %#q", got, want)
}
})
Expand All @@ -112,113 +92,37 @@ func TestBridge(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "text/plain", strings.NewReader(`{}`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
}
if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want {
t.Errorf("POST status: got %v, want %v", got, want)
} else if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want {
t.Errorf("POST response code: got %v, want %v", got, want)
}
})

// Verify that a POST that generates a JSON-RPC error succeeds.
t.Run("PostErrorReply", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{
got := mustPost(t, hsrv.URL, `{
"id": 1,
"jsonrpc": "2.0"
}
`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("POST status: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
}`, http.StatusOK)

const exp = `{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"empty method name"}}`
if got := string(body); got != exp {
if got != exp {
t.Errorf("POST body: got %#q, want %#q", got, exp)
}
})

// Verify that a notification returns an empty success.
t.Run("PostNotification", func(t *testing.T) {
rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{
got := mustPost(t, hsrv.URL, `{
"jsonrpc": "2.0",
"method": "TakeNotice",
"params": []
}`))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusNoContent; got != want {
t.Errorf("POST status: got %v, want %v", got, want)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
if got := string(body); got != "" {
}`, http.StatusNoContent)
if got != "" {
t.Errorf("POST body: got %q, want empty", got)
}
})
}

// Verify that the content-type check hook works.
func TestBridge_requestCheck(t *testing.T) {
defer leaktest.Check(t)()

b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{
CheckRequest: func(req *http.Request) error {
if req.Header.Get("x-test-header") == "fail" {
return errors.New("request rejected")
}
return nil
},
})
defer checkClose(t, b)

hsrv := httptest.NewServer(b)
defer hsrv.Close()

const reqBody = `{"jsonrpc":"2.0","id":1,"method":"Test1","params":["a","b","c"]}`
const wantReply = `{"jsonrpc":"2.0","id":1,"result":3}`

t.Run("Succeed", func(t *testing.T) {
// With a check hook set, the method and content-type checks should not happen.
req, err := http.NewRequest("GET", hsrv.URL, strings.NewReader(reqBody))
if err != nil {
t.Fatalf("NewRequest: %v", err)
}

rsp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("GET request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusOK; got != want {
t.Errorf("GET response code: got %v, want %v", got, want)
}
body, _ := io.ReadAll(rsp.Body)
rsp.Body.Close()
if got := string(body); got != wantReply {
t.Errorf("Response: got %#q, want %#q", got, wantReply)
}
})

t.Run("CheckFailed", func(t *testing.T) {
req, err := http.NewRequest("POST", hsrv.URL, strings.NewReader(reqBody))
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
req.Header.Set("X-Test-Header", "fail")

rsp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got, want := rsp.StatusCode, http.StatusInternalServerError; got != want {
t.Errorf("POST response code: got %v, want %v", got, want)
}
})
}

// Verify that the request-parsing hook works.
func TestBridge_parseRequest(t *testing.T) {
defer leaktest.Check(t)()
Expand Down Expand Up @@ -345,3 +249,18 @@ func checkClose(t *testing.T, c io.Closer) {
t.Errorf("Error in Close: %v", err)
}
}

func mustPost(t *testing.T, url, req string, code int) string {
t.Helper()
rsp, err := http.Post(url, "application/json", strings.NewReader(req))
if err != nil {
t.Fatalf("POST request failed: %v", err)
} else if got := rsp.StatusCode; got != code {
t.Errorf("POST response code: got %v, want %v", got, code)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Errorf("Reading POST body: %v", err)
}
return string(body)
}

0 comments on commit 9d2d659

Please sign in to comment.