Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions config/configgrpc/client_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
extensionmiddleware.GetGRPCClientOptionsFunc
}

// testClientMiddlewareContext is a mock implementation that uses context
type testClientMiddlewareContext struct {
extension.Extension
extensionmiddleware.GetGRPCClientOptionsContextFunc
}

func newTestMiddlewareConfig(name string) configmiddleware.Config {
return configmiddleware.Config{
ID: component.MustNewID(name),
Expand Down Expand Up @@ -87,6 +93,50 @@
}
}

func newTestClientMiddlewareContext(name string) extension.Extension {
return &testClientMiddlewareContext{
Extension: extensionmiddlewaretest.NewNop(),
GetGRPCClientOptionsContextFunc: func(ctx context.Context) ([]grpc.DialOption, error) {

Check failure on line 99 in config/configgrpc/client_middleware_test.go

View workflow job for this annotation

GitHub Actions / CodeQL-Build

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)

Check failure on line 99 in config/configgrpc/client_middleware_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)
return []grpc.DialOption{
grpc.WithChainUnaryInterceptor(
func(
callCtx context.Context,
method string,
req, reply any,
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
// Get existing metadata or create new metadata
md, ok := metadata.FromOutgoingContext(callCtx)
if !ok {
md = metadata.New(nil)
} else {
md = md.Copy()
}

sequence := ""
if values := md.Get("middleware-sequence"); len(values) > 0 {
sequence = values[0]
}

// Use "ctx-" prefix to indicate context-aware middleware was used
ctxName := "ctx-" + name
if sequence == "" {
sequence = ctxName
} else {
sequence = fmt.Sprintf("%s,%s", sequence, ctxName)
}

md.Set("middleware-sequence", sequence)
newCtx := metadata.NewOutgoingContext(callCtx, md)
return invoker(newCtx, method, req, reply, cc, opts...)
}),
}, nil
},
}
}

// TestClientMiddlewareOrdering verifies that client middleware
// interceptors are called in the right order.
func TestClientMiddlewareOrdering(t *testing.T) {
Expand Down Expand Up @@ -140,6 +190,58 @@
assert.Equal(t, expectedSequence, md[0])
}

// TestClientMiddlewareContextOrdering verifies that context-aware client middleware
// interceptors are called and work correctly.
func TestClientMiddlewareContextOrdering(t *testing.T) {
const middlewareTrackingHeader = "middleware-sequence"

// Create context-aware middleware extensions
mockMiddleware1 := newTestClientMiddlewareContext("middleware-1")
mockMiddleware2 := newTestClientMiddlewareContext("middleware-2")

mockExt := map[component.ID]component.Component{
component.MustNewID("middleware1"): mockMiddleware1,
component.MustNewID("middleware2"): mockMiddleware2,
}

// Start a gRPC server that will record the incoming metadata
server := &grpcTraceServer{}
srv, addr := server.startTestServer(t, configoptional.Some(ServerConfig{
NetAddr: confignet.AddrConfig{
Endpoint: "localhost:0",
Transport: confignet.TransportTypeTCP,
},
}))
defer srv.Stop()

// Create client config with middleware extensions
clientConfig := ClientConfig{
Endpoint: addr,
TLS: configtls.ClientConfig{
Insecure: true,
},
Middlewares: []configmiddleware.Config{
newTestMiddlewareConfig("middleware1"),
newTestMiddlewareConfig("middleware2"),
},
}

// Send a request using the client with middleware
resp, err := sendTestRequestWithExtensions(t, clientConfig, mockExt)
require.NoError(t, err)
assert.NotNil(t, resp)

// Verify that the context-aware middleware was used (indicated by "ctx-" prefix)
ictx, ok := metadata.FromIncomingContext(server.recordedContext)
require.True(t, ok, "middleware tracking header not found in metadata")
md := ictx[middlewareTrackingHeader]
require.Len(t, md, 1, "expected exactly one middleware tracking header value")

// The sequence should show ctx- prefix indicating context-aware interface was used
expectedSequence := "ctx-middleware-1,ctx-middleware-2"
assert.Equal(t, expectedSequence, md[0])
}

// TestClientMiddlewareToClientErrors tests failure cases for the ToClient method
// specifically related to middleware resolution and API calls.
func TestClientMiddlewareToClientErrors(t *testing.T) {
Expand Down
66 changes: 66 additions & 0 deletions config/configgrpc/server_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
extensionmiddleware.GetGRPCServerOptionsFunc
}

// testServerMiddlewareContext is a test implementation that uses context
type testServerMiddlewareContext struct {
extension.Extension
extensionmiddleware.GetGRPCServerOptionsContextFunc
}

