Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/143567.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
area: Inference
issues: []
pr: 143567
summary: "[Inference API] Update authorized endpoints when their fingerprint or version\
\ changed"
type: enhancement
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
* Endpoint metadata contains descriptive information for an inference endpoint. This information allows an upstream service to communicate
Expand Down Expand Up @@ -104,6 +105,16 @@ public boolean isEmpty() {
return this.equals(EMPTY_INSTANCE);
}

/**
 * Checks whether the internal fingerprint of this metadata equals the internal fingerprint of {@code other}.
 * The comparison is null-safe: two {@code null} fingerprints are treated as matching.
 *
 * @param other the metadata whose fingerprint is compared against this one
 * @return {@code true} if both fingerprints are equal (or both {@code null}), {@code false} otherwise
 */
public boolean fingerprintMatches(EndpointMetadata other) {
    var ownFingerprint = internal.fingerprint();
    var otherFingerprint = other.internal.fingerprint();
    return Objects.equals(ownFingerprint, otherFingerprint);
}

/**
 * Compares the internal version of this metadata with the internal version of {@code other}.
 * A {@code null} version on either side is treated as {@code 0}.
 *
 * @param other the metadata whose version is compared against this one
 * @return {@code true} only if this version is strictly greater than the other version
 */
public boolean hasNewerVersionThan(EndpointMetadata other) {
    final long ownVersion = Optional.ofNullable(internal.version()).orElse(0L);
    final long comparedVersion = Optional.ofNullable(other.internal.version()).orElse(0L);
    // Equal versions are not considered newer.
    return comparedVersion < ownVersion;
}

/**
 * Builds x-content params in which {@code INCLUDE_INTERNAL_FIELDS_PARAM_NAME} is set to {@code "false"}.
 *
 * @return a {@link ToXContent.MapParams} holding that single entry
 */
public Params getXContentParamsExcludeInternalFields() {
    Map<String, String> excludeInternalFields = Map.of(INCLUDE_INTERNAL_FIELDS_PARAM_NAME, Boolean.toString(false));
    return new ToXContent.MapParams(excludeInternalFields);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,84 @@ public void testToXContentExcludesInternalWhenParamSet() throws IOException {
assertThat(json, is(XContentHelper.stripWhitespace(NON_EMPTY_ENDPOINT_METADATA_JSON_WITHOUT_INTERNAL)));
}

/**
 * Verifies that {@code fingerprintMatches} returns true only for equal fingerprints,
 * where two {@code null} fingerprints count as equal.
 */
public void testFingerprintMatches() {
    EndpointMetadata nullFirst = metadataWithFingerprint(null);
    EndpointMetadata nullSecond = metadataWithFingerprint(null);
    EndpointMetadata abcFirst = metadataWithFingerprint("abc");
    EndpointMetadata abcSecond = metadataWithFingerprint("abc");
    EndpointMetadata xyzFirst = metadataWithFingerprint("xyz");
    EndpointMetadata xyzSecond = metadataWithFingerprint("xyz");

    // null compared against null and against non-null values
    assertThat(nullFirst.fingerprintMatches(nullSecond), is(true));
    assertThat(nullFirst.fingerprintMatches(abcFirst), is(false));
    assertThat(nullFirst.fingerprintMatches(xyzFirst), is(false));

    // identical non-null values on distinct instances
    assertThat(abcFirst.fingerprintMatches(abcSecond), is(true));
    assertThat(xyzFirst.fingerprintMatches(xyzSecond), is(true));

    // differing non-null values
    assertThat(xyzFirst.fingerprintMatches(abcFirst), is(false));
}

/** Creates metadata with random heuristics and display, the given fingerprint and a {@code null} version. */
private EndpointMetadata metadataWithFingerprint(String fingerprint) {
    return new EndpointMetadata(randomHeuristics(), new EndpointMetadata.Internal(fingerprint, null), randomDisplay());
}

/**
 * Verifies that {@code hasNewerVersionThan} is a strict greater-than comparison
 * in which a {@code null} version behaves like the lowest version.
 */
public void testHasNewerVersionThan() {
    EndpointMetadata unversionedFirst = metadataWithVersion(null);
    EndpointMetadata unversionedSecond = metadataWithVersion(null);
    EndpointMetadata versionFour = metadataWithVersion(4L);
    EndpointMetadata versionFourAgain = metadataWithVersion(4L);
    EndpointMetadata versionFive = metadataWithVersion(5L);

    // a null version is never newer than anything
    assertThat(unversionedFirst.hasNewerVersionThan(unversionedSecond), is(false));
    assertThat(unversionedFirst.hasNewerVersionThan(versionFour), is(false));

    // version 4 against null, itself and 5
    assertThat(versionFour.hasNewerVersionThan(unversionedFirst), is(true));
    assertThat(versionFour.hasNewerVersionThan(versionFourAgain), is(false));
    assertThat(versionFour.hasNewerVersionThan(versionFive), is(false));

    // version 5 against 4 and null
    assertThat(versionFive.hasNewerVersionThan(versionFour), is(true));
    assertThat(versionFive.hasNewerVersionThan(unversionedSecond), is(true));
}

/** Creates metadata with random heuristics and display, a {@code null} fingerprint and the given version. */
private EndpointMetadata metadataWithVersion(Long version) {
    return new EndpointMetadata(randomHeuristics(), new EndpointMetadata.Internal(null, version), randomDisplay());
}

@Override
protected EndpointMetadata createTestInstance() {
return randomInstance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequestBuilder;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.AdminClient;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
Expand All @@ -24,13 +25,16 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller;
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings;
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint;
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests;
import org.junit.After;
import org.junit.AfterClass;
Expand All @@ -51,6 +55,7 @@
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.createAuthorizedEndpoint;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -351,6 +356,69 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep
assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME));
}

