Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,13 @@ public void testDeploymentSurvivesRestart() throws Exception {
putModelDefinition(modelId);
putVocabulary(List.of("these", "are", "my", "words"), modelId);
startDeployment(modelId);
assertInfer(modelId);
assertBusy(() -> {
try {
assertInfer(modelId);
} catch (ResponseException e) {
throw new AssertionError("Inference failed on old cluster", e);
}
}, 90, TimeUnit.SECONDS);
} else {
ensureHealth(".ml-inference-*,.ml-config*", (request -> {
request.addParameter("wait_for_status", "yellow");
Expand Down Expand Up @@ -141,7 +147,10 @@ private void waitForDeploymentStarted(String modelId) throws Exception {

private void assertInfer(String modelId) throws IOException {
Response inference = infer("my words", modelId);
assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference_results\":[{\"predicted_value\":[[1.0,1.0]]}]}"));
String expectedResponse = oldClusterHasInferEndpoint()
? "{\"inference_results\":[{\"predicted_value\":[[1.0,1.0]]}]}"
: "{\"predicted_value\":[[1.0,1.0]]}";
assertThat(EntityUtils.toString(inference.getEntity()), equalTo(expectedResponse));
}

private void putModelDefinition(String modelId) throws IOException {
Expand Down Expand Up @@ -232,8 +241,15 @@ private Response getTrainedModelStats(String modelId) throws IOException {
return response;
}

private boolean oldClusterHasInferEndpoint() {
return isRunningAgainstOldCluster() == false || getOldClusterTestVersion().onOrAfter("8.3.0");
}

private Response infer(String input, String modelId) throws IOException {
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer");
String endpoint = oldClusterHasInferEndpoint()
? "/_ml/trained_models/" + modelId + "/_infer"
: "/_ml/trained_models/" + modelId + "/deployment/_infer?timeout=30s";
Request request = new Request("POST", endpoint);
request.setJsonEntity(Strings.format("""
{ "docs": [{"input":"%s"}] }
""", input));
Expand Down
Loading