Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
)

const (
contentType = "application/json"
maxHTTPRequestContentLength = 1024 * 128
)

Expand Down Expand Up @@ -69,8 +70,8 @@ func DialHTTP(endpoint string) (*Client, error) {
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", contentType)
req.Header.Set("Accept", contentType)

initctx := context.Background()
return newClient(initctx, func(context.Context) (net.Conn, error) {
Expand Down Expand Up @@ -150,21 +151,11 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && r.ContentLength == 0 && r.URL.RawQuery == "" {
return
}
// For meaningful requests, validate it's size and content type
if r.ContentLength > maxHTTPRequestContentLength {
http.Error(w,
fmt.Sprintf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength),
http.StatusRequestEntityTooLarge)
return
}
ct := r.Header.Get("content-type")
mt, _, err := mime.ParseMediaType(ct)
if err != nil || mt != "application/json" {
http.Error(w,
"invalid content type, only application/json is supported",
http.StatusUnsupportedMediaType)
if responseCode, errorMessage := httpErrorResponse(r); responseCode != 0 {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpicking, but perhaps use shorted variable names?

responseCode -> code
errorMessage -> err

http.Error(w, errorMessage, responseCode)
return
}

// All checks passed, create a codec that reads direct from the request body
// untilEOF and writes the response to w and order the server to process a
// single request.
Expand All @@ -175,6 +166,28 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
srv.ServeSingleRequest(codec, OptionMethodInvocation)
}

// Returns a non-zero response code and error message if the request is invalid.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Method docs in Go start with the name of the method. ie.

// httpErrorResponse returns a non-zero response code and error message if the request is invalid.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also rather call it validateRequest, to make it clearer what it does.

func httpErrorResponse(r *http.Request) (int, string) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of returning (int, string), please return (int, error). That makes it a lot cleaner imho, since you don't need to check for 0 equality outside (which is arbitrary), rather can do a clean error check which is the standard in Go.

if r.Method == "PUT" || r.Method == "DELETE" {
errorMessage := "method not allowed"
return http.StatusMethodNotAllowed, errorMessage
}

if r.ContentLength > maxHTTPRequestContentLength {
errorMessage := fmt.Sprintf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength)
return http.StatusRequestEntityTooLarge, errorMessage
}

ct := r.Header.Get("content-type")
mt, _, err := mime.ParseMediaType(ct)
if err != nil || mt != contentType {
errorMessage := fmt.Sprintf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, errorMessage
}

return 0, ""
}

func newCorsHandler(srv *Server, allowedOrigins []string) http.Handler {
// disable CORS support if user has not specified a custom CORS configuration
if len(allowedOrigins) == 0 {
Expand Down
40 changes: 40 additions & 0 deletions rpc/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package rpc
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the LGPL copyright header that we have on each of our source files.


import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestHTTPErrorResponseWithDelete(t *testing.T) {
httpErrorResponseTest(t, "DELETE", contentType, "", http.StatusMethodNotAllowed)
}

func TestHTTPErrorResponseWithPut(t *testing.T) {
httpErrorResponseTest(t, "PUT", contentType, "", http.StatusMethodNotAllowed)
}

func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
body := make([]rune, maxHTTPRequestContentLength+1, maxHTTPRequestContentLength+1)
httpErrorResponseTest(t,
"POST", contentType, string(body), http.StatusRequestEntityTooLarge)
}

func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) {
httpErrorResponseTest(t, "POST", "", "", http.StatusUnsupportedMediaType)
}

func TestHTTPErrorResponseWithValidRequest(t *testing.T) {
httpErrorResponseTest(t, "POST", contentType, "", 0)
}

func httpErrorResponseTest(t *testing.T,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please call this testHTTPErrorResponse. Lower case won't be executed and it's usually the convention.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a helper called by all the other tests.

method, contentType, body string, expectedResponse int) {

request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body))
request.Header.Set("content-type", contentType)
if response, _ := httpErrorResponse(request); response != expectedResponse {
t.Fatalf("response code should be %d not %d", expectedResponse, response)
}
}