diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 5f168166c236d..b7075fa6dc7f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -478,7 +478,9 @@ private CCMRelatedComponents createCCMDependentComponents( var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( inferenceServiceSettings.getElasticInferenceServiceUrl(), services.threadPool(), - ccmAuthApplierFactory + ccmAuthApplierFactory, + ccmFeature, + ccmService ); var authTaskExecutor = AuthorizationTaskExecutor.create( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index e776ba0690613..fd4b981375009 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -158,7 +158,7 @@ private void getEisAuthorization(ActionListener getServiceConfigurationsForServices( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java index f7e3c1e8f3896..db9d253361b28 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java @@ -112,7 +112,9 @@ protected void masterOperation( var authRequestHandler = new ElasticInferenceServiceAuthorizationRequestHandler( eisSettings.getElasticInferenceServiceUrl(), threadPool, - new ValidationAuthenticationFactory(request.getApiKey()) + new ValidationAuthenticationFactory(request.getApiKey()), + ccmFeature, + ccmService ); var errorListener = authValidationListener.delegateResponse((delegate, exception) -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index e1506a4b37ef1..4598d01d30bed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -23,6 +23,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.AuthenticationFactory; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -56,17 +58,23 @@ private static ResponseHandler createAuthResponseHandler() { private final Logger logger; private final CountDownLatch requestCompleteLatch = new CountDownLatch(1); private final AuthenticationFactory authFactory; + private final CCMFeature ccmFeature; + private final CCMService ccmService; public ElasticInferenceServiceAuthorizationRequestHandler( @Nullable String baseUrl, ThreadPool threadPool, - AuthenticationFactory authFactory + AuthenticationFactory authFactory, + CCMFeature ccmFeature, + CCMService ccmService ) { this( baseUrl, Objects.requireNonNull(threadPool), LogManager.getLogger(ElasticInferenceServiceAuthorizationRequestHandler.class), - authFactory + authFactory, + ccmFeature, + ccmService ); } @@ -75,12 +83,16 @@ public ElasticInferenceServiceAuthorizationRequestHandler( @Nullable String baseUrl, ThreadPool threadPool, Logger logger, - AuthenticationFactory authFactory + AuthenticationFactory authFactory, + CCMFeature ccmFeature, + CCMService ccmService ) { this.baseUrl = baseUrl; this.threadPool = Objects.requireNonNull(threadPool); this.logger = Objects.requireNonNull(logger); this.authFactory = Objects.requireNonNull(authFactory); + this.ccmFeature = Objects.requireNonNull(ccmFeature); + this.ccmService = Objects.requireNonNull(ccmService); } /** @@ -89,56 +101,92 @@ public ElasticInferenceServiceAuthorizationRequestHandler( * @param sender a {@link Sender} for making the request to the Elastic Inference Service */ public void getAuthorization(ActionListener listener, Sender sender) { - try { - logger.debug("Retrieving authorization information from the Elastic Inference Service."); + getAuthorization(listener, sender, false); + } - if (Strings.isNullOrEmpty(baseUrl)) { - logger.debug("The base URL for the authorization service is not valid, rejecting authorization."); - listener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized()); - return; - } + /** + * Retrieves the authorization information from Elastic Inference Service. This will skip making a request if CCM is not enabled + * and it is a supported environment. A supported environment is on-prem or ECK. ECH and serverless are not supported + * environments for CCM (because they can already connect to EIS). For environments where CCM is not supported, it will always + * attempt to retrieve the authorization information. + * @param listener a listener to receive the response + * @param sender a {@link Sender} for making the request to the Elastic Inference Service + */ + public void getAuthorizationIfPermittedEnvironment(ActionListener listener, Sender sender) { + getAuthorization(listener, sender, true); + } - var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> { - // unwrap because it's likely a retry exception - var exception = ExceptionsHelper.unwrapCause(e); - - logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception); - authModelListener.onFailure(e); - }); - - SubscribableListener.newForked(sender::startAsynchronously) - .andThen(authFactory::getAuthenticationApplier) - .andThen((authListener, authApplier) -> { - var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); - var request = new ElasticInferenceServiceAuthorizationRequest( - baseUrl, - getCurrentTraceInfo(), - requestMetadata, - authApplier - ); - sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); - }) - .andThenApply(authResult -> { - if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { - logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); - return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity, baseUrl); + private void getAuthorization( + ActionListener listener, + Sender sender, + boolean checkCcmState + ) { + var countdownListener = ActionListener.runAfter(listener, requestCompleteLatch::countDown); + + try { + if (checkCcmState && ccmFeature.isCcmSupportedEnvironment()) { + var isCcmEnabledListener = ActionListener.wrap(enabled -> { + if (enabled == null || enabled == false) { + logger.debug("CCM is not enabled, skipping authorization request to Elastic Inference Service"); + countdownListener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized()); + } else { + retrieveAuthorizationInformation(countdownListener, sender); } + }, e -> { + logger.atWarn().withThrowable(e).log("Failed to determine if CCM is enabled, returning unauthorized"); + countdownListener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized()); + }); + + ccmService.isEnabled(isCcmEnabledListener); + } else { + retrieveAuthorizationInformation(countdownListener, sender); + } + } catch (Exception e) { + logger.atWarn().withThrowable(e).log("Retrieving the authorization information encountered an exception"); + countdownListener.onFailure(e); + } + } - var errorMessage = Strings.format( - "%s Received an invalid response type from the Elastic Inference Service: %s", - FAILED_TO_RETRIEVE_MESSAGE, - authResult.getClass().getSimpleName() - ); + private void retrieveAuthorizationInformation(ActionListener listener, Sender sender) { + logger.debug("Retrieving authorization information from the Elastic Inference Service."); - logger.warn(errorMessage); - throw new ElasticsearchException(errorMessage); - }) - .addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown)); - } catch (Exception e) { - logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e)); - requestCompleteLatch.countDown(); - listener.onFailure(e); + if (Strings.isNullOrEmpty(baseUrl)) { + logger.debug("The base URL for the authorization service is not valid, rejecting authorization."); + listener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized()); + return; } + + var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> { + // unwrap because it's likely a retry exception + var exception = ExceptionsHelper.unwrapCause(e); + + logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception); + authModelListener.onFailure(e); + }); + + SubscribableListener.newForked(sender::startAsynchronously) + .andThen(authFactory::getAuthenticationApplier) + .andThen((authListener, authApplier) -> { + var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); + var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata, authApplier); + sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); + }) + .andThenApply(authResult -> { + if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { + logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); + return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity, baseUrl); + } + + var errorMessage = Strings.format( + "%s Received an invalid response type from the Elastic Inference Service: %s", + FAILED_TO_RETRIEVE_MESSAGE, + authResult.getClass().getSimpleName() + ); + + logger.warn(errorMessage); + throw new ElasticsearchException(errorMessage); + }) + .addListener(handleFailuresListener); } private TraceContext getCurrentTraceInfo() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 4efed45b9cf7e..3a81008c84c21 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -28,6 +28,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -36,7 +38,6 @@ import java.util.EnumSet; import java.util.List; import java.util.Set; -import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; @@ -51,20 +52,18 @@ import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisElserAuthorizationResponse; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -87,13 +86,20 @@ public void shutdown() throws IOException { public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(null, threadPool, logger, createNoopApplierFactory()); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + null, + threadPool, + logger, + createNoopApplierFactory(), + createMockCcmFeature(false), + createMockCcmService(false) + ); try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertTrue(authResponse.getTaskTypes().isEmpty()); assertTrue(authResponse.getEndpointIds().isEmpty()); assertFalse(authResponse.isAuthorized()); @@ -109,13 +115,20 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("", threadPool, logger, createNoopApplierFactory()); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + "", + threadPool, + logger, + createNoopApplierFactory(), + createMockCcmFeature(false), + createMockCcmService(false) + ); try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertTrue(authResponse.getTaskTypes().isEmpty()); assertTrue(authResponse.getEndpointIds().isEmpty()); assertFalse(authResponse.isAuthorized()); @@ -136,7 +149,9 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep eisGatewayUrl, threadPool, logger, - createNoopApplierFactory() + createNoopApplierFactory(), + createMockCcmFeature(false), + createMockCcmService(false) ); try (var sender = senderFactory.createSender()) { @@ -160,7 +175,7 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT)); + var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); assertThat(exception.getMessage(), containsString("failed to parse field [inference_endpoints]")); var stringCaptor = ArgumentCaptor.forClass(String.class); @@ -183,19 +198,20 @@ private void queueWebServerResponsesForRetries(String responseJson) { } } - public void testGetAuthorization_ReturnsFailure_WhenExceptionOccurs() throws IOException { + public void testGetAuthorization_ReturnsAValidResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var eisGatewayUrl = getUrl(webServer); - - var exceptionToThrow = new IllegalStateException("exception"); var logger = mock(Logger.class); - doThrow(exceptionToThrow).when(logger).debug(anyString()); + var mockCcmFeature = createMockCcmFeature(false); + var mockCcmService = createMockCcmService(false); var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( eisGatewayUrl, threadPool, logger, - createNoopApplierFactory() + createNoopApplierFactory(), + mockCcmFeature, + mockCcmService ); try (var sender = senderFactory.createSender()) { @@ -206,22 +222,105 @@ public void testGetAuthorization_ReturnsFailure_WhenExceptionOccurs() throws IOE PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat( + authResponse.getTaskTypes(), + is( + EnumSet.of( + TaskType.CHAT_COMPLETION, + TaskType.SPARSE_EMBEDDING, + TaskType.TEXT_EMBEDDING, + TaskType.RERANK, + TaskType.COMPLETION + ) + ) + ); + assertThat(authResponse.getEndpointIds(), containsInAnyOrder(responseData.inferenceIds().toArray(String[]::new))); + assertTrue(authResponse.isAuthorized()); + assertThat( + authResponse.getEndpoints(responseData.inferenceIds()), + containsInAnyOrder(responseData.expectedEndpoints().toArray(ElasticInferenceServiceModel[]::new)) + ); + + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger, times(1)).debug(loggerArgsCaptor.capture()); + + var message = loggerArgsCaptor.getValue(); + assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); + + assertNoAuthHeader(webServer.requests()); + + authHandler.waitForAuthRequestCompletion(ESTestCase.TEST_REQUEST_TIMEOUT); + + // It should never check if the CCM environment is supported since getAuthorization does not attempt to skip the authorization + // check + verify(mockCcmFeature, never()).isCcmSupportedEnvironment(); + verify(mockCcmService, never()).isEnabled(any()); + } + } + + private static void assertNoAuthHeader(List requests) { + assertThat(requests.size(), is(1)); + assertNull(requests.get(0).getHeader(HttpHeaders.AUTHORIZATION)); + } + + public void testGetAuthorizationIfPermittedEnvironment_ReturnsFailure_WhenExceptionThrown() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + + var exceptionToThrow = new IllegalStateException("exception"); + var mockCcmFeature = mock(CCMFeature.class); + when(mockCcmFeature.isCcmSupportedEnvironment()).thenThrow(exceptionToThrow); + var mockCcmService = createMockCcmService(false); + + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + eisGatewayUrl, + threadPool, + createNoopApplierFactory(), + mockCcmFeature, + mockCcmService + ); + + try (var sender = senderFactory.createSender()) { + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorizationIfPermittedEnvironment(listener, sender); + + var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); assertThat(exception, is(exceptionToThrow)); - assertThat(webServer.requests().size(), is(0)); + // There should be no requests made to EIS because an exception should be thrown before the request is made + assertThat(webServer.requests(), empty()); + + authHandler.waitForAuthRequestCompletion(TimeValue.THIRTY_SECONDS); + + verify(mockCcmService, never()).isEnabled(any()); } } - public void testGetAuthorization_ReturnsAValidResponse() throws IOException { + public void testGetAuthorizationIfPermittedEnvironment_ReturnsAValidResponse_WhenNotCcmSupportedEnvironment() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var eisGatewayUrl = getUrl(webServer); - var logger = mock(Logger.class); + + var mockCcmFeature = createMockCcmFeature(false); + var mockCcmService = createMockCcmService(randomBoolean()); + + assertReturnsValidResponse(eisGatewayUrl, mockCcmFeature, mockCcmService, senderFactory, 0); + } + + private void assertReturnsValidResponse( + String eisGatewayUrl, + CCMFeature mockCcmFeature, + CCMService mockCcmService, + HttpRequestSender.Factory senderFactory, + int numIsEnabledCalls + ) throws IOException { var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( eisGatewayUrl, threadPool, logger, - createNoopApplierFactory() + createNoopApplierFactory(), + mockCcmFeature, + mockCcmService ); try (var sender = senderFactory.createSender()) { @@ -230,9 +329,9 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseData.responseJson())); PlainActionFuture listener = new PlainActionFuture<>(); - authHandler.getAuthorization(listener, sender); + authHandler.getAuthorizationIfPermittedEnvironment(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat( authResponse.getTaskTypes(), is( @@ -252,19 +351,102 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { containsInAnyOrder(responseData.expectedEndpoints().toArray(ElasticInferenceServiceModel[]::new)) ); - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger, times(1)).debug(loggerArgsCaptor.capture()); + assertNoAuthHeader(webServer.requests()); - var message = loggerArgsCaptor.getValue(); - assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); + authHandler.waitForAuthRequestCompletion(TimeValue.THIRTY_SECONDS); - assertNoAuthHeader(webServer.requests()); + verify(mockCcmFeature, times(1)).isCcmSupportedEnvironment(); + verify(mockCcmService, times(numIsEnabledCalls)).isEnabled(any()); } } - private static void assertNoAuthHeader(List requests) { - assertThat(requests.size(), is(1)); - assertNull(requests.get(0).getHeader(HttpHeaders.AUTHORIZATION)); + public void testGetAuthorizationIfPermittedEnvironment_ReturnsAValidResponse_WhenCcmSupportedEnvironmentAndEnabled() + throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + + var mockCcmFeature = createMockCcmFeature(true); + var mockCcmService = createMockCcmService(true); + + assertReturnsValidResponse(eisGatewayUrl, mockCcmFeature, mockCcmService, senderFactory, 1); + } + + public void testGetAuthorizationIfPermittedEnvironment_ReturnsFailure_ForCcmSupportedEnvironment_WhenEnabledThrows() + throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + + var mockCcmFeature = createMockCcmFeature(true); + + var exceptionToThrow = new IllegalStateException("exception"); + var mockCcmService = createMockCcmServiceWithOnFailureCall(exceptionToThrow); + + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + eisGatewayUrl, + threadPool, + createNoopApplierFactory(), + mockCcmFeature, + mockCcmService + ); + + try (var sender = senderFactory.createSender()) { + var responseData = getEisAuthorizationResponseWithMultipleEndpoints(eisGatewayUrl); + + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorizationIfPermittedEnvironment(listener, sender); + + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(authResponse.getTaskTypes(), is(EnumSet.noneOf(TaskType.class))); + assertThat(authResponse.getEndpointIds(), empty()); + assertFalse(authResponse.isAuthorized()); + assertThat(authResponse.getEndpoints(responseData.inferenceIds()), empty()); + + // There should be no requests made to EIS because it is not configured + assertThat(webServer.requests(), empty()); + + authHandler.waitForAuthRequestCompletion(TimeValue.THIRTY_SECONDS); + + verify(mockCcmFeature, times(1)).isCcmSupportedEnvironment(); + } + } + + public void testGetAuthorizationIfPermittedEnvironment_ReturnsUnauthorized_WhenCcmSupportedEnvironmentAndDisabled() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + var logger = mock(Logger.class); + + var mockCcmFeature = createMockCcmFeature(true); + var mockCcmService = createMockCcmService(false); + + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + eisGatewayUrl, + threadPool, + logger, + createNoopApplierFactory(), + mockCcmFeature, + mockCcmService + ); + + try (var sender = senderFactory.createSender()) { + var responseData = getEisAuthorizationResponseWithMultipleEndpoints(eisGatewayUrl); + + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorizationIfPermittedEnvironment(listener, sender); + + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(authResponse.getTaskTypes(), is(EnumSet.noneOf(TaskType.class))); + assertThat(authResponse.getEndpointIds(), empty()); + assertFalse(authResponse.isAuthorized()); + assertThat(authResponse.getEndpoints(responseData.inferenceIds()), empty()); + + // There should be no requests made to EIS because it is not configured + assertThat(webServer.requests(), empty()); + + authHandler.waitForAuthRequestCompletion(TimeValue.THIRTY_SECONDS); + + verify(mockCcmFeature, times(1)).isCcmSupportedEnvironment(); + verify(mockCcmService, times(1)).isEnabled(any()); + } } public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws IOException { @@ -277,7 +459,9 @@ public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws I eisGatewayUrl, threadPool, logger, - createApplierFactory(secret) + createApplierFactory(secret), + createMockCcmFeature(false), + createMockCcmService(false) ); var elserResponseBody = getEisElserAuthorizationResponse(eisGatewayUrl).responseJson(); @@ -288,7 +472,7 @@ public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws I PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); assertThat(authResponse.getEndpointIds(), is(Set.of(ELSER_V2_ENDPOINT_ID))); @@ -317,7 +501,9 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { eisGatewayUrl, threadPool, logger, - createNoopApplierFactory() + createNoopApplierFactory(), + createMockCcmFeature(false), + createMockCcmService(false) ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -329,9 +515,9 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { try (var sender = senderFactory.createSender()) { authHandler.getAuthorization(onlyOnceListener, sender); - authHandler.waitForAuthRequestCompletion(TIMEOUT); + authHandler.waitForAuthRequestCompletion(ESTestCase.TEST_REQUEST_TIMEOUT); - var authResponse = listener.actionGet(TIMEOUT); + var authResponse = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); assertThat(authResponse.getEndpointIds(), is(Set.of(ELSER_V2_ENDPOINT_ID))); assertTrue(authResponse.isAuthorized()); @@ -357,13 +543,20 @@ public void testGetAuthorization_InvalidResponse() throws IOException { }).when(senderMock).sendWithoutQueuing(any(), any(), any(), any(), any()); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("abc", threadPool, logger, createNoopApplierFactory()); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + "abc", + threadPool, + logger, + createNoopApplierFactory(), + createMockCcmFeature(false), + createMockCcmService(false) + ); try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); assertThat(exception.getMessage(), containsString("Received an invalid response type from the Elastic Inference Service")); @@ -373,4 +566,30 @@ public void testGetAuthorization_InvalidResponse() throws IOException { assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service.")); } } + + private static CCMFeature createMockCcmFeature(boolean isCcmSupportedEnvironment) { + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(isCcmSupportedEnvironment); + return ccmFeature; + } + + private static CCMService createMockCcmService(boolean isEnabled) { + var ccmService = mock(CCMService.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(isEnabled); + return Void.TYPE; + }).when(ccmService).isEnabled(any()); + return ccmService; + } + + private static CCMService createMockCcmServiceWithOnFailureCall(Exception exception) { + var ccmService = mock(CCMService.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(exception); + return Void.TYPE; + }).when(ccmService).isEnabled(any()); + return ccmService; + } }