Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ test.finalizedBy(project.tasks.jacocoTestReport)
jacocoTestCoverageVerification {
violationRules {
rule {
element = 'CLASS'
excludes = [
'org.opensearch.sql.utils.MLCommonsConstants'
]
limit {
counter = 'LINE'
minimum = 1.0
Expand Down
44 changes: 44 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand All @@ -31,11 +37,13 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
Expand All @@ -58,11 +66,13 @@
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalDedupe;
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalProject;
import org.opensearch.sql.planner.logical.LogicalRareTopN;
Expand Down Expand Up @@ -395,6 +405,40 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) {
return new LogicalValues(valueExprs);
}

/**
* Build {@link LogicalMLCommons} for Kmeans command.
*/
@Override
public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
List<Argument> options = node.getOptions();

TypeEnvironment currentEnv = context.peek();
currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER);

return new LogicalMLCommons(child, "kmeans", options);
}

/**
* Build {@link LogicalAD} for AD command.
*/
@Override
public LogicalPlan visitAD(AD node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
java.util.Map<String, Literal> options = node.getArguments();

TypeEnvironment currentEnv = context.peek();

currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE);
if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN);
} else {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE);
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_TIMESTAMP), ExprCoreType.TIMESTAMP);
}
return new LogicalAD(child, options);
}

/**
* The first argument is always "asc", others are optional.
* Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
Expand Down Expand Up @@ -239,4 +241,12 @@ public T visitLimit(Limit node, C context) {
public T visitSpan(Span node, C context) {
return visitChildren(node, context);
}

public T visitKmeans(Kmeans node, C context) {
return visitChildren(node, context);
}

public T visitAD(AD node, C context) {
return visitChildren(node, context);
}
}
8 changes: 8 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ public static Literal longLiteral(Long value) {
return literal(value, DataType.LONG);
}

public static Literal shortLiteral(Short value) {
return literal(value, DataType.SHORT);
}

public static Literal floatLiteral(Float value) {
return literal(value, DataType.FLOAT);
}

public static Literal dateLiteral(String value) {
return literal(value, DataType.DATE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public enum DataType {

INTEGER(ExprCoreType.INTEGER),
LONG(ExprCoreType.LONG),
SHORT(ExprCoreType.SHORT),
FLOAT(ExprCoreType.FLOAT),
DOUBLE(ExprCoreType.DOUBLE),
STRING(ExprCoreType.STRING),
BOOLEAN(ExprCoreType.BOOLEAN),
Expand Down
47 changes: 47 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/AD.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class AD extends UnresolvedPlan {
private UnresolvedPlan child;

private final Map<String, Literal> arguments;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitAD(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
46 changes: 46 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Argument;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class Kmeans extends UnresolvedPlan {
private UnresolvedPlan child;

private final List<Argument> options;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitKmeans(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;

/*
* AD logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalAD extends LogicalPlan {
private final Map<String, Literal> arguments;

/**
* Constructor of LogicalAD.
* @param child child logical plan
* @param arguments arguments of the algorithm
*/
public LogicalAD(LogicalPlan child, Map<String, Literal> arguments) {
super(Collections.singletonList(child));
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitAD(this, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Argument;

/**
* ml-commons logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalMLCommons extends LogicalPlan {
private final String algorithm;

private final List<Argument> arguments;

/**
* Constructor of LogicalMLCommons.
* @param child child logical plan
* @param algorithm algorithm name
* @param arguments arguments of the algorithm
*/
public LogicalMLCommons(LogicalPlan child, String algorithm,
List<Argument> arguments) {
super(Collections.singletonList(child));
this.algorithm = algorithm;
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitMLCommons(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,12 @@ public R visitRareTopN(LogicalRareTopN plan, C context) {
public R visitLimit(LogicalLimit plan, C context) {
return visitNode(plan, context);
}

public R visitMLCommons(LogicalMLCommons plan, C context) {
return visitNode(plan, context);
}

public R visitAD(LogicalAD plan, C context) {
return visitNode(plan, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,13 @@ public R visitLimit(LimitOperator node, C context) {
return visitNode(node, context);
}

public R visitMLCommons(PhysicalPlan node, C context) {
return visitNode(node, context);
}

public R visitAD(PhysicalPlan node, C context) {
return visitNode(node, context);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.opensearch.sql.utils;

public class MLCommonsConstants {

public static final String SHINGLE_SIZE = "shingle_size";
public static final String TIME_DECAY = "time_decay";
public static final String TIME_FIELD = "time_field";

public static final String RCF_SCORE = "score";
public static final String RCF_ANOMALOUS = "anomalous";
public static final String RCF_ANOMALY_GRADE = "anomaly_grade";
public static final String RCF_TIMESTAMP = "timestamp";
}
Loading