diff --git a/CHANGELOG.md b/CHANGELOG.md index b460bb8e344ea..2e4e63d83cdd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Refactor the GetStats, FlushStats and QueryCacheStats class to use the Builder pattern instead of constructors ([#19935](https://github.com/opensearch-project/OpenSearch/pull/19935)) - Add RangeSemver for `dependencies` in `plugin-descriptor.properties` ([#19939](https://github.com/opensearch-project/OpenSearch/pull/19939)) - Refactor the FieldDataStats and CompletionStats class to use the Builder pattern instead of constructors ([#19936](https://github.com/opensearch-project/OpenSearch/pull/19936)) +- Thread Context Preservation by gRPC Interceptor ([#19776](https://github.com/opensearch-project/OpenSearch/pull/19776)) + ### Fixed - Fix Allocation and Rebalance Constraints of WeightFunction are incorrectly reset ([#19012](https://github.com/opensearch-project/OpenSearch/pull/19012)) diff --git a/modules/transport-grpc/spi/README.md b/modules/transport-grpc/spi/README.md index 9cc97e824d12e..9199b64f31ad1 100644 --- a/modules/transport-grpc/spi/README.md +++ b/modules/transport-grpc/spi/README.md @@ -334,7 +334,9 @@ The k-NN query's `filter` field is a `QueryContainer` protobuf type that can con ### Overview -Intercept incoming gRPC requests for authentication, authorization, logging, metrics, rate limiting,etc +Intercept incoming gRPC requests for authentication, authorization, logging, metrics, rate limiting, etc. Interceptors have access to OpenSearch's `ThreadContext` to store and retrieve request-scoped data. + +**Context Preservation:** The transport-grpc module automatically preserves ThreadContext across async boundaries. Any data set by interceptors will be available in the gRPC service implementation, even when execution switches to different threads. ### Basic Usage @@ -342,7 +344,7 @@ Intercept incoming gRPC requests for authentication, authorization, logging, met ```java public class SampleInterceptorProvider implements GrpcInterceptorProvider { @Override - public List getOrderedGrpcInterceptors() { + public List getOrderedGrpcInterceptors(ThreadContext threadContext) { return Arrays.asList( // First interceptor (order = 5, runs first) new GrpcInterceptorProvider.OrderedGrpcInterceptor() { @@ -353,6 +355,7 @@ public class SampleInterceptorProvider implements GrpcInterceptorProvider { public ServerInterceptor getInterceptor() { return (call, headers, next) -> { String methodName = call.getMethodDescriptor().getFullMethodName(); + threadContext.putTransient("grpc.method", methodName); System.out.println("First interceptor - Method: " + methodName); return next.startCall(call, headers); }; diff --git a/modules/transport-grpc/spi/src/main/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProvider.java b/modules/transport-grpc/spi/src/main/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProvider.java index d111ca6fa71f5..c838f63bc434a 100644 --- a/modules/transport-grpc/spi/src/main/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProvider.java +++ b/modules/transport-grpc/spi/src/main/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProvider.java @@ -7,6 +7,8 @@ */ package org.opensearch.transport.grpc.spi; +import org.opensearch.common.util.concurrent.ThreadContext; + import java.util.List; import io.grpc.ServerInterceptor; @@ -19,12 +21,17 @@ public interface GrpcInterceptorProvider { /** - * Returns a list of ordered gRPC interceptors. + * Returns a list of ordered gRPC interceptors with access to ThreadContext. * Each interceptor must have a unique order value. * + * This follows the pattern established by REST handler wrappers where + * the thread context is provided to allow interceptors to: + * - Extract headers from gRPC metadata and store in ThreadContext + * - Preserve context across async boundaries + * @param threadContext The thread context for managing request context * @return List of ordered gRPC interceptors */ - List getOrderedGrpcInterceptors(); + List getOrderedGrpcInterceptors(ThreadContext threadContext); /** * Provides a gRPC interceptor with an order value for execution priority. @@ -42,6 +49,8 @@ interface OrderedGrpcInterceptor { /** * Returns the actual gRPC ServerInterceptor instance. + * The interceptor can use the ThreadContext provided to the parent + * GrpcInterceptorProvider to manage request context. * * @return the server interceptor */ diff --git a/modules/transport-grpc/spi/src/test/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProviderTests.java b/modules/transport-grpc/spi/src/test/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProviderTests.java index 103fa610c3d0d..9c308e82f9a9a 100644 --- a/modules/transport-grpc/spi/src/test/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProviderTests.java +++ b/modules/transport-grpc/spi/src/test/java/org/opensearch/transport/grpc/spi/GrpcInterceptorProviderTests.java @@ -8,6 +8,8 @@ package org.opensearch.transport.grpc.spi; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.test.OpenSearchTestCase; import java.util.Collections; @@ -22,26 +24,45 @@ public class GrpcInterceptorProviderTests extends OpenSearchTestCase { public void testBasicProviderImplementation() { TestGrpcInterceptorProvider provider = new TestGrpcInterceptorProvider(10); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - List interceptors = provider.getOrderedGrpcInterceptors(); + List interceptors = provider.getOrderedGrpcInterceptors(threadContext); assertNotNull(interceptors); assertEquals(1, interceptors.size()); assertEquals(10, interceptors.get(0).order()); } public void testProviderReturnsEmptyList() { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); GrpcInterceptorProvider provider = new GrpcInterceptorProvider() { @Override - public List getOrderedGrpcInterceptors() { + public List getOrderedGrpcInterceptors(ThreadContext threadContext) { return Collections.emptyList(); } }; - List interceptors = provider.getOrderedGrpcInterceptors(); + List interceptors = provider.getOrderedGrpcInterceptors(threadContext); assertNotNull(interceptors); assertTrue(interceptors.isEmpty()); } + public void testProviderReceivesThreadContext() { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + threadContext.putHeader("X-Test-Header", "test-value"); + + GrpcInterceptorProvider provider = new GrpcInterceptorProvider() { + @Override + public List getOrderedGrpcInterceptors(ThreadContext ctx) { + // Verify that the provider receives the ThreadContext + assertNotNull("ThreadContext should not be null", ctx); + assertEquals("test-value", ctx.getHeader("X-Test-Header")); + return Collections.emptyList(); + } + }; + + provider.getOrderedGrpcInterceptors(threadContext); + } + private static class TestGrpcInterceptorProvider implements GrpcInterceptorProvider { private final int order; @@ -50,7 +71,7 @@ private static class TestGrpcInterceptorProvider implements GrpcInterceptorProvi } @Override - public List getOrderedGrpcInterceptors() { + public List getOrderedGrpcInterceptors(ThreadContext threadContext) { return Collections.singletonList(createTestInterceptor(order, "test-interceptor")); } } diff --git a/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/GrpcPlugin.java b/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/GrpcPlugin.java index 45e4a5b1eb022..3b59a28590901 100644 --- a/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/GrpcPlugin.java +++ b/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/GrpcPlugin.java @@ -52,6 +52,7 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; +import java.util.stream.Collectors; import io.grpc.BindableService; @@ -84,7 +85,8 @@ public final class GrpcPlugin extends Plugin implements NetworkPlugin, Extensibl private final List servicesFactory = new ArrayList<>(); private QueryBuilderProtoConverterRegistryImpl queryRegistry; private AbstractQueryBuilderProtoUtils queryUtils; - private GrpcInterceptorChain serverInterceptor = new GrpcInterceptorChain(); + private GrpcInterceptorChain serverInterceptor; // Initialized in createComponents + private List interceptorProviders = new ArrayList<>(); private Client client; /** @@ -118,39 +120,10 @@ public void loadExtensions(ExtensiblePlugin.ExtensionLoader loader) { } List providers = loader.loadExtensions(GrpcInterceptorProvider.class); if (providers != null) { - List orderedList = new ArrayList<>(); - for (GrpcInterceptorProvider provider : providers) { - orderedList.addAll(provider.getOrderedGrpcInterceptors()); - } - - // Validate that no two interceptors have the same order - Map> orderMap = new HashMap<>(); - for (OrderedGrpcInterceptor interceptor : orderedList) { - int order = interceptor.order(); - orderMap.computeIfAbsent(order, k -> new ArrayList<>()).add(interceptor); - } - - // Check for duplicates and throw exception if found - for (Map.Entry> entry : orderMap.entrySet()) { - if (entry.getValue().size() > 1) { - throw new IllegalArgumentException( - "Multiple gRPC interceptors have the same order value: " - + entry.getKey() - + ". Each interceptor must have a unique order value." - ); - } - } - - // Sort by order and create a chain - similar to OpenSearch's ActionFilter pattern - orderedList.sort(Comparator.comparingInt(OrderedGrpcInterceptor::order)); - - if (!orderedList.isEmpty()) { - // Create a single chain interceptor that manages the execution - // This ensures proper ordering and exception handling - serverInterceptor.addInterceptors(orderedList); - - logger.info("Loaded {} gRPC interceptors into chain", orderedList.size()); - } + // Note: ThreadContext will be provided during component creation + // For now, we collect providers to be initialized later with ThreadContext + this.interceptorProviders = providers; + logger.info("Found {} gRPC interceptor providers, will initialize during component creation", providers.size()); } // Load discovered gRPC service factories List services = loader.loadExtensions(GrpcServiceFactory.class); @@ -363,6 +336,53 @@ public Collection createComponents( ) { this.client = client; + // Initialize the interceptor chain with ThreadContext + this.serverInterceptor = new GrpcInterceptorChain(threadPool.getThreadContext()); + + List orderedList = new ArrayList<>(); + + // Then add plugin-provided interceptors + if (!interceptorProviders.isEmpty()) { + for (GrpcInterceptorProvider provider : interceptorProviders) { + orderedList.addAll(provider.getOrderedGrpcInterceptors(threadPool.getThreadContext())); + } + + // Validate that no two interceptors have the same order + Map> orderMap = new HashMap<>(); + for (OrderedGrpcInterceptor interceptor : orderedList) { + int order = interceptor.order(); + orderMap.computeIfAbsent(order, k -> new ArrayList<>()).add(interceptor); + } + + // Check for duplicates and throw exception if found + for (Map.Entry> entry : orderMap.entrySet()) { + if (entry.getValue().size() > 1) { + String conflictingInterceptors = entry.getValue() + .stream() + .map(i -> i.getInterceptor().getClass().getName()) + .collect(Collectors.joining(", ")); + throw new IllegalArgumentException( + "Multiple gRPC interceptors have the same order value [" + + entry.getKey() + + "]: " + + conflictingInterceptors + + ". Each interceptor must have a unique order value." + ); + } + } + + // Sort by order and create a chain - similar to OpenSearch's ActionFilter pattern + orderedList.sort(Comparator.comparingInt(OrderedGrpcInterceptor::order)); + + if (!orderedList.isEmpty()) { + // Create a single chain interceptor that manages the execution + // This ensures proper ordering and exception handling + serverInterceptor.addInterceptors(orderedList); + + logger.info("Loaded {} gRPC interceptors into chain", orderedList.size()); + } + } + // Create the registry this.queryRegistry = new QueryBuilderProtoConverterRegistryImpl(); diff --git a/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/Netty4GrpcServerTransport.java b/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/Netty4GrpcServerTransport.java index de2fb0079c652..8ea67aa9d7191 100644 --- a/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/Netty4GrpcServerTransport.java +++ b/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/Netty4GrpcServerTransport.java @@ -55,6 +55,7 @@ import static org.opensearch.common.settings.Setting.listSetting; import static org.opensearch.common.util.concurrent.OpenSearchExecutors.daemonThreadFactory; import static org.opensearch.transport.Transport.resolveTransportPublishPort; +import static org.opensearch.transport.grpc.GrpcPlugin.GRPC_THREAD_POOL_NAME; /** * Netty4 gRPC server implemented as a LifecycleComponent. @@ -275,7 +276,7 @@ public Netty4GrpcServerTransport( NetworkService networkService, ThreadPool threadPool ) { - this(settings, services, networkService, threadPool, new GrpcInterceptorChain()); + this(settings, services, networkService, threadPool, new GrpcInterceptorChain(threadPool.getThreadContext())); } /** @@ -320,7 +321,7 @@ protected void doStart() { this.workerEventLoopGroup = new NioEventLoopGroup(nettyEventLoopThreads, daemonThreadFactory(settings, "grpc_worker")); // Use OpenSearch's managed thread pool for gRPC request processing - this.grpcExecutor = threadPool.executor("grpc"); + this.grpcExecutor = threadPool.executor(GRPC_THREAD_POOL_NAME); bindServer(); success = true; diff --git a/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChain.java b/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChain.java index 1e5f075cab302..396cfcdb6c1e7 100644 --- a/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChain.java +++ b/modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChain.java @@ -10,12 +10,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.transport.grpc.spi.GrpcInterceptorProvider.OrderedGrpcInterceptor; import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.function.Supplier; +import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; @@ -24,7 +27,9 @@ import io.grpc.StatusRuntimeException; /** - * Simple gRPC interceptor chain that executes OrderedGrpcInterceptors in order and handles exceptions + * gRPC interceptor chain that executes OrderedGrpcInterceptors in order and handles exceptions. + * Captures OpenSearch ThreadContext after interceptors run and restores it during callback execution + * to survive thread switches between gRPC threads and OpenSearch executor threads. */ public class GrpcInterceptorChain implements ServerInterceptor { @@ -34,22 +39,31 @@ public class GrpcInterceptorChain implements ServerInterceptor { }; private final List interceptors = new ArrayList<>(); + private final ThreadContext threadContext; /** * Constructs an empty GrpcInterceptorChain. + * + * @param threadContext The ThreadContext to capture and propagate */ - public GrpcInterceptorChain() {} + public GrpcInterceptorChain(ThreadContext threadContext) { + this.threadContext = Objects.requireNonNull(threadContext, "ThreadContext cannot be null"); + } /** * Constructs a GrpcInterceptorChain with the provided list of ordered interceptors. - * @param interceptors List of OrderedGrpcInterceptor instances to be applied in order + * + * @param threadContext The ThreadContext to capture and propagate + * @param interceptors List of OrderedGrpcInterceptor instances to be applied in order */ - public GrpcInterceptorChain(List interceptors) { + public GrpcInterceptorChain(ThreadContext threadContext, List interceptors) { + this.threadContext = Objects.requireNonNull(threadContext, "ThreadContext cannot be null"); this.interceptors.addAll(Objects.requireNonNull(interceptors)); } /** * Adds interceptors to the chain. + * * @param interceptors List of OrderedGrpcInterceptor instances to be added */ public void addInterceptors(List interceptors) { @@ -58,9 +72,11 @@ public void addInterceptors(List interceptors) { /** * Intercepts a gRPC call, executing the chain of interceptors in order. - * @param call object to receive response messages + * Captures ThreadContext after interceptors execute and restores it in all listener callbacks. + * + * @param call object to receive response messages * @param headers which can contain extra call metadata - * @param next next processor in the interceptor chain + * @param next next processor in the interceptor chain * @return a listener for processing incoming request messages */ @Override @@ -99,7 +115,47 @@ public ServerCall.Listener startCall(ServerCall call, Metadat } }; } - return currentHandler.startCall(call, headers); + + ServerCall.Listener delegate = currentHandler.startCall(call, headers); + + // Capture ThreadContext state AFTER interceptors have executed using newRestorableContext + // This follows the same pattern as ContextPreservingActionListener. + // Interceptors may have added transients/headers that need to propagate across thread switches. + final Supplier contextSupplier = threadContext.newRestorableContext(false); + + // Wrap the listener to restore ThreadContext in all callbacks + return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(delegate) { + private void runWithThreadContext(Runnable r) { + try (ThreadContext.StoredContext ignored = contextSupplier.get()) { + r.run(); + } + } + + @Override + public void onMessage(ReqT message) { + runWithThreadContext(() -> super.onMessage(message)); + } + + @Override + public void onHalfClose() { + runWithThreadContext(super::onHalfClose); + } + + @Override + public void onReady() { + runWithThreadContext(super::onReady); + } + + @Override + public void onCancel() { + runWithThreadContext(super::onCancel); + } + + @Override + public void onComplete() { + runWithThreadContext(super::onComplete); + } + }; } /** diff --git a/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/GrpcPluginTests.java b/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/GrpcPluginTests.java index 31148fe6337d7..544b2d91d8755 100644 --- a/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/GrpcPluginTests.java +++ b/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/GrpcPluginTests.java @@ -12,6 +12,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.indices.breaker.CircuitBreakerService; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.SecureAuxTransportSettingsProvider; @@ -100,11 +101,14 @@ public void setup() { // Create a real ClusterSettings instance with the plugin's settings plugin = new GrpcPlugin(); + // Mock ThreadPool and ThreadContext + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + // Set the client in the plugin plugin.createComponents( client, null, // ClusterService - null, // ThreadPool + threadPool, // ThreadPool (now properly mocked) null, // ResourceWatcherService null, // ScriptService null, // NamedXContentRegistry @@ -254,7 +258,9 @@ public void testGetSecureAuxTransportsWithNullClient() { public void testGetAuxTransportsWithServiceFactories() { GrpcPlugin newPlugin = new GrpcPlugin(); - newPlugin.createComponents(Mockito.mock(Client.class), null, null, null, null, null, null, null, null, null, null); + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + newPlugin.createComponents(Mockito.mock(Client.class), null, mockThreadPool, null, null, null, null, null, null, null, null); ExtensiblePlugin.ExtensionLoader mockLoader = Mockito.mock(ExtensiblePlugin.ExtensionLoader.class); when(mockLoader.loadExtensions(GrpcServiceFactory.class)).thenReturn(List.of(new LoadableMockServiceFactory())); plugin.loadExtensions(mockLoader); @@ -273,7 +279,9 @@ public void testGetAuxTransportsWithServiceFactories() { public void testGetSecureAuxTransportsWithServiceFactories() { GrpcPlugin newPlugin = new GrpcPlugin(); - newPlugin.createComponents(Mockito.mock(Client.class), null, null, null, null, null, null, null, null, null, null); + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + newPlugin.createComponents(Mockito.mock(Client.class), null, mockThreadPool, null, null, null, null, null, null, null, null); ExtensiblePlugin.ExtensionLoader mockLoader = Mockito.mock(ExtensiblePlugin.ExtensionLoader.class); when(mockLoader.loadExtensions(GrpcServiceFactory.class)).thenReturn(List.of(new LoadableMockServiceFactory())); plugin.loadExtensions(mockLoader); @@ -341,11 +349,15 @@ public void testCreateComponents() { newPlugin.loadExtensions(extensionLoader); when(extensionLoader.loadExtensions(QueryBuilderProtoConverter.class)).thenReturn(List.of(mockConverter)); + // Mock ThreadPool for createComponents + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)); + // Call createComponents Collection components = newPlugin.createComponents( client, null, // ClusterService - null, // ThreadPool + mockThreadPool, // ThreadPool null, // ResourceWatcherService null, // ScriptService null, // NamedXContentRegistry @@ -378,11 +390,15 @@ public void testCreateComponentsWithExternalConverters() { // Verify the converter was added to the queryConverters list assertEquals("Should have 1 query converter loaded", 1, newPlugin.getQueryConverters().size()); + // Mock ThreadPool for createComponents + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)); + // Call createComponents to trigger registration of external converters Collection components = newPlugin.createComponents( client, null, // ClusterService - null, // ThreadPool + mockThreadPool, // ThreadPool null, // ResourceWatcherService null, // ScriptService null, // NamedXContentRegistry @@ -421,11 +437,45 @@ public void testLoadExtensionsWithGrpcInterceptorsOrdering() { } public void testLoadExtensionsWithDuplicateGrpcInterceptorOrder() { - testInterceptorLoading(List.of(1, 1), IllegalArgumentException.class); + GrpcPlugin plugin = new GrpcPlugin(); + ExtensiblePlugin.ExtensionLoader mockLoader = createMockLoader(List.of(1, 1)); + + assertDoesNotThrow(() -> plugin.loadExtensions(mockLoader)); + + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null) + ); + + String errorMessage = exception.getMessage(); + assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [1]")); + assertTrue(errorMessage.contains("ServerInterceptor")); // Mock class name will contain this + assertTrue(errorMessage.contains("Each interceptor must have a unique order value")); } public void testLoadExtensionsWithMultipleProvidersAndDuplicateOrder() { - testInterceptorLoadingWithMultipleProviders(List.of(List.of(5), List.of(5)), IllegalArgumentException.class); + GrpcPlugin plugin = new GrpcPlugin(); + ExtensiblePlugin.ExtensionLoader mockLoader = createMockLoaderWithMultipleProviders(List.of(List.of(5), List.of(5))); + + // loadExtensions should succeed + assertDoesNotThrow(() -> plugin.loadExtensions(mockLoader)); + + // createComponents should fail with duplicate order + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null) + ); + + String errorMessage = exception.getMessage(); + assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [5]")); + assertTrue(errorMessage.contains("ServerInterceptor")); + assertTrue(errorMessage.contains("Each interceptor must have a unique order value")); } public void testLoadExtensionsWithNullGrpcInterceptorProviders() { @@ -437,7 +487,23 @@ public void testLoadExtensionsWithEmptyGrpcInterceptorList() { } public void testLoadExtensionsWithSameExplicitOrderInterceptors() { - testInterceptorLoading(List.of(5, 5), IllegalArgumentException.class); + GrpcPlugin plugin = new GrpcPlugin(); + ExtensiblePlugin.ExtensionLoader mockLoader = createMockLoader(List.of(5, 5)); + + assertDoesNotThrow(() -> plugin.loadExtensions(mockLoader)); + + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null) + ); + + String errorMessage = exception.getMessage(); + assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [5]")); + assertTrue(errorMessage.contains("ServerInterceptor")); + assertTrue(errorMessage.contains("Each interceptor must have a unique order value")); } // Test cases for interceptor chain failure handling @@ -560,13 +626,13 @@ private ExtensiblePlugin.ExtensionLoader createMockLoader(List orders) when(mockLoader.loadExtensions(GrpcInterceptorProvider.class)).thenReturn(null); } else if (orders.isEmpty()) { GrpcInterceptorProvider mockProvider = Mockito.mock(GrpcInterceptorProvider.class); - when(mockProvider.getOrderedGrpcInterceptors()).thenReturn(new ArrayList<>()); + when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(new ArrayList<>()); when(mockLoader.loadExtensions(GrpcInterceptorProvider.class)).thenReturn(List.of(mockProvider)); } else { List interceptors = orders.stream().map(order -> createMockInterceptor(order)).toList(); GrpcInterceptorProvider mockProvider = Mockito.mock(GrpcInterceptorProvider.class); - when(mockProvider.getOrderedGrpcInterceptors()).thenReturn(interceptors); + when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors); when(mockLoader.loadExtensions(GrpcInterceptorProvider.class)).thenReturn(List.of(mockProvider)); } @@ -583,7 +649,7 @@ private ExtensiblePlugin.ExtensionLoader createMockLoaderWithMultipleProviders(L List providers = providerOrders.stream().map(orders -> { List interceptors = orders.stream().map(this::createMockInterceptor).toList(); GrpcInterceptorProvider provider = Mockito.mock(GrpcInterceptorProvider.class); - when(provider.getOrderedGrpcInterceptors()).thenReturn(interceptors); + when(provider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors); return provider; }).toList(); @@ -817,7 +883,7 @@ public void testGrpcInterceptorChainIntegrationWithPlugin() { createTestInterceptor(20, false), createTestInterceptor(30, false) ); - when(mockProvider.getOrderedGrpcInterceptors()).thenReturn(interceptors); + when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors); ExtensiblePlugin.ExtensionLoader mockLoader = Mockito.mock(ExtensiblePlugin.ExtensionLoader.class); when(mockLoader.loadExtensions(QueryBuilderProtoConverter.class)).thenReturn(null); @@ -825,8 +891,14 @@ public void testGrpcInterceptorChainIntegrationWithPlugin() { GrpcPlugin plugin = new GrpcPlugin(); - // Should not throw exception and should create chain + // Should not throw exception and should load providers assertDoesNotThrow(() -> plugin.loadExtensions(mockLoader)); + + // Need to call createComponents to actually initialize the chain + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)); + + assertDoesNotThrow(() -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null)); } public void testGrpcInterceptorChainWithDuplicateOrders() { @@ -836,7 +908,7 @@ public void testGrpcInterceptorChainWithDuplicateOrders() { createTestInterceptor(10, false), createTestInterceptor(10, false) // Duplicate order ); - when(mockProvider.getOrderedGrpcInterceptors()).thenReturn(interceptors); + when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors); ExtensiblePlugin.ExtensionLoader mockLoader = Mockito.mock(ExtensiblePlugin.ExtensionLoader.class); when(mockLoader.loadExtensions(QueryBuilderProtoConverter.class)).thenReturn(null); @@ -844,15 +916,32 @@ public void testGrpcInterceptorChainWithDuplicateOrders() { GrpcPlugin plugin = new GrpcPlugin(); - // Should throw exception due to duplicate orders - expectThrows(IllegalArgumentException.class, () -> plugin.loadExtensions(mockLoader)); + // Load extensions first + plugin.loadExtensions(mockLoader); + + // Mock ThreadPool for createComponents + ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class); + when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)); + + // Should throw exception due to duplicate orders during createComponents + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null) + ); + + // Verify error message includes order value and interceptor class names + String errorMessage = exception.getMessage(); + assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [10]")); + assertTrue(errorMessage.contains("GrpcPluginTests")); + assertTrue(errorMessage.contains("Each interceptor must have a unique order value")); } /** * Helper method to test GrpcInterceptorChain behavior */ private void testGrpcInterceptorChain(List interceptors, boolean shouldSucceed, String expectedErrorMessage) { - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); @SuppressWarnings("unchecked") ServerCall mockCall = Mockito.mock(ServerCall.class); diff --git a/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChainTests.java b/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChainTests.java index 540dfb34dfdc1..edc607704928c 100644 --- a/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChainTests.java +++ b/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChainTests.java @@ -8,6 +8,8 @@ package org.opensearch.transport.grpc.interceptor; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.grpc.spi.GrpcInterceptorProvider.OrderedGrpcInterceptor; import org.junit.Before; @@ -46,6 +48,7 @@ public class GrpcInterceptorChainTests extends OpenSearchTestCase { private ServerCall.Listener mockListener; private Metadata headers; + private ThreadContext threadContext; @Before public void setUp() throws Exception { @@ -53,21 +56,22 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); when(mockHandler.startCall(any(), any())).thenReturn(mockListener); headers = new Metadata(); + threadContext = new ThreadContext(Settings.EMPTY); } public void testEmptyChain() { - GrpcInterceptorChain chain = new GrpcInterceptorChain(Collections.emptyList()); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, Collections.emptyList()); ServerCall.Listener result = chain.interceptCall(mockCall, headers, mockHandler); assertNotNull(result); - assertEquals(mockListener, result); + // The result is now wrapped in a ThreadContextPreservingListener, not the raw mockListener verify(mockHandler).startCall(mockCall, headers); } public void testSingleSuccessfulInterceptor() { List interceptors = Arrays.asList(createTestInterceptor(10, false, null)); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); ServerCall.Listener result = chain.interceptCall(mockCall, headers, mockHandler); assertNotNull(result); @@ -81,7 +85,7 @@ public void testMultipleSuccessfulInterceptors() { createTestInterceptor(30, false, null) ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); ServerCall.Listener result = chain.interceptCall(mockCall, headers, mockHandler); assertNotNull(result); @@ -96,7 +100,7 @@ public void testFirstInterceptorFails() { createTestInterceptor(30, false, null) ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); chain.interceptCall(mockCall, headers, mockHandler); verify(mockCall).close( @@ -112,7 +116,7 @@ public void testMiddleInterceptorFails() { createTestInterceptor(30, false, null) ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); chain.interceptCall(mockCall, headers, mockHandler); verify(mockCall).close( @@ -128,7 +132,7 @@ public void testLastInterceptorFails() { createTestInterceptor(30, true, "Last failure") ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); chain.interceptCall(mockCall, headers, mockHandler); verify(mockCall).close( @@ -144,7 +148,7 @@ public void testInterceptorThrowsStatusRuntimeExceptionPermissionDenied() { createTestInterceptor(30, false, null) ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); ServerCall.Listener result = chain.interceptCall(mockCall, headers, mockHandler); assertNotNull(result); @@ -160,7 +164,7 @@ public void testInterceptorThrowsStatusRuntimeExceptionUnauthenticated() { createTestInterceptor(20, false, null) ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); ServerCall.Listener result = chain.interceptCall(mockCall, headers, mockHandler); assertNotNull(result); @@ -177,7 +181,7 @@ public void testInterceptorThrowsStatusRuntimeExceptionResourceExhausted() { createStatusRuntimeExceptionInterceptor(30, Status.RESOURCE_EXHAUSTED.withDescription("Rate limit exceeded")) ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); ServerCall.Listener result = chain.interceptCall(mockCall, headers, mockHandler); assertNotNull(result); @@ -199,7 +203,7 @@ public void testInterceptorOrdering() { // Sort as GrpcPlugin would interceptors.sort((a, b) -> Integer.compare(a.order(), b.order())); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); chain.interceptCall(mockCall, headers, mockHandler); // Verify execution order @@ -221,7 +225,7 @@ public void testChainIntegrationWithRealScenario() { createLoggingInterceptor(30, "METRICS", executionLog) ); - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); chain.interceptCall(mockCall, headers, mockHandler); assertEquals(Arrays.asList("AUTH", "LOGGING", "METRICS"), executionLog); @@ -231,7 +235,7 @@ public void testChainIntegrationWithRealScenario() { * Generic test method that can be extended for different scenarios */ public void testChainWithPattern(List interceptors, boolean expectSuccess, String expectedErrorMessage) { - GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors); + GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors); if (expectSuccess) { ServerCall.Listener result = chain.interceptCall(mockCall, headers, mockHandler); diff --git a/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/ssl/SecureNetty4GrpcServerTransportTests.java b/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/ssl/SecureNetty4GrpcServerTransportTests.java index cf35ad449c234..aeb59e247dcb2 100644 --- a/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/ssl/SecureNetty4GrpcServerTransportTests.java +++ b/modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/ssl/SecureNetty4GrpcServerTransportTests.java @@ -52,7 +52,7 @@ public void setup() { Settings settings = Settings.builder().put("node.name", "test-node").put("grpc.netty.executor_count", 4).build(); ExecutorBuilder grpcExecutorBuilder = new FixedExecutorBuilder(settings, "grpc", 4, 1000, "thread_pool.grpc"); threadPool = new ThreadPool(settings, grpcExecutorBuilder); - serverInterceptor = new GrpcInterceptorChain(Collections.emptyList()); + serverInterceptor = new GrpcInterceptorChain(threadPool.getThreadContext(), Collections.emptyList()); } @After