diff --git a/runtime/handler.go b/runtime/handler.go index 2f0b9e9e0f8..4c9083d790f 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -28,7 +28,9 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal } handleForwardResponseServerMetadata(w, mux, md) - w.Header().Set("Transfer-Encoding", "chunked") + if !mux.disableChunkedEncoding { + w.Header().Set("Transfer-Encoding", "chunked") + } if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil { HTTPError(ctx, mux, marshaler, w, req, err) return diff --git a/runtime/handler_test.go b/runtime/handler_test.go index 42ce5eb13d0..fef8298eff7 100644 --- a/runtime/handler_test.go +++ b/runtime/handler_test.go @@ -33,10 +33,11 @@ func TestForwardResponseStream(t *testing.T) { err error } tests := []struct { - name string - msgs []msg - statusCode int - responseBody bool + name string + msgs []msg + statusCode int + responseBody bool + disableChunkedEncoding bool }{{ name: "encoding", msgs: []msg{ @@ -74,6 +75,13 @@ func TestForwardResponseStream(t *testing.T) { }, responseBody: true, statusCode: http.StatusOK, + }, { + name: "disable chunked encoding", + msgs: []msg{ + {&pb.SimpleMessage{Id: "One"}, nil}, + }, + statusCode: http.StatusOK, + disableChunkedEncoding: true, }} newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) { @@ -97,14 +105,24 @@ func TestForwardResponseStream(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/foo", nil) resp := httptest.NewRecorder() - runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv) + mux := runtime.NewServeMux() + if tt.disableChunkedEncoding { + mux = runtime.NewServeMux(runtime.WithDisableChunkedEncoding()) + } + runtime.ForwardResponseStream(ctx, mux, marshaler, resp, req, recv) w := resp.Result() if w.StatusCode != tt.statusCode { t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode) } - if h := w.Header.Get("Transfer-Encoding"); h != "chunked" { - t.Errorf("ForwardResponseStream missing header chunked") + if !tt.disableChunkedEncoding { + if h := w.Header.Get("Transfer-Encoding"); h != "chunked" { + t.Errorf("ForwardResponseStream missing header chunked") + } + } else { + if h := w.Header.Get("Transfer-Encoding"); h != "" { + t.Errorf("ForwardResponseStream unexpected Transfer-Encoding header %s", h) + } } body, err := io.ReadAll(w.Body) if err != nil { diff --git a/runtime/mux.go b/runtime/mux.go index 3eb16167173..4e684c7de6c 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -73,6 +73,7 @@ type ServeMux struct { disablePathLengthFallback bool unescapingMode UnescapingMode writeContentLength bool + disableChunkedEncoding bool } // ServeMuxOption is an option that can be given to a ServeMux on construction. @@ -125,6 +126,16 @@ func WithMiddlewares(middlewares ...Middleware) ServeMuxOption { } } +// WithDisableChunkedEncoding disables the Transfer-Encoding: chunked header +// for streaming responses. This is useful for streaming implementations that use +// Content-Length, which is mutually exclusive with Transfer-Encoding:chunked. +// Note that this option will not automatically add Content-Length headers, so it should be used with caution. +func WithDisableChunkedEncoding() ServeMuxOption { + return func(mux *ServeMux) { + mux.disableChunkedEncoding = true + } +} + // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters. // Configuring this will mean the generated OpenAPI output is no longer correct, and it should be // done with careful consideration.