/** Runs the fingerprint-change scenario starting from a {@code null} fingerprint and moving to a random one. */
public void testEndpointGetsUpdated_GivenFingerprintChanges_FromNull() throws Exception {
    final String newFingerprint = randomAlphaOfLength(10);
    testEndpointGetsUpdated_GivenFingerprintChanged(null, newFingerprint);
}

/** Runs the fingerprint-change scenario between two distinct random, non-null fingerprints. */
public void testEndpointGetsUpdated_GivenFingerprintChanges_FromNonNull() throws Exception {
    final String before = randomAlphaOfLength(10);
    // The replacement must differ from the starting fingerprint.
    final String after = randomValueOtherThan(before, () -> randomAlphaOfLength(10));
    testEndpointGetsUpdated_GivenFingerprintChanged(before, after);
}

/**
 * Shared scenario for the fingerprint-change tests. It serves an authorization response containing a single
 * endpoint with {@code originalFingerprint}, waits for that endpoint to appear in the model registry, then serves
 * a second response for the same endpoint id and task type with {@code updatedFingerprint} and waits until the
 * stored endpoint reports the new fingerprint.
 *
 * @param originalFingerprint fingerprint returned by the first authorization response; may be {@code null}
 * @param updatedFingerprint fingerprint returned by the second authorization response
 * @throws Exception if building the response body or any of the awaited assertions fails
 */
private void testEndpointGetsUpdated_GivenFingerprintChanged(String originalFingerprint, String updatedFingerprint) throws Exception {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a heads up I've struggled to get these type of tests to succeed reliably. A number of them have been muted: #138012

If you have time, another set of eyes would be good to figure out why they are so flaky. I've tried to fix them a few times 😞

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did notice those issues. I wonder if #143584 might have been some of that. But I couldn't find full logs. Next time we get a failure I'll dive in.

// Precondition: start from a state with no authorized EIS endpoints.
assertNoAuthorizedEisEndpoints();

// Phase 1: serve a single endpoint with a random id, a randomly chosen task type and the original fingerprint.
resetWebServerQueues();
String endpointId = randomAlphaOfLength(10);
AuthorizedEndpoint originalAuthEndpoint = createAuthorizedEndpoint(
endpointId,
randomFrom(
TaskType.CHAT_COMPLETION,
TaskType.COMPLETION,
TaskType.EMBEDDING,
TaskType.RERANK,
TaskType.TEXT_EMBEDDING,
TaskType.SPARSE_EMBEDDING
),
() -> originalFingerprint
);
// The response must be enqueued before the polling task is restarted so that it is available to be served.
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(toJsonWrappedInInferenceEndpointsList(originalAuthEndpoint)));
restartPollingTaskAndWaitForAuthResponse();
assertWebServerReceivedRequest();

// Wait until exactly one EIS endpoint is present in the registry.
assertBusy(() -> assertThat(getEisEndpoints(modelRegistry).size(), is(1)));

// The stored endpoint must carry the id and fingerprint from the first response.
var eisEndpoints = getEisEndpoints(modelRegistry);
assertThat(eisEndpoints.size(), is(1));
var endpoint = eisEndpoints.get(0);
assertThat(endpoint.inferenceEntityId(), is(originalAuthEndpoint.id()));
assertThat(endpoint.endpointMetadata().internal().fingerprint(), is(originalFingerprint));

// Phase 2: serve the same endpoint id and task type again, now with the updated fingerprint.
resetWebServerQueues();
// Simulate the authorization response now reporting a different fingerprint for the same endpoint
AuthorizedEndpoint updatedAuthEndpoint = createAuthorizedEndpoint(
originalAuthEndpoint.id(),
endpoint.taskType(),
() -> updatedFingerprint
);
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(toJsonWrappedInInferenceEndpointsList(updatedAuthEndpoint)));

restartPollingTaskAndWaitForAuthResponse();
assertWebServerReceivedRequest();

// Wait until the single stored endpoint reports the updated fingerprint.
assertBusy(() -> {
var postUpdateEndpoints = getEisEndpoints(modelRegistry);
assertThat(postUpdateEndpoints.size(), is(1));
var updated = postUpdateEndpoints.get(0);
assertThat(updated.inferenceEntityId(), is(updatedAuthEndpoint.id()));
assertThat(updated.endpointMetadata().internal().fingerprint(), is(updatedFingerprint));
});

}

public void testRestartsTaskAfterAbort() throws Exception {
// Ensure the task is created and we get an initial authorization response
assertNoAuthorizedEisEndpoints();
Expand All @@ -361,4 +429,17 @@ public void testRestartsTaskAfterAbort() throws Exception {
restartPollingTaskAndWaitForAuthResponse();
assertWebServerReceivedRequest();
}

/**
 * Serializes the given endpoints into a JSON object with a single {@code inference_endpoints} array
 * holding one entry per endpoint, in the order given.
 *
 * @param endpoints the endpoints to write into the array; may be empty
 * @return the JSON document as a string
 * @throws IOException if the content builder fails while writing
 */
private static String toJsonWrappedInInferenceEndpointsList(AuthorizedEndpoint... endpoints) throws IOException {
    try (XContentBuilder json = JsonXContent.contentBuilder()) {
        json.startObject().startArray("inference_endpoints");
        for (AuthorizedEndpoint authorizedEndpoint : endpoints) {
            json.value(authorizedEndpoint);
        }
        json.endArray().endObject();
        return Strings.toString(json);
    }
}
}
Loading