diff --git a/instrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper.go b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper.go new file mode 100644 index 00000000000..afea8c35f3d --- /dev/null +++ b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper.go @@ -0,0 +1,75 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package request // import "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux/internal/request" + +import ( + "io" + "sync" +) + +var _ io.ReadCloser = &BodyWrapper{} + +// BodyWrapper wraps a http.Request.Body (an io.ReadCloser) to track the number +// of bytes read and the last error. +type BodyWrapper struct { + io.ReadCloser + OnRead func(n int64) // must not be nil + + mu sync.Mutex + read int64 + err error +} + +// NewBodyWrapper creates a new BodyWrapper. +// +// The onRead attribute is a callback that will be called every time the data +// is read, with the number of bytes being read. +func NewBodyWrapper(body io.ReadCloser, onRead func(int64)) *BodyWrapper { + return &BodyWrapper{ + ReadCloser: body, + OnRead: onRead, + } +} + +// Read reads the data from the io.ReadCloser, and stores the number of bytes +// read and the error. +func (w *BodyWrapper) Read(b []byte) (int, error) { + n, err := w.ReadCloser.Read(b) + n1 := int64(n) + + w.updateReadData(n1, err) + w.OnRead(n1) + return n, err +} + +func (w *BodyWrapper) updateReadData(n int64, err error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.read += n + if err != nil { + w.err = err + } +} + +// Closes closes the io.ReadCloser. +func (w *BodyWrapper) Close() error { + return w.ReadCloser.Close() +} + +// BytesRead returns the number of bytes read up to this point. +func (w *BodyWrapper) BytesRead() int64 { + w.mu.Lock() + defer w.mu.Unlock() + + return w.read +} + +// Error returns the last error. +func (w *BodyWrapper) Error() error { + w.mu.Lock() + defer w.mu.Unlock() + + return w.err +} diff --git a/instrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper_test.go b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper_test.go new file mode 100644 index 00000000000..794e54fb9c8 --- /dev/null +++ b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper_test.go @@ -0,0 +1,74 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package request + +import ( + "errors" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var errFirstCall = errors.New("first call") + +func TestBodyWrapper(t *testing.T) { + bw := NewBodyWrapper(io.NopCloser(strings.NewReader("hello world")), func(int64) {}) + + data, err := io.ReadAll(bw) + require.NoError(t, err) + assert.Equal(t, "hello world", string(data)) + + assert.Equal(t, int64(11), bw.BytesRead()) + assert.Equal(t, io.EOF, bw.Error()) +} + +type multipleErrorsReader struct { + calls int +} + +type errorWrapper struct{} + +func (errorWrapper) Error() string { + return "subsequent calls" +} + +func (mer *multipleErrorsReader) Read([]byte) (int, error) { + mer.calls = mer.calls + 1 + if mer.calls == 1 { + return 0, errFirstCall + } + + return 0, errorWrapper{} +} + +func TestBodyWrapperWithErrors(t *testing.T) { + bw := NewBodyWrapper(io.NopCloser(&multipleErrorsReader{}), func(int64) {}) + + data, err := io.ReadAll(bw) + require.Equal(t, errFirstCall, err) + assert.Equal(t, "", string(data)) + require.Equal(t, errFirstCall, bw.Error()) + + data, err = io.ReadAll(bw) + require.Equal(t, errorWrapper{}, err) + assert.Equal(t, "", string(data)) + require.Equal(t, errorWrapper{}, bw.Error()) +} + +func TestConcurrentBodyWrapper(t *testing.T) { + bw := NewBodyWrapper(io.NopCloser(strings.NewReader("hello world")), func(int64) {}) + + go func() { + _, _ = io.ReadAll(bw) + }() + + assert.NotNil(t, bw.BytesRead()) + assert.Eventually(t, func() bool { + return errors.Is(bw.Error(), io.EOF) + }, time.Second, 10*time.Millisecond) +} diff --git a/instrumentation/github.com/gorilla/mux/otelmux/internal/request/resp_writer_wrapper.go b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/resp_writer_wrapper.go new file mode 100644 index 00000000000..537465b05b1 --- /dev/null +++ b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/resp_writer_wrapper.go @@ -0,0 +1,119 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package request // import "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux/internal/request" + +import ( + "net/http" + "sync" +) + +var _ http.ResponseWriter = &RespWriterWrapper{} + +// RespWriterWrapper wraps a http.ResponseWriter in order to track the number of +// bytes written, the last error, and to catch the first written statusCode. +// TODO: The wrapped http.ResponseWriter doesn't implement any of the optional +// types (http.Hijacker, http.Pusher, http.CloseNotifier, etc) +// that may be useful when using it in real life situations. +type RespWriterWrapper struct { + http.ResponseWriter + OnWrite func(n int64) // must not be nil + + mu sync.RWMutex + written int64 + statusCode int + err error + wroteHeader bool +} + +// NewRespWriterWrapper creates a new RespWriterWrapper. +// +// The onWrite attribute is a callback that will be called every time the data +// is written, with the number of bytes that were written. +func NewRespWriterWrapper(w http.ResponseWriter, onWrite func(int64)) *RespWriterWrapper { + return &RespWriterWrapper{ + ResponseWriter: w, + OnWrite: onWrite, + statusCode: http.StatusOK, // default status code in case the Handler doesn't write anything + } +} + +// Write writes the bytes array into the [ResponseWriter], and tracks the +// number of bytes written and last error. +func (w *RespWriterWrapper) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.wroteHeader { + w.writeHeader(http.StatusOK) + } + + n, err := w.ResponseWriter.Write(p) + n1 := int64(n) + w.OnWrite(n1) + w.written += n1 + w.err = err + return n, err +} + +// WriteHeader persists initial statusCode for span attribution. +// All calls to WriteHeader will be propagated to the underlying ResponseWriter +// and will persist the statusCode from the first call. +// Blocking consecutive calls to WriteHeader alters expected behavior and will +// remove warning logs from net/http where developers will notice incorrect handler implementations. +func (w *RespWriterWrapper) WriteHeader(statusCode int) { + w.mu.Lock() + defer w.mu.Unlock() + + w.writeHeader(statusCode) +} + +// writeHeader persists the status code for span attribution, and propagates +// the call to the underlying ResponseWriter. +// It does not acquire a lock, and therefore assumes that is being handled by a +// parent method. +func (w *RespWriterWrapper) writeHeader(statusCode int) { + if !w.wroteHeader { + w.wroteHeader = true + w.statusCode = statusCode + } + w.ResponseWriter.WriteHeader(statusCode) +} + +// Flush implements [http.Flusher]. +func (w *RespWriterWrapper) Flush() { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.wroteHeader { + w.writeHeader(http.StatusOK) + } + + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// BytesWritten returns the number of bytes written. +func (w *RespWriterWrapper) BytesWritten() int64 { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.written +} + +// BytesWritten returns the HTTP status code that was sent. +func (w *RespWriterWrapper) StatusCode() int { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.statusCode +} + +// Error returns the last error. +func (w *RespWriterWrapper) Error() error { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.err +} diff --git a/instrumentation/github.com/gorilla/mux/otelmux/internal/request/resp_writer_wrapper_test.go b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/resp_writer_wrapper_test.go new file mode 100644 index 00000000000..21229b4dc69 --- /dev/null +++ b/instrumentation/github.com/gorilla/mux/otelmux/internal/request/resp_writer_wrapper_test.go @@ -0,0 +1,63 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package request + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRespWriterWriteHeader(t *testing.T) { + rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {}) + + rw.WriteHeader(http.StatusTeapot) + assert.Equal(t, http.StatusTeapot, rw.statusCode) + assert.True(t, rw.wroteHeader) + + rw.WriteHeader(http.StatusGone) + assert.Equal(t, http.StatusTeapot, rw.statusCode) +} + +func TestRespWriterFlush(t *testing.T) { + rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {}) + + rw.Flush() + assert.Equal(t, http.StatusOK, rw.statusCode) + assert.True(t, rw.wroteHeader) +} + +type nonFlushableResponseWriter struct{} + +func (_ nonFlushableResponseWriter) Header() http.Header { + return http.Header{} +} + +func (_ nonFlushableResponseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func (_ nonFlushableResponseWriter) WriteHeader(int) {} + +func TestRespWriterFlushNoFlusher(t *testing.T) { + rw := NewRespWriterWrapper(nonFlushableResponseWriter{}, func(int64) {}) + + rw.Flush() + assert.Equal(t, http.StatusOK, rw.statusCode) + assert.True(t, rw.wroteHeader) +} + +func TestConcurrentRespWriterWrapper(t *testing.T) { + rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {}) + + go func() { + _, _ = rw.Write([]byte("hello world")) + }() + + assert.NotNil(t, rw.BytesWritten()) + assert.NotNil(t, rw.StatusCode()) + assert.NoError(t, rw.Error()) +} diff --git a/instrumentation/github.com/gorilla/mux/otelmux/mux.go b/instrumentation/github.com/gorilla/mux/otelmux/mux.go index c7b2355eca8..73087e2f377 100644 --- a/instrumentation/github.com/gorilla/mux/otelmux/mux.go +++ b/instrumentation/github.com/gorilla/mux/otelmux/mux.go @@ -6,12 +6,13 @@ package otelmux // import "go.opentelemetry.io/contrib/instrumentation/github.co import ( "fmt" "net/http" - "sync" "github.com/felixge/httpsnoop" "github.com/gorilla/mux" + "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux/internal/request" "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux/internal/semconvutil" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.20.0" @@ -70,49 +71,6 @@ type traceware struct { filters []Filter } -type recordingResponseWriter struct { - writer http.ResponseWriter - written bool - status int -} - -var rrwPool = &sync.Pool{ - New: func() interface{} { - return &recordingResponseWriter{} - }, -} - -func getRRW(writer http.ResponseWriter) *recordingResponseWriter { - rrw := rrwPool.Get().(*recordingResponseWriter) - rrw.written = false - rrw.status = http.StatusOK - rrw.writer = httpsnoop.Wrap(writer, httpsnoop.Hooks{ - Write: func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc { - return func(b []byte) (int, error) { - if !rrw.written { - rrw.written = true - } - return next(b) - } - }, - WriteHeader: func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { - return func(statusCode int) { - if !rrw.written { - rrw.written = true - rrw.status = statusCode - } - next(statusCode) - } - }, - }) - return rrw -} - -func putRRW(rrw *recordingResponseWriter) { - rrw.writer = nil - rrwPool.Put(rrw) -} - // defaultSpanNameFunc just reuses the route name as the span name. func defaultSpanNameFunc(routeName string, _ *http.Request) string { return routeName } @@ -163,12 +121,41 @@ func (tw traceware) ServeHTTP(w http.ResponseWriter, r *http.Request) { spanName := tw.spanNameFormatter(routeStr, r) ctx, span := tw.tracer.Start(ctx, spanName, opts...) defer span.End() - r2 := r.WithContext(ctx) - rrw := getRRW(w) - defer putRRW(rrw) - tw.handler.ServeHTTP(rrw.writer, r2) - if rrw.status > 0 { - span.SetAttributes(semconv.HTTPStatusCode(rrw.status)) + + readRecordFunc := func(int64) {} + // if request body is nil or NoBody, we don't want to mutate the body as it + // will affect the identity of it in an unforeseeable way because we assert + // ReadCloser fulfills a certain interface and it is indeed nil or NoBody. + bw := request.NewBodyWrapper(r.Body, readRecordFunc) + if r.Body != nil && r.Body != http.NoBody { + r.Body = bw + } + + writeRecordFunc := func(int64) {} + rww := request.NewRespWriterWrapper(w, writeRecordFunc) + + // Wrap w to use our ResponseWriter methods while also exposing + // other interfaces that w may implement (http.CloseNotifier, + // http.Flusher, http.Hijacker, http.Pusher, io.ReaderFrom). + w = httpsnoop.Wrap(w, httpsnoop.Hooks{ + Header: func(httpsnoop.HeaderFunc) httpsnoop.HeaderFunc { + return rww.Header + }, + Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc { + return rww.Write + }, + WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { + return rww.WriteHeader + }, + Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc { + return rww.Flush + }, + }) + + tw.handler.ServeHTTP(w, r.WithContext(ctx)) + statusCode := rww.StatusCode() + if statusCode > 0 { + span.SetAttributes(semconv.HTTPStatusCode(statusCode)) } - span.SetStatus(semconvutil.HTTPServerStatus(rrw.status)) + span.SetStatus(semconvutil.HTTPServerStatus(statusCode)) }