diff --git a/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java index cfdc4cc7af77e..896acad42307e 100644 --- a/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java +++ b/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java @@ -99,7 +99,13 @@ public void testDeploymentSurvivesRestart() throws Exception { putModelDefinition(modelId); putVocabulary(List.of("these", "are", "my", "words"), modelId); startDeployment(modelId); - assertInfer(modelId); + assertBusy(() -> { + try { + assertInfer(modelId); + } catch (ResponseException e) { + throw new AssertionError("Inference failed on old cluster", e); + } + }, 90, TimeUnit.SECONDS); } else { ensureHealth(".ml-inference-*,.ml-config*", (request -> { request.addParameter("wait_for_status", "yellow"); @@ -141,7 +147,10 @@ private void waitForDeploymentStarted(String modelId) throws Exception { private void assertInfer(String modelId) throws IOException { Response inference = infer("my words", modelId); - assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference_results\":[{\"predicted_value\":[[1.0,1.0]]}]}")); + String expectedResponse = oldClusterHasInferEndpoint() + ? "{\"inference_results\":[{\"predicted_value\":[[1.0,1.0]]}]}" + : "{\"predicted_value\":[[1.0,1.0]]}"; + assertThat(EntityUtils.toString(inference.getEntity()), equalTo(expectedResponse)); } private void putModelDefinition(String modelId) throws IOException { @@ -232,8 +241,15 @@ private Response getTrainedModelStats(String modelId) throws IOException { return response; } + private boolean oldClusterHasInferEndpoint() { + return isRunningAgainstOldCluster() == false || getOldClusterTestVersion().onOrAfter("8.3.0"); + } + private Response infer(String input, String modelId) throws IOException { - Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer"); + String endpoint = oldClusterHasInferEndpoint() + ? "/_ml/trained_models/" + modelId + "/_infer" + : "/_ml/trained_models/" + modelId + "/deployment/_infer?timeout=30s"; + Request request = new Request("POST", endpoint); request.setJsonEntity(Strings.format(""" { "docs": [{"input":"%s"}] } """, input));