Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ErrorWriter IsSupportedCheck with required connect protocol option #700

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ linters:
enable-all: true
disable:
- cyclop # covered by gocyclo
- depguard # unnecessary for small libraries
- deadcode # abandoned
- depguard # unnecessary for small libraries
- exhaustivestruct # replaced by exhaustruct
- funlen # rely on code review to limit function length
- gocognit # dubious "cognitive overhead" quantification
Expand Down
2 changes: 1 addition & 1 deletion codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestStableCodec(t *testing.T) {
func TestJSONCodec(t *testing.T) {
t.Parallel()

codec := &protoJSONCodec{name: "json"}
codec := &protoJSONCodec{name: codecNameJSON}

t.Run("success", func(t *testing.T) {
t.Parallel()
Expand Down
10 changes: 10 additions & 0 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type ErrorWriter struct {
grpcWebContentTypes map[string]struct{}
unaryConnectContentTypes map[string]struct{}
streamingConnectContentTypes map[string]struct{}
requireConnectProtocolHeader bool
}

// NewErrorWriter constructs an ErrorWriter. To properly recognize supported
Expand All @@ -60,6 +61,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
grpcWebContentTypes: make(map[string]struct{}),
unaryConnectContentTypes: make(map[string]struct{}),
streamingConnectContentTypes: make(map[string]struct{}),
requireConnectProtocolHeader: config.RequireConnectProtocolHeader,
}
for name := range config.Codecs {
unary := connectContentTypeFromCodecName(StreamTypeUnary, name)
Expand Down Expand Up @@ -87,9 +89,17 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
if _, ok := w.unaryConnectContentTypes[ctype]; ok {
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
}
if _, ok := w.streamingConnectContentTypes[ctype]; ok {
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
}
if _, ok := w.grpcContentTypes[ctype]; ok {
Expand Down
55 changes: 55 additions & 0 deletions error_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2021-2024 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect

import (
"net/http"
"net/http/httptest"
"testing"

"connectrpc.com/connect/internal/assert"
)

func TestErrorWriter(t *testing.T) {
t.Parallel()

t.Run("RequireConnectProtocolHeader", func(t *testing.T) {
t.Parallel()
writer := NewErrorWriter(WithRequireConnectProtocolHeader())

t.Run("Unary", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON)
assert.False(t, writer.IsSupported(req))
req.Header.Set(connectHeaderProtocolVersion, connectProtocolVersion)
assert.True(t, writer.IsSupported(req))
})
t.Run("UnaryGET", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
assert.False(t, writer.IsSupported(req))
query := req.URL.Query()
query.Set(connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue)
req.URL.RawQuery = query.Encode()
assert.True(t, writer.IsSupported(req))
})
t.Run("Stream", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectStreamingContentTypePrefix+codecNameJSON)
assert.True(t, writer.IsSupported(req)) // ignores WithRequireConnectProtocolHeader
req.Header.Set(connectHeaderProtocolVersion, connectProtocolVersion)
assert.True(t, writer.IsSupported(req))
})
})
}
2 changes: 1 addition & 1 deletion option.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func WithRecover(handle func(context.Context, Spec, http.Header, any) error) Han
// header. This ensures that HTTP proxies and net/http middleware can easily
// identify valid Connect requests, even if they use a common Content-Type like
// application/json. However, it makes ad-hoc requests with tools like cURL
// more laborious.
// more laborious. Streaming requests are not affected by this option.
//
// This option has no effect if the client uses the gRPC or gRPC-Web protocols.
func WithRequireConnectProtocolHeader() HandlerOption {
Expand Down
42 changes: 26 additions & 16 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const (
connectFlagEnvelopeEndStream = 0b00000010

connectUnaryContentTypePrefix = "application/"
connectUnaryContentTypeJSON = connectUnaryContentTypePrefix + "json"
connectUnaryContentTypeJSON = connectUnaryContentTypePrefix + codecNameJSON
connectStreamingContentTypePrefix = "application/connect+"

connectUnaryEncodingQueryParameter = "encoding"
Expand Down Expand Up @@ -172,21 +172,9 @@ func (h *connectHandler) NewConn(
if failed == nil {
failed = checkServerStreamsCanFlush(h.Spec, responseWriter)
}
if failed == nil && request.Method == http.MethodGet {
version := query.Get(connectUnaryConnectQueryParameter)
if version == "" && h.RequireConnectProtocolHeader {
failed = errorf(CodeInvalidArgument, "missing required query parameter: set %s to %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue)
} else if version != "" && version != connectUnaryConnectQueryValue {
failed = errorf(CodeInvalidArgument, "%s must be %q: got %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue, version)
}
}
if failed == nil && request.Method == http.MethodPost {
version := getHeaderCanonical(request.Header, connectHeaderProtocolVersion)
if version == "" && h.RequireConnectProtocolHeader {
failed = errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion)
} else if version != "" && version != connectProtocolVersion {
failed = errorf(CodeInvalidArgument, "%s must be %q: got %q", connectHeaderProtocolVersion, connectProtocolVersion, version)
}
if failed == nil {
required := h.RequireConnectProtocolHeader && (h.Spec.StreamType == StreamTypeUnary)
failed = connectCheckProtocolVersion(request, required)
}

var requestBody io.ReadCloser
Expand Down Expand Up @@ -1442,3 +1430,25 @@ func connectValidateStreamResponseContentType(requestCodecName string, streamTyp
}
return nil
}

func connectCheckProtocolVersion(request *http.Request, required bool) *Error {
switch request.Method {
case http.MethodGet:
version := request.URL.Query().Get(connectUnaryConnectQueryParameter)
if version == "" && required {
return errorf(CodeInvalidArgument, "missing required query parameter: set %s to %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue)
} else if version != "" && version != connectUnaryConnectQueryValue {
return errorf(CodeInvalidArgument, "%s must be %q: got %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue, version)
}
case http.MethodPost:
version := getHeaderCanonical(request.Header, connectHeaderProtocolVersion)
if version == "" && required {
return errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion)
} else if version != "" && version != connectProtocolVersion {
return errorf(CodeInvalidArgument, "%s must be %q: got %q", connectHeaderProtocolVersion, connectProtocolVersion, version)
}
default:
return errorf(CodeInvalidArgument, "unsupported method: %q", request.Method)
}
return nil
}
Loading