Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,13 @@ public void clearRequests() {
requests.clear();
}

/**
* Removes all responses from the queue.
*/
public void clearResponses() {
responses.clear();
}

/**
* A utility method to peek into the requests and find out if #MockWebServer.takeRequests will not throw an out of bound exception
* @return true if more requests are available, false otherwise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.integration;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequestBuilder;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequestBuilder;
import org.elasticsearch.action.support.PlainActionFuture;
Expand All @@ -19,6 +21,7 @@
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
Expand Down Expand Up @@ -64,6 +67,7 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {

public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]";

private static final Logger logger = LogManager.getLogger(AuthorizationTaskExecutorIT.class);
private static final MockWebServer webServer = new MockWebServer();
private static String gatewayUrl;
private static String chatCompletionResponseBody;
Expand All @@ -80,8 +84,6 @@ public static void initClass() throws IOException {

@Before
public void createComponents() {
// Adding an empty response to ensure that the initial authorization polling request does not fail
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
modelRegistry = node().injector().getInstance(ModelRegistry.class);
authorizationTaskExecutor = node().injector().getInstance(AuthorizationTaskExecutor.class);
}
Expand All @@ -95,14 +97,26 @@ static void removeEisPreconfiguredEndpoints(ModelRegistry modelRegistry) {
// Delete all the eis preconfigured endpoints
var listener = new PlainActionFuture<Boolean>();
modelRegistry.deleteModels(EIS_PRECONFIGURED_ENDPOINT_IDS, listener);
listener.actionGet(TimeValue.THIRTY_SECONDS);
try {
listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT);
} catch (Exception e) {
logger.atWarn().withThrowable(e).log("Failed to delete eis preconfigured endpoints");
}
}

@AfterClass
public static void cleanUpClass() {
webServer.close();
}

@Override
protected void startNode(long seed) throws Exception {
// Adding an empty response to ensure that the initial authorization polling request does not fail
// We're doing this before the node is started to ensure that the authorization task isn't created yet
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
super.startNode(seed);
}

@Override
protected Settings nodeSettings() {
return Settings.builder()
Expand Down Expand Up @@ -131,7 +145,7 @@ protected boolean resetNodeAfterTest() {
public void testCreatesEisChatCompletionEndpoint() throws Exception {
assertNoAuthorizedEisEndpoints();

webServer.clearRequests();
resetWebServerQueues();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
restartPollingTaskAndWaitForAuthResponse();

Expand Down Expand Up @@ -246,14 +260,14 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception {
public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception {
assertNoAuthorizedEisEndpoints();

webServer.clearRequests();
resetWebServerQueues();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
restartPollingTaskAndWaitForAuthResponse();
assertWebServerReceivedRequest();

assertChatCompletionEndpointExists();

webServer.clearRequests();
resetWebServerQueues();
// Simulate that the model is no longer authorized
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
restartPollingTaskAndWaitForAuthResponse();
Expand All @@ -262,6 +276,11 @@ public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthor
assertChatCompletionEndpointExists();
}

private static void resetWebServerQueues() {
webServer.clearRequests();
webServer.clearResponses();
}

private void assertChatCompletionEndpointExists() throws Exception {
assertChatCompletionEndpointExists(modelRegistry);
}
Expand All @@ -286,22 +305,22 @@ static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesMode
public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception {
assertNoAuthorizedEisEndpoints();

webServer.clearRequests();
resetWebServerQueues();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
restartPollingTaskAndWaitForAuthResponse();
assertWebServerReceivedRequest();

assertChatCompletionEndpointExists();

// Simulate that the model is no longer authorized
webServer.clearRequests();
resetWebServerQueues();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
restartPollingTaskAndWaitForAuthResponse();
assertWebServerReceivedRequest();

assertChatCompletionEndpointExists();

webServer.clearRequests();
resetWebServerQueues();
// Simulate that a text embedding model is now authorized
var jinaEmbedResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl)
.responseJson();
Expand All @@ -327,7 +346,7 @@ public void testRestartsTaskAfterAbort() throws Exception {
// Ensure the task is created and we get an initial authorization response
assertNoAuthorizedEisEndpoints();

webServer.clearRequests();
resetWebServerQueues();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
// Abort the task and ensure it is restarted
restartPollingTaskAndWaitForAuthResponse();
Expand Down