diff --git a/cmd/jaeger/internal/extension/jaegermcp/server.go b/cmd/jaeger/internal/extension/jaegermcp/server.go index 1975ecb6f34..cd1ca3e682c 100644 --- a/cmd/jaeger/internal/extension/jaegermcp/server.go +++ b/cmd/jaeger/internal/extension/jaegermcp/server.go @@ -20,6 +20,7 @@ import ( "github.com/jaegertracing/jaeger/cmd/jaeger/internal/extension/jaegermcp/internal/handlers" "github.com/jaegertracing/jaeger/cmd/jaeger/internal/extension/jaegerquery" "github.com/jaegertracing/jaeger/cmd/jaeger/internal/extension/jaegerquery/querysvc" + "github.com/jaegertracing/jaeger/internal/tenancy" ) var ( @@ -61,6 +62,7 @@ func (s *server) Start(ctx context.Context, host component.Host) error { return fmt.Errorf("cannot get %s extension: %w", jaegerquery.ID, err) } s.queryAPI = queryExt.QueryService() + tenancyMgr := queryExt.TenancyManager() s.mcpServer = mcp.NewServer( &mcp.Implementation{ Name: s.config.ServerName, @@ -82,6 +84,8 @@ func (s *server) Start(ctx context.Context, host component.Host) error { }, ) + handler := tenancy.ExtractTenantHTTPHandler(tenancyMgr, mcpHandler) + s.listener, err = s.config.HTTP.ToListener(ctx) if err != nil { return fmt.Errorf("failed to listen on %s: %w", s.config.HTTP.NetAddr.Endpoint, err) @@ -91,7 +95,7 @@ func (s *server) Start(ctx context.Context, host component.Host) error { ctx, host.GetExtensions(), s.telset, - mcpHandler, + handler, ) if err != nil { s.listener.Close() diff --git a/cmd/jaeger/internal/extension/jaegermcp/server_test.go b/cmd/jaeger/internal/extension/jaegermcp/server_test.go index 34945f8fe6f..158f2c91a8e 100644 --- a/cmd/jaeger/internal/extension/jaegermcp/server_test.go +++ b/cmd/jaeger/internal/extension/jaegermcp/server_test.go @@ -31,25 +31,31 @@ import ( depstoremocks "github.com/jaegertracing/jaeger/internal/storage/v2/api/depstore/mocks" "github.com/jaegertracing/jaeger/internal/storage/v2/api/tracestore" tracestoremocks "github.com/jaegertracing/jaeger/internal/storage/v2/api/tracestore/mocks" + "github.com/jaegertracing/jaeger/internal/tenancy" ) // mockQueryExtension implements jaegerquery.Extension for testing type mockQueryExtension struct { extension.Extension svc *querysvc.QueryService + tm *tenancy.Manager } func newMockQueryExtension(svc *querysvc.QueryService) *mockQueryExtension { if svc == nil { svc = querysvc.NewQueryService(&tracestoremocks.Reader{}, &depstoremocks.Reader{}, querysvc.QueryServiceOptions{}) } - return &mockQueryExtension{svc: svc} + return &mockQueryExtension{svc: svc, tm: tenancy.NewManager(&tenancy.Options{})} } func (m *mockQueryExtension) QueryService() *querysvc.QueryService { return m.svc } +func (m *mockQueryExtension) TenancyManager() *tenancy.Manager { + return m.tm +} + // mockHost implements component.Host with a jaegerquery extension type mockHost struct { component.Host @@ -70,6 +76,16 @@ func newMockHostWithQueryService(svc *querysvc.QueryService) *mockHost { } } +func newMockHostWithQueryServiceAndTenancy(svc *querysvc.QueryService, tm *tenancy.Manager) *mockHost { + return &mockHost{ + Host: componenttest.NewNopHost(), + queryExt: &mockQueryExtension{ + svc: svc, + tm: tm, + }, + } +} + func (m *mockHost) GetExtensions() map[component.ID]component.Component { return map[component.ID]component.Component{ jaegerquery.ID: m.queryExt, @@ -632,3 +648,29 @@ func createTestTraceForIntegration() ptrace.Traces { return traces } + +func TestServerMCPEndpointEnforcesTenancy(t *testing.T) { + tm := tenancy.NewManager(&tenancy.Options{Enabled: true, Header: "x-tenant", Tenants: []string{"tenant-a"}}) + host := newMockHostWithQueryServiceAndTenancy(nil, tm) + telset := componenttest.NewNopTelemetrySettings() + config := &Config{ + HTTP: confighttp.ServerConfig{NetAddr: confignet.AddrConfig{Endpoint: "localhost:0", Transport: confignet.TransportTypeTCP}}, + ServerVersion: "1.0.0", + MaxSpanDetailsPerRequest: 20, + MaxSearchResults: 100, + } + + server := newServer(config, telset) + require.NoError(t, server.Start(context.Background(), host)) + t.Cleanup(func() { _ = server.Shutdown(context.Background()) }) + addr := server.listener.Addr().String() + + resp, err := http.Get(fmt.Sprintf("http://%s/mcp", addr)) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "missing tenant header") +} diff --git a/cmd/jaeger/internal/extension/jaegerquery/extension.go b/cmd/jaeger/internal/extension/jaegerquery/extension.go index c3bf88f6ac5..b05664c6904 100644 --- a/cmd/jaeger/internal/extension/jaegerquery/extension.go +++ b/cmd/jaeger/internal/extension/jaegerquery/extension.go @@ -10,6 +10,7 @@ import ( "go.opentelemetry.io/collector/extension" "github.com/jaegertracing/jaeger/cmd/jaeger/internal/extension/jaegerquery/querysvc" + "github.com/jaegertracing/jaeger/internal/tenancy" ) // Extension is the interface that the jaegerquery extension implements. @@ -18,6 +19,8 @@ type Extension interface { extension.Extension // QueryService returns the v2 query service. QueryService() *querysvc.QueryService + // TenancyManager returns the tenancy manager used by query endpoints. + TenancyManager() *tenancy.Manager } // GetExtension retrieves the jaegerquery extension from the host. diff --git a/cmd/jaeger/internal/extension/jaegerquery/extension_test.go b/cmd/jaeger/internal/extension/jaegerquery/extension_test.go index aa0cd957678..2d7f35b006f 100644 --- a/cmd/jaeger/internal/extension/jaegerquery/extension_test.go +++ b/cmd/jaeger/internal/extension/jaegerquery/extension_test.go @@ -14,22 +14,29 @@ import ( "go.opentelemetry.io/collector/extension" "github.com/jaegertracing/jaeger/cmd/jaeger/internal/extension/jaegerquery/querysvc" + "github.com/jaegertracing/jaeger/internal/tenancy" ) // mockExtension implements Extension for testing type mockExtension struct { extension.Extension qs *querysvc.QueryService + tm *tenancy.Manager } func (m *mockExtension) QueryService() *querysvc.QueryService { return m.qs } +func (m *mockExtension) TenancyManager() *tenancy.Manager { + return m.tm +} + func TestGetExtension_Success(t *testing.T) { // Create a mock QueryService mockQS := &querysvc.QueryService{} - mockExt := &mockExtension{qs: mockQS} + mockTM := tenancy.NewManager(&tenancy.Options{}) + mockExt := &mockExtension{qs: mockQS, tm: mockTM} // Create a mock host with the jaegerquery extension host := &mockHost{ @@ -45,6 +52,10 @@ func TestGetExtension_Success(t *testing.T) { // Verify we got the right extension qs := ext.QueryService() assert.Equal(t, mockQS, qs) + + // Verify we got the right tenancy manager + tm := ext.TenancyManager() + assert.Equal(t, mockTM, tm) } func TestGetExtension_NotFound(t *testing.T) { diff --git a/cmd/jaeger/internal/extension/jaegerquery/server.go b/cmd/jaeger/internal/extension/jaegerquery/server.go index a846a7fb912..39b4fa530d1 100644 --- a/cmd/jaeger/internal/extension/jaegerquery/server.go +++ b/cmd/jaeger/internal/extension/jaegerquery/server.go @@ -35,11 +35,12 @@ var ( ) type server struct { - config *Config - server *queryapp.Server - telset component.TelemetrySettings - closeTracer func(ctx context.Context) error - qs *querysvc.QueryService + config *Config + server *queryapp.Server + telset component.TelemetrySettings + closeTracer func(ctx context.Context) error + qs *querysvc.QueryService + tenancyManager *tenancy.Manager } func newServer(config *Config, otel component.TelemetrySettings) *server { @@ -117,6 +118,7 @@ func (s *server) Start(ctx context.Context, host component.Host) error { } tm := tenancy.NewManager(&s.config.Tenancy) + s.tenancyManager = tm caps := querysvc.StorageCapabilities{ ArchiveStorage: opts.ArchiveTraceReader != nil && opts.ArchiveTraceWriter != nil, @@ -218,3 +220,8 @@ func (s *server) Shutdown(ctx context.Context) error { func (s *server) QueryService() *querysvc.QueryService { return s.qs } + +// TenancyManager returns the tenancy manager used by query endpoints. +func (s *server) TenancyManager() *tenancy.Manager { + return s.tenancyManager +} diff --git a/cmd/jaeger/internal/extension/jaegerquery/server_test.go b/cmd/jaeger/internal/extension/jaegerquery/server_test.go index 980c994392c..8d125a8863f 100644 --- a/cmd/jaeger/internal/extension/jaegerquery/server_test.go +++ b/cmd/jaeger/internal/extension/jaegerquery/server_test.go @@ -433,7 +433,9 @@ func TestQueryService(t *testing.T) { require.NoError(t, server.Shutdown(context.Background())) }() - // Test QueryService method qs := server.QueryService() require.NotNil(t, qs, "QueryService should not be nil") + + tm := server.TenancyManager() + require.NotNil(t, tm, "TenancyManager should not be nil") }