func newTestServerMiddleware(name string) extension.Extension {
return &testServerMiddleware{
Extension: extensionmiddlewaretest.NewNop(),
Expand All @@ -60,6 +66,25 @@
}
}

func newTestServerMiddlewareContext(name string) extension.Extension {
return &testServerMiddlewareContext{
Extension: extensionmiddlewaretest.NewNop(),
GetGRPCServerOptionsContextFunc: func(ctx context.Context) ([]grpc.ServerOption, error) {

Check failure on line 72 in config/configgrpc/server_middleware_test.go

View workflow job for this annotation

GitHub Actions / CodeQL-Build

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)

Check failure on line 72 in config/configgrpc/server_middleware_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)
// Use "ctx-" prefix to indicate context-aware middleware was used
ctxName := "ctx-" + name
return []grpc.ServerOption{grpc.ChainUnaryInterceptor(
func(
callCtx context.Context,
req any, _ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
callCtx = context.WithValue(callCtx, middlewareCallsKey, append(getMiddlewareCalls(callCtx), ctxName))
return handler(callCtx, req)
})}, nil
},
}
}

func TestGrpcServerUnaryInterceptor(t *testing.T) {
// Register two test extensions
extensions := map[component.ID]component.Component{
Expand Down Expand Up @@ -101,6 +126,47 @@
assert.Equal(t, []string{"test1", "test2"}, getMiddlewareCalls(server.recordedContext))
}

func TestGrpcServerUnaryInterceptorContext(t *testing.T) {
// Register two context-aware test extensions
extensions := map[component.ID]component.Component{
component.MustNewID("test1"): newTestServerMiddlewareContext("test1"),
component.MustNewID("test2"): newTestServerMiddlewareContext("test2"),
}

// Setup the server with both middleware options
server := &grpcTraceServer{}
var addr string

// Create the server with middleware interceptors
{
var srv *grpc.Server
srv, addr = server.startTestServerWithExtensions(t, configoptional.Some(ServerConfig{
NetAddr: confignet.AddrConfig{
Endpoint: "localhost:0",
Transport: confignet.TransportTypeTCP,
},
Middlewares: []configmiddleware.Config{
newTestMiddlewareConfig("test1"),
newTestMiddlewareConfig("test2"),
},
}), extensions)
defer srv.Stop()
}

// Send a request to trigger the interceptors
resp, errResp := sendTestRequest(t, ClientConfig{
Endpoint: addr,
TLS: configtls.ClientConfig{
Insecure: true,
},
})
require.NoError(t, errResp)
require.NotNil(t, resp)

// Verify context-aware interceptors were called (indicated by "ctx-" prefix)
assert.Equal(t, []string{"ctx-test1", "ctx-test2"}, getMiddlewareCalls(server.recordedContext))
}

// TestServerMiddlewareToServerErrors tests failure cases for the ToServer method
// specifically related to middleware resolution and API calls.
func TestServerMiddlewareToServerErrors(t *testing.T) {
Expand Down
18 changes: 16 additions & 2 deletions config/configmiddleware/configmiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,15 @@ func (m Config) GetHTTPServerHandler(_ context.Context, extensions map[component
// extensionmiddleware.GRPCClient from the map of extensions, and
// returns the gRPC dial options. If a middleware is not found, an
// error is returned. This should only be used by gRPC clients.
func (m Config) GetGRPCClientOptions(_ context.Context, extensions map[component.ID]component.Component) ([]grpc.DialOption, error) {
//
// This function first checks if the extension implements
// GRPCClientContext (which accepts a context), falling back to
// GRPCClient for backwards compatibility.
func (m Config) GetGRPCClientOptions(ctx context.Context, extensions map[component.ID]component.Component) ([]grpc.DialOption, error) {
if ext, found := extensions[m.ID]; found {
if client, ok := ext.(extensionmiddleware.GRPCClientContext); ok {
return client.GetGRPCClientOptionsContext(ctx)
}
if client, ok := ext.(extensionmiddleware.GRPCClient); ok {
return client.GetGRPCClientOptions()
}
Expand All @@ -82,8 +89,15 @@ func (m Config) GetGRPCClientOptions(_ context.Context, extensions map[component
// extensionmiddleware.GRPCServer from the map of extensions, and
// returns the gRPC server options. If a middleware is not found, an
// error is returned. This should only be used by gRPC servers.
func (m Config) GetGRPCServerOptions(_ context.Context, extensions map[component.ID]component.Component) ([]grpc.ServerOption, error) {
//
// This function first checks if the extension implements
// GRPCServerContext (which accepts a context), falling back to
// GRPCServer for backwards compatibility.
func (m Config) GetGRPCServerOptions(ctx context.Context, extensions map[component.ID]component.Component) ([]grpc.ServerOption, error) {
if ext, found := extensions[m.ID]; found {
if server, ok := ext.(extensionmiddleware.GRPCServerContext); ok {
return server.GetGRPCServerOptionsContext(ctx)
}
if server, ok := ext.(extensionmiddleware.GRPCServer); ok {
return server.GetGRPCServerOptions()
}
Expand Down
88 changes: 88 additions & 0 deletions config/configmiddleware/configmiddleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ func TestConfig_GetGRPCServerOptions(t *testing.T) {
},
wantErr: nil,
},
{
name: "found_and_valid_context",
middleware: Config{
ID: testID,
},
extensions: map[component.ID]component.Component{
testID: struct {
extension.Extension
extensionmiddleware.GetGRPCServerOptionsContextFunc
}{
Extension: extensionmiddlewaretest.NewNop(),
GetGRPCServerOptionsContextFunc: func(context.Context) ([]grpc.ServerOption, error) {
return []grpc.ServerOption{
grpc.EmptyServerOption{},
}, nil
},
},
},
wantErr: nil,
},
{
name: "middleware_not_found",
middleware: Config{
Expand Down Expand Up @@ -221,6 +241,26 @@ func TestConfig_GetGRPCClientOptions(t *testing.T) {
},
wantErr: nil,
},
{
name: "found_and_valid_context",
middleware: Config{
ID: testID,
},
extensions: map[component.ID]component.Component{
testID: struct {
extension.Extension
extensionmiddleware.GetGRPCClientOptionsContextFunc
}{
Extension: extensionmiddlewaretest.NewNop(),
GetGRPCClientOptionsContextFunc: func(context.Context) ([]grpc.DialOption, error) {
return []grpc.DialOption{
grpc.EmptyDialOption{},
}, nil
},
},
},
wantErr: nil,
},
{
name: "middleware_not_found",
middleware: Config{
Expand Down Expand Up @@ -254,3 +294,51 @@ func TestConfig_GetGRPCClientOptions(t *testing.T) {
})
}
}

func TestConfig_GetGRPCClientOptions_ContextPassed(t *testing.T) {
type ctxKey struct{}
testCtx := context.WithValue(context.Background(), ctxKey{}, "test-value")
var receivedCtx context.Context

middleware := Config{ID: testID}
extensions := map[component.ID]component.Component{
testID: struct {
extension.Extension
extensionmiddleware.GetGRPCClientOptionsContextFunc
}{
Extension: extensionmiddlewaretest.NewNop(),
GetGRPCClientOptionsContextFunc: func(ctx context.Context) ([]grpc.DialOption, error) {
receivedCtx = ctx
return []grpc.DialOption{grpc.EmptyDialOption{}}, nil
},
},
}

_, err := middleware.GetGRPCClientOptions(testCtx, extensions)
require.NoError(t, err)
require.Equal(t, "test-value", receivedCtx.Value(ctxKey{}))
}

func TestConfig_GetGRPCServerOptions_ContextPassed(t *testing.T) {
type ctxKey struct{}
testCtx := context.WithValue(context.Background(), ctxKey{}, "server-test-value")
var receivedCtx context.Context

middleware := Config{ID: testID}
extensions := map[component.ID]component.Component{
testID: struct {
extension.Extension
extensionmiddleware.GetGRPCServerOptionsContextFunc
}{
Extension: extensionmiddlewaretest.NewNop(),
GetGRPCServerOptionsContextFunc: func(ctx context.Context) ([]grpc.ServerOption, error) {
receivedCtx = ctx
return []grpc.ServerOption{grpc.EmptyServerOption{}}, nil
},
},
}

_, err := middleware.GetGRPCServerOptions(testCtx, extensions)
require.NoError(t, err)
require.Equal(t, "server-test-value", receivedCtx.Value(ctxKey{}))
}
2 changes: 2 additions & 0 deletions extension/extensionmiddleware/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,5 @@ server-side middleware object.

- **GRPCClient**: The extension returns `[]grpc.DialOption`.
- **GRPCServer**: The extension returns `[]grpc.ServerOption`.
- **GRPCClientContext**: Like `GRPCClient`, but accepts a `context.Context` for operations requiring it (e.g., fetching TLS credentials from a remote source).
- **GRPCServerContext**: Like `GRPCServer`, but accepts a `context.Context`.
Loading
Loading