Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
}
handleForwardResponseServerMetadata(w, mux, md)

w.Header().Set("Transfer-Encoding", "chunked")
if !mux.disableChunkedEncoding {
w.Header().Set("Transfer-Encoding", "chunked")
}
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
Expand Down
32 changes: 25 additions & 7 deletions runtime/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ func TestForwardResponseStream(t *testing.T) {
err error
}
tests := []struct {
name string
msgs []msg
statusCode int
responseBody bool
name string
msgs []msg
statusCode int
responseBody bool
disableChunkedEncoding bool
}{{
name: "encoding",
msgs: []msg{
Expand Down Expand Up @@ -74,6 +75,13 @@ func TestForwardResponseStream(t *testing.T) {
},
responseBody: true,
statusCode: http.StatusOK,
}, {
name: "disable chunked encoding",
msgs: []msg{
{&pb.SimpleMessage{Id: "One"}, nil},
},
statusCode: http.StatusOK,
disableChunkedEncoding: true,
}}

newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
Expand All @@ -97,14 +105,24 @@ func TestForwardResponseStream(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
resp := httptest.NewRecorder()

runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
mux := runtime.NewServeMux()
if tt.disableChunkedEncoding {
mux = runtime.NewServeMux(runtime.WithDisableChunkedEncoding())
}
runtime.ForwardResponseStream(ctx, mux, marshaler, resp, req, recv)

w := resp.Result()
if w.StatusCode != tt.statusCode {
t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
}
if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
t.Errorf("ForwardResponseStream missing header chunked")
if !tt.disableChunkedEncoding {
if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
t.Errorf("ForwardResponseStream missing header chunked")
}
} else {
if h := w.Header.Get("Transfer-Encoding"); h != "" {
t.Errorf("ForwardResponseStream unexpected Transfer-Encoding header %s", h)
}
}
body, err := io.ReadAll(w.Body)
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type ServeMux struct {
disablePathLengthFallback bool
unescapingMode UnescapingMode
writeContentLength bool
disableChunkedEncoding bool
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand Down Expand Up @@ -125,6 +126,16 @@ func WithMiddlewares(middlewares ...Middleware) ServeMuxOption {
}
}

// WithDisableChunkedEncoding disables the Transfer-Encoding: chunked header
// for streaming responses. This is useful for streaming implementations that use
// Content-Length, which is mutually exclusive with Transfer-Encoding:chunked.
// Note that this option will not automatically add Content-Length headers, so it should be used with caution.
func WithDisableChunkedEncoding() ServeMuxOption {
return func(mux *ServeMux) {
mux.disableChunkedEncoding = true
}
}

// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
Expand Down
Loading