diff --git a/CHANGELOG.md b/CHANGELOG.md index a0150e526a4..11fe8f064d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - The error received by `otelecho` middleware is then passed back to upstream middleware instead of being swallowed. (#3656) - Prevent taking from reservoir in AWS XRay Remote Sampler when there is zero capacity in `go.opentelemetry.io/contrib/samplers/aws/xray`. (#3684) +- Fix `otelhttp.Handler` in `go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp` to propagate multiple `WriteHeader` calls while persisting the initial `statusCode`. (#3580) ## [1.16.0-rc.2/0.41.0-rc.2/0.10.0-rc.2] - 2023-03-23 diff --git a/instrumentation/net/http/otelhttp/test/handler_test.go b/instrumentation/net/http/otelhttp/test/handler_test.go index 1ca75649d15..bb32935c176 100644 --- a/instrumentation/net/http/otelhttp/test/handler_test.go +++ b/instrumentation/net/http/otelhttp/test/handler_test.go @@ -178,6 +178,16 @@ func TestHandlerEmittedAttributes(t *testing.T) { attribute.Int("http.status_code", http.StatusOK), }, }, + { + name: "With persisting initial failing status in handler with multiple WriteHeader calls", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusOK) + }, + attributes: []attribute.KeyValue{ + attribute.Int("http.status_code", http.StatusInternalServerError), + }, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -201,6 +211,72 @@ func TestHandlerEmittedAttributes(t *testing.T) { } } +type respWriteHeaderCounter struct { + http.ResponseWriter + + headersWritten []int +} + +func (rw *respWriteHeaderCounter) WriteHeader(statusCode int) { + rw.headersWritten = append(rw.headersWritten, statusCode) + rw.ResponseWriter.WriteHeader(statusCode) +} + +func TestHandlerPropagateWriteHeaderCalls(t *testing.T) { + testCases := []struct { + name string + handler func(http.ResponseWriter, *http.Request) + expectHeadersWritten []int + }{ + { + name: "With a success handler", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, + expectHeadersWritten: []int{http.StatusOK}, + }, + { + name: "With a failing handler", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + expectHeadersWritten: []int{http.StatusBadRequest}, + }, + { + name: "With an empty handler", + handler: func(w http.ResponseWriter, r *http.Request) { + }, + + expectHeadersWritten: nil, + }, + { + name: "With calling WriteHeader twice", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusOK) + }, + expectHeadersWritten: []int{http.StatusInternalServerError, http.StatusOK}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sr := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider() + provider.RegisterSpanProcessor(sr) + h := otelhttp.NewHandler( + http.HandlerFunc(tc.handler), "test_handler", + otelhttp.WithTracerProvider(provider), + ) + + recorder := httptest.NewRecorder() + rw := &respWriteHeaderCounter{ResponseWriter: recorder} + h.ServeHTTP(rw, httptest.NewRequest("GET", "/", nil)) + require.EqualValues(t, tc.expectHeadersWritten, rw.headersWritten, "should propagate all WriteHeader calls to underlying ResponseWriter") + }) + } +} + func TestHandlerRequestWithTraceContext(t *testing.T) { rr := httptest.NewRecorder() diff --git a/instrumentation/net/http/otelhttp/wrap.go b/instrumentation/net/http/otelhttp/wrap.go index da6468c4e59..11a35ed167f 100644 --- a/instrumentation/net/http/otelhttp/wrap.go +++ b/instrumentation/net/http/otelhttp/wrap.go @@ -50,7 +50,7 @@ func (w *bodyWrapper) Close() error { var _ http.ResponseWriter = &respWriterWrapper{} // respWriterWrapper wraps a http.ResponseWriter in order to track the number of -// bytes written, the last error, and to catch the returned statusCode +// bytes written, the last error, and to catch the first written statusCode. // TODO: The wrapped http.ResponseWriter doesn't implement any of the optional // types (http.Hijacker, http.Pusher, http.CloseNotifier, http.Flusher, etc) // that may be useful when using it in real life situations. @@ -85,11 +85,15 @@ func (w *respWriterWrapper) Write(p []byte) (int, error) { return n, err } +// WriteHeader persists initial statusCode for span attribution. +// All calls to WriteHeader will be propagated to the underlying ResponseWriter +// and will persist the statusCode from the first call. +// Blocking consecutive calls to WriteHeader alters expected behavior and will +// remove warning logs from net/http where developers will notice incorrect handler implementations. func (w *respWriterWrapper) WriteHeader(statusCode int) { - if w.wroteHeader { - return + if !w.wroteHeader { + w.wroteHeader = true + w.statusCode = statusCode } - w.wroteHeader = true - w.statusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) }