diff --git a/core.go b/core.go index e80ddee..f34b8c0 100644 --- a/core.go +++ b/core.go @@ -225,7 +225,9 @@ func httpSpanAttributes(r *http.Request) []attribute.KeyValue { attrs := []attribute.KeyValue{ semconv.HTTPRequestMethodKey.String(r.Method), semconv.URLPath(r.URL.Path), - semconv.ServerAddress(host), + } + if host != "" { + attrs = append(attrs, semconv.ServerAddress(host)) } if port != "" { if p, err := strconv.Atoi(port); err == nil { @@ -235,9 +237,17 @@ func httpSpanAttributes(r *http.Request) []attribute.KeyValue { if r.URL.RawQuery != "" { attrs = append(attrs, semconv.URLQuery(r.URL.RawQuery)) } - if r.URL.Scheme != "" { - attrs = append(attrs, semconv.URLScheme(r.URL.Scheme)) + scheme := r.URL.Scheme + if scheme == "" { + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + scheme = proto + } else if r.TLS != nil { + scheme = "https" + } else { + scheme = "http" + } } + attrs = append(attrs, semconv.URLScheme(scheme)) return attrs } @@ -256,7 +266,18 @@ func tracingWrapper(h http.Handler) http.Handler { ) rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} w = rec - defer endSpan(serverSpan, rec) + defer func() { + if recovered := recover(); recovered != nil { + if !rec.wroteHeader { + rec.status = http.StatusInternalServerError + } + serverSpan.RecordError(fmt.Errorf("panic: %v", recovered)) + serverSpan.SetStatus(codes.Error, "panic") + endSpan(serverSpan, rec) + panic(recovered) + } + endSpan(serverSpan, rec) + }() } _, han := interceptors.NRHttpTracer("", h.ServeHTTP) @@ -268,12 +289,15 @@ func tracingWrapper(h http.Handler) http.Handler { // spanRouteMiddleware is a grpc-gateway middleware that updates the OTEL span // name and http.route attribute with the matched route pattern after routing. +// It uses runtime.HTTPPattern (the Pattern struct set by handleHandler) rather +// than runtime.HTTPPathPattern (the string set later inside AnnotateContext). func spanRouteMiddleware(next runtime.HandlerFunc) runtime.HandlerFunc { return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { - if pattern, ok := runtime.HTTPPathPattern(r.Context()); ok { + if pattern, ok := runtime.HTTPPattern(r.Context()); ok { + route := pattern.String() span := oteltrace.SpanFromContext(r.Context()) - span.SetName(r.Method + " " + pattern) - span.SetAttributes(semconv.HTTPRoute(pattern)) + span.SetName(r.Method + " " + route) + span.SetAttributes(semconv.HTTPRoute(route)) } next(w, r, pathParams) } diff --git a/core_coverage_test.go b/core_coverage_test.go index 9103d67..56792b8 100644 --- a/core_coverage_test.go +++ b/core_coverage_test.go @@ -12,6 +12,10 @@ import ( "github.com/go-coldbrew/core/config" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/codes" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -286,6 +290,143 @@ func TestSpanRouteMiddleware(t *testing.T) { }) } +// setupTestTracer installs an in-memory span exporter and returns it along with +// a cleanup function that restores the previous tracer provider. +func setupTestTracer() (*tracetest.InMemoryExporter, func()) { + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + prev := otel.GetTracerProvider() + otel.SetTracerProvider(tp) + return exporter, func() { + otel.SetTracerProvider(prev) + _ = tp.Shutdown(context.Background()) + } +} + +// findSpanByName returns the first span with the given name, or nil if not found. +func findSpanByName(spans tracetest.SpanStubs, name string) *tracetest.SpanStub { + for i := range spans { + if spans[i].Name == name { + return &spans[i] + } + } + return nil +} + +// spanAttrMap returns a map of attribute key to value for a span. +func spanAttrMap(span *tracetest.SpanStub) map[string]any { + m := make(map[string]any, len(span.Attributes)) + for _, a := range span.Attributes { + m[string(a.Key)] = a.Value.AsInterface() + } + return m +} + +func TestTracingWrapperSpanAttributes(t *testing.T) { + exporter, cleanup := setupTestTracer() + defer cleanup() + + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + wrapped := tracingWrapper(inner) + + req := httptest.NewRequest("GET", "/api/v1/rules?page=1", nil) + req.Host = "example.com:9091" + w := httptest.NewRecorder() + wrapped.ServeHTTP(w, req) + + span := findSpanByName(exporter.GetSpans(), "GET") + if span == nil { + t.Fatal("expected span named 'GET'") + } + + attrs := spanAttrMap(span) + if v := attrs["http.request.method"]; v != "GET" { + t.Fatalf("expected http.request.method=GET, got %v", v) + } + if v := attrs["url.path"]; v != "/api/v1/rules" { + t.Fatalf("expected url.path=/api/v1/rules, got %v", v) + } + if v := attrs["url.query"]; v != "page=1" { + t.Fatalf("expected url.query=page=1, got %v", v) + } + if v := attrs["server.address"]; v != "example.com" { + t.Fatalf("expected server.address=example.com, got %v", v) + } + if v := attrs["server.port"]; v != int64(9091) { + t.Fatalf("expected server.port=9091, got %v", v) + } + if v := attrs["http.response.status_code"]; v != int64(200) { + t.Fatalf("expected http.response.status_code=200, got %v", v) + } + if v := attrs["url.scheme"]; v != "http" { + t.Fatalf("expected url.scheme=http, got %v", v) + } +} + +func TestTracingWrapperSpanErrorStatus(t *testing.T) { + exporter, cleanup := setupTestTracer() + defer cleanup() + + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + wrapped := tracingWrapper(inner) + + req := httptest.NewRequest("GET", "/api/error", nil) + w := httptest.NewRecorder() + wrapped.ServeHTTP(w, req) + + span := findSpanByName(exporter.GetSpans(), "GET") + if span == nil { + t.Fatal("expected span named 'GET'") + } + + attrs := spanAttrMap(span) + if v := attrs["http.response.status_code"]; v != int64(500) { + t.Fatalf("expected http.response.status_code=500, got %v", v) + } + if span.Status.Code != codes.Error { + t.Fatalf("expected span status Error, got %v", span.Status.Code) + } +} + +func TestTracingWrapperGatewaySpanName(t *testing.T) { + exporter, cleanup := setupTestTracer() + defer cleanup() + + // Create a grpc-gateway mux with spanRouteMiddleware and a test handler. + mux := runtime.NewServeMux(runtime.WithMiddlewares(spanRouteMiddleware)) + err := mux.HandlePath("GET", "/api/v1/rules/{rule_id}", func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + w.WriteHeader(http.StatusOK) + }) + if err != nil { + t.Fatal(err) + } + + wrapped := tracingWrapper(mux) + req := httptest.NewRequest("GET", "/api/v1/rules/123", nil) + w := httptest.NewRecorder() + wrapped.ServeHTTP(w, req) + + // Pattern.String() includes wildcard spec, e.g. {rule_id=*} + wantName := "GET /api/v1/rules/{rule_id=*}" + span := findSpanByName(exporter.GetSpans(), wantName) + if span == nil { + names := make([]string, 0) + for _, s := range exporter.GetSpans() { + names = append(names, s.Name) + } + t.Fatalf("expected span named %q, got spans: %v", wantName, names) + } + + attrs := spanAttrMap(span) + if v := attrs["http.route"]; v != "/api/v1/rules/{rule_id=*}" { + t.Fatalf("expected http.route=/api/v1/rules/{rule_id=*}, got %v", v) + } +} + func TestGetCustomHeaderMatcher_EmptyPrefixes(t *testing.T) { t.Parallel() matcher := getCustomHeaderMatcher(nil, "X-Trace-Id")