Skip to content

Commit

Permalink
Fix ErrorWriter IsSupportedCheck with required connect protocol option (
Browse files Browse the repository at this point in the history
#700)

This PR fixes ErrorWriter to correctly return unsupported protocol if
the option `WithRequireConnectProtocolHeader` is set and the header or
query value isn't include in the request. It will now correctly return
unsupported to ensure fallback options can process the error.

Fixes #699
  • Loading branch information
emcfarlane authored Mar 6, 2024
1 parent 32b3f43 commit 6fab35e
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ linters:
enable-all: true
disable:
- cyclop # covered by gocyclo
- depguard # unnecessary for small libraries
- deadcode # abandoned
- depguard # unnecessary for small libraries
- exhaustivestruct # replaced by exhaustruct
- funlen # rely on code review to limit function length
- gocognit # dubious "cognitive overhead" quantification
Expand Down
2 changes: 1 addition & 1 deletion codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestStableCodec(t *testing.T) {
func TestJSONCodec(t *testing.T) {
t.Parallel()

codec := &protoJSONCodec{name: "json"}
codec := &protoJSONCodec{name: codecNameJSON}

t.Run("success", func(t *testing.T) {
t.Parallel()
Expand Down
10 changes: 10 additions & 0 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type ErrorWriter struct {
grpcWebContentTypes map[string]struct{}
unaryConnectContentTypes map[string]struct{}
streamingConnectContentTypes map[string]struct{}
requireConnectProtocolHeader bool
}

// NewErrorWriter constructs an ErrorWriter. To properly recognize supported
Expand All @@ -60,6 +61,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
grpcWebContentTypes: make(map[string]struct{}),
unaryConnectContentTypes: make(map[string]struct{}),
streamingConnectContentTypes: make(map[string]struct{}),
requireConnectProtocolHeader: config.RequireConnectProtocolHeader,
}
for name := range config.Codecs {
unary := connectContentTypeFromCodecName(StreamTypeUnary, name)
Expand Down Expand Up @@ -87,9 +89,17 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
if _, ok := w.unaryConnectContentTypes[ctype]; ok {
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
}
if _, ok := w.streamingConnectContentTypes[ctype]; ok {
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
}
if _, ok := w.grpcContentTypes[ctype]; ok {
Expand Down
55 changes: 55 additions & 0 deletions error_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2021-2024 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect

import (
"net/http"
"net/http/httptest"
"testing"

"connectrpc.com/connect/internal/assert"
)

func TestErrorWriter(t *testing.T) {
t.Parallel()

t.Run("RequireConnectProtocolHeader", func(t *testing.T) {
t.Parallel()
writer := NewErrorWriter(WithRequireConnectProtocolHeader())

t.Run("Unary", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON)
assert.False(t, writer.IsSupported(req))
req.Header.Set(connectHeaderProtocolVersion, connectProtocolVersion)
assert.True(t, writer.IsSupported(req))
})
t.Run("UnaryGET", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
assert.False(t, writer.IsSupported(req))
query := req.URL.Query()
query.Set(connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue)
req.URL.RawQuery = query.Encode()
assert.True(t, writer.IsSupported(req))
})
t.Run("Stream", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectStreamingContentTypePrefix+codecNameJSON)
assert.True(t, writer.IsSupported(req)) // ignores WithRequireConnectProtocolHeader
req.Header.Set(connectHeaderProtocolVersion, connectProtocolVersion)
assert.True(t, writer.IsSupported(req))
})
})
}
2 changes: 1 addition & 1 deletion option.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func WithRecover(handle func(context.Context, Spec, http.Header, any) error) Han
// header. This ensures that HTTP proxies and net/http middleware can easily
// identify valid Connect requests, even if they use a common Content-Type like
// application/json. However, it makes ad-hoc requests with tools like cURL
// more laborious.
// more laborious. Streaming requests are not affected by this option.
//
// This option has no effect if the client uses the gRPC or gRPC-Web protocols.
func WithRequireConnectProtocolHeader() HandlerOption {
Expand Down
42 changes: 26 additions & 16 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const (
connectFlagEnvelopeEndStream = 0b00000010

connectUnaryContentTypePrefix = "application/"
connectUnaryContentTypeJSON = connectUnaryContentTypePrefix + "json"
connectUnaryContentTypeJSON = connectUnaryContentTypePrefix + codecNameJSON
connectStreamingContentTypePrefix = "application/connect+"

connectUnaryEncodingQueryParameter = "encoding"
Expand Down Expand Up @@ -172,21 +172,9 @@ func (h *connectHandler) NewConn(
if failed == nil {
failed = checkServerStreamsCanFlush(h.Spec, responseWriter)
}
if failed == nil && request.Method == http.MethodGet {
version := query.Get(connectUnaryConnectQueryParameter)
if version == "" && h.RequireConnectProtocolHeader {
failed = errorf(CodeInvalidArgument, "missing required query parameter: set %s to %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue)
} else if version != "" && version != connectUnaryConnectQueryValue {
failed = errorf(CodeInvalidArgument, "%s must be %q: got %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue, version)
}
}
if failed == nil && request.Method == http.MethodPost {
version := getHeaderCanonical(request.Header, connectHeaderProtocolVersion)
if version == "" && h.RequireConnectProtocolHeader {
failed = errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion)
} else if version != "" && version != connectProtocolVersion {
failed = errorf(CodeInvalidArgument, "%s must be %q: got %q", connectHeaderProtocolVersion, connectProtocolVersion, version)
}
if failed == nil {
required := h.RequireConnectProtocolHeader && (h.Spec.StreamType == StreamTypeUnary)
failed = connectCheckProtocolVersion(request, required)
}

var requestBody io.ReadCloser
Expand Down Expand Up @@ -1442,3 +1430,25 @@ func connectValidateStreamResponseContentType(requestCodecName string, streamTyp
}
return nil
}

func connectCheckProtocolVersion(request *http.Request, required bool) *Error {
switch request.Method {
case http.MethodGet:
version := request.URL.Query().Get(connectUnaryConnectQueryParameter)
if version == "" && required {
return errorf(CodeInvalidArgument, "missing required query parameter: set %s to %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue)
} else if version != "" && version != connectUnaryConnectQueryValue {
return errorf(CodeInvalidArgument, "%s must be %q: got %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue, version)
}
case http.MethodPost:
version := getHeaderCanonical(request.Header, connectHeaderProtocolVersion)
if version == "" && required {
return errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion)
} else if version != "" && version != connectProtocolVersion {
return errorf(CodeInvalidArgument, "%s must be %q: got %q", connectHeaderProtocolVersion, connectProtocolVersion, version)
}
default:
return errorf(CodeInvalidArgument, "unsupported method: %q", request.Method)
}
return nil
}

0 comments on commit 6fab35e

Please sign in to comment.