diff --git a/CHANGELOG.md b/CHANGELOG.md index f40d8501cc0..f5fe8c4bc22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +- Implement `http.Hijacker` in`go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux`. (#6562) + ### Added - Generate server metrics with semantic conventions v1.26 in `go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp` when `OTEL_SEMCONV_STABILITY_OPT_IN` is set to `http/dup`. (#6411) diff --git a/instrumentation/github.com/gorilla/mux/otelmux/mux.go b/instrumentation/github.com/gorilla/mux/otelmux/mux.go index c7b2355eca8..d7f083d8c28 100644 --- a/instrumentation/github.com/gorilla/mux/otelmux/mux.go +++ b/instrumentation/github.com/gorilla/mux/otelmux/mux.go @@ -4,7 +4,9 @@ package otelmux // import "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux" import ( + "bufio" "fmt" + "net" "net/http" "sync" @@ -108,6 +110,13 @@ func getRRW(writer http.ResponseWriter) *recordingResponseWriter { return rrw } +func (h *recordingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := h.writer.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking") +} + func putRRW(rrw *recordingResponseWriter) { rrw.writer = nil rrwPool.Put(rrw) diff --git a/instrumentation/github.com/gorilla/mux/otelmux/mux_test.go b/instrumentation/github.com/gorilla/mux/otelmux/mux_test.go index a83e1689ebb..0c3ba0dbd2b 100644 --- a/instrumentation/github.com/gorilla/mux/otelmux/mux_test.go +++ b/instrumentation/github.com/gorilla/mux/otelmux/mux_test.go @@ -35,17 +35,14 @@ func TestPassthroughSpanFromGlobalTracer(t *testing.T) { // span context in the incoming request context. router.HandleFunc("/user/{id}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true - got := trace.SpanFromContext(r.Context()).SpanContext() - assert.Equal(t, sc, got) - w.WriteHeader(http.StatusOK) })) - r := httptest.NewRequest("GET", "/user/123", nil) - r = r.WithContext(trace.ContextWithRemoteSpanContext(context.Background(), sc)) + req := httptest.NewRequest("GET", "/user/123", nil) + req = req.WithContext(trace.ContextWithSpanContext(context.Background(), sc)) w := httptest.NewRecorder() + router.ServeHTTP(w, req) - router.ServeHTTP(w, r) - assert.True(t, called, "failed to run test") + assert.True(t, called) } func TestPropagationWithGlobalPropagators(t *testing.T) { @@ -190,3 +187,18 @@ func TestFilter(t *testing.T) { assert.Equal(t, 1, calledHealth, "failed to run test") assert.Equal(t, 1, calledTest, "failed to run test") } + +func TestRecordingResponseWriterHijack(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rrw := getRRW(w) + conn, rw, err := rrw.Hijack() + assert.Nil(t, conn) + assert.Nil(t, rw) + assert.NotNil(t, err) + assert.Equal(t, "underlying ResponseWriter does not support hijacking", err.Error()) + }) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) +} diff --git a/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go b/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go index b009022b1bb..95f0f5809fb 100644 --- a/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go +++ b/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go @@ -282,3 +282,25 @@ func TestWithPublicEndpointFn(t *testing.T) { }) } } + +func getRRW(w http.ResponseWriter) http.Hijacker { + if hijacker, ok := w.(http.Hijacker); ok { + return hijacker + } + return nil +} + +func TestRecordingResponseWriterHijack(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rrw := getRRW(w) + conn, rw, err := rrw.Hijack() + assert.Nil(t, conn) + assert.Nil(t, rw) + assert.NotNil(t, err) + assert.Equal(t, "underlying ResponseWriter does not support hijacking", err.Error()) + }) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) +}