Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ test.finalizedBy(project.tasks.jacocoTestReport)
jacocoTestCoverageVerification {
violationRules {
rule {
element = 'CLASS'
excludes = [
'org.opensearch.sql.utils.MLCommonsConstants'
]
limit {
counter = 'LINE'
minimum = 1.0
Expand Down
28 changes: 28 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand All @@ -31,6 +37,7 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
Expand All @@ -57,6 +64,7 @@
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalDedupe;
import org.opensearch.sql.planner.logical.LogicalEval;
Expand Down Expand Up @@ -383,6 +391,26 @@ public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) {
return new LogicalMLCommons(child, "kmeans", options);
}

/**
* Build {@link LogicalAD} for AD command.
*/
@Override
public LogicalPlan visitAD(AD node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
java.util.Map<String, Literal> options = node.getArguments();

TypeEnvironment currentEnv = context.peek();

currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE);
if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN);
} else {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE);
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_TIMESTAMP), ExprCoreType.TIMESTAMP);
}
return new LogicalAD(child, options);
}

/**
* The first argument is always "asc", others are optional.
* Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
Expand Down Expand Up @@ -239,4 +240,8 @@ public T visitSpan(Span node, C context) {
public T visitKmeans(Kmeans node, C context) {
return visitChildren(node, context);
}

public T visitAD(AD node, C context) {
return visitChildren(node, context);
}
}
41 changes: 41 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/AD.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class AD extends UnresolvedPlan {
private UnresolvedPlan child;

private final Map<String, Literal> arguments;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitAD(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;

/*
* AD logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalAD extends LogicalPlan {
private final Map<String, Literal> arguments;

/**
* Constructor of LogicalAD.
* @param child child logical plan
* @param arguments arguments of the algorithm
*/
public LogicalAD(LogicalPlan child, Map<String, Literal> arguments) {
super(Collections.singletonList(child));
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitAD(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ public R visitLimit(LogicalLimit plan, C context) {
public R visitMLCommons(LogicalMLCommons plan, C context) {
return visitNode(plan, context);
}

public R visitAD(LogicalAD plan, C context) {
return visitNode(plan, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,9 @@ public R visitMLCommons(PhysicalPlan node, C context) {
return visitNode(node, context);
}

public R visitAD(PhysicalPlan node, C context) {
return visitNode(node, context);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.opensearch.sql.utils;

public class MLCommonsConstants {

public static final String SHINGLE_SIZE = "shingle_size";
public static final String TIME_DECAY = "time_decay";
public static final String TIME_FIELD = "time_field";

public static final String RCF_SCORE = "score";
public static final String RCF_ANOMALOUS = "anomalous";
public static final String RCF_ANOMALY_GRADE = "anomaly_grade";
public static final String RCF_TIMESTAMP = "timestamp";
}
33 changes: 33 additions & 0 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,25 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.config.ExpressionConfig;
import org.opensearch.sql.expression.window.WindowDefinition;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlanDSL;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -657,4 +663,31 @@ public void kmeanns_relation() {
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3))))
);
}

@Test
public void ad_batchRCF_relation() {
Map<String, Literal> argumentMap =
new HashMap<String, Literal>() {{
put("shingle_size", new Literal(8, DataType.INTEGER));
put("time_decay", new Literal(0.0001, DataType.DOUBLE));
put("time_field", new Literal(null, DataType.STRING));
}};
assertAnalyzeEqual(
new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap),
new AD(AstDSL.relation("schema"), argumentMap)
);
}

@Test
public void ad_fitRCF_relation() {
Map<String, Literal> argumentMap = new HashMap<String, Literal>() {{
put("shingle_size", new Literal(8, DataType.INTEGER));
put("time_decay", new Literal(0.0001, DataType.DOUBLE));
put("time_field", new Literal("timestamp", DataType.STRING));
}};
assertAnalyzeEqual(
new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap),
new AD(AstDSL.relation("schema"), argumentMap)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.expression.DSL;
Expand Down Expand Up @@ -115,6 +118,16 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() {
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3))));
assertNull(mlCommons.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));

LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema"),
new HashMap<String, Literal>() {{
put("shingle_size", new Literal(8, DataType.INTEGER));
put("time_decay", new Literal(0.0001, DataType.DOUBLE));
put("time_field", new Literal(null, DataType.STRING));
}
});
assertNull(ad.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));
}

private static class NodesCount extends LogicalPlanNodeVisitor<Integer, Object> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ public void test_visitMLCommons() {
assertNull(physicalPlanNodeVisitor.visitMLCommons(plan, null));
}

@Test
public void test_visitAD() {
PhysicalPlanNodeVisitor physicalPlanNodeVisitor =
new PhysicalPlanNodeVisitor<Integer, Object>() {};

assertNull(physicalPlanNodeVisitor.visitAD(plan, null));
}

public static class PhysicalPlanPrinter extends PhysicalPlanNodeVisitor<String, Integer> {

public String print(PhysicalPlan node) {
Expand Down
2 changes: 1 addition & 1 deletion opensearch/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies {
compile group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: '2.11.4'
compile group: 'org.json', name: 'json', version:'20180813'
compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}"
compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0'
compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0-SNAPSHOT'

testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
testCompile group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import lombok.RequiredArgsConstructor;
import org.opensearch.sql.monitor.ResourceMonitor;
import org.opensearch.sql.opensearch.planner.physical.ADOperator;
import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator;
import org.opensearch.sql.planner.physical.AggregationOperator;
import org.opensearch.sql.planner.physical.DedupeOperator;
Expand Down Expand Up @@ -137,6 +138,17 @@ public PhysicalPlan visitMLCommons(PhysicalPlan node, Object context) {
);
}

@Override
public PhysicalPlan visitAD(PhysicalPlan node, Object context) {
ADOperator adOperator = (ADOperator) node;
return doProtect(
new ADOperator(visitInput(adOperator.getInput(), context),
adOperator.getArguments(),
adOperator.getNodeClient()
)
);
}

PhysicalPlan visitInput(PhysicalPlan node, Object context) {
if (null == node) {
return node;
Expand Down
Loading