diff --git a/config/configgrpc/client_middleware_test.go b/config/configgrpc/client_middleware_test.go index 7ddb99ba164..0e3ffe5b972 100644 --- a/config/configgrpc/client_middleware_test.go +++ b/config/configgrpc/client_middleware_test.go @@ -31,6 +31,12 @@ type testClientMiddleware struct { extensionmiddleware.GetGRPCClientOptionsFunc } +// testClientMiddlewareContext is a mock implementation that uses context +type testClientMiddlewareContext struct { + extension.Extension + extensionmiddleware.GetGRPCClientOptionsContextFunc +} + func newTestMiddlewareConfig(name string) configmiddleware.Config { return configmiddleware.Config{ ID: component.MustNewID(name), @@ -87,6 +93,50 @@ func newTestClientMiddleware(name string) extension.Extension { } } +func newTestClientMiddlewareContext(name string) extension.Extension { + return &testClientMiddlewareContext{ + Extension: extensionmiddlewaretest.NewNop(), + GetGRPCClientOptionsContextFunc: func(ctx context.Context) ([]grpc.DialOption, error) { + return []grpc.DialOption{ + grpc.WithChainUnaryInterceptor( + func( + callCtx context.Context, + method string, + req, reply any, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + // Get existing metadata or create new metadata + md, ok := metadata.FromOutgoingContext(callCtx) + if !ok { + md = metadata.New(nil) + } else { + md = md.Copy() + } + + sequence := "" + if values := md.Get("middleware-sequence"); len(values) > 0 { + sequence = values[0] + } + + // Use "ctx-" prefix to indicate context-aware middleware was used + ctxName := "ctx-" + name + if sequence == "" { + sequence = ctxName + } else { + sequence = fmt.Sprintf("%s,%s", sequence, ctxName) + } + + md.Set("middleware-sequence", sequence) + newCtx := metadata.NewOutgoingContext(callCtx, md) + return invoker(newCtx, method, req, reply, cc, opts...) + }), + }, nil + }, + } +} + // TestClientMiddlewareOrdering verifies that client middleware // interceptors are called in the right order. func TestClientMiddlewareOrdering(t *testing.T) { @@ -140,6 +190,58 @@ func TestClientMiddlewareOrdering(t *testing.T) { assert.Equal(t, expectedSequence, md[0]) } +// TestClientMiddlewareContextOrdering verifies that context-aware client middleware +// interceptors are called and work correctly. +func TestClientMiddlewareContextOrdering(t *testing.T) { + const middlewareTrackingHeader = "middleware-sequence" + + // Create context-aware middleware extensions + mockMiddleware1 := newTestClientMiddlewareContext("middleware-1") + mockMiddleware2 := newTestClientMiddlewareContext("middleware-2") + + mockExt := map[component.ID]component.Component{ + component.MustNewID("middleware1"): mockMiddleware1, + component.MustNewID("middleware2"): mockMiddleware2, + } + + // Start a gRPC server that will record the incoming metadata + server := &grpcTraceServer{} + srv, addr := server.startTestServer(t, configoptional.Some(ServerConfig{ + NetAddr: confignet.AddrConfig{ + Endpoint: "localhost:0", + Transport: confignet.TransportTypeTCP, + }, + })) + defer srv.Stop() + + // Create client config with middleware extensions + clientConfig := ClientConfig{ + Endpoint: addr, + TLS: configtls.ClientConfig{ + Insecure: true, + }, + Middlewares: []configmiddleware.Config{ + newTestMiddlewareConfig("middleware1"), + newTestMiddlewareConfig("middleware2"), + }, + } + + // Send a request using the client with middleware + resp, err := sendTestRequestWithExtensions(t, clientConfig, mockExt) + require.NoError(t, err) + assert.NotNil(t, resp) + + // Verify that the context-aware middleware was used (indicated by "ctx-" prefix) + ictx, ok := metadata.FromIncomingContext(server.recordedContext) + require.True(t, ok, "middleware tracking header not found in metadata") + md := ictx[middlewareTrackingHeader] + require.Len(t, md, 1, "expected exactly one middleware tracking header value") + + // The sequence should show ctx- prefix indicating context-aware interface was used + expectedSequence := "ctx-middleware-1,ctx-middleware-2" + assert.Equal(t, expectedSequence, md[0]) +} + // TestClientMiddlewareToClientErrors tests failure cases for the ToClient method // specifically related to middleware resolution and API calls. func TestClientMiddlewareToClientErrors(t *testing.T) { diff --git a/config/configgrpc/server_middleware_test.go b/config/configgrpc/server_middleware_test.go index c62748c78be..ff924752d1c 100644 --- a/config/configgrpc/server_middleware_test.go +++ b/config/configgrpc/server_middleware_test.go @@ -43,6 +43,12 @@ type testServerMiddleware struct { extensionmiddleware.GetGRPCServerOptionsFunc } +// testServerMiddlewareContext is a test implementation that uses context +type testServerMiddlewareContext struct { + extension.Extension + extensionmiddleware.GetGRPCServerOptionsContextFunc +} + func newTestServerMiddleware(name string) extension.Extension { return &testServerMiddleware{ Extension: extensionmiddlewaretest.NewNop(), @@ -60,6 +66,25 @@ func newTestServerMiddleware(name string) extension.Extension { } } +func newTestServerMiddlewareContext(name string) extension.Extension { + return &testServerMiddlewareContext{ + Extension: extensionmiddlewaretest.NewNop(), + GetGRPCServerOptionsContextFunc: func(ctx context.Context) ([]grpc.ServerOption, error) { + // Use "ctx-" prefix to indicate context-aware middleware was used + ctxName := "ctx-" + name + return []grpc.ServerOption{grpc.ChainUnaryInterceptor( + func( + callCtx context.Context, + req any, _ *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + ) (any, error) { + callCtx = context.WithValue(callCtx, middlewareCallsKey, append(getMiddlewareCalls(callCtx), ctxName)) + return handler(callCtx, req) + })}, nil + }, + } +} + func TestGrpcServerUnaryInterceptor(t *testing.T) { // Register two test extensions extensions := map[component.ID]component.Component{ @@ -101,6 +126,47 @@ func TestGrpcServerUnaryInterceptor(t *testing.T) { assert.Equal(t, []string{"test1", "test2"}, getMiddlewareCalls(server.recordedContext)) } +func TestGrpcServerUnaryInterceptorContext(t *testing.T) { + // Register two context-aware test extensions + extensions := map[component.ID]component.Component{ + component.MustNewID("test1"): newTestServerMiddlewareContext("test1"), + component.MustNewID("test2"): newTestServerMiddlewareContext("test2"), + } + + // Setup the server with both middleware options + server := &grpcTraceServer{} + var addr string + + // Create the server with middleware interceptors + { + var srv *grpc.Server + srv, addr = server.startTestServerWithExtensions(t, configoptional.Some(ServerConfig{ + NetAddr: confignet.AddrConfig{ + Endpoint: "localhost:0", + Transport: confignet.TransportTypeTCP, + }, + Middlewares: []configmiddleware.Config{ + newTestMiddlewareConfig("test1"), + newTestMiddlewareConfig("test2"), + }, + }), extensions) + defer srv.Stop() + } + + // Send a request to trigger the interceptors + resp, errResp := sendTestRequest(t, ClientConfig{ + Endpoint: addr, + TLS: configtls.ClientConfig{ + Insecure: true, + }, + }) + require.NoError(t, errResp) + require.NotNil(t, resp) + + // Verify context-aware interceptors were called (indicated by "ctx-" prefix) + assert.Equal(t, []string{"ctx-test1", "ctx-test2"}, getMiddlewareCalls(server.recordedContext)) +} + // TestServerMiddlewareToServerErrors tests failure cases for the ToServer method // specifically related to middleware resolution and API calls. func TestServerMiddlewareToServerErrors(t *testing.T) { diff --git a/config/configmiddleware/configmiddleware.go b/config/configmiddleware/configmiddleware.go index 831ba7766fe..0662e805bad 100644 --- a/config/configmiddleware/configmiddleware.go +++ b/config/configmiddleware/configmiddleware.go @@ -68,8 +68,15 @@ func (m Config) GetHTTPServerHandler(_ context.Context, extensions map[component // extensionmiddleware.GRPCClient from the map of extensions, and // returns the gRPC dial options. If a middleware is not found, an // error is returned. This should only be used by gRPC clients. -func (m Config) GetGRPCClientOptions(_ context.Context, extensions map[component.ID]component.Component) ([]grpc.DialOption, error) { +// +// This function first checks if the extension implements +// GRPCClientContext (which accepts a context), falling back to +// GRPCClient for backwards compatibility. +func (m Config) GetGRPCClientOptions(ctx context.Context, extensions map[component.ID]component.Component) ([]grpc.DialOption, error) { if ext, found := extensions[m.ID]; found { + if client, ok := ext.(extensionmiddleware.GRPCClientContext); ok { + return client.GetGRPCClientOptionsContext(ctx) + } if client, ok := ext.(extensionmiddleware.GRPCClient); ok { return client.GetGRPCClientOptions() } @@ -82,8 +89,15 @@ func (m Config) GetGRPCClientOptions(_ context.Context, extensions map[component // extensionmiddleware.GRPCServer from the map of extensions, and // returns the gRPC server options. If a middleware is not found, an // error is returned. This should only be used by gRPC servers. -func (m Config) GetGRPCServerOptions(_ context.Context, extensions map[component.ID]component.Component) ([]grpc.ServerOption, error) { +// +// This function first checks if the extension implements +// GRPCServerContext (which accepts a context), falling back to +// GRPCServer for backwards compatibility. +func (m Config) GetGRPCServerOptions(ctx context.Context, extensions map[component.ID]component.Component) ([]grpc.ServerOption, error) { if ext, found := extensions[m.ID]; found { + if server, ok := ext.(extensionmiddleware.GRPCServerContext); ok { + return server.GetGRPCServerOptionsContext(ctx) + } if server, ok := ext.(extensionmiddleware.GRPCServer); ok { return server.GetGRPCServerOptions() } diff --git a/config/configmiddleware/configmiddleware_test.go b/config/configmiddleware/configmiddleware_test.go index 5cdb2628fba..20e60132c39 100644 --- a/config/configmiddleware/configmiddleware_test.go +++ b/config/configmiddleware/configmiddleware_test.go @@ -158,6 +158,26 @@ func TestConfig_GetGRPCServerOptions(t *testing.T) { }, wantErr: nil, }, + { + name: "found_and_valid_context", + middleware: Config{ + ID: testID, + }, + extensions: map[component.ID]component.Component{ + testID: struct { + extension.Extension + extensionmiddleware.GetGRPCServerOptionsContextFunc + }{ + Extension: extensionmiddlewaretest.NewNop(), + GetGRPCServerOptionsContextFunc: func(context.Context) ([]grpc.ServerOption, error) { + return []grpc.ServerOption{ + grpc.EmptyServerOption{}, + }, nil + }, + }, + }, + wantErr: nil, + }, { name: "middleware_not_found", middleware: Config{ @@ -221,6 +241,26 @@ func TestConfig_GetGRPCClientOptions(t *testing.T) { }, wantErr: nil, }, + { + name: "found_and_valid_context", + middleware: Config{ + ID: testID, + }, + extensions: map[component.ID]component.Component{ + testID: struct { + extension.Extension + extensionmiddleware.GetGRPCClientOptionsContextFunc + }{ + Extension: extensionmiddlewaretest.NewNop(), + GetGRPCClientOptionsContextFunc: func(context.Context) ([]grpc.DialOption, error) { + return []grpc.DialOption{ + grpc.EmptyDialOption{}, + }, nil + }, + }, + }, + wantErr: nil, + }, { name: "middleware_not_found", middleware: Config{ @@ -254,3 +294,51 @@ func TestConfig_GetGRPCClientOptions(t *testing.T) { }) } } + +func TestConfig_GetGRPCClientOptions_ContextPassed(t *testing.T) { + type ctxKey struct{} + testCtx := context.WithValue(context.Background(), ctxKey{}, "test-value") + var receivedCtx context.Context + + middleware := Config{ID: testID} + extensions := map[component.ID]component.Component{ + testID: struct { + extension.Extension + extensionmiddleware.GetGRPCClientOptionsContextFunc + }{ + Extension: extensionmiddlewaretest.NewNop(), + GetGRPCClientOptionsContextFunc: func(ctx context.Context) ([]grpc.DialOption, error) { + receivedCtx = ctx + return []grpc.DialOption{grpc.EmptyDialOption{}}, nil + }, + }, + } + + _, err := middleware.GetGRPCClientOptions(testCtx, extensions) + require.NoError(t, err) + require.Equal(t, "test-value", receivedCtx.Value(ctxKey{})) +} + +func TestConfig_GetGRPCServerOptions_ContextPassed(t *testing.T) { + type ctxKey struct{} + testCtx := context.WithValue(context.Background(), ctxKey{}, "server-test-value") + var receivedCtx context.Context + + middleware := Config{ID: testID} + extensions := map[component.ID]component.Component{ + testID: struct { + extension.Extension + extensionmiddleware.GetGRPCServerOptionsContextFunc + }{ + Extension: extensionmiddlewaretest.NewNop(), + GetGRPCServerOptionsContextFunc: func(ctx context.Context) ([]grpc.ServerOption, error) { + receivedCtx = ctx + return []grpc.ServerOption{grpc.EmptyServerOption{}}, nil + }, + }, + } + + _, err := middleware.GetGRPCServerOptions(testCtx, extensions) + require.NoError(t, err) + require.Equal(t, "server-test-value", receivedCtx.Value(ctxKey{})) +} diff --git a/extension/extensionmiddleware/README.md b/extension/extensionmiddleware/README.md index fd4c8747ed8..c7efded7260 100644 --- a/extension/extensionmiddleware/README.md +++ b/extension/extensionmiddleware/README.md @@ -54,3 +54,5 @@ server-side middleware object. - **GRPCClient**: The extension returns `[]grpc.DialOption`. - **GRPCServer**: The extension returns `[]grpc.ServerOption`. +- **GRPCClientContext**: Like `GRPCClient`, but accepts a `context.Context` for operations requiring it (e.g., fetching TLS credentials from a remote source). +- **GRPCServerContext**: Like `GRPCServer`, but accepts a `context.Context`. diff --git a/extension/extensionmiddleware/client.go b/extension/extensionmiddleware/client.go index fb24a18339e..d820210e335 100644 --- a/extension/extensionmiddleware/client.go +++ b/extension/extensionmiddleware/client.go @@ -4,28 +4,46 @@ package extensionmiddleware // import "go.opentelemetry.io/collector/extension/extensionmiddleware" import ( + "context" "net/http" "google.golang.org/grpc" ) // HTTPClient is an interface for HTTP client middleware extensions. +// This is an open interface for capability detection - extensions +// implement this interface to provide HTTP client middleware. type HTTPClient interface { // GetHTTPRoundTripper wraps the provided client RoundTripper. GetHTTPRoundTripper(http.RoundTripper) (http.RoundTripper, error) } // GRPCClient is an interface for gRPC client middleware extensions. +// This is an open interface for capability detection - extensions +// implement this interface to provide gRPC client middleware. type GRPCClient interface { // GetGRPCClientOptions returns the gRPC dial options to use for client connections. GetGRPCClientOptions() ([]grpc.DialOption, error) } +// GRPCClientContext is an extended interface for gRPC client middleware +// that accepts a context parameter. Extensions should implement this interface +// when they need context for operations like fetching TLS credentials from +// a remote source. +// +// This interface is consumed by configgrpc via configmiddleware. +type GRPCClientContext interface { + // GetGRPCClientOptionsContext returns the gRPC dial options with context support. + GetGRPCClientOptionsContext(context.Context) ([]grpc.DialOption, error) +} + var _ HTTPClient = (*GetHTTPRoundTripperFunc)(nil) // GetHTTPRoundTripperFunc is a function that implements HTTPClient. +// The nil value is a valid no-op implementation that returns the base unchanged. type GetHTTPRoundTripperFunc func(base http.RoundTripper) (http.RoundTripper, error) +// GetHTTPRoundTripper implements HTTPClient. A nil function returns base unchanged. func (f GetHTTPRoundTripperFunc) GetHTTPRoundTripper(base http.RoundTripper) (http.RoundTripper, error) { if f == nil { return base, nil @@ -36,11 +54,27 @@ func (f GetHTTPRoundTripperFunc) GetHTTPRoundTripper(base http.RoundTripper) (ht var _ GRPCClient = (*GetGRPCClientOptionsFunc)(nil) // GetGRPCClientOptionsFunc is a function that implements GRPCClient. +// The nil value is a valid no-op implementation that returns no options. type GetGRPCClientOptionsFunc func() ([]grpc.DialOption, error) +// GetGRPCClientOptions implements GRPCClient. A nil function returns nil options. func (f GetGRPCClientOptionsFunc) GetGRPCClientOptions() ([]grpc.DialOption, error) { if f == nil { return nil, nil } return f() } + +var _ GRPCClientContext = (*GetGRPCClientOptionsContextFunc)(nil) + +// GetGRPCClientOptionsContextFunc is a function that implements GRPCClientContext. +// The nil value is a valid no-op implementation that returns no options. +type GetGRPCClientOptionsContextFunc func(context.Context) ([]grpc.DialOption, error) + +// GetGRPCClientOptionsContext implements GRPCClientContext. A nil function returns nil options. +func (f GetGRPCClientOptionsContextFunc) GetGRPCClientOptionsContext(ctx context.Context) ([]grpc.DialOption, error) { + if f == nil { + return nil, nil + } + return f(ctx) +} diff --git a/extension/extensionmiddleware/client_test.go b/extension/extensionmiddleware/client_test.go index 8d4eadcade2..3b4b3993a7c 100644 --- a/extension/extensionmiddleware/client_test.go +++ b/extension/extensionmiddleware/client_test.go @@ -4,6 +4,7 @@ package extensionmiddleware import ( + "context" "errors" "net/http" "testing" @@ -89,3 +90,42 @@ func TestGetGRPCClientOptionsFunc(t *testing.T) { require.Nil(t, options) }) } + +func TestGetGRPCClientOptionsContextFunc(t *testing.T) { + ctx := context.Background() + + t.Run("nil function", func(t *testing.T) { + var nilFunc GetGRPCClientOptionsContextFunc + options, err := nilFunc.GetGRPCClientOptionsContext(ctx) + require.NoError(t, err) + require.Nil(t, options) + }) + + t.Run("receives context", func(t *testing.T) { + type ctxKey struct{} + testCtx := context.WithValue(ctx, ctxKey{}, "test-value") + var receivedCtx context.Context + + f := GetGRPCClientOptionsContextFunc(func(c context.Context) ([]grpc.DialOption, error) { + receivedCtx = c + return []grpc.DialOption{grpc.WithAuthority("test")}, nil + }) + + options, err := f.GetGRPCClientOptionsContext(testCtx) + require.NoError(t, err) + require.Len(t, options, 1) + require.Equal(t, "test-value", receivedCtx.Value(ctxKey{})) + }) + + t.Run("error function", func(t *testing.T) { + expectedErr := errors.New("context grpc options error") + errorFunc := GetGRPCClientOptionsContextFunc(func(context.Context) ([]grpc.DialOption, error) { + return nil, expectedErr + }) + + options, err := errorFunc.GetGRPCClientOptionsContext(ctx) + require.Error(t, err) + require.Equal(t, expectedErr, err) + require.Nil(t, options) + }) +} diff --git a/extension/extensionmiddleware/extensionmiddlewaretest/err.go b/extension/extensionmiddleware/extensionmiddlewaretest/err.go index b881dae2c51..978884f9489 100644 --- a/extension/extensionmiddleware/extensionmiddlewaretest/err.go +++ b/extension/extensionmiddleware/extensionmiddlewaretest/err.go @@ -4,6 +4,7 @@ package extensionmiddlewaretest // import "go.opentelemetry.io/collector/extension/extensionmiddleware/extensionmiddlewaretest" import ( + "context" "net/http" "google.golang.org/grpc" @@ -14,11 +15,13 @@ import ( ) var ( - _ extension.Extension = (*baseExtension)(nil) - _ extensionmiddleware.HTTPClient = (*baseExtension)(nil) - _ extensionmiddleware.GRPCClient = (*baseExtension)(nil) - _ extensionmiddleware.HTTPServer = (*baseExtension)(nil) - _ extensionmiddleware.GRPCServer = (*baseExtension)(nil) + _ extension.Extension = (*baseExtension)(nil) + _ extensionmiddleware.HTTPClient = (*baseExtension)(nil) + _ extensionmiddleware.GRPCClient = (*baseExtension)(nil) + _ extensionmiddleware.GRPCClientContext = (*baseExtension)(nil) + _ extensionmiddleware.HTTPServer = (*baseExtension)(nil) + _ extensionmiddleware.GRPCServer = (*baseExtension)(nil) + _ extensionmiddleware.GRPCServerContext = (*baseExtension)(nil) ) type baseExtension struct { @@ -26,8 +29,10 @@ type baseExtension struct { component.ShutdownFunc extensionmiddleware.GetHTTPHandlerFunc extensionmiddleware.GetGRPCServerOptionsFunc + extensionmiddleware.GetGRPCServerOptionsContextFunc extensionmiddleware.GetHTTPRoundTripperFunc extensionmiddleware.GetGRPCClientOptionsFunc + extensionmiddleware.GetGRPCClientOptionsContextFunc } // NewErr returns a new [extension.Extension] that implements all @@ -40,11 +45,17 @@ func NewErr(err error) extension.Extension { GetGRPCClientOptionsFunc: func() ([]grpc.DialOption, error) { return nil, err }, + GetGRPCClientOptionsContextFunc: func(context.Context) ([]grpc.DialOption, error) { + return nil, err + }, GetHTTPHandlerFunc: func(http.Handler) (http.Handler, error) { return nil, err }, GetGRPCServerOptionsFunc: func() ([]grpc.ServerOption, error) { return nil, err }, + GetGRPCServerOptionsContextFunc: func(context.Context) ([]grpc.ServerOption, error) { + return nil, err + }, } } diff --git a/extension/extensionmiddleware/extensionmiddlewaretest/err_test.go b/extension/extensionmiddleware/extensionmiddlewaretest/err_test.go index 14bcf291b5b..b9dd9dacf86 100644 --- a/extension/extensionmiddleware/extensionmiddlewaretest/err_test.go +++ b/extension/extensionmiddleware/extensionmiddlewaretest/err_test.go @@ -4,6 +4,7 @@ package extensionmiddlewaretest import ( + "context" "errors" "testing" @@ -24,6 +25,11 @@ func TestErrClient(t *testing.T) { require.True(t, ok) _, err = grpcClient.GetGRPCClientOptions() require.Error(t, err) + + grpcClientContext, ok := client.(extensionmiddleware.GRPCClientContext) + require.True(t, ok) + _, err = grpcClientContext.GetGRPCClientOptionsContext(context.Background()) + require.Error(t, err) } func TestErrServer(t *testing.T) { @@ -38,4 +44,9 @@ func TestErrServer(t *testing.T) { require.True(t, ok) _, err = grpcServer.GetGRPCServerOptions() require.Error(t, err) + + grpcServerContext, ok := server.(extensionmiddleware.GRPCServerContext) + require.True(t, ok) + _, err = grpcServerContext.GetGRPCServerOptionsContext(context.Background()) + require.Error(t, err) } diff --git a/extension/extensionmiddleware/extensionmiddlewaretest/nop_test.go b/extension/extensionmiddleware/extensionmiddlewaretest/nop_test.go index d1190fbff98..e7b366defb8 100644 --- a/extension/extensionmiddleware/extensionmiddlewaretest/nop_test.go +++ b/extension/extensionmiddleware/extensionmiddlewaretest/nop_test.go @@ -4,6 +4,7 @@ package extensionmiddlewaretest import ( + "context" "net/http" "testing" @@ -26,6 +27,12 @@ func TestNopClient(t *testing.T) { grpcOpts, err := grpcClient.GetGRPCClientOptions() require.NoError(t, err) require.Nil(t, grpcOpts) + + grpcClientContext, ok := client.(extensionmiddleware.GRPCClientContext) + require.True(t, ok) + grpcOptsContext, err := grpcClientContext.GetGRPCClientOptionsContext(context.Background()) + require.NoError(t, err) + require.Nil(t, grpcOptsContext) } func TestNopServer(t *testing.T) { @@ -42,6 +49,12 @@ func TestNopServer(t *testing.T) { grpcOpts, err := grpcServer.GetGRPCServerOptions() require.NoError(t, err) require.Nil(t, grpcOpts) + + grpcServerContext, ok := client.(extensionmiddleware.GRPCServerContext) + require.True(t, ok) + grpcOptsContext, err := grpcServerContext.GetGRPCServerOptionsContext(context.Background()) + require.NoError(t, err) + require.Nil(t, grpcOptsContext) } func TestRoundTripperFunc(t *testing.T) { diff --git a/extension/extensionmiddleware/server.go b/extension/extensionmiddleware/server.go index 6fa9ca1ae60..41657787a9d 100644 --- a/extension/extensionmiddleware/server.go +++ b/extension/extensionmiddleware/server.go @@ -4,28 +4,46 @@ package extensionmiddleware // import "go.opentelemetry.io/collector/extension/extensionmiddleware" import ( + "context" "net/http" "google.golang.org/grpc" ) // HTTPServer defines the interface for HTTP server middleware extensions. +// This is an open interface for capability detection - extensions +// implement this interface to provide HTTP server middleware. type HTTPServer interface { // GetHTTPHandler wraps the provided base http.Handler. GetHTTPHandler(base http.Handler) (http.Handler, error) } // GRPCServer defines the interface for gRPC server middleware extensions. +// This is an open interface for capability detection - extensions +// implement this interface to provide gRPC server middleware. type GRPCServer interface { // GetGRPCServerOptions returns options for a gRPC server. GetGRPCServerOptions() ([]grpc.ServerOption, error) } +// GRPCServerContext is an extended interface for gRPC server middleware +// that accepts a context parameter. Extensions should implement this interface +// when they need context for operations like fetching TLS credentials from +// a remote source. +// +// This interface is consumed by configgrpc via configmiddleware. +type GRPCServerContext interface { + // GetGRPCServerOptionsContext returns the gRPC server options with context support. + GetGRPCServerOptionsContext(context.Context) ([]grpc.ServerOption, error) +} + var _ HTTPServer = (*GetHTTPHandlerFunc)(nil) // GetHTTPHandlerFunc is a function that implements HTTPServer. +// The nil value is a valid no-op implementation that returns base unchanged. type GetHTTPHandlerFunc func(base http.Handler) (http.Handler, error) +// GetHTTPHandler implements HTTPServer. A nil function returns base unchanged. func (f GetHTTPHandlerFunc) GetHTTPHandler(base http.Handler) (http.Handler, error) { if f == nil { return base, nil @@ -36,11 +54,27 @@ func (f GetHTTPHandlerFunc) GetHTTPHandler(base http.Handler) (http.Handler, err var _ GRPCServer = (*GetGRPCServerOptionsFunc)(nil) // GetGRPCServerOptionsFunc is a function that implements GRPCServer. +// The nil value is a valid no-op implementation that returns no options. type GetGRPCServerOptionsFunc func() ([]grpc.ServerOption, error) +// GetGRPCServerOptions implements GRPCServer. A nil function returns nil options. func (f GetGRPCServerOptionsFunc) GetGRPCServerOptions() ([]grpc.ServerOption, error) { if f == nil { return nil, nil } return f() } + +var _ GRPCServerContext = (*GetGRPCServerOptionsContextFunc)(nil) + +// GetGRPCServerOptionsContextFunc is a function that implements GRPCServerContext. +// The nil value is a valid no-op implementation that returns no options. +type GetGRPCServerOptionsContextFunc func(context.Context) ([]grpc.ServerOption, error) + +// GetGRPCServerOptionsContext implements GRPCServerContext. A nil function returns nil options. +func (f GetGRPCServerOptionsContextFunc) GetGRPCServerOptionsContext(ctx context.Context) ([]grpc.ServerOption, error) { + if f == nil { + return nil, nil + } + return f(ctx) +} diff --git a/extension/extensionmiddleware/server_test.go b/extension/extensionmiddleware/server_test.go index eee61f5410e..b67efe48567 100644 --- a/extension/extensionmiddleware/server_test.go +++ b/extension/extensionmiddleware/server_test.go @@ -104,3 +104,51 @@ func TestGetGRPCServerOptionsFunc(t *testing.T) { require.Nil(t, opts) }) } + +func TestGetGRPCServerOptionsContextFunc(t *testing.T) { + ctx := context.Background() + + t.Run("nil_function", func(t *testing.T) { + var f GetGRPCServerOptionsContextFunc + opts, err := f.GetGRPCServerOptionsContext(ctx) + require.NoError(t, err) + require.Nil(t, opts) + }) + + t.Run("receives_context", func(t *testing.T) { + type ctxKey struct{} + testCtx := context.WithValue(ctx, ctxKey{}, "test-value") + var receivedCtx context.Context + + var interceptor grpc.UnaryServerInterceptor = func( + context.Context, + any, + *grpc.UnaryServerInfo, + grpc.UnaryHandler, + ) (resp any, err error) { + return nil, nil + } + expectedOpts := []grpc.ServerOption{grpc.UnaryInterceptor(interceptor)} + + f := GetGRPCServerOptionsContextFunc(func(c context.Context) ([]grpc.ServerOption, error) { + receivedCtx = c + return expectedOpts, nil + }) + + opts, err := f.GetGRPCServerOptionsContext(testCtx) + require.NoError(t, err) + require.Equal(t, expectedOpts, opts) + require.Equal(t, "test-value", receivedCtx.Value(ctxKey{})) + }) + + t.Run("returns_error", func(t *testing.T) { + expectedErr := errors.New("context server options error") + f := GetGRPCServerOptionsContextFunc(func(context.Context) ([]grpc.ServerOption, error) { + return nil, expectedErr + }) + + opts, err := f.GetGRPCServerOptionsContext(ctx) + require.Equal(t, expectedErr, err) + require.Nil(t, opts) + }) +}