From 5f3b7e2989f049ad9119ed7caad7fe17e69cbd2a Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Wed, 23 Feb 2022 18:55:14 -0800 Subject: [PATCH 1/8] PPL AD integration Signed-off-by: jackieyanghan --- .../org/opensearch/sql/analysis/Analyzer.java | 18 ++ .../sql/ast/AbstractNodeVisitor.java | 5 + .../java/org/opensearch/sql/ast/tree/AD.java | 42 ++++ .../sql/planner/logical/LogicalAD.java | 31 +++ .../logical/LogicalPlanNodeVisitor.java | 4 + .../physical/PhysicalPlanNodeVisitor.java | 5 + .../OpenSearchExecutionProtector.java | 12 ++ .../planner/physical/ADOperator.java | 185 ++++++++++++++++++ .../opensearch/storage/OpenSearchIndex.java | 8 + ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 2 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 13 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 9 + .../sql/ppl/utils/ArgumentFactory.java | 24 +++ 13 files changed, 357 insertions(+), 1 deletion(-) create mode 100644 core/src/main/java/org/opensearch/sql/ast/tree/AD.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 93367cc4138..9b3c2e2c116 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -31,6 +31,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; @@ -57,6 +58,7 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; @@ -383,6 +385,22 @@ public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { return new LogicalMLCommons(child, "kmeans", options); } + /** + * Build {@link } for Kmeans command. + */ + @Override + public LogicalPlan visitAD(AD node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + java.util.Map options = node.getArguments(); + + TypeEnvironment currentEnv = context.peek(); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "timestamp"), ExprCoreType.TIMESTAMP); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "score"), ExprCoreType.DOUBLE); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomaly_grade"), ExprCoreType.DOUBLE); + + return new LogicalAD(child, options); + } + /** * The first argument is always "asc", others are optional. * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index f591007ad15..47368782de2 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -32,6 +32,7 @@ import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; @@ -239,4 +240,8 @@ public T visitSpan(Span node, C context) { public T visitKmeans(Kmeans node, C context) { return visitChildren(node, context); } + + public T visitAD(AD node, C context) { + return visitChildren(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AD.java b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java new file mode 100644 index 00000000000..6fd1c1dfd9b --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java @@ -0,0 +1,42 @@ +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Literal; + +import java.util.List; +import java.util.Map; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = true) +@RequiredArgsConstructor +@AllArgsConstructor +public class AD extends UnresolvedPlan { + private UnresolvedPlan child; + + private final Map arguments; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAD(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java new file mode 100644 index 00000000000..fbdfcb1f695 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java @@ -0,0 +1,31 @@ +package org.opensearch.sql.planner.logical; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/* + * AD logical plan. + */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalAD extends LogicalPlan { + private final Map arguments; + + public LogicalAD(LogicalPlan child, Map arguments) { + super(Collections.singletonList(child)); + this.arguments = arguments; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitAD(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index c1f0d5d0418..5163e44edb2 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -73,4 +73,8 @@ public R visitLimit(LogicalLimit plan, C context) { public R visitMLCommons(LogicalMLCommons plan, C context) { return visitNode(plan, context); } + + public R visitAD(LogicalAD plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index fb7e3d0fe3f..87582df3bbf 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -76,4 +76,9 @@ public R visitMLCommons(PhysicalPlan node, C context) { return visitNode(node, context); } + public R visitAD(PhysicalPlan node, C context) { + return visitNode(node, context); + } + + } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 2ae4255a546..3df87db21d2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -8,6 +8,7 @@ import lombok.RequiredArgsConstructor; import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.planner.physical.AggregationOperator; import org.opensearch.sql.planner.physical.DedupeOperator; @@ -137,6 +138,17 @@ public PhysicalPlan visitMLCommons(PhysicalPlan node, Object context) { ); } + @Override + public PhysicalPlan visitAD(PhysicalPlan node, Object context) { + ADOperator adOperator = (ADOperator) node; + return doProtect( + new ADOperator(visitInput(adOperator.getInput(), context), + adOperator.getArguments(), + adOperator.getNodeClient() + ) + ); + } + PhysicalPlan visitInput(PhysicalPlan node, Object context) { if (null == node) { return node; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java new file mode 100644 index 00000000000..223c18298ef --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -0,0 +1,185 @@ +package org.opensearch.sql.opensearch.planner.physical; + +import com.google.common.collect.ImmutableMap; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.parameter.BatchRCFParams; +import org.opensearch.ml.common.parameter.FitRCFParams; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; +import sun.swing.AccumulativeRunnable; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; + +/** + * AD Physical operator to call AD interface to get results for + * algorithm execution. + */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class ADOperator extends PhysicalPlan { + + @Getter + private final PhysicalPlan input; + + @Getter + private final Map arguments; + + @Getter + private final NodeClient nodeClient; + + @EqualsAndHashCode.Exclude + private Iterator iterator; + + private FunctionName rcfType; + + @Override + public void open() { + super.open(); + DataFrame inputDataFrame = generateInputDataset(); + MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments); + + MLInput mlinput = MLInput.builder() + .algorithm(rcfType) + .parameters(mlAlgoParams) + .inputDataset(new DataFrameInputDataset(inputDataFrame)) + .build(); + + MachineLearningNodeClient machineLearningClient = + MLClient.getMLClient(nodeClient); + MLPredictionOutput predictionResult = (MLPredictionOutput) machineLearningClient + .trainAndPredict(mlinput) + .actionGet(30, TimeUnit.SECONDS); + Iterator inputRowIter = inputDataFrame.iterator(); + Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); + iterator = new Iterator() { + @Override + public boolean hasNext() { + return inputRowIter.hasNext(); + } + + @Override + public ExprValue next() { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), + inputRowIter.next())); + resultBuilder.putAll(convertRowIntoExprValue( + predictionResult.getPredictionResult().columnMetas(), + resultRowIter.next())); + return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + } + }; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitAD(this, context); + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public List getChild() { + return null; + } + + protected MLAlgoParams convertArgumentToMLParameter(Map arguments) { + if (arguments.get("time_field").getValue() == null) { + rcfType = FunctionName.BATCH_RCF; + return BatchRCFParams.builder() + .shingleSize((Integer) arguments.get("shingle_size").getValue()) + .build(); + } + rcfType = FunctionName.FIT_RCF; + return FitRCFParams.builder() + .shingleSize((Integer) arguments.get("shingle_size").getValue()) + .timeDecay((Double) arguments.get("time_decay").getValue()) + .timeField((String) arguments.get("time_field").getValue()) + .build(); + } + + private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + switch (columnValue.columnType()) { + case INTEGER: + resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); + break; + case DOUBLE: + resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); + break; + case STRING: + resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); + break; + case SHORT: + resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); + break; + case LONG: + resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); + break; + case FLOAT: + resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); + break; + default: + break; + } + } + return resultBuilder.build(); + } + + private DataFrame generateInputDataset() { + List> inputData = new LinkedList<>(); + while (input.hasNext()) { + inputData.add(new HashMap() { + { + input.next().tupleValue().forEach((key, value) + -> put(key, value.value())); + } + }); + } + + return DataFrameBuilder.load(inputData); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index f116fe62fd6..69b71fa4848 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -22,6 +22,7 @@ import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -30,6 +31,7 @@ import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.DefaultImplementor; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalRelation; @@ -168,5 +170,11 @@ public PhysicalPlan visitMLCommons(LogicalMLCommons node, OpenSearchIndexScan co return new MLCommonsOperator(visitChild(node, context), node.getAlgorithm(), node.getArguments(), client.getNodeClient()); } + + @Override + public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { + return new ADOperator(visitChild(node, context), + node.getArguments(), client.getNodeClient()); + } } } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 4d105c27b38..49501835072 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -23,6 +23,7 @@ HEAD: 'HEAD'; TOP: 'TOP'; RARE: 'RARE'; KMEANS: 'KMEANS'; +AD: 'AD'; // COMMAND ASSIST KEYWORDS AS: 'AS'; @@ -267,6 +268,7 @@ Y: 'Y'; //STRING_LITERAL: DQUOTA_STRING | SQUOTA_STRING | BQUOTA_STRING; ID: ID_LITERAL; INTEGER_LITERAL: DEC_DIGIT+; +DOUBLE_LITERAL: (DEC_DIGIT+)? '.' DEC_DIGIT+; DECIMAL_LITERAL: (DEC_DIGIT+)? '.' DEC_DIGIT+; fragment DATE_SUFFIX: ([\-.][*0-9]+)*; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 15bcec67dd9..943f5b81bb6 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -20,7 +20,7 @@ pplStatement /** commands */ commands : whereCommand | fieldsCommand | renameCommand | statsCommand | dedupCommand | sortCommand | evalCommand | headCommand - | topCommand | rareCommand | kmeansCommand; + | topCommand | rareCommand | kmeansCommand | adCommand; searchCommand : (SEARCH)? fromClause #searchFrom @@ -89,6 +89,13 @@ kmeansCommand k=integerLiteral ; +adCommand + : AD + (shingle_size=integerLiteral)? + (time_decay=doubleLiteral)? + (time_field=stringLiteral)? + ; + /** clauses */ fromClause : SOURCE EQUAL tableSource @@ -321,6 +328,10 @@ integerLiteral : (PLUS | MINUS)? INTEGER_LITERAL ; +doubleLiteral + : (PLUS | MINUS)? DOUBLE_LITERAL + ; + decimalLiteral : (PLUS | MINUS)? DECIMAL_LITERAL ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index d4dbe08061a..411f1993f7e 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -31,10 +31,13 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; @@ -53,6 +56,7 @@ import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ByClauseContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldListContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -288,6 +292,11 @@ public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { return new Kmeans(ArgumentFactory.getArgumentList(ctx)); } + @Override + public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { + return new AD(ArgumentFactory.getArgumentMap(ctx)); + } + /** * Get original text in query. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 59c91a50a5e..17923b96346 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -17,12 +17,16 @@ import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; + import org.antlr.v4.runtime.ParserRuleContext; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; @@ -148,6 +152,26 @@ public static List getArgumentList(KmeansCommandContext ctx) { .singletonList(new Argument("k", getArgumentValue(ctx.k))); } + /** + * Get list of {@link Argument}. + * + * @param ctx ADCommandContext instance + * @return the list of arguments fetched from the kmeans command + */ + public static Map getArgumentMap(OpenSearchPPLParser.AdCommandContext ctx) { + return new HashMap() {{ + put("shingle_size", (ctx.shingle_size != null) + ? getArgumentValue(ctx.shingle_size) + : new Literal(8, DataType.INTEGER)); + put("time_decay", (ctx.time_decay != null) + ? getArgumentValue(ctx.time_decay) + : new Literal(0.0001, DataType.DOUBLE)); + put("time_field", (ctx.time_field != null) + ? getArgumentValue(ctx.time_field) + : new Literal(null, DataType.STRING)); + }}; + } + private static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) From f9cdb0b326f2454d838b70c5567dcad046d5762a Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Thu, 24 Feb 2022 21:47:11 -0800 Subject: [PATCH 2/8] Update AD command format Signed-off-by: jackieyanghan --- .../src/main/java/org/opensearch/sql/analysis/Analyzer.java | 2 +- ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 3 +++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 6 +++--- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 9b3c2e2c116..76ba3f3c96c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -386,7 +386,7 @@ public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { } /** - * Build {@link } for Kmeans command. + * Build {@link LogicalAD} for AD command. */ @Override public LogicalPlan visitAD(AD node, AnalysisContext context) { diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 49501835072..96c820ffed5 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -49,6 +49,9 @@ DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; DELIM: 'DELIM'; +SHINGLE_SIZE: 'SINGLE_SIZE'; +TIME_DECAY: 'TIME_DECAY'; +TIME_FIELD: 'TIME_FIELD'; // COMPARISON FUNCTION KEYWORDS CASE: 'CASE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 943f5b81bb6..2d215694430 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -91,9 +91,9 @@ kmeansCommand adCommand : AD - (shingle_size=integerLiteral)? - (time_decay=doubleLiteral)? - (time_field=stringLiteral)? + (SHINGLE_SIZE EQUAL shingle_size=integerLiteral)? + (TIME_DECAY EQUAL time_decay=doubleLiteral)? + (TIME_FIELD EQUAL time_field=stringLiteral)? ; /** clauses */ From 7dc1b4557a47f03d6c97b2498116b2497807899c Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Fri, 25 Feb 2022 15:26:06 -0800 Subject: [PATCH 3/8] Address issues with ad command Signed-off-by: jackieyanghan --- .../org/opensearch/sql/analysis/Analyzer.java | 1 + .../sql/data/model/ExprBooleanValue.java | 2 +- .../sql/planner/logical/LogicalAD.java | 12 +++-- .../planner/physical/ADOperator.java | 49 +++++++++++++++++-- ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 3 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 6 +-- .../opensearch/sql/ppl/parser/AstBuilder.java | 2 - .../sql/ppl/utils/ArgumentFactory.java | 11 +++-- 8 files changed, 63 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 76ba3f3c96c..022abb67b42 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -397,6 +397,7 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) { currentEnv.define(new Symbol(Namespace.FIELD_NAME, "timestamp"), ExprCoreType.TIMESTAMP); currentEnv.define(new Symbol(Namespace.FIELD_NAME, "score"), ExprCoreType.DOUBLE); currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomaly_grade"), ExprCoreType.DOUBLE); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomalous"), ExprCoreType.BOOLEAN); return new LogicalAD(child, options); } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java index d655c0dabbb..b74be264901 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java @@ -19,7 +19,7 @@ public class ExprBooleanValue extends AbstractExprValue { private final Boolean value; - private ExprBooleanValue(Boolean value) { + public ExprBooleanValue(Boolean value) { this.value = value; } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java index fbdfcb1f695..c8c04b18177 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalAD.java @@ -1,15 +1,12 @@ package org.opensearch.sql.planner.logical; +import java.util.Collections; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Literal; -import java.util.Collections; -import java.util.List; -import java.util.Map; - /* * AD logical plan. */ @@ -19,6 +16,11 @@ public class LogicalAD extends LogicalPlan { private final Map arguments; + /** + * Constructor of LogicalAD. + * @param child child logical plan + * @param arguments arguments of the algorithm + */ public LogicalAD(LogicalPlan child, Map arguments) { super(Collections.singletonList(child)); this.arguments = arguments; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index 223c18298ef..701c6a84f82 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -21,6 +21,7 @@ import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprFloatValue; import org.opensearch.sql.data.model.ExprIntegerValue; @@ -34,6 +35,7 @@ import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; import sun.swing.AccumulativeRunnable; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; @@ -95,7 +97,7 @@ public ExprValue next() { ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); resultBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), inputRowIter.next())); - resultBuilder.putAll(convertRowIntoExprValue( + resultBuilder.putAll(convertResultRowIntoExprValue( predictionResult.getPredictionResult().columnMetas(), resultRowIter.next())); return ExprTupleValue.fromExprValueMap(resultBuilder.build()); @@ -110,17 +112,17 @@ public R accept(PhysicalPlanNodeVisitor visitor, C context) { @Override public boolean hasNext() { - return false; + return iterator.hasNext(); } @Override public ExprValue next() { - return null; + return iterator.next(); } @Override public List getChild() { - return null; + return Collections.singletonList(input); } protected MLAlgoParams convertArgumentToMLParameter(Map arguments) { @@ -135,10 +137,47 @@ protected MLAlgoParams convertArgumentToMLParameter(Map argumen .shingleSize((Integer) arguments.get("shingle_size").getValue()) .timeDecay((Double) arguments.get("time_decay").getValue()) .timeField((String) arguments.get("time_field").getValue()) + .dateFormat("yyyy-MM-dd HH:mm:ss") .build(); } private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + if ("timestamp".equalsIgnoreCase(resultKeyName)) { + resultKeyName = "timestamp1"; + } + switch (columnValue.columnType()) { + case INTEGER: + resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); + break; + case DOUBLE: + resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); + break; + case STRING: + resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); + break; + case SHORT: + resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); + break; + case LONG: + resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); + break; + case FLOAT: + resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); + break; + case BOOLEAN: + resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); + default: + break; + } + } + return resultBuilder.build(); + } + + private Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); for (int i = 0; i < columnMetas.length; i++) { ColumnValue columnValue = row.getValue(i); @@ -162,6 +201,8 @@ private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, case FLOAT: resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); break; + case BOOLEAN: + resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); default: break; } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 96c820ffed5..f2306bb9de7 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -49,7 +49,7 @@ DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; PARTITIONS: 'PARTITIONS'; ALLNUM: 'ALLNUM'; DELIM: 'DELIM'; -SHINGLE_SIZE: 'SINGLE_SIZE'; +SHINGLE_SIZE: 'SHINGLE_SIZE'; TIME_DECAY: 'TIME_DECAY'; TIME_FIELD: 'TIME_FIELD'; @@ -271,7 +271,6 @@ Y: 'Y'; //STRING_LITERAL: DQUOTA_STRING | SQUOTA_STRING | BQUOTA_STRING; ID: ID_LITERAL; INTEGER_LITERAL: DEC_DIGIT+; -DOUBLE_LITERAL: (DEC_DIGIT+)? '.' DEC_DIGIT+; DECIMAL_LITERAL: (DEC_DIGIT+)? '.' DEC_DIGIT+; fragment DATE_SUFFIX: ([\-.][*0-9]+)*; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 2d215694430..27ce4e19ed1 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -92,7 +92,7 @@ kmeansCommand adCommand : AD (SHINGLE_SIZE EQUAL shingle_size=integerLiteral)? - (TIME_DECAY EQUAL time_decay=doubleLiteral)? + (TIME_DECAY EQUAL time_decay=decimalLiteral)? (TIME_FIELD EQUAL time_field=stringLiteral)? ; @@ -328,10 +328,6 @@ integerLiteral : (PLUS | MINUS)? INTEGER_LITERAL ; -doubleLiteral - : (PLUS | MINUS)? DOUBLE_LITERAL - ; - decimalLiteral : (PLUS | MINUS)? DECIMAL_LITERAL ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 411f1993f7e..85bc7aec94a 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -31,10 +31,8 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Let; -import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.AD; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 17923b96346..e4641b90394 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -7,6 +7,7 @@ package org.opensearch.sql.ppl.utils; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DedupCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; @@ -26,8 +27,8 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; /** @@ -153,12 +154,12 @@ public static List getArgumentList(KmeansCommandContext ctx) { } /** - * Get list of {@link Argument}. + * Get map of {@link Argument}. * * @param ctx ADCommandContext instance - * @return the list of arguments fetched from the kmeans command + * @return the list of arguments fetched from the AD command */ - public static Map getArgumentMap(OpenSearchPPLParser.AdCommandContext ctx) { + public static Map getArgumentMap(AdCommandContext ctx) { return new HashMap() {{ put("shingle_size", (ctx.shingle_size != null) ? getArgumentValue(ctx.shingle_size) @@ -177,6 +178,8 @@ private static Literal getArgumentValue(ParserRuleContext ctx) { ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) : ctx instanceof BooleanLiteralContext ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) + : ctx instanceof DecimalLiteralContext + ? new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE) : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } From 2a41e812b14630c4eae68fc8d0b1ad6842de6188 Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Mon, 28 Feb 2022 16:09:03 -0800 Subject: [PATCH 4/8] Separate fit_rcf schema and batch_rcf schema Signed-off-by: jackieyanghan --- .../java/org/opensearch/sql/analysis/Analyzer.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 022abb67b42..477f894e41a 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -394,11 +395,14 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) { java.util.Map options = node.getArguments(); TypeEnvironment currentEnv = context.peek(); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "timestamp"), ExprCoreType.TIMESTAMP); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "score"), ExprCoreType.DOUBLE); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomaly_grade"), ExprCoreType.DOUBLE); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomalous"), ExprCoreType.BOOLEAN); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "score"), ExprCoreType.DOUBLE); + if (Objects.isNull(node.getArguments().get("time_field").getValue())) { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomalous"), ExprCoreType.BOOLEAN); + } else { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomaly_grade"), ExprCoreType.DOUBLE); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "timestamp"), ExprCoreType.TIMESTAMP); + } return new LogicalAD(child, options); } From 96f5ba906f8c6fd763e34c7a560d25e2c3ee27cb Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Thu, 3 Mar 2022 13:15:16 -0800 Subject: [PATCH 5/8] Add tests for AD command Signed-off-by: jackieyanghan --- .../java/org/opensearch/sql/ast/tree/AD.java | 5 +- .../opensearch/sql/analysis/AnalyzerTest.java | 33 +++++ .../logical/LogicalPlanNodeVisitorTest.java | 13 ++ .../physical/PhysicalPlanNodeVisitorTest.java | 8 ++ .../planner/physical/ADOperator.java | 123 ++++++++---------- .../OpenSearchExecutionProtectorTest.java | 21 +++ .../OpenSearchDefaultImplementorTest.java | 11 ++ .../opensearch/sql/ppl/parser/AstBuilder.java | 2 +- .../sql/ppl/utils/ArgumentFactory.java | 13 +- .../sql/ppl/parser/AstBuilderTest.java | 26 ++++ 10 files changed, 176 insertions(+), 79 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AD.java b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java index 6fd1c1dfd9b..4d1c9ebf531 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/AD.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AD.java @@ -1,6 +1,8 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -10,9 +12,6 @@ import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.Literal; -import java.util.List; -import java.util.Map; - @Getter @Setter @ToString diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 2e9a6fe843a..9781204e633 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -34,6 +34,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; @@ -41,12 +43,16 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.springframework.context.annotation.Configuration; @@ -657,4 +663,31 @@ public void kmeanns_relation() { AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))) ); } + + @Test + public void ad_batchRCF_relation() { + Map argumentMap = + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + }}; + assertAnalyzeEqual( + new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), + new AD(AstDSL.relation("schema"), argumentMap) + ); + } + + @Test + public void ad_fitRCF_relation() { + Map argumentMap = new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + }}; + assertAnalyzeEqual( + new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), + new AD(AstDSL.relation("schema"), argumentMap) + ); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index f3fe6b5a84f..1b8d606211c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -12,6 +12,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; @@ -19,6 +20,8 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.DSL; @@ -115,6 +118,16 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))); assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { }, null)); + + LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema"), + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + }); + assertNull(ad.accept(new LogicalPlanNodeVisitor() { + }, null)); } private static class NodesCount extends LogicalPlanNodeVisitor { diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 7e86f3e68a1..cd561f3c093 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -142,6 +142,14 @@ public void test_visitMLCommons() { assertNull(physicalPlanNodeVisitor.visitMLCommons(plan, null)); } + @Test + public void test_visitAD() { + PhysicalPlanNodeVisitor physicalPlanNodeVisitor = + new PhysicalPlanNodeVisitor() {}; + + assertNull(physicalPlanNodeVisitor.visitAD(plan, null)); + } + public static class PhysicalPlanPrinter extends PhysicalPlanNodeVisitor { public String print(PhysicalPlan node) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index 701c6a84f82..625029aedb3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -1,6 +1,13 @@ package org.opensearch.sql.opensearch.planner.physical; import com.google.common.collect.ImmutableMap; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -15,11 +22,9 @@ import org.opensearch.ml.common.parameter.BatchRCFParams; import org.opensearch.ml.common.parameter.FitRCFParams; import org.opensearch.ml.common.parameter.FunctionName; -import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.MLAlgoParams; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLPredictionOutput; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprDoubleValue; @@ -33,17 +38,6 @@ import org.opensearch.sql.opensearch.client.MLClient; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; -import sun.swing.AccumulativeRunnable; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; - -import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; /** * AD Physical operator to call AD interface to get results for @@ -94,12 +88,16 @@ public boolean hasNext() { @Override public ExprValue next() { - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - resultBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), + ImmutableMap.Builder resultSchemaBuilder = new ImmutableMap.Builder<>(); + resultSchemaBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), inputRowIter.next())); + Map resultSchema = resultSchemaBuilder.build(); + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); resultBuilder.putAll(convertResultRowIntoExprValue( predictionResult.getPredictionResult().columnMetas(), - resultRowIter.next())); + resultRowIter.next(), + resultSchema)); + resultBuilder.putAll(resultSchema); return ExprTupleValue.fromExprValueMap(resultBuilder.build()); } }; @@ -146,66 +144,55 @@ private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, for (int i = 0; i < columnMetas.length; i++) { ColumnValue columnValue = row.getValue(i); String resultKeyName = columnMetas[i].getName(); - if ("timestamp".equalsIgnoreCase(resultKeyName)) { - resultKeyName = "timestamp1"; - } - switch (columnValue.columnType()) { - case INTEGER: - resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); - break; - case DOUBLE: - resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); - break; - case STRING: - resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); - break; - case SHORT: - resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); - break; - case LONG: - resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); - break; - case FLOAT: - resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); - break; - case BOOLEAN: - resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); - default: - break; - } + popluateResultBuilder(columnValue, resultKeyName, resultBuilder); } return resultBuilder.build(); } - private Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { + private void popluateResultBuilder(ColumnValue columnValue, + String resultKeyName, + ImmutableMap.Builder resultBuilder) { + switch (columnValue.columnType()) { + case INTEGER: + resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); + break; + case DOUBLE: + resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); + break; + case STRING: + resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); + break; + case SHORT: + resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); + break; + case LONG: + resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); + break; + case FLOAT: + resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); + break; + case BOOLEAN: + resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); + break; + default: + break; + } + } + + private Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, + Row row, + Map schema) { ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); for (int i = 0; i < columnMetas.length; i++) { ColumnValue columnValue = row.getValue(i); String resultKeyName = columnMetas[i].getName(); - switch (columnValue.columnType()) { - case INTEGER: - resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); - break; - case DOUBLE: - resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); - break; - case STRING: - resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); - break; - case SHORT: - resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); - break; - case LONG: - resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); - break; - case FLOAT: - resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); - break; - case BOOLEAN: - resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); - default: - break; + // change key name to avoid duplicate key issue in result map + // only value will be shown in the final returned result + if (schema.containsKey(resultKeyName)) { + resultKeyName = resultKeyName + "1"; } + popluateResultBuilder(columnValue, resultKeyName, resultBuilder); + } return resultBuilder.build(); } @@ -216,7 +203,7 @@ private DataFrame generateInputDataset() { inputData.add(new HashMap() { { input.next().tupleValue().forEach((key, value) - -> put(key, value.value())); + -> put(key, value.value())); } }); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index fce7cc88ed1..cc521ff1f7c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -37,6 +38,8 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.client.MachineLearningClient; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; @@ -55,6 +58,7 @@ import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; import org.opensearch.sql.opensearch.executor.protector.ResourceMonitorPlan; +import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; @@ -271,6 +275,23 @@ public void testVisitMlCommons() { executionProtector.visitMLCommons(mlCommonsOperator, null)); } + @Test + public void testVisitAD() { + NodeClient nodeClient = mock(NodeClient.class); + ADOperator adOperator = + new ADOperator( + values(emptyList()), + new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + }, nodeClient); + + assertEquals(executionProtector.doProtect(adOperator), + executionProtector.visitAD(adOperator, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index 52770df8db6..0770ea3938c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -18,6 +18,7 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; @@ -56,4 +57,14 @@ public void visitMachineLearning() { new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); assertNotNull(implementor.visitMLCommons(node, indexScan)); } + + @Test + public void visitAD() { + LogicalAD node = Mockito.mock(LogicalAD.class, + Answers.RETURNS_DEEP_STUBS); + Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); + OpenSearchIndex.OpenSearchDefaultImplementor implementor = + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + assertNotNull(implementor.visitAD(node, indexScan)); + } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 85bc7aec94a..a55192be34c 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -51,10 +51,10 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ByClauseContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldListContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; import org.opensearch.sql.ppl.utils.ArgumentFactory; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index e4641b90394..62a2678e420 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -21,15 +21,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; - import org.antlr.v4.runtime.ParserRuleContext; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; -import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.AdCommandContext; - +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.KmeansCommandContext; /** * Util class to get all arguments as a list from the PPL command. @@ -161,16 +159,17 @@ public static List getArgumentList(KmeansCommandContext ctx) { */ public static Map getArgumentMap(AdCommandContext ctx) { return new HashMap() {{ - put("shingle_size", (ctx.shingle_size != null) + put("shingle_size", (ctx.shingle_size != null) ? getArgumentValue(ctx.shingle_size) : new Literal(8, DataType.INTEGER)); - put("time_decay", (ctx.time_decay != null) + put("time_decay", (ctx.time_decay != null) ? getArgumentValue(ctx.time_decay) : new Literal(0.0001, DataType.DOUBLE)); - put("time_field", (ctx.time_field != null) + put("time_field", (ctx.time_field != null) ? getArgumentValue(ctx.time_field) : new Literal(null, DataType.STRING)); - }}; + } + }; } private static Literal getArgumentValue(ParserRuleContext ctx) { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index ce3f327f09d..fda856c3736 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -38,12 +38,16 @@ import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; +import java.util.HashMap; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.SpanUnit; +import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; @@ -540,6 +544,28 @@ public void testKmeansCommand() { new Kmeans(relation("t"),exprList(argument("k", intLiteral(3))))); } + @Test + public void test_fitRCFADCommand() { + assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp'", + new AD(relation("t"),new HashMap() {{ + put("shingle_size", new Literal(10, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + } + })); + } + + @Test + public void test_batchRCFADCommand() { + assertEqual("source=t | AD", + new AD(relation("t"),new HashMap() {{ + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal(null, DataType.STRING)); + } + })); + } + protected void assertEqual(String query, Node expectedPlan) { Node actualPlan = plan(query); assertEquals(expectedPlan, actualPlan); From 192537e7fd68d275b75bd44916a3f39a17009596 Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Fri, 4 Mar 2022 13:27:48 -0800 Subject: [PATCH 6/8] Abstract duplicate code in ADOperator and MLCommonsOperator Signed-off-by: jackieyanghan --- core/build.gradle | 4 + .../org/opensearch/sql/analysis/Analyzer.java | 15 +- .../sql/utils/MLCommonsConstants.java | 13 ++ .../planner/physical/ADOperator.java | 132 ++-------------- .../planner/physical/MLCommonsOperator.java | 89 +---------- .../planner/physical/OperatorActions.java | 144 ++++++++++++++++++ .../sql/ppl/utils/ArgumentFactory.java | 13 +- .../sql/ppl/parser/AstBuilderTest.java | 4 +- 8 files changed, 202 insertions(+), 212 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OperatorActions.java diff --git a/core/build.gradle b/core/build.gradle index 63ecd8c104a..697a6bb775c 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -82,6 +82,10 @@ test.finalizedBy(project.tasks.jacocoTestReport) jacocoTestCoverageVerification { violationRules { rule { + element = 'CLASS' + excludes = [ + 'org.opensearch.sql.utils.MLCommonsConstants' + ] limit { counter = 'LINE' minimum = 1.0 diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 477f894e41a..0379035c00b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -11,6 +11,11 @@ import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; +import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -396,12 +401,12 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) { TypeEnvironment currentEnv = context.peek(); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "score"), ExprCoreType.DOUBLE); - if (Objects.isNull(node.getArguments().get("time_field").getValue())) { - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomalous"), ExprCoreType.BOOLEAN); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE); + if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN); } else { - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "anomaly_grade"), ExprCoreType.DOUBLE); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "timestamp"), ExprCoreType.TIMESTAMP); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_TIMESTAMP), ExprCoreType.TIMESTAMP); } return new LogicalAD(child, options); } diff --git a/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java new file mode 100644 index 00000000000..3e957f1bda0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/utils/MLCommonsConstants.java @@ -0,0 +1,13 @@ +package org.opensearch.sql.utils; + +public class MLCommonsConstants { + + public static final String SHINGLE_SIZE = "shingle_size"; + public static final String TIME_DECAY = "time_decay"; + public static final String TIME_FIELD = "time_field"; + + public static final String RCF_SCORE = "score"; + public static final String RCF_ANOMALOUS = "anomalous"; + public static final String RCF_ANOMALY_GRADE = "anomaly_grade"; + public static final String RCF_TIMESTAMP = "timestamp"; +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index 625029aedb3..bf52f478f9f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -1,41 +1,26 @@ package org.opensearch.sql.opensearch.planner.physical; -import com.google.common.collect.ImmutableMap; +import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; + import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.client.node.NodeClient; -import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.dataframe.ColumnMeta; -import org.opensearch.ml.common.dataframe.ColumnValue; import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataframe.Row; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.parameter.BatchRCFParams; import org.opensearch.ml.common.parameter.FitRCFParams; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; -import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.data.model.ExprBooleanValue; -import org.opensearch.sql.data.model.ExprDoubleValue; -import org.opensearch.sql.data.model.ExprFloatValue; -import org.opensearch.sql.data.model.ExprIntegerValue; -import org.opensearch.sql.data.model.ExprLongValue; -import org.opensearch.sql.data.model.ExprShortValue; -import org.opensearch.sql.data.model.ExprStringValue; -import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.opensearch.client.MLClient; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -45,7 +30,7 @@ */ @RequiredArgsConstructor @EqualsAndHashCode(callSuper = false) -public class ADOperator extends PhysicalPlan { +public class ADOperator extends OperatorActions { @Getter private final PhysicalPlan input; @@ -64,20 +49,12 @@ public class ADOperator extends PhysicalPlan { @Override public void open() { super.open(); - DataFrame inputDataFrame = generateInputDataset(); + DataFrame inputDataFrame = generateInputDataset(input); MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments); - MLInput mlinput = MLInput.builder() - .algorithm(rcfType) - .parameters(mlAlgoParams) - .inputDataset(new DataFrameInputDataset(inputDataFrame)) - .build(); + MLPredictionOutput predictionResult = + getMLPredictionResult(rcfType, mlAlgoParams, inputDataFrame, nodeClient); - MachineLearningNodeClient machineLearningClient = - MLClient.getMLClient(nodeClient); - MLPredictionOutput predictionResult = (MLPredictionOutput) machineLearningClient - .trainAndPredict(mlinput) - .actionGet(30, TimeUnit.SECONDS); Iterator inputRowIter = inputDataFrame.iterator(); Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); iterator = new Iterator() { @@ -88,17 +65,7 @@ public boolean hasNext() { @Override public ExprValue next() { - ImmutableMap.Builder resultSchemaBuilder = new ImmutableMap.Builder<>(); - resultSchemaBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), - inputRowIter.next())); - Map resultSchema = resultSchemaBuilder.build(); - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - resultBuilder.putAll(convertResultRowIntoExprValue( - predictionResult.getPredictionResult().columnMetas(), - resultRowIter.next(), - resultSchema)); - resultBuilder.putAll(resultSchema); - return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + return buildResult(inputRowIter, inputDataFrame, predictionResult, resultRowIter); } }; } @@ -124,90 +91,19 @@ public List getChild() { } protected MLAlgoParams convertArgumentToMLParameter(Map arguments) { - if (arguments.get("time_field").getValue() == null) { + if (arguments.get(TIME_FIELD).getValue() == null) { rcfType = FunctionName.BATCH_RCF; return BatchRCFParams.builder() - .shingleSize((Integer) arguments.get("shingle_size").getValue()) + .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) .build(); } rcfType = FunctionName.FIT_RCF; return FitRCFParams.builder() - .shingleSize((Integer) arguments.get("shingle_size").getValue()) - .timeDecay((Double) arguments.get("time_decay").getValue()) - .timeField((String) arguments.get("time_field").getValue()) + .shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue()) + .timeDecay((Double) arguments.get(TIME_DECAY).getValue()) + .timeField((String) arguments.get(TIME_FIELD).getValue()) .dateFormat("yyyy-MM-dd HH:mm:ss") .build(); } - private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - for (int i = 0; i < columnMetas.length; i++) { - ColumnValue columnValue = row.getValue(i); - String resultKeyName = columnMetas[i].getName(); - popluateResultBuilder(columnValue, resultKeyName, resultBuilder); - } - return resultBuilder.build(); - } - - private void popluateResultBuilder(ColumnValue columnValue, - String resultKeyName, - ImmutableMap.Builder resultBuilder) { - switch (columnValue.columnType()) { - case INTEGER: - resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); - break; - case DOUBLE: - resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); - break; - case STRING: - resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); - break; - case SHORT: - resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); - break; - case LONG: - resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); - break; - case FLOAT: - resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); - break; - case BOOLEAN: - resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); - break; - default: - break; - } - } - - private Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, - Row row, - Map schema) { - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - for (int i = 0; i < columnMetas.length; i++) { - ColumnValue columnValue = row.getValue(i); - String resultKeyName = columnMetas[i].getName(); - // change key name to avoid duplicate key issue in result map - // only value will be shown in the final returned result - if (schema.containsKey(resultKeyName)) { - resultKeyName = resultKeyName + "1"; - } - popluateResultBuilder(columnValue, resultKeyName, resultBuilder); - - } - return resultBuilder.build(); - } - - private DataFrame generateInputDataset() { - List> inputData = new LinkedList<>(); - while (input.hasNext()) { - inputData.add(new HashMap() { - { - input.next().tupleValue().forEach((key, value) - -> put(key, value.value())); - } - }); - } - - return DataFrameBuilder.load(inputData); - } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java index 5401298070d..7342aabb687 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -7,40 +7,21 @@ import static org.opensearch.ml.common.parameter.FunctionName.KMEANS; -import com.google.common.collect.ImmutableMap; import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.client.node.NodeClient; -import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.dataframe.ColumnMeta; -import org.opensearch.ml.common.dataframe.ColumnValue; import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataframe.Row; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.MLAlgoParams; -import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLPredictionOutput; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.data.model.ExprDoubleValue; -import org.opensearch.sql.data.model.ExprFloatValue; -import org.opensearch.sql.data.model.ExprIntegerValue; -import org.opensearch.sql.data.model.ExprLongValue; -import org.opensearch.sql.data.model.ExprShortValue; -import org.opensearch.sql.data.model.ExprStringValue; -import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.opensearch.client.MLClient; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -50,7 +31,7 @@ */ @RequiredArgsConstructor @EqualsAndHashCode(callSuper = false) -public class MLCommonsOperator extends PhysicalPlan { +public class MLCommonsOperator extends OperatorActions { @Getter private final PhysicalPlan input; @@ -69,19 +50,12 @@ public class MLCommonsOperator extends PhysicalPlan { @Override public void open() { super.open(); - DataFrame inputDataFrame = generateInputDataset(); + DataFrame inputDataFrame = generateInputDataset(input); MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments.get(0), algorithm); - MLInput mlinput = MLInput.builder() - .algorithm(FunctionName.valueOf(algorithm.toUpperCase())) - .parameters(mlAlgoParams) - .inputDataset(new DataFrameInputDataset(inputDataFrame)) - .build(); - - MachineLearningNodeClient machineLearningClient = - MLClient.getMLClient(nodeClient); - MLPredictionOutput predictionResult = (MLPredictionOutput) machineLearningClient - .trainAndPredict(mlinput) - .actionGet(30, TimeUnit.SECONDS); + MLPredictionOutput predictionResult = + getMLPredictionResult(FunctionName.valueOf(algorithm.toUpperCase()), + mlAlgoParams, inputDataFrame, nodeClient); + Iterator inputRowIter = inputDataFrame.iterator(); Iterator resultRowIter = predictionResult.getPredictionResult().iterator(); iterator = new Iterator() { @@ -92,13 +66,7 @@ public boolean hasNext() { @Override public ExprValue next() { - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - resultBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), - inputRowIter.next())); - resultBuilder.putAll(convertRowIntoExprValue( - predictionResult.getPredictionResult().columnMetas(), - resultRowIter.next())); - return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + return buildResult(inputRowIter, inputDataFrame, predictionResult, resultRowIter); } }; } @@ -140,48 +108,5 @@ protected MLAlgoParams convertArgumentToMLParameter(Argument argument, String al } } - private Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { - ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); - for (int i = 0; i < columnMetas.length; i++) { - ColumnValue columnValue = row.getValue(i); - String resultKeyName = columnMetas[i].getName(); - switch (columnValue.columnType()) { - case INTEGER: - resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); - break; - case DOUBLE: - resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); - break; - case STRING: - resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); - break; - case SHORT: - resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); - break; - case LONG: - resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); - break; - case FLOAT: - resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); - break; - default: - break; - } - } - return resultBuilder.build(); - } - - private DataFrame generateInputDataset() { - List> inputData = new LinkedList<>(); - while (input.hasNext()) { - inputData.add(new HashMap() { - { - input.next().tupleValue().forEach((key, value) -> put(key, value.value())); - } - }); - } - - return DataFrameBuilder.load(inputData); - } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OperatorActions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OperatorActions.java new file mode 100644 index 00000000000..2aba2497752 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OperatorActions.java @@ -0,0 +1,144 @@ +package org.opensearch.sql.opensearch.planner.physical; + +import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import lombok.EqualsAndHashCode; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.dataframe.ColumnMeta; +import org.opensearch.ml.common.dataframe.ColumnValue; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLPredictionOutput; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.opensearch.client.MLClient; +import org.opensearch.sql.planner.physical.PhysicalPlan; + +public abstract class OperatorActions extends PhysicalPlan { + + @EqualsAndHashCode.Exclude + private Iterator iterator; + + protected DataFrame generateInputDataset(PhysicalPlan input) { + List> inputData = new LinkedList<>(); + while (input.hasNext()) { + inputData.add(new HashMap() { + { + input.next().tupleValue().forEach((key, value) -> put(key, value.value())); + } + }); + } + + return DataFrameBuilder.load(inputData); + } + + protected Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + populateResultBuilder(columnValue, resultKeyName, resultBuilder); + } + return resultBuilder.build(); + } + + protected void populateResultBuilder(ColumnValue columnValue, + String resultKeyName, + ImmutableMap.Builder resultBuilder) { + switch (columnValue.columnType()) { + case INTEGER: + resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue())); + break; + case DOUBLE: + resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue())); + break; + case STRING: + resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue())); + break; + case SHORT: + resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue())); + break; + case LONG: + resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue())); + break; + case FLOAT: + resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); + break; + case BOOLEAN: + resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); + break; + default: + break; + } + } + + protected Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, + Row row, + Map schema) { + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + for (int i = 0; i < columnMetas.length; i++) { + ColumnValue columnValue = row.getValue(i); + String resultKeyName = columnMetas[i].getName(); + // change key name to avoid duplicate key issue in result map + // only value will be shown in the final returned result + if (schema.containsKey(resultKeyName)) { + resultKeyName = resultKeyName + "1"; + } + populateResultBuilder(columnValue, resultKeyName, resultBuilder); + + } + return resultBuilder.build(); + } + + protected ExprTupleValue buildResult(Iterator inputRowIter, DataFrame inputDataFrame, + MLPredictionOutput predictionResult, Iterator resultRowIter) { + ImmutableMap.Builder resultSchemaBuilder = new ImmutableMap.Builder<>(); + resultSchemaBuilder.putAll(convertRowIntoExprValue(inputDataFrame.columnMetas(), + inputRowIter.next())); + Map resultSchema = resultSchemaBuilder.build(); + ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); + resultBuilder.putAll(convertResultRowIntoExprValue( + predictionResult.getPredictionResult().columnMetas(), + resultRowIter.next(), + resultSchema)); + resultBuilder.putAll(resultSchema); + return ExprTupleValue.fromExprValueMap(resultBuilder.build()); + } + + protected MLPredictionOutput getMLPredictionResult(FunctionName functionName, + MLAlgoParams mlAlgoParams, + DataFrame inputDataFrame, + NodeClient nodeClient) { + MLInput mlinput = MLInput.builder() + .algorithm(functionName) + .parameters(mlAlgoParams) + .inputDataset(new DataFrameInputDataset(inputDataFrame)) + .build(); + + MachineLearningNodeClient machineLearningClient = + MLClient.getMLClient(nodeClient); + + return (MLPredictionOutput) machineLearningClient + .trainAndPredict(mlinput) + .actionGet(30, TimeUnit.SECONDS); + } + +} diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 62a2678e420..09cef7c9110 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -15,6 +15,9 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; +import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; +import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; import java.util.Arrays; import java.util.Collections; @@ -159,13 +162,13 @@ public static List getArgumentList(KmeansCommandContext ctx) { */ public static Map getArgumentMap(AdCommandContext ctx) { return new HashMap() {{ - put("shingle_size", (ctx.shingle_size != null) + put(SHINGLE_SIZE, (ctx.shingle_size != null) ? getArgumentValue(ctx.shingle_size) - : new Literal(8, DataType.INTEGER)); - put("time_decay", (ctx.time_decay != null) + : new Literal(null, DataType.INTEGER)); + put(TIME_DECAY, (ctx.time_decay != null) ? getArgumentValue(ctx.time_decay) - : new Literal(0.0001, DataType.DOUBLE)); - put("time_field", (ctx.time_field != null) + : new Literal(null, DataType.DOUBLE)); + put(TIME_FIELD, (ctx.time_field != null) ? getArgumentValue(ctx.time_field) : new Literal(null, DataType.STRING)); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index fda856c3736..274a3ace774 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -559,8 +559,8 @@ public void test_fitRCFADCommand() { public void test_batchRCFADCommand() { assertEqual("source=t | AD", new AD(relation("t"),new HashMap() {{ - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("shingle_size", new Literal(null, DataType.INTEGER)); + put("time_decay", new Literal(null, DataType.DOUBLE)); put("time_field", new Literal(null, DataType.STRING)); } })); From 54660d80de248cb9d00a82a3dd362a599d13e87e Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Fri, 4 Mar 2022 13:54:04 -0800 Subject: [PATCH 7/8] Update dependency opensearch-ml-client version Signed-off-by: jackieyanghan --- opensearch/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensearch/build.gradle b/opensearch/build.gradle index c2905c54f33..9446786bfd0 100644 --- a/opensearch/build.gradle +++ b/opensearch/build.gradle @@ -37,7 +37,7 @@ dependencies { compile group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: '2.11.4' compile group: 'org.json', name: 'json', version:'20180813' compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}" - compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0' + compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0-SNAPSHOT' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testCompile group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' From fa1ef167ff0238618a357e208630497e56eeb6e4 Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Tue, 8 Mar 2022 12:15:04 -0800 Subject: [PATCH 8/8] Add java doc on MLCommonsOperatorActions abstract class Signed-off-by: jackieyanghan --- .../sql/data/model/ExprBooleanValue.java | 2 +- .../planner/physical/ADOperator.java | 2 +- .../planner/physical/MLCommonsOperator.java | 2 +- ...ons.java => MLCommonsOperatorActions.java} | 51 ++++++++++++++++--- .../OpenSearchExecutionProtectorTest.java | 3 -- 5 files changed, 48 insertions(+), 12 deletions(-) rename opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/{OperatorActions.java => MLCommonsOperatorActions.java} (79%) diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java index b74be264901..d655c0dabbb 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java @@ -19,7 +19,7 @@ public class ExprBooleanValue extends AbstractExprValue { private final Boolean value; - public ExprBooleanValue(Boolean value) { + private ExprBooleanValue(Boolean value) { this.value = value; } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java index bf52f478f9f..388b4a47750 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ADOperator.java @@ -30,7 +30,7 @@ */ @RequiredArgsConstructor @EqualsAndHashCode(callSuper = false) -public class ADOperator extends OperatorActions { +public class ADOperator extends MLCommonsOperatorActions { @Getter private final PhysicalPlan input; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java index 7342aabb687..75870b5ee13 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperator.java @@ -31,7 +31,7 @@ */ @RequiredArgsConstructor @EqualsAndHashCode(callSuper = false) -public class MLCommonsOperator extends OperatorActions { +public class MLCommonsOperator extends MLCommonsOperatorActions { @Getter private final PhysicalPlan input; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OperatorActions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java similarity index 79% rename from opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OperatorActions.java rename to opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java index 2aba2497752..201b9c5ec7c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OperatorActions.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java @@ -7,7 +7,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import lombok.EqualsAndHashCode; import org.opensearch.client.node.NodeClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.dataframe.ColumnMeta; @@ -32,11 +31,16 @@ import org.opensearch.sql.opensearch.client.MLClient; import org.opensearch.sql.planner.physical.PhysicalPlan; -public abstract class OperatorActions extends PhysicalPlan { - - @EqualsAndHashCode.Exclude - private Iterator iterator; +/** + * Common method actions for ml-commons related operators. + */ +public abstract class MLCommonsOperatorActions extends PhysicalPlan { + /** + * generate ml-commons request input dataset. + * @param input physical input + * @return ml-commons dataframe + */ protected DataFrame generateInputDataset(PhysicalPlan input) { List> inputData = new LinkedList<>(); while (input.hasNext()) { @@ -50,6 +54,12 @@ protected DataFrame generateInputDataset(PhysicalPlan input) { return DataFrameBuilder.load(inputData); } + /** + * covert result schema into ExprValue. + * @param columnMetas column metas + * @param row row + * @return a map of result schema in ExprValue format + */ protected Map convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) { ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder<>(); for (int i = 0; i < columnMetas.length; i++) { @@ -60,6 +70,12 @@ protected Map convertRowIntoExprValue(ColumnMeta[] columnMeta return resultBuilder.build(); } + /** + * populate result map by ml-commons supported data type. + * @param columnValue column value + * @param resultKeyName result kay name + * @param resultBuilder result builder + */ protected void populateResultBuilder(ColumnValue columnValue, String resultKeyName, ImmutableMap.Builder resultBuilder) { @@ -83,13 +99,20 @@ protected void populateResultBuilder(ColumnValue columnValue, resultBuilder.put(resultKeyName, new ExprFloatValue(columnValue.floatValue())); break; case BOOLEAN: - resultBuilder.put(resultKeyName, new ExprBooleanValue(columnValue.booleanValue())); + resultBuilder.put(resultKeyName, ExprBooleanValue.of(columnValue.booleanValue())); break; default: break; } } + /** + * concert result into ExprValue. + * @param columnMetas column metas + * @param row row + * @param schema schema + * @return a map of result in ExprValue format + */ protected Map convertResultRowIntoExprValue(ColumnMeta[] columnMetas, Row row, Map schema) { @@ -108,6 +131,14 @@ protected Map convertResultRowIntoExprValue(ColumnMeta[] colu return resultBuilder.build(); } + /** + * iterate result and built it into ExprTupleValue. + * @param inputRowIter input row iterator + * @param inputDataFrame input data frame + * @param predictionResult prediction result + * @param resultRowIter result row iterator + * @return result in ExprTupleValue format + */ protected ExprTupleValue buildResult(Iterator inputRowIter, DataFrame inputDataFrame, MLPredictionOutput predictionResult, Iterator resultRowIter) { ImmutableMap.Builder resultSchemaBuilder = new ImmutableMap.Builder<>(); @@ -123,6 +154,14 @@ protected ExprTupleValue buildResult(Iterator inputRowIter, DataFrame input return ExprTupleValue.fromExprValueMap(resultBuilder.build()); } + /** + * get ml-commons train and predict result. + * @param functionName ml-commons algorithm name + * @param mlAlgoParams ml-commons algorithm parameters + * @param inputDataFrame input data frame + * @param nodeClient node client + * @return ml-commons train and predict result + */ protected MLPredictionOutput getMLPredictionResult(FunctionName functionName, MLAlgoParams mlAlgoParams, DataFrame inputDataFrame, diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index cc521ff1f7c..2427ac4fe5c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -36,7 +36,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.node.NodeClient; -import org.opensearch.ml.client.MachineLearningClient; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; @@ -56,8 +55,6 @@ import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; -import org.opensearch.sql.opensearch.executor.protector.ResourceMonitorPlan; import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.setting.OpenSearchSettings;