diff --git a/plugin/build.gradle b/plugin/build.gradle index 51c53541e4..5694318c99 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -386,11 +386,7 @@ String baseName = "mlCommonsBwcCluster" String bwcMlPlugin = "opensearch-ml-" + bwcVersion + ".zip" String bwcFilePath = "src/test/resources/org/opensearch/ml/bwc/" String bwcRemoteFile = "https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/" + bwcShortVersion + "/latest/linux/x64/tar/builds/opensearch/plugins/" + bwcMlPlugin -String project_no_snapshot = project.version.replace("-SNAPSHOT","") -String opensearch_no_snapshot = opensearch_version.replace("-SNAPSHOT","") -String opensearchMlPlugin = "opensearch-ml-" + project_no_snapshot + ".zip" -String opensearchMlRemoteFile = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + opensearch_no_snapshot + - '/latest/linux/x64/tar/builds/opensearch/plugins/' + opensearchMlPlugin +String opensearchMlPlugin = "opensearch-ml-" + project.version + ".zip" 2.times {i -> testClusters { @@ -404,6 +400,10 @@ String opensearchMlRemoteFile = 'https://ci.opensearch.org/ci/dbc/distribution-b return new RegularFile() { @Override File getAsFile() { + File bwcDir = new File('./plugin/' + bwcFilePath) + if (!bwcDir.exists()) { + bwcDir.mkdirs() + } File dir = new File('./plugin/' + bwcFilePath + bwcVersion) if (!dir.exists()) { dir.mkdirs() @@ -430,13 +430,11 @@ List> plugins = [ return new RegularFile() { @Override File getAsFile() { - if (new File('./plugin/' + bwcFilePath + project.version).exists()) { - project.delete(files('./plugin/' + bwcFilePath + project.version)) + project.mkdir "$bwcFilePath/$project.version" + copy { + from "$buildDir/distributions/$opensearchMlPlugin" + into "$bwcFilePath/$project.version" } - project.mkdir bwcFilePath + project.version - ant.get(src: opensearchMlRemoteFile, - dest: bwcFilePath + project.version, - httpusecaches: false) return fileTree(bwcFilePath + project.version).getSingleFile() } } diff --git a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityIT.java b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityIT.java index 393a5e4e48..00b4440b36 100644 --- a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityIT.java +++ b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityIT.java @@ -13,9 +13,12 @@ import java.util.stream.Collectors; import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; import org.junit.Assume; import org.junit.Before; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.common.settings.Settings; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.FunctionName; @@ -26,6 +29,9 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.rest.OpenSearchRestTestCase; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + public class MLCommonsBackwardsCompatibilityIT extends MLCommonsBackwardsCompatibilityRestTestCase { private final ClusterType CLUSTER_TYPE = ClusterType.parse(System.getProperty("tests.rest.bwcsuite")); @@ -171,31 +177,120 @@ public void testBackwardsCompatibility() throws Exception { ArrayList rows = (ArrayList) predictionResult.get("rows"); assertTrue(rows.size() > 1); }); - } else if (opensearchVersion.equals("2.5.0")) { - // train predict with old data + } else if (isNewerVersion(opensearchVersion)) { ingestIrisData(irisIndex); - trainAndPredict(client(), FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, predictionResult -> { - ArrayList rows = (ArrayList) predictionResult.get("rows"); - assertTrue(rows.size() > 0); - }); + try { + trainAndPredict( + client(), + FunctionName.KMEANS, + irisIndex, + kMeansParams, + searchSourceBuilder, + predictionResult -> { + ArrayList rows = (ArrayList) predictionResult.get("rows"); + assertTrue(rows.size() > 0); + } + ); + } catch (ResponseException e1) { + mlNodeSettingShifting(); + try { + trainAndPredict( + client(), + FunctionName.KMEANS, + irisIndex, + kMeansParams, + searchSourceBuilder, + predictionResult -> { + ArrayList rows = (ArrayList) predictionResult.get("rows"); + assertTrue(rows.size() > 0); + } + ); + } catch (ResponseException e2) { + Map modelResponseMap = gson.fromJson(("{" + e2.getMessage().split("[{]", 2)[1]), Map.class); + Map errorMap = (Map) modelResponseMap.get("error"); + List> rootCauses = (List>) errorMap.get("root_cause"); + Set rootCauseTypeSet = rootCauses.stream().map(map -> map.get("type")).collect(Collectors.toSet()); + assertEquals("m_l_limit_exceeded_exception", rootCauseTypeSet.iterator().next().toString()); + break; + } + } } else { throw new AssertionError("Cannot get the correct version for opensearch ml-commons plugin for the bwc test."); } break; case UPGRADED: assertTrue(pluginNames.contains("opensearch-ml")); - assertEquals("2.5.0", opensearchVersion); + assertTrue(isNewerVersion(opensearchVersion)); ingestIrisData(irisIndex); - trainAndPredict(client(), FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, predictionResult -> { - ArrayList rows = (ArrayList) predictionResult.get("rows"); - assertTrue(rows.size() > 0); - }); + try { + trainAndPredict(client(), FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, predictionResult -> { + ArrayList rows = (ArrayList) predictionResult.get("rows"); + assertTrue(rows.size() > 0); + }); + } catch (ResponseException e1) { + mlNodeSettingShifting(); + try { + trainAndPredict( + client(), + FunctionName.KMEANS, + irisIndex, + kMeansParams, + searchSourceBuilder, + predictionResult -> { + ArrayList rows = (ArrayList) predictionResult.get("rows"); + assertTrue(rows.size() > 0); + } + ); + } catch (ResponseException e2) { + Map modelResponseMap = gson.fromJson(("{" + e2.getMessage().split("[{]", 2)[1]), Map.class); + Map errorMap = (Map) modelResponseMap.get("error"); + List> rootCauses = (List>) errorMap.get("root_cause"); + Set rootCauseTypeSet = rootCauses.stream().map(map -> map.get("type")).collect(Collectors.toSet()); + assertEquals("m_l_limit_exceeded_exception", rootCauseTypeSet.iterator().next().toString()); + memoryThresholdSettingShifting(); + trainAndPredict( + client(), + FunctionName.KMEANS, + irisIndex, + kMeansParams, + searchSourceBuilder, + predictionResult -> { + ArrayList rows = (ArrayList) predictionResult.get("rows"); + assertTrue(rows.size() > 0); + } + ); + } + } break; } break; } } + private void mlNodeSettingShifting() throws IOException { + Response bwcResponse = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.only_run_on_ml_node\":false}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, bwcResponse.getStatusLine().getStatusCode()); + } + + private void memoryThresholdSettingShifting() throws IOException { + String jsonEntity = "{\n" + + " \"persistent\" : {\n" + + " \"plugins.ml_commons.native_memory_threshold\" : 100 \n" + + " }\n" + + "}"; + Response bwcResponse = TestHelper + .makeRequest(client(), "PUT", "_cluster/settings", ImmutableMap.of(), TestHelper.toHttpEntity(jsonEntity), null); + assertEquals(200, bwcResponse.getStatusLine().getStatusCode()); + } + private String getModelIdWithFunctionName(FunctionName functionName) throws IOException { String modelQuery = "{\"query\": {" + "\"term\": {" @@ -217,6 +312,10 @@ private String getModelIdWithFunctionName(FunctionName functionName) throws IOEx return modelIdSet.iterator().next().toString(); } + private boolean isNewerVersion(String osVersion) { + return (Integer.parseInt(osVersion.substring(2, 3)) > 4) || (Integer.parseInt(osVersion.substring(0, 1)) > 2); + } + private void verifyMlResponse(String uri) throws Exception { Response response = TestHelper.makeRequest(client(), "GET", uri, null, TestData.matchAllSearchQuery(), null); HttpEntity entity = response.getEntity();