diff --git a/core/build.gradle b/core/build.gradle index 63ecd8c104a..697a6bb775c 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -82,6 +82,10 @@ test.finalizedBy(project.tasks.jacocoTestReport) jacocoTestCoverageVerification { violationRules { rule { + element = 'CLASS' + excludes = [ + 'org.opensearch.sql.utils.MLCommonsConstants' + ] limit { counter = 'LINE' minimum = 1.0 diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 93367cc4138..0379035c00b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -11,6 +11,11 @@ import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -18,6 +23,7 @@ import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -31,6 +37,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; @@ -57,6 +64,7 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; @@ -383,6 +391,26 @@ public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { return new LogicalMLCommons(child, "kmeans", options); } + /** + * Build {@link LogicalAD} for AD command. + */ + @Override + public LogicalPlan visitAD(AD node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + java.util.Map options = node.getArguments(); + + TypeEnvironment currentEnv = context.peek(); + + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE); + if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN); + } else { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_TIMESTAMP), ExprCoreType.TIMESTAMP); + } + return new LogicalAD(child, options); + } + /** * The first argument is always "asc", others are optional. * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index f591007ad15..47368782de2 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -32,6 +32,7 @@ import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; @@ -239,4 +240,8 @@ public T visitSpan(Span node, C context) { public T visitKmeans(Kmeans node, C context) { return visitChildren(node, context); } + + public T visitAD(AD node, C context) { + return visitChildren(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AD.java b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java new file mode 100644 index 00000000000..4d1c9ebf531 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java @@ -0,0 +1,41 @@ +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Literal; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = true) +@RequiredArgsConstructor +@AllArgsConstructor +public class AD extends UnresolvedPlan { + private UnresolvedPlan child; + + private final Map arguments; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAD(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java new file mode 100644 index 00000000000..c8c04b18177 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java @@ -0,0 +1,33 @@ +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.expression.Literal; + +/* + * AD logical plan. + */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalAD extends LogicalPlan { + private final Map arguments; + + /** + * Constructor of LogicalAD. + * @param child child logical plan + * @param arguments arguments of the algorithm + */ + public LogicalAD(LogicalPlan child, Map arguments) { + super(Collections.singletonList(child)); + this.arguments = arguments; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitAD(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index c1f0d5d0418..5163e44edb2 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -73,4 +73,8 @@ public R visitLimit(LogicalLimit plan, C context) { public R visitMLCommons(LogicalMLCommons plan, C context) { return visitNode(plan, context); } + + public R visitAD(LogicalAD plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index fb7e3d0fe3f..87582df3bbf 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -76,4 +76,9 @@ public R visitMLCommons(PhysicalPlan node, C context) { return visitNode(node, context); } + public R visitAD(PhysicalPlan node, C context) { + return visitNode(node, context); + } + + } diff --git a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java new file mode 100644 index 00000000000..3e957f1bda0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java @@ -0,0 +1,13 @@ +package org.opensearch.sql.utils; + +public class MLCommonsConstants { + + public static final String SHINGLE_SIZE = "shingle_size"; + public static final String TIME_DECAY = "time_decay"; + public static final String TIME_FIELD = "time_field"; + + public static final String RCF_SCORE = "score"; + public static final String RCF_ANOMALOUS = "anomalous"; + public static final String RCF_ANOMALY_GRADE = "anomaly_grade"; + public static final String RCF_TIMESTAMP = "timestamp"; +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 2e9a6fe843a..9781204e633 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -34,6 +34,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; @@ -41,12 +43,16 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.springframework.context.annotation.Configuration; @@ -657,4 +663,31 @@ public void kmeanns_relation() { AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))) ); } + + @Test + public void ad_batchRCF_relation() { + Map argumentMap = + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + }}; + assertAnalyzeEqual( + new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), + new AD(AstDSL.relation("schema"), argumentMap) + ); + } + + @Test + public void ad_fitRCF_relation() { + Map argumentMap = new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + }}; + assertAnalyzeEqual( + new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), + new AD(AstDSL.relation("schema"), argumentMap) + ); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index f3fe6b5a84f..1b8d606211c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -12,6 +12,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; @@ -19,6 +20,8 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.DSL; @@ -115,6 +118,16 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))); assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { }, null)); + + LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema"), + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + }); + assertNull(ad.accept(new LogicalPlanNodeVisitor() { + }, null)); } private static class NodesCount extends LogicalPlanNodeVisitor { diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 7e86f3e68a1..cd561f3c093 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -142,6 +142,14 @@ public void test_visitMLCommons() { assertNull(physicalPlanNodeVisitor.visitMLCommons(plan, null)); } + @Test + public void test_visitAD() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor = + new PhysicalPlanNodeVisitor() {}; + + assertNull(physicalPlanNodeVisitor.visitAD(plan, null)); + } + public static class PhysicalPlanPrinter extends PhysicalPlanNodeVisitor { public String print(PhysicalPlan node) { diff --git a/opensearch/build.gradle b/opensearch/build.gradle index c2905c54f33..9446786bfd0 100644 --- a/opensearch/build.gradle +++ b/opensearch/build.gradle @@ -37,7 +37,7 @@ dependencies { compile group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: '2.11.4' compile group: 'org.json', name: 'json', version:'20180813' compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}" - compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0' + compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0-SNAPSHOT' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testCompile group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 2ae4255a546..3df87db21d2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -8,6 +8,7 @@ import lombok.RequiredArgsConstructor; import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.planner.physical.AggregationOperator; import org.opensearch.sql.planner.physical.DedupeOperator; @@ -137,6 +138,17 @@ public PhysicalPlan visitMLCommons(PhysicalPlan node, Object context) { ); } + @Override + public PhysicalPlan visitAD(PhysicalPlan node, Object context) { + ADOperator adOperator = (ADOperator) node; + return doProtect( + new ADOperator(visitInput(adOperator.getInput(), context), + adOperator.getArguments(), + adOperator.getNodeClient() + ) + ); + } + PhysicalPlan visitInput(PhysicalPlan node, Object context) { if (null == node) { return node; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java new file mode 100644 index 00000000000..388b4a47750 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -0,0 +1,109 @@ +package org.opensearch.sql.opensearch.planner.physical; + +import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.parameter.BatchRCFParams; +import org.opensearch.ml.common.parameter.FitRCFParams; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +/** + * AD Physical operator to call AD interface to get results for + * algorithm execution. + */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class ADOperator extends MLCommonsOperatorActions { + + @Getter + private final PhysicalPlan input; + + @Getter + private final Map arguments; + + @Getter + private final NodeClient nodeClient; + + @EqualsAndHashCode.Exclude + private Iterator iterator; + + private FunctionName rcfType; + + @Override + public void open() { + super.open(); + DataFrame inputDataFrame = generateInputDataset(input); + MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments); + + MLPredictionOutput predictionResult = + getMLPredictionResult(rcfType, mlAlgoParams, inputDataFrame, nodeClient); + + Iterator inputRowIter = inputDataFrame.iterator(); + Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); + iterator = new Iterator() { + @Override + public boolean hasNext() { + return inputRowIter.hasNext(); + } + + @Override + public ExprValue next() { + return buildResult(inputRowIter, inputDataFrame, predictionResult, resultRowIter); + } + }; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitAD(this, context); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + return iterator.next(); + } + + @Override + public List getChild() { + return Collections.singletonList(input); + } + + protected MLAlgoParams convertArgumentToMLParameter(Map arguments) { + if (arguments.get(TIME_FIELD).getValue() == null) { + rcfType = FunctionName.BATCH_RCF; + return BatchRCFParams.builder() + .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .build(); + } + rcfType = FunctionName.FIT_RCF; + return FitRCFParams.builder() + .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .timeDecay((Double) arguments.get(TIME_DECAY).getValue()) + .timeField((String) arguments.get(TIME_FIELD).getValue()) + .dateFormat("yyyy-MM-dd HH:mm:ss") + .build(); + } + +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java index 5401298070d..75870b5ee13 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -7,40 +7,21 @@ import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; -import com.google.common.collect.ImmutableMap; import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.client.node.NodeClient; -import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.dataframe.ColumnMeta; -import org.opensearch.ml.common.dataframe.ColumnValue; import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataframe.Row; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.MLAlgoParams; -import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.data.model.ExprDoubleValue; -import org.opensearch.sql.data.model.ExprFloatValue; -import org.opensearch.sql.data.model.ExprIntegerValue; -import org.opensearch.sql.data.model.ExprLongValue; -import org.opensearch.sql.data.model.ExprShortValue; -import org.opensearch.sql.data.model.ExprStringValue; -import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.opensearch.client.MLClient; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -50,7 +31,7 @@ */ @RequiredArgsConstructor @EqualsAndHashCode(callSuper = false) -public class MLCommonsOperator extends PhysicalPlan { +public class MLCommonsOperator extends MLCommonsOperatorActions { @Getter private final PhysicalPlan input; @@ -69,19 +50,12 @@ public class MLCommonsOperator extends PhysicalPlan { @Override public void open() { super.open(); - DataFrame inputDataFrame = generateInputDataset(); + DataFrame inputDataFrame = generateInputDataset(input); MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments.get(0), algorithm); - MLInput mlinput = MLInput.builder() - .algorithm(FunctionName.valueOf(algorithm.toUpperCase())) - .parameters(mlAlgoParams) - .inputDataset(new DataFrameInputDataset(inputDataFrame)) - .build(); - - MachineLearningNodeClient machineLearningClient = - MLClient.getMLClient(nodeClient); - MLPredictionOutput predictionResult = (MLPredictionOutput) machineLearningClient - .trainAndPredict(mlinput) - .actionGet(30, TimeUnit.SECONDS); + MLPredictionOutput predictionResult = + getMLPredictionResult(FunctionName.valueOf(algorithm.toUpperCase()), + mlAlgoParams, inputDataFrame, nodeClient); + Iterator inputRowIter = inputDataFrame.iterator(); Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); iterator = new Iterator() { @@ -92,13 +66,7 @@ public boolean hasNext() { @Override public ExprValue next() { - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - resultBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), - inputRowIter.next())); - resultBuilder.putAll(convertRowIntoExprValue( - predictionResult.getPredictionResult().columnMetas(), - resultRowIter.next())); - return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + return buildResult(inputRowIter, inputDataFrame, predictionResult, resultRowIter); } }; } @@ -140,48 +108,5 @@ protected MLAlgoParams convertArgumentToMLParameter(Argument argument, String al } } - private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - for (int i = 0; i < columnMetas.length; i++) { - ColumnValue columnValue = row.getValue(i); - String resultKeyName = columnMetas[i].getName(); - switch (columnValue.columnType()) { - case INTEGER: - resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); - break; - case DOUBLE: - resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); - break; - case STRING: - resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); - break; - case SHORT: - resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); - break; - case LONG: - resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); - break; - case FLOAT: - resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); - break; - default: - break; - } - } - return resultBuilder.build(); - } - - private DataFrame generateInputDataset() { - List> inputData = new LinkedList<>(); - while (input.hasNext()) { - inputData.add(new HashMap() { - { - input.next().tupleValue().forEach((key, value) -> put(key, value.value())); - } - }); - } - - return DataFrameBuilder.load(inputData); - } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java new file mode 100644 index 00000000000..201b9c5ec7c --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java @@ -0,0 +1,183 @@ +package org.opensearch.sql.opensearch.planner.physical; + +import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; + +/** + * Common method actions for ml-commons related operators. + */ +public abstract class MLCommonsOperatorActions extends PhysicalPlan { + + /** + * generate ml-commons request input dataset. + * @param input physical input + * @return ml-commons dataframe + */ + protected DataFrame generateInputDataset(PhysicalPlan input) { + List> inputData = new LinkedList<>(); + while (input.hasNext()) { + inputData.add(new HashMap() { + { + input.next().tupleValue().forEach((key, value) -> put(key, value.value())); + } + }); + } + + return DataFrameBuilder.load(inputData); + } + + /** + * covert result schema into ExprValue. + * @param columnMetas column metas + * @param row row + * @return a map of result schema in ExprValue format + */ + protected Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + populateResultBuilder(columnValue, resultKeyName, resultBuilder); + } + return resultBuilder.build(); + } + + /** + * populate result map by ml-commons supported data type. + * @param columnValue column value + * @param resultKeyName result kay name + * @param resultBuilder result builder + */ + protected void populateResultBuilder(ColumnValue columnValue, + String resultKeyName, + ImmutableMap.Builder resultBuilder) { + switch (columnValue.columnType()) { + case INTEGER: + resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); + break; + case DOUBLE: + resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); + break; + case STRING: + resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); + break; + case SHORT: + resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); + break; + case LONG: + resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); + break; + case FLOAT: + resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); + break; + case BOOLEAN: + resultBuilder.put(resultKeyName, ExprBooleanValue.of(columnValue.booleanValue())); + break; + default: + break; + } + } + + /** + * concert result into ExprValue. + * @param columnMetas column metas + * @param row row + * @param schema schema + * @return a map of result in ExprValue format + */ + protected Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, + Row row, + Map schema) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + // change key name to avoid duplicate key issue in result map + // only value will be shown in the final returned result + if (schema.containsKey(resultKeyName)) { + resultKeyName = resultKeyName + "1"; + } + populateResultBuilder(columnValue, resultKeyName, resultBuilder); + + } + return resultBuilder.build(); + } + + /** + * iterate result and built it into ExprTupleValue. + * @param inputRowIter input row iterator + * @param inputDataFrame input data frame + * @param predictionResult prediction result + * @param resultRowIter result row iterator + * @return result in ExprTupleValue format + */ + protected ExprTupleValue buildResult(Iterator inputRowIter, DataFrame inputDataFrame, + MLPredictionOutput predictionResult, Iterator resultRowIter) { + ImmutableMap.Builder resultSchemaBuilder = new ImmutableMap.Builder<>(); + resultSchemaBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), + inputRowIter.next())); + Map resultSchema = resultSchemaBuilder.build(); + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.putAll(convertResultRowIntoExprValue( + predictionResult.getPredictionResult().columnMetas(), + resultRowIter.next(), + resultSchema)); + resultBuilder.putAll(resultSchema); + return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + } + + /** + * get ml-commons train and predict result. + * @param functionName ml-commons algorithm name + * @param mlAlgoParams ml-commons algorithm parameters + * @param inputDataFrame input data frame + * @param nodeClient node client + * @return ml-commons train and predict result + */ + protected MLPredictionOutput getMLPredictionResult(FunctionName functionName, + MLAlgoParams mlAlgoParams, + DataFrame inputDataFrame, + NodeClient nodeClient) { + MLInput mlinput = MLInput.builder() + .algorithm(functionName) + .parameters(mlAlgoParams) + .inputDataset(new DataFrameInputDataset(inputDataFrame)) + .build(); + + MachineLearningNodeClient machineLearningClient = + MLClient.getMLClient(nodeClient); + + return (MLPredictionOutput) machineLearningClient + .trainAndPredict(mlinput) + .actionGet(30, TimeUnit.SECONDS); + } + +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index f116fe62fd6..69b71fa4848 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -22,6 +22,7 @@ import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -30,6 +31,7 @@ import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.DefaultImplementor; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalRelation; @@ -168,5 +170,11 @@ public PhysicalPlan visitMLCommons(LogicalMLCommons node, OpenSearchIndexScan co return new MLCommonsOperator(visitChild(node, context), node.getAlgorithm(), node.getArguments(), client.getNodeClient()); } + + @Override + public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { + return new ADOperator(visitChild(node, context), + node.getArguments(), client.getNodeClient()); + } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index fce7cc88ed1..2427ac4fe5c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -35,8 +36,9 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.node.NodeClient; -import org.opensearch.ml.client.MachineLearningClient; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; @@ -53,8 +55,7 @@ import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; -import org.opensearch.sql.opensearch.executor.protector.ResourceMonitorPlan; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; @@ -271,6 +272,23 @@ public void testVisitMlCommons() { executionProtector.visitMLCommons(mlCommonsOperator, null)); } + @Test + public void testVisitAD() { + NodeClient nodeClient = mock(NodeClient.class); + ADOperator adOperator = + new ADOperator( + values(emptyList()), + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + }, nodeClient); + + assertEquals(executionProtector.doProtect(adOperator), + executionProtector.visitAD(adOperator, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index 52770df8db6..0770ea3938c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -18,6 +18,7 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; @@ -56,4 +57,14 @@ public void visitMachineLearning() { new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); assertNotNull(implementor.visitMLCommons(node, indexScan)); } + + @Test + public void visitAD() { + LogicalAD node = Mockito.mock(LogicalAD.class, + Answers.RETURNS_DEEP_STUBS); + Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); + OpenSearchIndex.OpenSearchDefaultImplementor implementor = + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + assertNotNull(implementor.visitAD(node, indexScan)); + } } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 4d105c27b38..f2306bb9de7 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -23,6 +23,7 @@ HEAD: 'HEAD'; TOP: 'TOP'; RARE: 'RARE'; KMEANS: 'KMEANS'; +AD: 'AD'; // COMMAND ASSIST KEYWORDS AS: 'AS'; @@ -48,6 +49,9 @@ DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; DELIM: 'DELIM'; +SHINGLE_SIZE: 'SHINGLE_SIZE'; +TIME_DECAY: 'TIME_DECAY'; +TIME_FIELD: 'TIME_FIELD'; // COMPARISON FUNCTION KEYWORDS CASE: 'CASE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 15bcec67dd9..27ce4e19ed1 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -20,7 +20,7 @@ pplStatement /** commands */ commands : whereCommand | fieldsCommand | renameCommand | statsCommand | dedupCommand | sortCommand | evalCommand | headCommand - | topCommand | rareCommand | kmeansCommand; + | topCommand | rareCommand | kmeansCommand | adCommand; searchCommand : (SEARCH)? fromClause #searchFrom @@ -89,6 +89,13 @@ kmeansCommand k=integerLiteral ; +adCommand + : AD + (SHINGLE_SIZE EQUAL shingle_size=integerLiteral)? + (TIME_DECAY EQUAL time_decay=decimalLiteral)? + (TIME_FIELD EQUAL time_field=stringLiteral)? + ; + /** clauses */ fromClause : SOURCE EQUAL tableSource diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index d4dbe08061a..a55192be34c 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -35,6 +35,7 @@ import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; @@ -50,6 +51,7 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ByClauseContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldListContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; @@ -288,6 +290,11 @@ public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { return new Kmeans(ArgumentFactory.getArgumentList(ctx)); } + @Override + public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { + return new AD(ArgumentFactory.getArgumentMap(ctx)); + } + /** * Get original text in query. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 59c91a50a5e..09cef7c9110 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -7,6 +7,7 @@ package org.opensearch.sql.ppl.utils; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DedupCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; @@ -14,18 +15,23 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; +import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.antlr.v4.runtime.ParserRuleContext; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; - /** * Util class to get all arguments as a list from the PPL command. */ @@ -148,11 +154,34 @@ public static List getArgumentList(KmeansCommandContext ctx) { .singletonList(new Argument("k", getArgumentValue(ctx.k))); } + /** + * Get map of {@link Argument}. + * + * @param ctx ADCommandContext instance + * @return the list of arguments fetched from the AD command + */ + public static Map getArgumentMap(AdCommandContext ctx) { + return new HashMap() {{ + put(SHINGLE_SIZE, (ctx.shingle_size != null) + ? getArgumentValue(ctx.shingle_size) + : new Literal(null, DataType.INTEGER)); + put(TIME_DECAY, (ctx.time_decay != null) + ? getArgumentValue(ctx.time_decay) + : new Literal(null, DataType.DOUBLE)); + put(TIME_FIELD, (ctx.time_field != null) + ? getArgumentValue(ctx.time_field) + : new Literal(null, DataType.STRING)); + } + }; + } + private static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) : ctx instanceof BooleanLiteralContext ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) + : ctx instanceof DecimalLiteralContext + ? new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE) : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index ce3f327f09d..274a3ace774 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -38,12 +38,16 @@ import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; +import java.util.HashMap; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.SpanUnit; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; @@ -540,6 +544,28 @@ public void testKmeansCommand() { new Kmeans(relation("t"),exprList(argument("k", intLiteral(3))))); } + @Test + public void test_fitRCFADCommand() { + assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp'", + new AD(relation("t"),new HashMap() {{ + put("shingle_size", new Literal(10, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + } + })); + } + + @Test + public void test_batchRCFADCommand() { + assertEqual("source=t | AD", + new AD(relation("t"),new HashMap() {{ + put("shingle_size", new Literal(null, DataType.INTEGER)); + put("time_decay", new Literal(null, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + })); + } + protected void assertEqual(String query, Node expectedPlan) { Node actualPlan = plan(query); assertEquals(expectedPlan, actualPlan);