diff --git a/core/build.gradle b/core/build.gradle index d26af11cc2d..a0f0cf53e9e 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -81,6 +81,10 @@ test.finalizedBy(project.tasks.jacocoTestReport) jacocoTestCoverageVerification { violationRules { rule { + element = 'CLASS' + excludes = [ + 'org.opensearch.sql.utils.MLCommonsConstants' + ] limit { counter = 'LINE' minimum = 1.0 diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 52216aefdae..968fa07f18d 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -11,6 +11,11 @@ import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -18,6 +23,7 @@ import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -31,11 +37,13 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; import 
org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; @@ -58,11 +66,13 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalRareTopN; @@ -395,6 +405,44 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) { return new LogicalValues(valueExprs); } + /** + * Build {@link LogicalMLCommons} for Kmeans command. + */ + @Override + public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + List options = node.getOptions(); + + TypeEnvironment currentEnv = context.peek(); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER); + + return new LogicalMLCommons(child, "kmeans", options); + } + + /** + * Build {@link LogicalAD} for AD command.
+ * Always defines the score field; defines the anomalous flag when no time field is set, + * otherwise the anomaly grade and timestamp fields. + */ + @Override + public LogicalPlan visitAD(AD node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + java.util.Map options = node.getArguments(); + + TypeEnvironment currentEnv = context.peek(); + + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE); + // time_field may be missing from the arguments or carry a null value. + Literal timeField = node.getArguments().get(TIME_FIELD); + if (Objects.isNull(timeField) || Objects.isNull(timeField.getValue())) { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN); + } else { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_TIMESTAMP), ExprCoreType.TIMESTAMP); + } + return new LogicalAD(child, options); + } + /** * The first argument is always "asc", others are optional. * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 86f5a6ebc8a..5708bb3b99c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -32,11 +32,13 @@ import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; @@ -239,4 +241,12 @@ public T visitLimit(Limit node, C context) { public T visitSpan(Span node, C context) { return visitChildren(node, context); } + + public T
visitKmeans(Kmeans node, C context) { + return visitChildren(node, context); + } + + public T visitAD(AD node, C context) { + return visitChildren(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 65f060a9214..3478697f4a9 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -153,6 +153,14 @@ public static Literal longLiteral(Long value) { return literal(value, DataType.LONG); } + public static Literal shortLiteral(Short value) { + return literal(value, DataType.SHORT); + } + + public static Literal floatLiteral(Float value) { + return literal(value, DataType.FLOAT); + } + public static Literal dateLiteral(String value) { return literal(value, DataType.DATE); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java b/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java index ddea7f2f26d..8755a15177f 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/DataType.java @@ -21,6 +21,8 @@ public enum DataType { INTEGER(ExprCoreType.INTEGER), LONG(ExprCoreType.LONG), + SHORT(ExprCoreType.SHORT), + FLOAT(ExprCoreType.FLOAT), DOUBLE(ExprCoreType.DOUBLE), STRING(ExprCoreType.STRING), BOOLEAN(ExprCoreType.BOOLEAN), diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AD.java b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java new file mode 100644 index 00000000000..e9aee25c230 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import
lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Literal; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = true) +@RequiredArgsConstructor +@AllArgsConstructor +public class AD extends UnresolvedPlan { + private UnresolvedPlan child; + + private final Map arguments; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAD(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java new file mode 100644 index 00000000000..34099ebbbd3 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Argument; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = true) +@RequiredArgsConstructor +@AllArgsConstructor +public class Kmeans extends UnresolvedPlan { + private UnresolvedPlan child; + + private final List options; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return
nodeVisitor.visitKmeans(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java new file mode 100644 index 00000000000..c8c04b18177 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.expression.Literal; + +/** + * AD logical plan. + */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalAD extends LogicalPlan { + private final Map arguments; + + /** + * Constructor of LogicalAD. + * @param child child logical plan + * @param arguments arguments of the algorithm + */ + public LogicalAD(LogicalPlan child, Map arguments) { + super(Collections.singletonList(child)); + this.arguments = arguments; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitAD(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java new file mode 100644 index 00000000000..c4b44317dd8 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalMLCommons.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.expression.Argument; + +/** + * ml-commons logical plan.
+ */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalMLCommons extends LogicalPlan { + private final String algorithm; + + private final List arguments; + + /** + * Constructor of LogicalMLCommons. + * @param child child logical plan + * @param algorithm algorithm name + * @param arguments arguments of the algorithm + */ + public LogicalMLCommons(LogicalPlan child, String algorithm, + List arguments) { + super(Collections.singletonList(child)); + this.algorithm = algorithm; + this.arguments = arguments; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitMLCommons(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 5c11d230a10..5163e44edb2 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -69,4 +69,12 @@ public R visitRareTopN(LogicalRareTopN plan, C context) { public R visitLimit(LogicalLimit plan, C context) { return visitNode(plan, context); } + + public R visitMLCommons(LogicalMLCommons plan, C context) { + return visitNode(plan, context); + } + + public R visitAD(LogicalAD plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 110a4ff16b2..87582df3bbf 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -72,4 +72,13 @@ public R visitLimit(LimitOperator node, C context) { return visitNode(node, context); } + public R visitMLCommons(PhysicalPlan node, C context) { +
return visitNode(node, context); + } + + public R visitAD(PhysicalPlan node, C context) { + return visitNode(node, context); + } + + } diff --git a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java new file mode 100644 index 00000000000..3e957f1bda0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.utils; + +/** + * Argument keys and result field names used by the AD command. + */ +public final class MLCommonsConstants { + + public static final String SHINGLE_SIZE = "shingle_size"; + public static final String TIME_DECAY = "time_decay"; + public static final String TIME_FIELD = "time_field"; + + public static final String RCF_SCORE = "score"; + public static final String RCF_ANOMALOUS = "anomalous"; + public static final String RCF_ANOMALY_GRADE = "anomaly_grade"; + public static final String RCF_TIMESTAMP = "timestamp"; + + private MLCommonsConstants() { + } +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 63ee4f827a0..fde22f2485a 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -35,6 +35,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; @@ -42,12 +44,18 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.SpanUnit; +import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.Kmeans; import
org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.springframework.context.annotation.Configuration; import org.springframework.test.annotation.DirtiesContext; @@ -690,4 +698,45 @@ public void parse_relation() { AstDSL.alias("string_value", qualifiedName("string_value")) )); } + + @Test + public void kmeans_relation() { + // A kmeans node is analyzed into LogicalMLCommons with the same options. + assertAnalyzeEqual( + new LogicalMLCommons(LogicalPlanDSL.relation("schema"), + "kmeans", + AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))), + new Kmeans(AstDSL.relation("schema"), + AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))) + ); + } + + @Test + public void ad_batchRCF_relation() { + // time_field carries a null value here. + Map argumentMap = + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + }}; + assertAnalyzeEqual( + new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), + new AD(AstDSL.relation("schema"), argumentMap) + ); + } + + @Test + public void ad_fitRCF_relation() { + // time_field carries a non-null value here. + Map argumentMap = new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + }}; + assertAnalyzeEqual( + new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), + new AD(AstDSL.relation("schema"), argumentMap) + ); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java
b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index a6a0a9d519e..1b8d606211c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -12,12 +12,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.DSL; @@ -108,6 +112,22 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { relation, CommandType.TOP, ImmutableList.of(expression), expression); assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { }, null)); + + LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"), + "kmeans", + AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))); + assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { + }, null)); + + LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema"), + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + }); + assertNull(ad.accept(new LogicalPlanNodeVisitor() { + }, null)); } private static class NodesCount extends LogicalPlanNodeVisitor { diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java 
b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 092abb87caa..cd561f3c093 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -134,6 +134,22 @@ public void test_PhysicalPlanVisitor_should_return_null() { }, null)); } + @Test + public void test_visitMLCommons() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor = + new PhysicalPlanNodeVisitor() {}; + + assertNull(physicalPlanNodeVisitor.visitMLCommons(plan, null)); + } + + @Test + public void test_visitAD() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor = + new PhysicalPlanNodeVisitor() {}; + + assertNull(physicalPlanNodeVisitor.visitAD(plan, null)); + } + public static class PhysicalPlanPrinter extends PhysicalPlanNodeVisitor { public String print(PhysicalPlan node) { diff --git a/opensearch/build.gradle b/opensearch/build.gradle index 3b32c1ee554..726b56f390b 100644 --- a/opensearch/build.gradle +++ b/opensearch/build.gradle @@ -37,6 +37,7 @@ dependencies { compile group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: '2.12.6' compile group: 'org.json', name: 'json', version:'20180813' compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}" + compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0-SNAPSHOT' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testCompile group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' @@ -75,7 +76,9 @@ jacocoTestCoverageVerification { rule { element = 'CLASS' excludes = [ - 'org.opensearch.sql.opensearch.security.SecurityAccess' + 'org.opensearch.sql.opensearch.security.SecurityAccess', + 'org.opensearch.sql.opensearch.planner.physical.*', + 'org.opensearch.sql.opensearch.client.MLClient' ] limit { counter = 'LINE' diff --git 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/MLClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/MLClient.java new file mode 100644 index 00000000000..19f49d0e5f2 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/MLClient.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.client; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; + +/** + * Lazily creates and caches a single {@link MachineLearningNodeClient}. + */ +public final class MLClient { + private static MachineLearningNodeClient INSTANCE; + + private MLClient() { + + } + + /** + * Get the shared machine learning client, creating it on the first call. + * Synchronized so that concurrent first calls cannot create two clients. The + * {@code nodeClient} of later calls is ignored because the cached client is returned. + * @param nodeClient node client used to create the client on the first call + * @return machine learning client + */ + public static synchronized MachineLearningNodeClient getMLClient(NodeClient nodeClient) { + if (INSTANCE == null) { + INSTANCE = new MachineLearningNodeClient(nodeClient); + } + return INSTANCE; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java index 5961560f55f..c1b7d782d27 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Map; +import org.opensearch.client.node.NodeClient; import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -64,4 +65,11 @@ public interface OpenSearchClient { * @param task task */ void schedule(Runnable task); + + /** + * Get the node client behind this client. + * Implementations that have none throw {@link UnsupportedOperationException}. + * @return node client + */ + NodeClient getNodeClient(); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java index 7bc71391636..b66a1dc7ed3 100644 ---
a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java @@ -147,6 +147,11 @@ public void schedule(Runnable task) { ); } + @Override + public NodeClient getNodeClient() { + return client; + } + private String[] resolveIndexExpression(ClusterState state, String[] indices) { return resolver.concreteIndexNames(state, IndicesOptions.strictExpandOpen(), true, indices); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java index 0ff860cb0bf..9da8c442e09 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java @@ -24,6 +24,7 @@ import org.opensearch.client.indices.GetIndexResponse; import org.opensearch.client.indices.GetMappingsRequest; import org.opensearch.client.indices.GetMappingsResponse; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.common.settings.Settings; import org.opensearch.sql.opensearch.mapping.IndexMapping; @@ -135,4 +136,9 @@ public void cleanup(OpenSearchRequest request) { public void schedule(Runnable task) { task.run(); } + + @Override + public NodeClient getNodeClient() { + throw new UnsupportedOperationException("Unsupported method."); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java
b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -8,6 +8,8 @@ import lombok.RequiredArgsConstructor; import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; +import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.planner.physical.AggregationOperator; import org.opensearch.sql.planner.physical.DedupeOperator; import org.opensearch.sql.planner.physical.EvalOperator; @@ -126,6 +128,28 @@ public PhysicalPlan visitLimit(LimitOperator node, Object context) { node.getOffset()); } + @Override + public PhysicalPlan visitMLCommons(PhysicalPlan node, Object context) { + MLCommonsOperator mlCommonsOperator = (MLCommonsOperator) node; + return doProtect( + new MLCommonsOperator(visitInput(mlCommonsOperator.getInput(), context), + mlCommonsOperator.getAlgorithm(), + mlCommonsOperator.getArguments(), + mlCommonsOperator.getNodeClient()) + ); + } + + @Override + public PhysicalPlan visitAD(PhysicalPlan node, Object context) { + ADOperator adOperator = (ADOperator) node; + return doProtect( + new ADOperator(visitInput(adOperator.getInput(), context), + adOperator.getArguments(), + adOperator.getNodeClient() + ) + ); + } + PhysicalPlan visitInput(PhysicalPlan node, Object context) { if (null == node) { return node; @@ -134,7 +158,7 @@ PhysicalPlan visitInput(PhysicalPlan node, Object context) { } } - private PhysicalPlan doProtect(PhysicalPlan node) { + protected PhysicalPlan doProtect(PhysicalPlan node) { if (isProtected(node)) { return node; } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java new file mode 100644 index 00000000000..acf3bbdc22c --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -0,0 +1,123 @@ +/* + *
Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.planner.physical; + +import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.parameter.BatchRCFParams; +import org.opensearch.ml.common.parameter.FitRCFParams; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +/** + * AD Physical operator to call AD interface to get results for + * algorithm execution.
+ */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class ADOperator extends MLCommonsOperatorActions { + + @Getter + private final PhysicalPlan input; + + @Getter + private final Map arguments; + + @Getter + private final NodeClient nodeClient; + + @EqualsAndHashCode.Exclude + private Iterator iterator; + + /** RCF variant selected from the arguments; assigned by convertArgumentToMLParameter. */ + @EqualsAndHashCode.Exclude + private FunctionName rcfType; + + @Override + public void open() { + super.open(); + DataFrame inputDataFrame = generateInputDataset(input); + MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments); + + MLPredictionOutput predictionResult = + getMLPredictionResult(rcfType, mlAlgoParams, inputDataFrame, nodeClient); + + Iterator inputRowIter = inputDataFrame.iterator(); + Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); + iterator = new Iterator() { + @Override + public boolean hasNext() { + return inputRowIter.hasNext(); + } + + @Override + public ExprValue next() { + return buildResult(inputRowIter, inputDataFrame, predictionResult, resultRowIter); + } + }; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitAD(this, context); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + return iterator.next(); + } + + @Override + public List getChild() { + return Collections.singletonList(input); + } + + /** + * Convert the AD arguments into ml-commons parameters and record the RCF variant to run. + * Batch RCF is chosen when time_field is missing or has a null value, FIT RCF otherwise. + * @param arguments AD arguments keyed by argument name + * @return parameters for the chosen RCF variant + */ + protected MLAlgoParams convertArgumentToMLParameter(Map arguments) { + if (arguments.get(TIME_FIELD) == null || arguments.get(TIME_FIELD).getValue() == null) { + rcfType = FunctionName.BATCH_RCF; + return BatchRCFParams.builder() + .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .build(); + } + rcfType = FunctionName.FIT_RCF; + return FitRCFParams.builder() + .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .timeDecay((Double) arguments.get(TIME_DECAY).getValue()) + .timeField((String) arguments.get(TIME_FIELD).getValue()) + .dateFormat("yyyy-MM-dd HH:mm:ss") + .build(); + } + +} diff
--git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java new file mode 100644 index 00000000000..75870b5ee13 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +/** + * ml-commons Physical operator to call machine learning interface to get results for + * algorithm execution. 
+ */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class MLCommonsOperator extends MLCommonsOperatorActions { + @Getter + private final PhysicalPlan input; + + @Getter + private final String algorithm; + + @Getter + private final List arguments; + + @Getter + private final NodeClient nodeClient; + + @EqualsAndHashCode.Exclude + private Iterator iterator; + /** + * Read every row of the input plan into a data frame, build the algorithm parameters from + * the first argument only, run train-and-predict, and prepare the iterator that merges each + * input row with its prediction row. The algorithm name is upper-cased and resolved with + * FunctionName.valueOf, which throws IllegalArgumentException for a name that is not a + * FunctionName constant. + * NOTE(review): the merged iterator only checks the input rows, so it presumably relies on + * the prediction result holding one row per input row - confirm. + */ + @Override + public void open() { + super.open(); + DataFrame inputDataFrame = generateInputDataset(input); + MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments.get(0), algorithm); + MLPredictionOutput predictionResult = + getMLPredictionResult(FunctionName.valueOf(algorithm.toUpperCase()), + mlAlgoParams, inputDataFrame, nodeClient); + + Iterator inputRowIter = inputDataFrame.iterator(); + Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); + iterator = new Iterator() { + @Override + public boolean hasNext() { + return inputRowIter.hasNext(); + } + + @Override + public ExprValue next() { + return buildResult(inputRowIter, inputDataFrame, predictionResult, resultRowIter); + } + }; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitMLCommons(this, context); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + return iterator.next(); + } + + @Override + public List getChild() { + return Collections.singletonList(input); + } + /** + * Build the ml-commons parameters for the given algorithm. Only KMEANS is handled: the + * argument value becomes the number of centroids. Every other FunctionName constant is + * rejected with IllegalArgumentException. + * NOTE(review): the KMEANS branch accepts any Number but then casts the value to Integer, so + * a Long or Double value looks like it would fail with ClassCastException instead of the + * IllegalArgumentException raised for other types - confirm. + * @param argument command argument whose literal value is read + * @param algorithm algorithm name, upper-cased before it is matched against FunctionName + * @return parameters for the selected algorithm + * @throws IllegalArgumentException for an unsupported algorithm or a non-numeric KMEANS value + */ + protected MLAlgoParams convertArgumentToMLParameter(Argument argument, String algorithm) { + switch (FunctionName.valueOf(algorithm.toUpperCase())) { + case KMEANS: + if (argument.getValue().getValue() instanceof Number) { + return KMeansParams.builder().centroids((Integer) argument.getValue().getValue()).build(); + } else { + throw new IllegalArgumentException("unsupported Kmeans argument type:" + + argument.getValue().getType()); + } + default: + // TODO: update available algorithms in the message 
when adding a new case + throw new IllegalArgumentException( + String.format("unsupported algorithm: %s, available algorithms: %s.", + FunctionName.valueOf(algorithm.toUpperCase()), KMEANS)); + } + } + +} + diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java new file mode 100644 index 00000000000..21b232c031e --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.planner.physical; + +import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import 
org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; + +/** + * Common helper methods shared by the ml-commons related operators. + */ +public abstract class MLCommonsOperatorActions extends PhysicalPlan { + + /** + * generate the ml-commons request input dataset by consuming the input plan until + * hasNext() is false; each row becomes a map of field name to the field's raw value. + * NOTE(review): the double-brace initializer creates an anonymous HashMap subclass + * instance per row. + * @param input physical input + * @return ml-commons dataframe + */ + protected DataFrame generateInputDataset(PhysicalPlan input) { + List> inputData = new LinkedList<>(); + while (input.hasNext()) { + inputData.add(new HashMap() { + { + input.next().tupleValue().forEach((key, value) -> put(key, value.value())); + } + }); + } + + return DataFrameBuilder.load(inputData); + } + + /** + * convert one data frame row into ExprValue entries keyed by column name. + * @param columnMetas column metas + * @param row row + * @return a map from column name to ExprValue; columns whose type populateResultBuilder + * does not handle are left out + */ + protected Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + populateResultBuilder(columnValue, resultKeyName, resultBuilder); + } + return resultBuilder.build(); + } + + /** + * add one column value to the result map as the ExprValue matching its ml-commons type. 
+ * Column types without a case below are skipped silently. + * @param columnValue column value + * @param resultKeyName result key name + * @param resultBuilder result builder + */ + protected void populateResultBuilder(ColumnValue columnValue, + String resultKeyName, + ImmutableMap.Builder resultBuilder) { + switch (columnValue.columnType()) { + case INTEGER: + resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); + break; + case DOUBLE: + resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); + break; + case STRING: + resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); + break; + case SHORT: + resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); + break; + case LONG: + resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); + break; + case FLOAT: + resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); + break; + case BOOLEAN: + resultBuilder.put(resultKeyName, ExprBooleanValue.of(columnValue.booleanValue())); + break; + default: + break; + } + } + + /** + * convert one prediction result row into ExprValue entries keyed by column name. + * A column whose name is already a key of schema is stored under that name plus "1". + * NOTE(review): a single "1" suffix could itself clash with an existing key - confirm + * that such column names cannot occur. + * @param columnMetas column metas + * @param row row + * @param schema map whose keys are checked for name clashes + * @return a map of result in ExprValue format + */ + protected Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, + Row row, + Map schema) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + // change key name to avoid duplicate key issue in result map + // only value will be shown in the final returned result + if (schema.containsKey(resultKeyName)) { + resultKeyName = resultKeyName + "1"; + } + populateResultBuilder(columnValue, resultKeyName, resultBuilder); + + } + return resultBuilder.build(); + } + + /** + * take the next input row and the next prediction row and build them into one ExprTupleValue. + * NOTE(review): both iterators are advanced without a hasNext() check, so this presumably + * relies on the prediction result holding a row for every input row - confirm. 
+ * @param inputRowIter input row iterator + * @param inputDataFrame input data frame + * @param predictionResult prediction result + * @param resultRowIter result row iterator + * @return result in ExprTupleValue format + */ + protected ExprTupleValue buildResult(Iterator inputRowIter, DataFrame inputDataFrame, + MLPredictionOutput predictionResult, Iterator resultRowIter) { + ImmutableMap.Builder resultSchemaBuilder = new ImmutableMap.Builder<>(); + resultSchemaBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), + inputRowIter.next())); + Map resultSchema = resultSchemaBuilder.build(); + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.putAll(convertResultRowIntoExprValue( + predictionResult.getPredictionResult().columnMetas(), + resultRowIter.next(), + resultSchema)); + resultBuilder.putAll(resultSchema); + return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + } + + /** + * get ml-commons train and predict result, waiting up to 30 seconds for the response + * (the actionGet timeout). + * @param functionName ml-commons algorithm name + * @param mlAlgoParams ml-commons algorithm parameters + * @param inputDataFrame input data frame + * @param nodeClient node client + * @return ml-commons train and predict result + */ + protected MLPredictionOutput getMLPredictionResult(FunctionName functionName, + MLAlgoParams mlAlgoParams, + DataFrame inputDataFrame, + NodeClient nodeClient) { + MLInput mlinput = MLInput.builder() + .algorithm(functionName) + .parameters(mlAlgoParams) + .inputDataset(new DataFrameInputDataset(inputDataFrame)) + .build(); + + MachineLearningNodeClient machineLearningClient = + MLClient.getMLClient(nodeClient); + + return (MLPredictionOutput) machineLearningClient + .trainAndPredict(mlinput) + .actionGet(30, TimeUnit.SECONDS); + } + +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index e0cde82a817..49301cbf536 100644 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -22,6 +22,8 @@ import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; +import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -30,6 +32,8 @@ import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.DefaultImplementor; +import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -88,7 +92,7 @@ public PhysicalPlan implement(LogicalPlan plan) { * aggregation, filter, will accumulate (push down) OpenSearch query and aggregation DSL on * index scan. 
*/ - return plan.accept(new OpenSearchDefaultImplementor(indexScan), indexScan); + return plan.accept(new OpenSearchDefaultImplementor(indexScan, client), indexScan); } @Override @@ -102,6 +106,8 @@ public static class OpenSearchDefaultImplementor extends DefaultImplementor { private final OpenSearchIndexScan indexScan; + private final OpenSearchClient client; + @Override public PhysicalPlan visitNode(LogicalPlan plan, OpenSearchIndexScan context) { if (plan instanceof OpenSearchLogicalIndexScan) { @@ -169,5 +175,17 @@ public PhysicalPlan visitIndexAggregation(OpenSearchLogicalIndexAgg node, public PhysicalPlan visitRelation(LogicalRelation node, OpenSearchIndexScan context) { return indexScan; } + + @Override + public PhysicalPlan visitMLCommons(LogicalMLCommons node, OpenSearchIndexScan context) { + return new MLCommonsOperator(visitChild(node, context), node.getAlgorithm(), + node.getArguments(), client.getNodeClient()); + } + + @Override + public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { + return new ADOperator(visitChild(node, context), + node.getArguments(), client.getNodeClient()); + } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index ec391e15db3..bcb318793cf 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; @@ -280,6 +281,12 @@ void meta() { 
assertEquals("cluster-name", meta.get(META_CLUSTER_NAME)); } + @Test + void ml() { + OpenSearchNodeClient client = new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + assertNotNull(client.getNodeClient()); + } + private OpenSearchNodeClient mockClient(String indexName, String mappings) { ClusterService clusterService = mockClusterService(indexName, mappings); return new OpenSearchNodeClient(clusterService, nodeClient); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index e4500972b75..0c2503ea57a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -267,6 +267,11 @@ void metaWithIOException() throws IOException { assertThrows(IllegalStateException.class, () -> client.meta()); } + @Test + void mlWithException() { + assertThrows(UnsupportedOperationException.class, () -> client.getNodeClient()); + } + private Map mockFieldMappings(String indexName, String mappings) throws IOException { return ImmutableMap.of(indexName, IndexMetadata.fromXContent(createParser(mappings)).mapping()); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java similarity index 86% rename from opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionProtectorTest.java rename to opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index c63de400732..2427ac4fe5c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionProtectorTest.java +++ 
b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -4,7 +4,7 @@ */ -package org.opensearch.sql.opensearch.executor; +package org.opensearch.sql.opensearch.executor.protector; import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -34,6 +35,10 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.node.NodeClient; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; @@ -50,8 +55,8 @@ import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; -import org.opensearch.sql.opensearch.executor.protector.ResourceMonitorPlan; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; +import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -252,6 +257,38 @@ public void testWithoutProtection() { ); } + @Test + public void testVisitMlCommons() { + NodeClient nodeClient = mock(NodeClient.class); + MLCommonsOperator mlCommonsOperator = 
+ new MLCommonsOperator( + values(emptyList()), + "kmeans", + AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3))), + nodeClient + ); + + assertEquals(executionProtector.doProtect(mlCommonsOperator), + executionProtector.visitMLCommons(mlCommonsOperator, null)); + } + + @Test + public void testVisitAD() { + NodeClient nodeClient = mock(NodeClient.class); + ADOperator adOperator = + new ADOperator( + values(emptyList()), + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + }, nodeClient); + + assertEquals(executionProtector.doProtect(adOperator), + executionProtector.visitAD(adOperator, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java new file mode 100644 index 00000000000..260f52770f8 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorTest.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import java.util.Collections; 
+import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.runner.RunWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +@RunWith(MockitoJUnitRunner.Silent.class) +public class MLCommonsOperatorTest { + @Mock + private PhysicalPlan input; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private NodeClient nodeClient; + + private MLCommonsOperator mlCommonsOperator; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private MachineLearningNodeClient machineLearningNodeClient; + + @BeforeEach + void setUp() { + mlCommonsOperator = new MLCommonsOperator(input, "kmeans", + AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3)), + AstDSL.argument("k2", AstDSL.stringLiteral("v1")), + 
AstDSL.argument("k3", AstDSL.booleanLiteral(true)), + AstDSL.argument("k4", AstDSL.doubleLiteral(2.0D)), + AstDSL.argument("k5", AstDSL.shortLiteral((short)2)), + AstDSL.argument("k6", AstDSL.longLiteral(2L)), + AstDSL.argument("k7", AstDSL.floatLiteral(2F))), + nodeClient); + when(input.hasNext()).thenReturn(true).thenReturn(false); + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.put("k1", new ExprIntegerValue(2)); + when(input.next()).thenReturn(ExprTupleValue.fromExprValueMap(resultBuilder.build())); + + DataFrame dataFrame = DataFrameBuilder + .load(Collections.singletonList( + ImmutableMap.builder().put("result-k1", 2D) + .put("result-k2", 1) + .put("result-k3", "v3") + .put("result-k4", true) + .put("result-k5", (short)2) + .put("result-k6", 2L) + .put("result-k7", 2F) + .build()) + ); + MLPredictionOutput mlPredictionOutput = MLPredictionOutput.builder() + .taskId("test_task_id") + .status("test_status") + .predictionResult(dataFrame) + .build(); + + try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { + mlClientMockedStatic.when(() -> MLClient.getMLClient(any(NodeClient.class))) + .thenReturn(machineLearningNodeClient); + when(machineLearningNodeClient.trainAndPredict(any(MLInput.class)) + .actionGet(anyLong(), + eq(TimeUnit.SECONDS))) + .thenReturn(mlPredictionOutput); + } + } + + @Disabled + @Test + public void testOpen() { + mlCommonsOperator.open(); + assertTrue(mlCommonsOperator.hasNext()); + assertNotNull(mlCommonsOperator.next()); + assertFalse(mlCommonsOperator.hasNext()); + } + + @Test + public void testAccept() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor + = new PhysicalPlanNodeVisitor() {}; + assertNull(mlCommonsOperator.accept(physicalPlanNodeVisitor, null)); + } + + @Test + public void testConvertArgumentToMLParameter_UnsupportedType() { + Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("2020-10-31")); + assertThrows(IllegalArgumentException.class, () -> 
mlCommonsOperator + .convertArgumentToMLParameter(argument, "LINEAR_REGRESSION")); + } + + @Test + public void testConvertArgumentToMLParameter_KMeansUnsupportedType() { + Argument argument = AstDSL.argument("k2", AstDSL.dateLiteral("string value")); + assertThrows(IllegalArgumentException.class, () -> mlCommonsOperator + .convertArgumentToMLParameter(argument, "KMEANS")); + } + +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index a29f3f49fd3..0770ea3938c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -7,19 +7,28 @@ package org.opensearch.sql.opensearch.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalMLCommons; +import org.opensearch.sql.planner.logical.LogicalPlan; @ExtendWith(MockitoExtension.class) public class OpenSearchDefaultImplementorTest { @Mock OpenSearchIndexScan indexScan; + @Mock + OpenSearchClient client; /** * For test coverage. 
@@ -27,7 +36,7 @@ public class OpenSearchDefaultImplementorTest { @Test public void visitInvalidTypeShouldThrowException() { final OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan); + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); final IllegalStateException exception = assertThrows(IllegalStateException.class, () -> implementor.visitNode(relation("index"), @@ -38,4 +47,24 @@ public void visitInvalidTypeShouldThrowException() { + "class org.opensearch.sql.planner.logical.LogicalRelation", exception.getMessage()); } + + @Test + public void visitMachineLearning() { + LogicalMLCommons node = Mockito.mock(LogicalMLCommons.class, + Answers.RETURNS_DEEP_STUBS); + Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); + OpenSearchIndex.OpenSearchDefaultImplementor implementor = + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + assertNotNull(implementor.visitMLCommons(node, indexScan)); + } + + @Test + public void visitAD() { + LogicalAD node = Mockito.mock(LogicalAD.class, + Answers.RETURNS_DEEP_STUBS); + Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); + OpenSearchIndex.OpenSearchDefaultImplementor implementor = + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + assertNotNull(implementor.visitAD(node, indexScan)); + } } diff --git a/plugin/src/main/plugin-metadata/plugin-security.policy b/plugin/src/main/plugin-metadata/plugin-security.policy index 1c2403f4ff7..14b88c49e73 100644 --- a/plugin/src/main/plugin-metadata/plugin-security.policy +++ b/plugin/src/main/plugin-metadata/plugin-security.policy @@ -8,4 +8,7 @@ grant { permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; permission java.lang.RuntimePermission "accessDeclaredMembers"; permission java.lang.RuntimePermission "defineClass"; + permission java.lang.RuntimePermission "accessDeclaredMembers"; + 
permission java.lang.RuntimePermission "getClassLoader"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; }; diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index eee8fe46a66..189a329de69 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -23,6 +23,8 @@ HEAD: 'HEAD'; TOP: 'TOP'; RARE: 'RARE'; PARSE: 'PARSE'; +KMEANS: 'KMEANS'; +AD: 'AD'; // COMMAND ASSIST KEYWORDS AS: 'AS'; @@ -48,6 +50,9 @@ DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; DELIM: 'DELIM'; +SHINGLE_SIZE: 'SHINGLE_SIZE'; +TIME_DECAY: 'TIME_DECAY'; +TIME_FIELD: 'TIME_FIELD'; // COMPARISON FUNCTION KEYWORDS CASE: 'CASE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 932f81e83d6..d6cd1e99b85 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -20,7 +20,7 @@ pplStatement /** commands */ commands : whereCommand | fieldsCommand | renameCommand | statsCommand | dedupCommand | sortCommand | evalCommand | headCommand - | topCommand | rareCommand | parseCommand; + | topCommand | rareCommand | parseCommand | kmeansCommand | adCommand; searchCommand : (SEARCH)? fromClause #searchFrom @@ -87,6 +87,18 @@ rareCommand parseCommand : PARSE expression pattern ; + +kmeansCommand + : KMEANS + k=integerLiteral + ; + +adCommand + : AD + (SHINGLE_SIZE EQUAL shingle_size=integerLiteral)? + (TIME_DECAY EQUAL time_decay=decimalLiteral)? + (TIME_FIELD EQUAL time_field=stringLiteral)? 
+ ; /** clauses */ fromClause diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ab1509129b9..88b61fbcb8a 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -37,11 +37,13 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -52,8 +54,10 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ByClauseContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldListContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -303,6 +307,16 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla return aggregate; } + @Override + public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { + return new Kmeans(ArgumentFactory.getArgumentList(ctx)); + } + + @Override + public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { + return new 
AD(ArgumentFactory.getArgumentMap(ctx)); + } + /** * Get original text in query. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 59ba431873e..09cef7c9110 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -7,6 +7,7 @@ package org.opensearch.sql.ppl.utils; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DedupCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; @@ -14,16 +15,22 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; +import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.antlr.v4.runtime.ParserRuleContext; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; - +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; /** * 
Util class to get all arguments as a list from the PPL command. @@ -135,11 +142,46 @@ public static List getArgumentList(RareCommandContext ctx) { .singletonList(new Argument("noOfResults", new Literal(10, DataType.INTEGER))); } + /** + * Get list of {@link Argument}. + * + * @param ctx KmeansCommandContext instance + * @return the list of arguments fetched from the kmeans command + */ + public static List getArgumentList(KmeansCommandContext ctx) { + // TODO: add iterations and distanceType parameters for Kmeans + return Collections + .singletonList(new Argument("k", getArgumentValue(ctx.k))); + } + + /** + * Get map of {@link Argument}. + * + * @param ctx AdCommandContext instance + * @return the map of arguments fetched from the AD command + */ + public static Map getArgumentMap(AdCommandContext ctx) { + return new HashMap() {{ + put(SHINGLE_SIZE, (ctx.shingle_size != null) + ? getArgumentValue(ctx.shingle_size) + : new Literal(null, DataType.INTEGER)); + put(TIME_DECAY, (ctx.time_decay != null) + ? getArgumentValue(ctx.time_decay) + : new Literal(null, DataType.DOUBLE)); + put(TIME_FIELD, (ctx.time_field != null) + ? getArgumentValue(ctx.time_field) + : new Literal(null, DataType.STRING)); + } + }; + } + private static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) : ctx instanceof BooleanLiteralContext ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) + : ctx instanceof DecimalLiteralContext + ? 
new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE) : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 87cc79873a5..5f729e5d06f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -39,12 +39,17 @@ import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; +import java.util.HashMap; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.SpanUnit; +import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; @@ -575,6 +580,34 @@ public void testParseCommand() { )); } + @Test + public void testKmeansCommand() { + assertEqual("source=t | kmeans 3", + new Kmeans(relation("t"),exprList(argument("k", intLiteral(3))))); + } + + @Test + public void test_fitRCFADCommand() { + assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp'", + new AD(relation("t"),new HashMap() {{ + put("shingle_size", new Literal(10, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + } + })); + } + + @Test + public void test_batchRCFADCommand() { + assertEqual("source=t | AD", + new AD(relation("t"),new HashMap() {{ + put("shingle_size", new Literal(null, DataType.INTEGER)); + put("time_decay", new Literal(null, DataType.DOUBLE)); + put("time_field", new Literal(null, 
DataType.STRING)); + } + })); + } + protected void assertEqual(String query, Node expectedPlan) { Node actualPlan = plan(query); assertEquals(expectedPlan, actualPlan);