diff --git a/server/etcdserver/api/v3rpc/grpc.go b/server/etcdserver/api/v3rpc/grpc.go index ed55e0357c9f..efa151437d39 100644 --- a/server/etcdserver/api/v3rpc/grpc.go +++ b/server/etcdserver/api/v3rpc/grpc.go @@ -67,8 +67,7 @@ func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnarySer } if s.Cfg.EnableDistributedTracing { - chainUnaryInterceptors = append(chainUnaryInterceptors, otelgrpc.UnaryServerInterceptor(s.Cfg.TracerOptions...)) - chainStreamInterceptors = append(chainStreamInterceptors, otelgrpc.StreamServerInterceptor(s.Cfg.TracerOptions...)) + opts = append(opts, grpc.StatsHandler(otelgrpc.NewServerHandler(s.Cfg.TracerOptions...))) } opts = append(opts, grpc.ChainUnaryInterceptor(chainUnaryInterceptors...)) diff --git a/tests/integration/tracing_test.go b/tests/integration/tracing_test.go index 74f231a3aba4..5ecfc12a5b9d 100644 --- a/tests/integration/tracing_test.go +++ b/tests/integration/tracing_test.go @@ -38,6 +38,48 @@ import ( func TestTracing(t *testing.T) { testutil.SkipTestIfShortMode(t, "Wal creation tests are depending on embedded etcd server so are integration-level tests.") + + // Test Unary RPC tracing + t.Run("UnaryRPC", func(t *testing.T) { + testRPCTracing(t, "UnaryRPC", containsUnaryRPCSpan, func(cli *clientv3.Client) error { + // make a request with the instrumented client + resp, err := cli.Get(context.TODO(), "key") + require.NoError(t, err) + require.Empty(t, resp.Kvs) + return nil + }) + }) + + // Test Stream RPC tracing + t.Run("StreamRPC", func(t *testing.T) { + testRPCTracing(t, "StreamRPC", containsStreamRPCSpan, func(cli *clientv3.Client) error { + // Create a context with a reasonable timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Create a watch channel + watchChan := cli.Watch(ctx, "watch-key") + + // Put a value to trigger the watch + _, err := cli.Put(context.TODO(), "watch-key", "watch-value") + require.NoError(t, err) + + // Wait for watch event + select { + case watchResp := <-watchChan: + require.NoError(t, watchResp.Err()) + require.Len(t, watchResp.Events, 1) + t.Log("Received watch event successfully") + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for watch event") + } + return nil + }) + }) +} + +// testRPCTracing is a common test function for both Unary and Stream RPC tracing +func testRPCTracing(t *testing.T, testName string, filterFunc func(*traceservice.ExportTraceServiceRequest) bool, clientAction func(*clientv3.Client) error) { // set up trace collector listener, err := net.Listen("tcp", "localhost:") require.NoError(t, err) @@ -48,7 +90,7 @@ func TestTracing(t *testing.T) { srv := grpc.NewServer() traceservice.RegisterTraceServiceServer(srv, &traceServer{ traceFound: traceFound, - filterFunc: containsNodeListSpan, + filterFunc: filterFunc, }) go srv.Serve(listener) @@ -89,8 +131,7 @@ func TestTracing(t *testing.T) { } dialOptions := []grpc.DialOption{ - grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor(tracingOpts...)), - grpc.WithStreamInterceptor(otelgrpc.StreamClientInterceptor(tracingOpts...)), + grpc.WithStatsHandler(otelgrpc.NewClientHandler(tracingOpts...)), } ccfg := clientv3.Config{DialOptions: dialOptions, Endpoints: []string{cfg.AdvertiseClientUrls[0].String()}} cli, err := integration.NewClient(t, ccfg) @@ -100,21 +141,21 @@ func TestTracing(t *testing.T) { } defer cli.Close() - // make a request with the instrumented client - resp, err := cli.Get(context.TODO(), "key") + // Execute the client action (either Unary or Stream RPC) + err = clientAction(cli) require.NoError(t, err) - require.Empty(t, resp.Kvs) // Wait for a span to be recorded from our request select { case <-traceFound: + t.Logf("%s trace found", testName) return case <-time.After(30 * time.Second): t.Fatal("Timed out waiting for trace") } } -func containsNodeListSpan(req *traceservice.ExportTraceServiceRequest) bool { +func containsUnaryRPCSpan(req *traceservice.ExportTraceServiceRequest) bool { for _, resourceSpans := range req.GetResourceSpans() { for _, attr := range resourceSpans.GetResource().GetAttributes() { if attr.GetKey() != "service.name" && attr.GetValue().GetStringValue() != "integration-test-tracing" { @@ -132,6 +173,20 @@ func containsNodeListSpan(req *traceservice.ExportTraceServiceRequest) bool { return false } +// containsStreamRPCSpan checks for Watch/Watch spans in trace data +func containsStreamRPCSpan(req *traceservice.ExportTraceServiceRequest) bool { + for _, resourceSpans := range req.GetResourceSpans() { + for _, scoped := range resourceSpans.GetScopeSpans() { + for _, span := range scoped.GetSpans() { + if span.GetName() == "etcdserverpb.Watch/Watch" { + return true + } + } + } + } + return false +} + // traceServer implements TracesServiceServer type traceServer struct { traceFound chan struct{} @@ -142,7 +197,11 @@ type traceServer struct { func (t *traceServer) Export(ctx context.Context, req *traceservice.ExportTraceServiceRequest) (*traceservice.ExportTraceServiceResponse, error) { emptyValue := traceservice.ExportTraceServiceResponse{} if t.filterFunc(req) { - t.traceFound <- struct{}{} + select { + case t.traceFound <- struct{}{}: + default: + // Channel already notified + } } return &emptyValue, nil }