diff --git a/jhttp/bridge.go b/jhttp/bridge.go index d4a9756..c55086f 100644 --- a/jhttp/bridge.go +++ b/jhttp/bridge.go @@ -5,7 +5,6 @@ package jhttp import ( - "bytes" "context" "encoding/json" "fmt" @@ -19,9 +18,12 @@ import ( // A Bridge is a http.Handler that bridges requests to a JSON-RPC server. // -// The body of the HTTP POST request must contain the complete JSON-RPC request -// message, encoded with Content-Type: application/json. Either a single -// request object or a list of request objects is supported. +// By default, the bridge accepts only HTTP POST requests with the complete +// JSON-RPC request message in the body, with Content-Type application/json. +// Either a single request object or a list of request objects is supported. +// +// If either a CheckRequest or ParseRequest hook is set, these requirements are +// disabled, and the hooks are responsible for checking request structure. // // If the request completes, whether or not there is an error, the HTTP // response is 200 (OK) for ordinary requests or 204 (No Response) for @@ -35,22 +37,27 @@ import ( // client, allowing an EncodeContext callback to retrieve state from the HTTP // headers. Use jhttp.HTTPRequest to retrieve the request from the context. type Bridge struct { - local server.Local - checkType func(string) bool - checkReq func(*http.Request) error + local server.Local + checkReq func(*http.Request) error + parseReq func(*http.Request) ([]*jrpc2.Request, error) } // ServeHTTP implements the required method of http.Handler. func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - if !b.checkType(req.Header.Get("Content-Type")) { - w.WriteHeader(http.StatusUnsupportedMediaType) - return + // If neither a check hook nor a parse hook are defined, insist that the + // method is POST and the content-type is application/json. Setting either + // hook disables these checks. + if b.checkReq == nil && b.parseReq == nil { + if req.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if req.Header.Get("Content-Type") != "application/json" { + w.WriteHeader(http.StatusUnsupportedMediaType) + return + } } - if err := b.checkReq(req); err != nil { + if err := b.checkHTTPRequest(req); err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintln(w, err.Error()) return @@ -62,11 +69,6 @@ func (b Bridge) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { - body, err := io.ReadAll(req.Body) - if err != nil { - return err - } - // The HTTP request requires a response, but the server will not reply if // all the requests are notifications. Check whether we have any calls // needing a response, and choose whether to wait for a reply based on that. @@ -74,7 +76,7 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { // Note that we are forgiving about a missing version marker in a request, // since we can't tell at this point whether the server is willing to accept // messages like that. - jreq, err := jrpc2.ParseRequests(body) + jreq, err := b.parseHTTPRequest(req) if err != nil && err != jrpc2.ErrInvalidVersion { return err } @@ -126,10 +128,15 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { rsp.SetID(inboundID[i]) } - // If the original request was a single message, make sure we encode the - // response the same way. + // If there is only a single reply, send it alone; otherwise encode a batch. + // Per the spec (https://www.jsonrpc.org/specification#batch), this is OK; + // we are not required to respond to a batch with an array: + // + // The Server SHOULD respond with an Array containing the corresponding + // Response objects + // var reply []byte - if len(rsps) == 1 && !bytes.HasPrefix(bytes.TrimSpace(body), []byte("[")) { + if len(rsps) == 1 { reply, err = json.Marshal(rsps[0]) } else { reply, err = json.Marshal(rsps) @@ -143,6 +150,24 @@ func (b Bridge) serveInternal(w http.ResponseWriter, req *http.Request) error { return nil } +func (b Bridge) checkHTTPRequest(req *http.Request) error { + if b.checkReq != nil { + return b.checkReq(req) + } + return nil +} + +func (b Bridge) parseHTTPRequest(req *http.Request) ([]*jrpc2.Request, error) { + if b.parseReq != nil { + return b.parseReq(req) + } + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + return jrpc2.ParseRequests(body) +} + // Close closes the channel to the server, waits for the server to exit, and // reports its exit status. func (b Bridge) Close() error { return b.local.Close() } @@ -162,8 +187,8 @@ func NewBridge(mux jrpc2.Assigner, opts *BridgeOptions) Bridge { Client: opts.clientOptions(), Server: opts.serverOptions(), }), - checkType: opts.checkContentType(), - checkReq: opts.checkRequest(), + checkReq: opts.checkRequest(), + parseReq: opts.parseRequest(), } } @@ -176,15 +201,20 @@ type BridgeOptions struct { // Options for the bridge server (default nil). Server *jrpc2.ServerOptions - // If non-nil, this function is called to check whether the HTTP request's - // declared content-type is valid. If this function returns false, the - // request is rejected. If nil, the default check requires a content type of - // "application/json". - CheckContentType func(contentType string) bool - // If non-nil, this function is called to check the HTTP request. If this // function reports an error, the request is rejected. + // + // Setting this hook disables the default requirement that the request + // method be POST and the content-type be application/json. CheckRequest func(*http.Request) error + + // If non-nil, this function is called to parse JSON-RPC requests from the + // HTTP request. If this function reports an error, the request fails. By + // default, the bridge uses jrpc2.ParseRequests on the HTTP request body. + // + // Setting this hook disables the default requirement that the request + // method be POST and the content-type be application/json. + ParseRequest func(*http.Request) ([]*jrpc2.Request, error) } func (o *BridgeOptions) clientOptions() *jrpc2.ClientOptions { @@ -201,18 +231,18 @@ func (o *BridgeOptions) serverOptions() *jrpc2.ServerOptions { return o.Server } -func (o *BridgeOptions) checkContentType() func(string) bool { - if o == nil || o.CheckContentType == nil { - return func(ctype string) bool { return ctype == "application/json" } +func (o *BridgeOptions) checkRequest() func(*http.Request) error { + if o == nil { + return nil } - return o.CheckContentType + return o.CheckRequest } -func (o *BridgeOptions) checkRequest() func(*http.Request) error { - if o == nil || o.CheckRequest == nil { - return func(*http.Request) error { return nil } +func (o *BridgeOptions) parseRequest() func(*http.Request) ([]*jrpc2.Request, error) { + if o == nil { + return nil } - return o.CheckRequest + return o.ParseRequest } type httpReqKey struct{} diff --git a/jhttp/jhttp_test.go b/jhttp/jhttp_test.go index 2290885..9f52de0 100644 --- a/jhttp/jhttp_test.go +++ b/jhttp/jhttp_test.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net/http" "net/http/httptest" @@ -51,12 +50,12 @@ func TestBridge(t *testing.T) { // Verify that a valid POST request succeeds. t.Run("PostOK", func(t *testing.T) { rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ - "jsonrpc": "2.0", - "id": 1, - "method": "Test1", - "params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"] -} -`)) + "jsonrpc": "2.0", + "id": 1, + "method": "Test1", + "params": ["a", "foolish", "consistency", "is", "the", "hobgoblin"] + } + `)) if err != nil { t.Fatalf("POST request failed: %v", err) } else if got, want := rsp.StatusCode, http.StatusOK; got != want { @@ -76,10 +75,10 @@ func TestBridge(t *testing.T) { // Verify that the bridge will accept a batch. t.Run("PostBatchOK", func(t *testing.T) { rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`[ - {"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]}, - {"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]} -] -`)) + {"jsonrpc":"2.0", "id": 3, "method": "Test1", "params": ["first"]}, + {"jsonrpc":"2.0", "id": 7, "method": "Test1", "params": ["among", "equals"]} + ] + `)) if err != nil { t.Fatalf("POST request failed: %v", err) } else if got, want := rsp.StatusCode, http.StatusOK; got != want { @@ -122,10 +121,10 @@ func TestBridge(t *testing.T) { // Verify that a POST that generates a JSON-RPC error succeeds. t.Run("PostErrorReply", func(t *testing.T) { rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ - "id": 1, - "jsonrpc": "2.0" -} -`)) + "id": 1, + "jsonrpc": "2.0" + } + `)) if err != nil { t.Fatalf("POST request failed: %v", err) } else if got, want := rsp.StatusCode, http.StatusOK; got != want { @@ -145,10 +144,10 @@ func TestBridge(t *testing.T) { // Verify that a notification returns an empty success. t.Run("PostNotification", func(t *testing.T) { rsp, err := http.Post(hsrv.URL, "application/json", strings.NewReader(`{ - "jsonrpc": "2.0", - "method": "TakeNotice", - "params": [] -}`)) + "jsonrpc": "2.0", + "method": "TakeNotice", + "params": [] + }`)) if err != nil { t.Fatalf("POST request failed: %v", err) } else if got, want := rsp.StatusCode, http.StatusNoContent; got != want { @@ -165,12 +164,15 @@ func TestBridge(t *testing.T) { } // Verify that the content-type check hook works. -func TestBridge_contentTypeCheck(t *testing.T) { +func TestBridge_requestCheck(t *testing.T) { defer leaktest.Check(t)() b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{ - CheckContentType: func(ctype string) bool { - return ctype == "application/octet-stream" + CheckRequest: func(req *http.Request) error { + if req.Header.Get("x-test-header") == "fail" { + return errors.New("request rejected") + } + return nil }, }) defer checkClose(t, b) @@ -178,38 +180,59 @@ func TestBridge_contentTypeCheck(t *testing.T) { hsrv := httptest.NewServer(b) defer hsrv.Close() - const reqTemplate = `{"jsonrpc":"2.0","id":%q,"method":"Test1","params":["a","b","c"]}` - t.Run("ContentTypeOK", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "application/octet-stream", - strings.NewReader(fmt.Sprintf(reqTemplate, "ok"))) + const reqBody = `{"jsonrpc":"2.0","id":1,"method":"Test1","params":["a","b","c"]}` + const wantReply = `{"jsonrpc":"2.0","id":1,"result":3}` + + t.Run("Succeed", func(t *testing.T) { + // With a check hook set, the method and content-type checks should not happen. + req, err := http.NewRequest("GET", hsrv.URL, strings.NewReader(reqBody)) if err != nil { - t.Fatalf("POST request failed: %v", err) + t.Fatalf("NewRequest: %v", err) + } + + rsp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("GET request failed: %v", err) } else if got, want := rsp.StatusCode, http.StatusOK; got != want { - t.Errorf("POST response code: got %v, want %v", got, want) + t.Errorf("GET response code: got %v, want %v", got, want) + } + body, _ := io.ReadAll(rsp.Body) + rsp.Body.Close() + if got := string(body); got != wantReply { + t.Errorf("Response: got %#q, want %#q", got, wantReply) } }) - t.Run("ContentTypeBad", func(t *testing.T) { - rsp, err := http.Post(hsrv.URL, "text/plain", - strings.NewReader(fmt.Sprintf(reqTemplate, "bad"))) + t.Run("CheckFailed", func(t *testing.T) { + req, err := http.NewRequest("POST", hsrv.URL, strings.NewReader(reqBody)) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + req.Header.Set("X-Test-Header", "fail") + + rsp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("POST request failed: %v", err) - } else if got, want := rsp.StatusCode, http.StatusUnsupportedMediaType; got != want { + } else if got, want := rsp.StatusCode, http.StatusInternalServerError; got != want { t.Errorf("POST response code: got %v, want %v", got, want) } }) } -// Verify that the content-type check hook works. -func TestBridge_requestCheck(t *testing.T) { +// Verify that the request-parsing hook works. +func TestBridge_parseRequest(t *testing.T) { defer leaktest.Check(t)() + const reqMessage = `{"jsonrpc":"2.0", "method": "Test2", "id": 100, "params":null}` + const wantReply = `{"jsonrpc":"2.0","id":100,"result":0}` + b := jhttp.NewBridge(testService, &jhttp.BridgeOptions{ - CheckRequest: func(req *http.Request) error { - if req.Header.Get("x-test-header") == "fail" { - return errors.New("request rejected") + ParseRequest: func(req *http.Request) ([]*jrpc2.Request, error) { + action := req.Header.Get("x-test-header") + if action == "fail" { + return nil, errors.New("parse hook reporting failure") } - return nil + return jrpc2.ParseRequests([]byte(reqMessage)) }, }) defer checkClose(t, b) @@ -217,30 +240,33 @@ func TestBridge_requestCheck(t *testing.T) { hsrv := httptest.NewServer(b) defer hsrv.Close() - const reqBody = `{"jsonrpc":"2.0","id":1,"method":"Test1","params":["a","b","c"]}` - t.Run("RequestOK", func(t *testing.T) { - req, err := http.NewRequest("POST", hsrv.URL, strings.NewReader(reqBody)) + t.Run("Succeed", func(t *testing.T) { + // Since a parse hook is set, the method and content-type checks should not occur. + // We send an empty body to be sure the request comes from the hook. + req, err := http.NewRequest("GET", hsrv.URL, strings.NewReader("")) if err != nil { t.Fatalf("NewRequest: %v", err) } - req.Header.Set("X-Test-Header", "succeed") - req.Header.Set("Content-Type", "application/json") rsp, err := http.DefaultClient.Do(req) if err != nil { - t.Fatalf("POST request failed: %v", err) + t.Fatalf("GET request failed: %v", err) } else if got, want := rsp.StatusCode, http.StatusOK; got != want { - t.Errorf("POST response code: got %v, want %v", got, want) + t.Errorf("GET response code: got %v, want %v", got, want) + } + body, _ := io.ReadAll(rsp.Body) + rsp.Body.Close() + if got := string(body); got != wantReply { + t.Errorf("Response: got %#q, want %#q", got, wantReply) } }) - t.Run("RequestBad", func(t *testing.T) { - req, err := http.NewRequest("POST", hsrv.URL, strings.NewReader(reqBody)) + t.Run("Fail", func(t *testing.T) { + req, err := http.NewRequest("POST", hsrv.URL, strings.NewReader("")) if err != nil { t.Fatalf("NewRequest: %v", err) } req.Header.Set("X-Test-Header", "fail") - req.Header.Set("Content-Type", "application/json") rsp, err := http.DefaultClient.Do(req) if err != nil {