Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/sql-cli-release-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ jobs:
- name: Checkout SQL CLI
uses: actions/checkout@v2

# dependencies: ml-commons
- name: Checkout ml-commons
uses: actions/checkout@v2
with:
repository: 'opensearch-project/ml-commons'
path: ml-commons
ref: 'main'
- name: Build ml-commons
working-directory: ./ml-commons
run: ./gradlew publishToMavenLocal

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/sql-cli-test-and-build-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ jobs:
- name: Checkout SQL CLI
uses: actions/checkout@v2

# dependencies: ml-commons
- name: Checkout ml-commons
uses: actions/checkout@v2
with:
repository: 'opensearch-project/ml-commons'
path: ml-commons
ref: 'main'
- name: Build ml-commons
working-directory: ./ml-commons
run: ./gradlew publishToMavenLocal

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/sql-release-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ jobs:
runs-on: ubuntu-latest

steps:
# dependencies: ml-commons
- name: Checkout ml-commons
uses: actions/checkout@v2
with:
repository: 'opensearch-project/ml-commons'
path: ml-commons
ref: 'main'
- name: Build ml-commons
working-directory: ./ml-commons
run: ./gradlew publishToMavenLocal

- name: Checkout SQL
uses: actions/checkout@v1

Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/sql-test-and-build-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ jobs:
uses: actions/setup-java@v1
with:
java-version: 1.14

# dependencies: ml-commons
- name: Checkout ml-commons
uses: actions/checkout@v2
with:
repository: 'opensearch-project/ml-commons'
path: ml-commons
ref: 'main'
- name: Build ml-commons
working-directory: ./ml-commons
run: ./gradlew publishToMavenLocal

- name: Build with Gradle
run: ./gradlew build assemble -Dopensearch.version=${{ env.OPENSEARCH_VERSION }}
Expand Down
17 changes: 17 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareTopN;
Expand All @@ -47,6 +48,7 @@
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.data.model.ExprMissingValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
Expand All @@ -60,6 +62,7 @@
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalProject;
import org.opensearch.sql.planner.logical.LogicalRareTopN;
Expand Down Expand Up @@ -366,6 +369,20 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) {
return new LogicalValues(valueExprs);
}

/**
* Build {@link LogicalMLCommons} for Kmeans command.
*/
@Override
public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
List<Argument> options = node.getOptions();

TypeEnvironment currentEnv = context.peek();
currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER);

return new LogicalMLCommons(child, "kmeans", options);
}

/**
* The first argument is always "asc", others are optional.
* Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareTopN;
Expand Down Expand Up @@ -234,4 +235,8 @@ public T visitLimit(Limit node, C context) {
public T visitSpan(Span node, C context) {
return visitChildren(node, context);
}

public T visitKmeans(Kmeans node, C context) {
return visitChildren(node, context);
}
}
8 changes: 8 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ public static Literal longLiteral(Long value) {
return literal(value, DataType.LONG);
}

public static Literal shortLiteral(Short value) {
return literal(value, DataType.SHORT);
}

public static Literal floatLiteral(Float value) {
return literal(value, DataType.FLOAT);
}

public static Literal dateLiteral(String value) {
return literal(value, DataType.DATE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public enum DataType {

INTEGER(ExprCoreType.INTEGER),
LONG(ExprCoreType.LONG),
SHORT(ExprCoreType.SHORT),
FLOAT(ExprCoreType.FLOAT),
DOUBLE(ExprCoreType.DOUBLE),
STRING(ExprCoreType.STRING),
BOOLEAN(ExprCoreType.BOOLEAN),
Expand Down
40 changes: 40 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Argument;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class Kmeans extends UnresolvedPlan {
private UnresolvedPlan child;

private final List<Argument> options;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitKmeans(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Argument;

/**
* ml-commons logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalMLCommons extends LogicalPlan {
private final String algorithm;

private final List<Argument> arguments;

/**
* Constructor of LogicalMLCommons.
* @param child child logical plan
* @param algorithm algorithm name
* @param arguments arguments of the algorithm
*/
public LogicalMLCommons(LogicalPlan child, String algorithm,
List<Argument> arguments) {
super(Collections.singletonList(child));
this.algorithm = algorithm;
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitMLCommons(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,8 @@ public R visitRareTopN(LogicalRareTopN plan, C context) {
public R visitLimit(LogicalLimit plan, C context) {
return visitNode(plan, context);
}

public R visitMLCommons(LogicalMLCommons plan, C context) {
return visitNode(plan, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,8 @@ public R visitLimit(LimitOperator node, C context) {
return visitNode(node, context);
}

public R visitMLCommons(PhysicalPlan node, C context) {
return visitNode(node, context);
}

}
13 changes: 13 additions & 0 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.config.ExpressionConfig;
import org.opensearch.sql.expression.window.WindowDefinition;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlanDSL;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.ContextConfiguration;
Expand Down Expand Up @@ -644,4 +646,15 @@ public void named_aggregator_with_condition() {
)
);
}

@Test
public void kmeanns_relation() {
assertAnalyzeEqual(
new LogicalMLCommons(LogicalPlanDSL.relation("schema"),
"kmeans",
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))),
new Kmeans(AstDSL.relation("schema"),
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3))))
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.expression.DSL;
Expand Down Expand Up @@ -108,6 +109,12 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() {
relation, CommandType.TOP, ImmutableList.of(expression), expression);
assertNull(rareTopN.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));

LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"),
"kmeans",
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3))));
assertNull(mlCommons.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));
}

private static class NodesCount extends LogicalPlanNodeVisitor<Integer, Object> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ public void test_PhysicalPlanVisitor_should_return_null() {
}, null));
}

@Test
public void test_visitMLCommons() {
PhysicalPlanNodeVisitor physicalPlanNodeVisitor =
new PhysicalPlanNodeVisitor<Integer, Object>() {};

assertNull(physicalPlanNodeVisitor.visitMLCommons(plan, null));
}

public static class PhysicalPlanPrinter extends PhysicalPlanNodeVisitor<String, Integer> {

public String print(PhysicalPlan node) {
Expand Down
5 changes: 4 additions & 1 deletion opensearch/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies {
compile group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: '2.11.4'
compile group: 'org.json', name: 'json', version:'20180813'
compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}"
compile group: 'org.opensearch.ml', name:'opensearch-ml-client', version: '1.3.0.0'

testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
testCompile group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1'
Expand Down Expand Up @@ -75,7 +76,9 @@ jacocoTestCoverageVerification {
rule {
element = 'CLASS'
excludes = [
'org.opensearch.sql.opensearch.security.SecurityAccess'
'org.opensearch.sql.opensearch.security.SecurityAccess',
'org.opensearch.sql.opensearch.planner.physical.*',
'org.opensearch.sql.opensearch.client.MLClient'
]
limit {
counter = 'LINE'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.opensearch.sql.opensearch.client;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.client.MachineLearningNodeClient;


public class MLClient {
private static MachineLearningNodeClient INSTANCE;

private MLClient() {

}

/**
* get machine learning client.
* @param nodeClient node client
* @return machine learning client
*/
public static MachineLearningNodeClient getMLClient(NodeClient nodeClient) {
if (INSTANCE == null) {
INSTANCE = new MachineLearningNodeClient(nodeClient);
}
return INSTANCE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import java.util.List;
import java.util.Map;
import org.opensearch.client.node.NodeClient;
import org.opensearch.sql.opensearch.mapping.IndexMapping;
import org.opensearch.sql.opensearch.request.OpenSearchRequest;
import org.opensearch.sql.opensearch.response.OpenSearchResponse;
Expand Down Expand Up @@ -64,4 +65,6 @@ public interface OpenSearchClient {
* @param task task
*/
void schedule(Runnable task);

NodeClient getNodeClient();
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ public void schedule(Runnable task) {
);
}

@Override
public NodeClient getNodeClient() {
return client;
}

private String[] resolveIndexExpression(ClusterState state, String[] indices) {
return resolver.concreteIndexNames(state, IndicesOptions.strictExpandOpen(), true, indices);
}
Expand Down
Loading