diff --git a/.github/workflows/sql-cli-release-workflow.yml b/.github/workflows/sql-cli-release-workflow.yml index a7042bcd327..a5eb0c4da05 100644 --- a/.github/workflows/sql-cli-release-workflow.yml +++ b/.github/workflows/sql-cli-release-workflow.yml @@ -20,6 +20,17 @@ jobs: - name: Checkout SQL CLI uses: actions/checkout@v2 + # dependencies: ml-commons + - name: Checkout ml-commons + uses: actions/checkout@v2 + with: + repository: 'opensearch-project/ml-commons' + path: ml-commons + ref: 'main' + - name: Build ml-commons + working-directory: ./ml-commons + run: ./gradlew publishToMavenLocal + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: diff --git a/.github/workflows/sql-cli-test-and-build-workflow.yml b/.github/workflows/sql-cli-test-and-build-workflow.yml index 876780a86c2..c07ff95ecaa 100644 --- a/.github/workflows/sql-cli-test-and-build-workflow.yml +++ b/.github/workflows/sql-cli-test-and-build-workflow.yml @@ -17,6 +17,17 @@ jobs: - name: Checkout SQL CLI uses: actions/checkout@v2 + # dependencies: ml-commons + - name: Checkout ml-commons + uses: actions/checkout@v2 + with: + repository: 'opensearch-project/ml-commons' + path: ml-commons + ref: 'main' + - name: Build ml-commons + working-directory: ./ml-commons + run: ./gradlew publishToMavenLocal + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: diff --git a/.github/workflows/sql-release-workflow.yml b/.github/workflows/sql-release-workflow.yml index 974f801d368..a7d6947f4b1 100644 --- a/.github/workflows/sql-release-workflow.yml +++ b/.github/workflows/sql-release-workflow.yml @@ -15,6 +15,17 @@ jobs: runs-on: ubuntu-latest steps: + # dependencies: ml-commons + - name: Checkout ml-commons + uses: actions/checkout@v2 + with: + repository: 'opensearch-project/ml-commons' + path: ml-commons + ref: 'main' + - name: Build ml-commons + working-directory: ./ml-commons + run: ./gradlew publishToMavenLocal + - name: Checkout SQL 
uses: actions/checkout@v1 diff --git a/.github/workflows/sql-test-and-build-workflow.yml b/.github/workflows/sql-test-and-build-workflow.yml index c6c010fd83e..4ffffb0268a 100644 --- a/.github/workflows/sql-test-and-build-workflow.yml +++ b/.github/workflows/sql-test-and-build-workflow.yml @@ -17,6 +17,17 @@ jobs: uses: actions/setup-java@v1 with: java-version: 1.14 + + # dependencies: ml-commons + - name: Checkout ml-commons + uses: actions/checkout@v2 + with: + repository: 'opensearch-project/ml-commons' + path: ml-commons + ref: 'main' + - name: Build ml-commons + working-directory: ./ml-commons + run: ./gradlew publishToMavenLocal - name: Build with Gradle run: ./gradlew build assemble -Dopensearch.version=${{ env.OPENSEARCH_VERSION }} diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 3ab6dcb420d..93367cc4138 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -36,6 +36,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -47,6 +48,7 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.data.model.ExprMissingValue; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -60,6 +62,7 @@ import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalMLCommons; import 
org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalRareTopN; @@ -366,6 +369,20 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) { return new LogicalValues(valueExprs); } + /** + * Build {@link LogicalMLCommons} for Kmeans command. + */ + @Override + public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + List options = node.getOptions(); + + TypeEnvironment currentEnv = context.peek(); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER); + + return new LogicalMLCommons(child, "kmeans", options); + } + /** * The first argument is always "asc", others are optional. * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index aa04aa5ccef..f591007ad15 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -37,6 +37,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -234,4 +235,8 @@ public T visitLimit(Limit node, C context) { public T visitSpan(Span node, C context) { return visitChildren(node, context); } + + public T visitKmeans(Kmeans node, C context) { + return visitChildren(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 1266eae73f3..e17318eda18 100644 --- 
a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -142,6 +142,14 @@ public static Literal longLiteral(Long value) { return literal(value, DataType.LONG); } + public static Literal shortLiteral(Short value) { + return literal(value, DataType.SHORT); + } + + public static Literal floatLiteral(Float value) { + return literal(value, DataType.FLOAT); + } + public static Literal dateLiteral(String value) { return literal(value, DataType.DATE); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java b/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java index ddea7f2f26d..8755a15177f 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java @@ -21,6 +21,8 @@ public enum DataType { INTEGER(ExprCoreType.INTEGER), LONG(ExprCoreType.LONG), + SHORT(ExprCoreType.SHORT), + FLOAT(ExprCoreType.FLOAT), DOUBLE(ExprCoreType.DOUBLE), STRING(ExprCoreType.STRING), BOOLEAN(ExprCoreType.BOOLEAN), diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java new file mode 100644 index 00000000000..9adfd04fb4c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java @@ -0,0 +1,40 @@ +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Argument; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = true) +@RequiredArgsConstructor +@AllArgsConstructor +public class Kmeans extends UnresolvedPlan { + private UnresolvedPlan child; + + private final List 
options; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitKmeans(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java new file mode 100644 index 00000000000..c4b44317dd8 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java @@ -0,0 +1,38 @@ +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.expression.Argument; + +/** + * ml-commons logical plan. + */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalMLCommons extends LogicalPlan { + private final String algorithm; + + private final List arguments; + + /** + * Constructor of LogicalMLCommons. 
+ * @param child child logical plan + * @param algorithm algorithm name + * @param arguments arguments of the algorithm + */ + public LogicalMLCommons(LogicalPlan child, String algorithm, + List arguments) { + super(Collections.singletonList(child)); + this.algorithm = algorithm; + this.arguments = arguments; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitMLCommons(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 5c11d230a10..c1f0d5d0418 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -69,4 +69,8 @@ public R visitRareTopN(LogicalRareTopN plan, C context) { public R visitLimit(LogicalLimit plan, C context) { return visitNode(plan, context); } + + public R visitMLCommons(LogicalMLCommons plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 110a4ff16b2..fb7e3d0fe3f 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -72,4 +72,8 @@ public R visitLimit(LimitOperator node, C context) { return visitNode(node, context); } + public R visitMLCommons(PhysicalPlan node, C context) { + return visitNode(node, context); + } + } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 0908a8bc8ad..2e9a6fe843a 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ 
b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -41,11 +41,13 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.ContextConfiguration; @@ -644,4 +646,15 @@ public void named_aggregator_with_condition() { ) ); } + + @Test + public void kmeanns_relation() { + assertAnalyzeEqual( + new LogicalMLCommons(LogicalPlanDSL.relation("schema"), + "kmeans", + AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))), + new Kmeans(AstDSL.relation("schema"), + AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))) + ); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index a6a0a9d519e..f3fe6b5a84f 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.DSL; @@ -108,6 +109,12 @@ public void 
testAbstractPlanNodeVisitorShouldReturnNull() { relation, CommandType.TOP, ImmutableList.of(expression), expression); assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { }, null)); + + LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"), + "kmeans", + AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))); + assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { + }, null)); } private static class NodesCount extends LogicalPlanNodeVisitor { diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 092abb87caa..7e86f3e68a1 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -134,6 +134,14 @@ public void test_PhysicalPlanVisitor_should_return_null() { }, null)); } + @Test + public void test_visitMLCommons() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor = + new PhysicalPlanNodeVisitor() {}; + + assertNull(physicalPlanNodeVisitor.visitMLCommons(plan, null)); + } + public static class PhysicalPlanPrinter extends PhysicalPlanNodeVisitor { public String print(PhysicalPlan node) { diff --git a/opensearch/build.gradle b/opensearch/build.gradle index ebe5372e2fe..c2905c54f33 100644 --- a/opensearch/build.gradle +++ b/opensearch/build.gradle @@ -37,6 +37,7 @@ dependencies { compile group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: '2.11.4' compile group: 'org.json', name: 'json', version:'20180813' compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}" + compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testCompile group: 'org.hamcrest', name: 
'hamcrest-library', version: '2.1' @@ -75,7 +76,9 @@ jacocoTestCoverageVerification { rule { element = 'CLASS' excludes = [ - 'org.opensearch.sql.opensearch.security.SecurityAccess' + 'org.opensearch.sql.opensearch.security.SecurityAccess', + 'org.opensearch.sql.opensearch.planner.physical.*', + 'org.opensearch.sql.opensearch.client.MLClient' ] limit { counter = 'LINE' diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/MLClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/MLClient.java new file mode 100644 index 00000000000..19f49d0e5f2 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/MLClient.java @@ -0,0 +1,25 @@ +package org.opensearch.sql.opensearch.client; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; + + +public class MLClient { + private static MachineLearningNodeClient INSTANCE; + + private MLClient() { + + } + + /** + * get machine learning client. 
+ * @param nodeClient node client + * @return machine learning client + */ + public static MachineLearningNodeClient getMLClient(NodeClient nodeClient) { + if (INSTANCE == null) { + INSTANCE = new MachineLearningNodeClient(nodeClient); + } + return INSTANCE; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java index a6ecaa13d32..67a5ac9e6ae 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Map; +import org.opensearch.client.node.NodeClient; import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -64,4 +65,6 @@ public interface OpenSearchClient { * @param task task */ void schedule(Runnable task); + + NodeClient getNodeClient(); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java index 9c06586067a..18197e9c339 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java @@ -147,6 +147,11 @@ public void schedule(Runnable task) { ); } + @Override + public NodeClient getNodeClient() { + return client; + } + private String[] resolveIndexExpression(ClusterState state, String[] indices) { return resolver.concreteIndexNames(state, IndicesOptions.strictExpandOpen(), true, indices); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java index c6a8661daef..91eddfc39ab 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java @@ -24,6 +24,7 @@ import org.opensearch.client.indices.GetIndexResponse; import org.opensearch.client.indices.GetMappingsRequest; import org.opensearch.client.indices.GetMappingsResponse; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.common.settings.Settings; import org.opensearch.sql.opensearch.mapping.IndexMapping; @@ -135,4 +136,9 @@ public void cleanup(OpenSearchRequest request) { public void schedule(Runnable task) { task.run(); } + + @Override + public NodeClient getNodeClient() { + throw new UnsupportedOperationException("Unsupported method."); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index a286737cc4c..2ae4255a546 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -8,6 +8,7 @@ import lombok.RequiredArgsConstructor; import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.planner.physical.AggregationOperator; import org.opensearch.sql.planner.physical.DedupeOperator; import org.opensearch.sql.planner.physical.EvalOperator; @@ -125,6 +126,17 @@ public PhysicalPlan visitLimit(LimitOperator node, Object context) { node.getOffset()); } + @Override + public PhysicalPlan visitMLCommons(PhysicalPlan node, Object 
context) { + MLCommonsOperator mlCommonsOperator = (MLCommonsOperator) node; + return doProtect( + new MLCommonsOperator(visitInput(mlCommonsOperator.getInput(), context), + mlCommonsOperator.getAlgorithm(), + mlCommonsOperator.getArguments(), + mlCommonsOperator.getNodeClient()) + ); + } + PhysicalPlan visitInput(PhysicalPlan node, Object context) { if (null == node) { return node; @@ -133,7 +145,7 @@ PhysicalPlan visitInput(PhysicalPlan node, Object context) { } } - private PhysicalPlan doProtect(PhysicalPlan node) { + protected PhysicalPlan doProtect(PhysicalPlan node) { if (isProtected(node)) { return node; } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java new file mode 100644 index 00000000000..5401298070d --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -0,0 +1,187 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; + +import com.google.common.collect.ImmutableMap; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataframe.Row; +import 
org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +/** + * ml-commons Physical operator to call machine learning interface to get results for + * algorithm execution. 
+ */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class MLCommonsOperator extends PhysicalPlan { + @Getter + private final PhysicalPlan input; + + @Getter + private final String algorithm; + + @Getter + private final List arguments; + + @Getter + private final NodeClient nodeClient; + + @EqualsAndHashCode.Exclude + private Iterator iterator; + + @Override + public void open() { + super.open(); + DataFrame inputDataFrame = generateInputDataset(); + MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments.get(0), algorithm); + MLInput mlinput = MLInput.builder() + .algorithm(FunctionName.valueOf(algorithm.toUpperCase())) + .parameters(mlAlgoParams) + .inputDataset(new DataFrameInputDataset(inputDataFrame)) + .build(); + + MachineLearningNodeClient machineLearningClient = + MLClient.getMLClient(nodeClient); + MLPredictionOutput predictionResult = (MLPredictionOutput) machineLearningClient + .trainAndPredict(mlinput) + .actionGet(30, TimeUnit.SECONDS); + Iterator inputRowIter = inputDataFrame.iterator(); + Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); + iterator = new Iterator() { + @Override + public boolean hasNext() { + return inputRowIter.hasNext(); + } + + @Override + public ExprValue next() { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), + inputRowIter.next())); + resultBuilder.putAll(convertRowIntoExprValue( + predictionResult.getPredictionResult().columnMetas(), + resultRowIter.next())); + return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + } + }; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitMLCommons(this, context); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + return iterator.next(); + } + + @Override + public List getChild() { + return 
Collections.singletonList(input); + } + + protected MLAlgoParams convertArgumentToMLParameter(Argument argument, String algorithm) { + switch (FunctionName.valueOf(algorithm.toUpperCase())) { + case KMEANS: + if (!(argument.getValue().getValue() instanceof Number)) { + throw new IllegalArgumentException("unsupported Kmeans argument type:" + + argument.getValue().getType()); + } + Number centroids = (Number) argument.getValue().getValue(); // not always an Integer + return KMeansParams.builder().centroids(centroids.intValue()).build(); + default: + // TODO: update available algorithms in the message when adding a new case + throw new IllegalArgumentException( + String.format("unsupported algorithm: %s, available algorithms: %s.", + FunctionName.valueOf(algorithm.toUpperCase()), KMEANS)); + } + } + + private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + switch (columnValue.columnType()) { + case INTEGER: + resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); + break; + case DOUBLE: + resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); + break; + case STRING: + resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); + break; + case SHORT: + resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); + break; + case LONG: + resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); + break; + case FLOAT: + resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); + break; + default: + break; + } + } + return resultBuilder.build(); + } + + private DataFrame generateInputDataset() { + List> inputData = new LinkedList<>(); + while (input.hasNext()) { + inputData.add(new HashMap() { + { + input.next().tupleValue().forEach((key, value) -> put(key, 
value.value())); + } + }); + } + + return DataFrameBuilder.load(inputData); + } +} + diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index a90b31f40b4..f116fe62fd6 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -22,6 +22,7 @@ import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; +import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; @@ -29,6 +30,7 @@ import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.DefaultImplementor; +import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -77,7 +79,7 @@ public PhysicalPlan implement(LogicalPlan plan) { * aggregation, filter, will accumulate (push down) OpenSearch query and aggregation DSL on * index scan. 
*/ - return plan.accept(new OpenSearchDefaultImplementor(indexScan), indexScan); + return plan.accept(new OpenSearchDefaultImplementor(indexScan, client), indexScan); } @Override @@ -91,6 +93,8 @@ public static class OpenSearchDefaultImplementor extends DefaultImplementor { private final OpenSearchIndexScan indexScan; + private final OpenSearchClient client; + @Override public PhysicalPlan visitNode(LogicalPlan plan, OpenSearchIndexScan context) { if (plan instanceof OpenSearchLogicalIndexScan) { @@ -158,5 +162,11 @@ public PhysicalPlan visitIndexAggregation(OpenSearchLogicalIndexAgg node, public PhysicalPlan visitRelation(LogicalRelation node, OpenSearchIndexScan context) { return indexScan; } + + @Override + public PhysicalPlan visitMLCommons(LogicalMLCommons node, OpenSearchIndexScan context) { + return new MLCommonsOperator(visitChild(node, context), node.getAlgorithm(), + node.getArguments(), client.getNodeClient()); + } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index ec391e15db3..bcb318793cf 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; @@ -280,6 +281,12 @@ void meta() { assertEquals("cluster-name", meta.get(META_CLUSTER_NAME)); } + @Test + void ml() { + OpenSearchNodeClient client = new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + 
assertNotNull(client.getNodeClient()); + } + private OpenSearchNodeClient mockClient(String indexName, String mappings) { ClusterService clusterService = mockClusterService(indexName, mappings); return new OpenSearchNodeClient(clusterService, nodeClient); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index e4500972b75..0c2503ea57a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -267,6 +267,11 @@ void metaWithIOException() throws IOException { assertThrows(IllegalStateException.class, () -> client.meta()); } + @Test + void mlWithException() { + assertThrows(UnsupportedOperationException.class, () -> client.getNodeClient()); + } + private Map mockFieldMappings(String indexName, String mappings) throws IOException { return ImmutableMap.of(indexName, IndexMetadata.fromXContent(createParser(mappings)).mapping()); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java similarity index 92% rename from opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionProtectorTest.java rename to opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index c63de400732..fce7cc88ed1 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -4,7 +4,7 @@ */ -package org.opensearch.sql.opensearch.executor; +package 
org.opensearch.sql.opensearch.executor.protector; import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -34,6 +34,9 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningClient; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; @@ -52,6 +55,7 @@ import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; import org.opensearch.sql.opensearch.executor.protector.ResourceMonitorPlan; +import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -252,6 +256,21 @@ public void testWithoutProtection() { ); } + @Test + public void testVisitMlCommons() { + NodeClient nodeClient = mock(NodeClient.class); + MLCommonsOperator mlCommonsOperator = + new MLCommonsOperator( + values(emptyList()), + "kmeans", + AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3))), + nodeClient + ); + + assertEquals(executionProtector.doProtect(mlCommonsOperator), + executionProtector.visitMLCommons(mlCommonsOperator, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java new file mode 100644 index 00000000000..260f52770f8 --- /dev/null +++ 
b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import java.util.Collections; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.runner.RunWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import 
org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +@RunWith(MockitoJUnitRunner.Silent.class) +public class MLCommonsOperatorTest { + @Mock + private PhysicalPlan input; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private NodeClient nodeClient; + + private MLCommonsOperator mlCommonsOperator; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private MachineLearningNodeClient machineLearningNodeClient; + + @BeforeEach + void setUp() { + mlCommonsOperator = new MLCommonsOperator(input, "kmeans", + AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3)), + AstDSL.argument("k2", AstDSL.stringLiteral("v1")), + AstDSL.argument("k3", AstDSL.booleanLiteral(true)), + AstDSL.argument("k4", AstDSL.doubleLiteral(2.0D)), + AstDSL.argument("k5", AstDSL.shortLiteral((short)2)), + AstDSL.argument("k6", AstDSL.longLiteral(2L)), + AstDSL.argument("k7", AstDSL.floatLiteral(2F))), + nodeClient); + when(input.hasNext()).thenReturn(true).thenReturn(false); + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.put("k1", new ExprIntegerValue(2)); + when(input.next()).thenReturn(ExprTupleValue.fromExprValueMap(resultBuilder.build())); + + DataFrame dataFrame = DataFrameBuilder + .load(Collections.singletonList( + ImmutableMap.builder().put("result-k1", 2D) + .put("result-k2", 1) + .put("result-k3", "v3") + .put("result-k4", true) + .put("result-k5", (short)2) + .put("result-k6", 2L) + .put("result-k7", 2F) + .build()) + ); + MLPredictionOutput mlPredictionOutput = MLPredictionOutput.builder() + .taskId("test_task_id") + .status("test_status") + .predictionResult(dataFrame) + .build(); + + try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { + mlClientMockedStatic.when(() -> 
MLClient.getMLClient(any(NodeClient.class))) + .thenReturn(machineLearningNodeClient); + when(machineLearningNodeClient.trainAndPredict(any(MLInput.class)) + .actionGet(anyLong(), + eq(TimeUnit.SECONDS))) + .thenReturn(mlPredictionOutput); + } + } + + @Disabled + @Test + public void testOpen() { + mlCommonsOperator.open(); + assertTrue(mlCommonsOperator.hasNext()); + assertNotNull(mlCommonsOperator.next()); + assertFalse(mlCommonsOperator.hasNext()); + } + + @Test + public void testAccept() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor + = new PhysicalPlanNodeVisitor() {}; + assertNull(mlCommonsOperator.accept(physicalPlanNodeVisitor, null)); + } + + @Test + public void testConvertArgumentToMLParameter_UnsupportedType() { + Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("2020-10-31")); + assertThrows(IllegalArgumentException.class, () -> mlCommonsOperator + .convertArgumentToMLParameter(argument, "LINEAR_REGRESSION")); + } + + @Test + public void testConvertArgumentToMLParameter_KMeansUnsupportedType() { + Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("string value")); + assertThrows(IllegalArgumentException.class, () -> mlCommonsOperator + .convertArgumentToMLParameter(argument, "KMEANS")); + } + +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index a29f3f49fd3..52770df8db6 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -7,19 +7,27 @@ package org.opensearch.sql.opensearch.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static 
org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.planner.logical.LogicalMLCommons; +import org.opensearch.sql.planner.logical.LogicalPlan; @ExtendWith(MockitoExtension.class) public class OpenSearchDefaultImplementorTest { @Mock OpenSearchIndexScan indexScan; + @Mock + OpenSearchClient client; /** * For test coverage. @@ -27,7 +35,7 @@ public class OpenSearchDefaultImplementorTest { @Test public void visitInvalidTypeShouldThrowException() { final OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan); + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); final IllegalStateException exception = assertThrows(IllegalStateException.class, () -> implementor.visitNode(relation("index"), @@ -38,4 +46,14 @@ public void visitInvalidTypeShouldThrowException() { + "class org.opensearch.sql.planner.logical.LogicalRelation", exception.getMessage()); } + + @Test + public void visitMachineLearning() { + LogicalMLCommons node = Mockito.mock(LogicalMLCommons.class, + Answers.RETURNS_DEEP_STUBS); + Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); + OpenSearchIndex.OpenSearchDefaultImplementor implementor = + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + assertNotNull(implementor.visitMLCommons(node, indexScan)); + } } diff --git a/plugin/src/main/plugin-metadata/plugin-security.policy b/plugin/src/main/plugin-metadata/plugin-security.policy index 1c2403f4ff7..14b88c49e73 100644 --- a/plugin/src/main/plugin-metadata/plugin-security.policy +++ b/plugin/src/main/plugin-metadata/plugin-security.policy @@ -8,4 +8,7 @@ grant { 
permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; permission java.lang.RuntimePermission "accessDeclaredMembers"; permission java.lang.RuntimePermission "defineClass"; + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.RuntimePermission "getClassLoader"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; }; diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index a6563bf9e8b..4d105c27b38 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -22,6 +22,7 @@ EVAL: 'EVAL'; HEAD: 'HEAD'; TOP: 'TOP'; RARE: 'RARE'; +KMEANS: 'KMEANS'; // COMMAND ASSIST KEYWORDS AS: 'AS'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 4ca3788c5dc..15bcec67dd9 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -20,7 +20,7 @@ pplStatement /** commands */ commands : whereCommand | fieldsCommand | renameCommand | statsCommand | dedupCommand | sortCommand | evalCommand | headCommand - | topCommand | rareCommand; + | topCommand | rareCommand | kmeansCommand; searchCommand : (SEARCH)? fromClause #searchFrom @@ -84,6 +84,11 @@ rareCommand (byClause)? 
; +kmeansCommand + : KMEANS + k=integerLiteral + ; + /** clauses */ fromClause : SOURCE EQUAL tableSource diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 849cfe6fa2d..d4dbe08061a 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -40,6 +40,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.RareTopN.CommandType; @@ -51,6 +52,7 @@ import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ByClauseContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldListContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -281,6 +283,11 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla return aggregate; } + @Override + public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { + return new Kmeans(ArgumentFactory.getArgumentList(ctx)); + } + /** * Get original text in query. 
*/ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 59ba431873e..59c91a50a5e 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -23,6 +23,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; /** @@ -135,6 +136,18 @@ public static List getArgumentList(RareCommandContext ctx) { .singletonList(new Argument("noOfResults", new Literal(10, DataType.INTEGER))); } + /** + * Get list of {@link Argument}. + * + * @param ctx KmeansCommandContext instance + * @return the list of arguments fetched from the kmeans command + */ + public static List getArgumentList(KmeansCommandContext ctx) { + // TODO: add iterations and distanceType parameters for Kmeans + return Collections + .singletonList(new Argument("k", getArgumentValue(ctx.k))); + } + private static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext ? 
new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index b874862c65b..ce3f327f09d 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -44,6 +44,7 @@ import org.junit.rules.ExpectedException; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.SpanUnit; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; @@ -533,6 +534,12 @@ public void testTopCommandWithMultipleFields() { )); } + @Test + public void testKmeansCommand() { + assertEqual("source=t | kmeans 3", + new Kmeans(relation("t"),exprList(argument("k", intLiteral(3))))); + } + protected void assertEqual(String query, Node expectedPlan) { Node actualPlan = plan(query); assertEquals(expectedPlan, actualPlan);