diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 7fd4560bf0e..6f75810495b 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -38,6 +38,7 @@ import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; @@ -348,6 +349,10 @@ public T visitSubqueryAlias(SubqueryAlias node, C context) { return visitChildren(node, context); } + public T visitScalarSubquery(ScalarSubquery node, C context) { + return visitChildren(node, context); + } + public T visitExistsSubquery(ExistsSubquery node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java index ec23ba421d4..e9976ceffa7 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java @@ -16,9 +16,9 @@ import org.opensearch.sql.common.utils.StringUtils; @Getter -@EqualsAndHashCode(callSuper = false) +@EqualsAndHashCode(callSuper = true) @RequiredArgsConstructor -public class ExistsSubquery extends UnresolvedExpression { +public class ExistsSubquery extends SubqueryExpression { private final UnresolvedPlan query; @Override diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java index 74ca0584fd9..cc0a06ed01a 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java @@ -15,9 +15,9 @@ import org.opensearch.sql.common.utils.StringUtils; @Getter -@EqualsAndHashCode(callSuper = false) +@EqualsAndHashCode(callSuper = true) @RequiredArgsConstructor -public class InSubquery extends UnresolvedExpression { +public class InSubquery extends SubqueryExpression { private final List value; private final UnresolvedPlan query; diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java new file mode 100644 index 00000000000..a1af8d8e657 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression.subquery; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +@RequiredArgsConstructor +public class ScalarSubquery extends SubqueryExpression { + private final UnresolvedPlan query; + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitScalarSubquery(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(); + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/subquery/SubqueryExpression.java b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/SubqueryExpression.java new file mode 100644 index 00000000000..dcf99fbdb91 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/subquery/SubqueryExpression.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression.subquery; + +import lombok.EqualsAndHashCode; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +/** Basic class of subquery expression */ +@EqualsAndHashCode(callSuper = false) +public abstract class SubqueryExpression extends UnresolvedExpression {} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 748028ce604..94cf976156a 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -13,6 +13,7 @@ import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -31,14 +32,14 @@ import org.apache.calcite.util.Holder; import org.checkerframework.checker.nullness.qual.Nullable; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; @@ -91,14 +92,14 @@ private RelBuilder scan(RelOptTable tableSchema, CalcitePlanContext context) { @Override public RelNode visitFilter(Filter node, CalcitePlanContext context) { visitChildren(node, context); - boolean containsExistsSubquery = containsExistsSubquery(node.getCondition()); + boolean containsSubqueryExpression = containsSubqueryExpression(node.getCondition()); final Holder<@Nullable RexCorrelVariable> v = Holder.empty(); - if (containsExistsSubquery) { + if (containsSubqueryExpression) { context.relBuilder.variable(v::set); context.pushCorrelVar(v.get()); } RexNode condition = rexVisitor.analyze(node.getCondition(), context); - if (containsExistsSubquery) { + if (containsSubqueryExpression) { context.relBuilder.filter(ImmutableList.of(v.get().id), condition); context.popCorrelVar(); } else { @@ -107,15 +108,20 @@ public RelNode visitFilter(Filter node, CalcitePlanContext context) { return context.relBuilder.peek(); } - private boolean containsExistsSubquery(Object condition) { - if (condition instanceof ExistsSubquery) { + private boolean containsSubqueryExpression(Node expr) { + if (expr == null) { + return false; + } + if (expr instanceof SubqueryExpression) { return true; } - if (condition instanceof Not n) { - return containsExistsSubquery(n.getExpression()); + if (expr instanceof Let l) { + return containsSubqueryExpression(l.getExpression()); } - if (condition instanceof Compare c) { - return containsExistsSubquery(c.getLeft()) || containsExistsSubquery(c.getRight()); + for (Node child : expr.getChild()) { + if (containsSubqueryExpression(child)) { + return true; + } } return false; } @@ -187,8 +193,25 @@ public RelNode visitEval(Eval node, CalcitePlanContext context) { node.getExpressionList().stream() .map( expr -> { + boolean containsSubqueryExpression = containsSubqueryExpression(expr); + final Holder<@Nullable RexCorrelVariable> v = Holder.empty(); + if (containsSubqueryExpression) { + context.relBuilder.variable(v::set); + context.pushCorrelVar(v.get()); + } RexNode eval = rexVisitor.analyze(expr, context); - context.relBuilder.projectPlus(eval); + if (containsSubqueryExpression) { + // RelBuilder.projectPlus doesn't have a parameter with variablesSet: + // projectPlus(Iterable variablesSet, RexNode... nodes) + context.relBuilder.project( + Iterables.concat(context.relBuilder.fields(), ImmutableList.of(eval)), + ImmutableList.of(), + false, + ImmutableList.of(v.get().id)); + context.popCorrelVar(); + } else { + context.relBuilder.projectPlus(eval); + } return eval; }) .collect(Collectors.toList()); diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index 5531dad49cf..a26765a36b6 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -18,17 +18,14 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexCorrelVariable; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlIntervalQualifier; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserUtil; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.DateString; -import org.apache.calcite.util.Holder; import org.apache.calcite.util.TimeString; import org.apache.calcite.util.TimestampString; -import org.checkerframework.checker.nullness.qual.Nullable; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.And; @@ -46,6 +43,7 @@ import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.calcite.utils.BuiltinFunctionUtils; import org.opensearch.sql.exception.SemanticCheckException; @@ -283,7 +281,7 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) { public RexNode visitInSubquery(InSubquery node, CalcitePlanContext context) { List nodes = node.getChild().stream().map(child -> analyze(child, context)).toList(); UnresolvedPlan subquery = node.getQuery(); - RelNode subqueryRel = resolveSubqueryPlan(subquery, false, context); + RelNode subqueryRel = resolveSubqueryPlan(subquery, context); try { return context.relBuilder.in(subqueryRel, nodes); // TODO @@ -303,18 +301,25 @@ public RexNode visitInSubquery(InSubquery node, CalcitePlanContext context) { } } + @Override + public RexNode visitScalarSubquery(ScalarSubquery node, CalcitePlanContext context) { + return context.relBuilder.scalarQuery( + b -> { + UnresolvedPlan subquery = node.getQuery(); + return resolveSubqueryPlan(subquery, context); + }); + } + @Override public RexNode visitExistsSubquery(ExistsSubquery node, CalcitePlanContext context) { - final Holder<@Nullable RexCorrelVariable> v = Holder.empty(); return context.relBuilder.exists( b -> { UnresolvedPlan subquery = node.getQuery(); - return resolveSubqueryPlan(subquery, true, context); + return resolveSubqueryPlan(subquery, context); }); } - private RelNode resolveSubqueryPlan( - UnresolvedPlan subquery, boolean isExists, CalcitePlanContext context) { + private RelNode resolveSubqueryPlan(UnresolvedPlan subquery, CalcitePlanContext context) { // clear and store the outer state boolean isResolvingJoinConditionOuter = context.isResolvingJoinCondition(); if (isResolvingJoinConditionOuter) { diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLScalarSubqueryIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLScalarSubqueryIT.java new file mode 100644 index 00000000000..d02a1e277a7 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLScalarSubqueryIT.java @@ -0,0 +1,335 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.standalone; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_OCCUPATION; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WORKER; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WORK_INFORMATION; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.client.Request; + +public class CalcitePPLScalarSubqueryIT extends CalcitePPLIntegTestCase { + + @Override + public void init() throws IOException { + super.init(); + + loadIndex(Index.WORKER); + loadIndex(Index.WORK_INFORMATION); + loadIndex(Index.OCCUPATION); + + // {"index":{"_id":"7"}} + // {"id":1006,"name":"Tommy","occupation":"Teacher","country":"USA","salary":30000} + Request request1 = new Request("PUT", "/" + TEST_INDEX_WORKER + "/_doc/7?refresh=true"); + request1.setJsonEntity( + "{\"id\":1006,\"name\":\"Tommy\",\"occupation\":\"Teacher\",\"country\":\"USA\",\"salary\":30000}"); + client().performRequest(request1); + } + + @Test + public void testUncorrelatedScalarSubqueryInSelect() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | eval count_dept = [ + source = %s | stats count(department) + ] + | fields name, count_dept + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("name", "string"), schema("count_dept", "long")); + verifyDataRows( + result, + rows("Jake", 5), + rows("Hello", 5), + rows("John", 5), + rows("David", 5), + rows("David", 5), + rows("Jane", 5), + rows("Tommy", 5)); + } + + @Test + public void testUncorrelatedScalarSubqueryInExpressionInSelect() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | eval count_dept = [ + source = %s | stats count(department) + ] + 10 + | fields name, count_dept + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("name", "string"), schema("count_dept", "long")); + verifyDataRows( + result, + rows("Jake", 15), + rows("Hello", 15), + rows("John", 15), + rows("David", 15), + rows("David", 15), + rows("Jane", 15), + rows("Tommy", 15)); + } + + @Test + public void testUncorrelatedScalarSubqueryInSelectAndWhere() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | where id > [ + source = %s | stats count(department) + ] + 999 + | eval count_dept = [ + source = %s | stats count(department) + ] + | fields name, count_dept + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("name", "string"), schema("count_dept", "long")); + verifyDataRows(result, rows("Jane", 5), rows("Tommy", 5)); + } + + @Test + public void testUncorrelatedScalarSubqueryInSelectAndInFilter() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s id > [ source = %s | stats count(department) ] + 999 + | eval count_dept = [ + source = %s | stats count(department) + ] + | fields name, count_dept + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("name", "string"), schema("count_dept", "long")); + verifyDataRows(result, rows("Jane", 5), rows("Tommy", 5)); + } + + @Test + public void testCorrelatedScalarSubqueryInSelect() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | eval count_dept = [ + source = %s + | where id = uid | stats count(department) + ] + | fields id, name, count_dept + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)); + verifySchema( + result, schema("id", "integer"), schema("name", "string"), schema("count_dept", "long")); + verifyDataRows( + result, + rows(1000, "Jake", 1), + rows(1001, "Hello", 0), + rows(1002, "John", 1), + rows(1003, "David", 1), + rows(1004, "David", 0), + rows(1005, "Jane", 1), + rows(1006, "Tommy", 1)); + } + + @Test + public void testCorrelatedScalarSubqueryInSelectWithNonEqual() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | eval count_dept = [ + source = %s + | where id > uid | stats count(department) + ] + | fields id, name, count_dept + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)); + verifySchema( + result, schema("id", "integer"), schema("name", "string"), schema("count_dept", "long")); + verifyDataRows( + result, + rows(1000, "Jake", 0), + rows(1001, "Hello", 1), + rows(1002, "John", 1), + rows(1003, "David", 2), + rows(1004, "David", 3), + rows(1005, "Jane", 3), + rows(1006, "Tommy", 4)); + } + + @Test + public void testCorrelatedScalarSubqueryInWhere() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | where id = [ + source = %s | where id = uid | stats max(uid) + ] + | fields id, name + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("id", "integer"), schema("name", "string")); + verifyDataRows( + result, + rows(1000, "Jake"), + rows(1002, "John"), + rows(1003, "David"), + rows(1005, "Jane"), + rows(1006, "Tommy")); + } + + @Test + public void testCorrelatedScalarSubqueryInFilter() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s id = [ source = %s | where id = uid | stats max(uid) ] + | fields id, name + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("id", "integer"), schema("name", "string")); + verifyDataRows( + result, + rows(1000, "Jake"), + rows(1002, "John"), + rows(1003, "David"), + rows(1005, "Jane"), + rows(1006, "Tommy")); + } + + @Test + public void testDisjunctiveCorrelatedScalarSubquery() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | where [ + source = %s | where id = uid OR uid = 1010 | stats count() + ] > 0 + | fields id, name + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("id", "integer"), schema("name", "string")); + verifyDataRows( + result, + rows(1000, "Jake"), + rows(1002, "John"), + rows(1003, "David"), + rows(1005, "Jane"), + rows(1006, "Tommy")); + } + + @Test + public void testTwoUncorrelatedScalarSubqueriesInOr() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | where id = [ + source = %s | sort uid | stats max(uid) + ] OR id = [ + source = %s | sort uid | where department = 'DATA' | stats min(uid) + ] + | fields id, name + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("id", "integer"), schema("name", "string")); + verifyDataRows(result, rows(1002, "John"), rows(1006, "Tommy")); + } + + @Test + public void testTwoCorrelatedScalarSubqueriesInOr() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | where id = [ + source = %s | where id = uid | stats max(uid) + ] OR id = [ + source = %s | sort uid | where department = 'DATA' | stats min(uid) + ] + | fields id, name + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION, TEST_INDEX_WORK_INFORMATION)); + verifySchema(result, schema("id", "integer"), schema("name", "string")); + verifyDataRows( + result, + rows(1000, "Jake"), + rows(1002, "John"), + rows(1003, "David"), + rows(1005, "Jane"), + rows(1006, "Tommy")); + } + + @Test + public void testNestedScalarSubquery() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s + | where id = [ + source = %s + | where uid = [ + source = %s + | stats min(salary) + ] + 1000 + | sort department + | stats max(uid) + ] + | fields id, name + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION, TEST_INDEX_OCCUPATION)); + verifySchema(result, schema("id", "integer"), schema("name", "string")); + verifyDataRows(result, rows(1000, "Jake")); + } + + @Test + public void testNestedScalarSubqueryWithTableAlias() { + JSONObject result = + executeQuery( + String.format( + """ + source = %s as o + | where id = [ + source = %s as i + | where uid = [ + source = %s as n + | stats min(n.salary) + ] + 1000 + | sort i.department + | stats max(i.uid) + ] + | fields o.id, o.name + """, + TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION, TEST_INDEX_OCCUPATION)); + verifySchema(result, schema("id", "integer"), schema("name", "string")); + verifyDataRows(result, rows(1000, "Jake")); + } +} diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 638fd055d5a..a34b6433ae4 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -359,6 +359,7 @@ valueExpression | getFormatFunction # getFormatFunctionCall | timestampFunction # timestampFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr + | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr ; primaryExpression diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index aa6bb3db0cb..211128ec909 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -55,6 +55,7 @@ import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; @@ -426,6 +427,12 @@ public UnresolvedExpression visitInSubqueryExpr(OpenSearchPPLParser.InSubqueryEx return ctx.NOT() != null ? new Not(expr) : expr; } + @Override + public UnresolvedExpression visitScalarSubqueryExpr( + OpenSearchPPLParser.ScalarSubqueryExprContext ctx) { + return new ScalarSubquery(astBuilder.visitSubSearch(ctx.subSearch())); + } + @Override public UnresolvedExpression visitExistsSubqueryExpr( OpenSearchPPLParser.ExistsSubqueryExprContext ctx) { diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index 6488540679e..9831bc65843 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -31,6 +31,7 @@ import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; @@ -413,6 +414,12 @@ public String visitInSubquery(InSubquery node, String context) { return StringUtils.format("(%s) in [ %s ]", nodes, subquery); } + @Override + public String visitScalarSubquery(ScalarSubquery node, String context) { + String subquery = queryAnonymizer.anonymizeData(node.getQuery()); + return StringUtils.format("[ %s ]", subquery); + } + @Override public String visitExistsSubquery(ExistsSubquery node, String context) { String subquery = queryAnonymizer.anonymizeData(node.getQuery()); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLInSubqueryTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLInSubqueryTest.java index 5f46101a7c9..8f3b7517d4b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLInSubqueryTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLInSubqueryTest.java @@ -33,7 +33,7 @@ public void testInSubquery() { + " LogicalFilter(condition=[IN($7, {\n" + "LogicalProject(DEPTNO=[$0])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" - + "})])\n" + + "})], variablesSet=[[$cor0]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); @@ -66,7 +66,7 @@ public void testSelfInSubquery() { + "LogicalProject(MGR=[$3])\n" + " LogicalFilter(condition=[=($7, 10)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" - + "})])\n" + + "})], variablesSet=[[$cor0]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); @@ -95,7 +95,7 @@ public void testTwoExpressionsInSubquery() { + " LogicalFilter(condition=[IN($7, $1, {\n" + "LogicalProject(DEPTNO=[$0], DNAME=[$1])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" - + "})])\n" + + "})], variablesSet=[[$cor0]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); @@ -124,7 +124,7 @@ public void testFilterInSubquery() { + " LogicalFilter(condition=[IN($7, $1, {\n" + "LogicalProject(DEPTNO=[$0], DNAME=[$1])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" - + "})])\n" + + "})], variablesSet=[[$cor0]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); @@ -153,7 +153,7 @@ public void testNotInSubquery() { + " LogicalFilter(condition=[NOT(IN($7, $1, {\n" + "LogicalProject(DEPTNO=[$0], DNAME=[$1])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" - + "}))])\n" + + "}))], variablesSet=[[$cor0]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); @@ -184,9 +184,9 @@ public void testNestedSubquery() { + " LogicalFilter(condition=[IN($1, {\n" + "LogicalProject(ENAME=[$0])\n" + " LogicalTableScan(table=[[scott, BONUS]])\n" - + "})])\n" + + "})], variablesSet=[[$cor1]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" - + "})])\n" + + "})], variablesSet=[[$cor0]])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n"; verifyLogical(root, expectedLogical); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLScalarSubqueryTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLScalarSubqueryTest.java new file mode 100644 index 00000000000..f243179c7d2 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLScalarSubqueryTest.java @@ -0,0 +1,348 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +public class CalcitePPLScalarSubqueryTest extends CalcitePPLAbstractTest { + + public CalcitePPLScalarSubqueryTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testUncorrelatedScalarSubqueryInWhere() { + String ppl = + """ + source=EMP + | where SAL > [ + source=EMP + | stats AVG(SAL) + ] + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalFilter(condition=[>($5, $SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], AVG(SAL)=[AVG($5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + "}))], variablesSet=[[$cor0]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `SAL` > (((SELECT AVG(`SAL`) `AVG(SAL)`\n" + + "FROM `scott`.`EMP`)))"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testUncorrelatedScalarSubqueryInSelect() { + String ppl = + """ + source=EMP + | eval min_empno = [ + source=EMP | stats min(EMPNO) + ] + | fields min_empno, SAL + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalProject(variablesSet=[[$cor0]], min_empno=[$SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], min(EMPNO)=[MIN($0)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + "})], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "" + + "SELECT (((SELECT MIN(`EMPNO`) `min(EMPNO)`\n" + + "FROM `scott`.`EMP`))) `min_empno`, `SAL`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testUncorrelatedScalarSubqueryInWhereAndSelect() { + String ppl = + """ + source=EMP + | eval min_empno = [ + source=EMP | stats min(EMPNO) + ] + | where SAL > [ + source=EMP + | stats AVG(SAL) + ] + | fields min_empno, SAL + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(min_empno=[$8], SAL=[$5])\n" + + " LogicalFilter(condition=[>($5, $SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], AVG(SAL)=[AVG($5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + "}))], variablesSet=[[$cor1]])\n" + + " LogicalProject(variablesSet=[[$cor0]], EMPNO=[$0], ENAME=[$1], JOB=[$2]," + + " MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7]," + + " min_empno=[$SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], min(EMPNO)=[MIN($0)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + "})])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `min_empno`, `SAL`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " (((SELECT MIN(`EMPNO`) `min(EMPNO)`\n" + + "FROM `scott`.`EMP`))) `min_empno`\n" + + "FROM `scott`.`EMP`) `t0`\n" + + "WHERE `SAL` > (((SELECT AVG(`SAL`) `AVG(SAL)`\n" + + "FROM `scott`.`EMP`)))"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testCorrelatedScalarSubqueryInWhere() { + String ppl = + """ + source=EMP + | where SAL > [ + source=SALGRADE | where SAL = HISAL | stats AVG(SAL) + ] + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalFilter(condition=[>($5, $SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], AVG(SAL)=[AVG($0)])\n" + + " LogicalProject($f3=[$cor0.SAL])\n" + + " LogicalFilter(condition=[=($cor0.SAL, $2)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n" + + "}))], variablesSet=[[$cor0]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `SAL` > (((SELECT AVG(`EMP`.`SAL`) `AVG(SAL)`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `EMP`.`SAL` = `HISAL`)))"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testCorrelatedScalarSubqueryInSelect() { + String ppl = + """ + source=EMP + | eval min_empno = [ + source=SALGRADE | where SAL = HISAL | stats min(EMPNO) + ] + | fields min_empno, SAL + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalProject(variablesSet=[[$cor0]], min_empno=[$SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], min(EMPNO)=[MIN($0)])\n" + + " LogicalProject($f3=[$cor0.EMPNO])\n" + + " LogicalFilter(condition=[=($cor0.SAL, $2)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n" + + "})], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "" + + "SELECT (((SELECT MIN(`EMP`.`EMPNO`) `min(EMPNO)`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `EMP`.`SAL` = `HISAL`))) `min_empno`, `SAL`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testDisjunctiveCorrelatedScalarSubqueryInWhere() { + String ppl = + """ + source=EMP + | where [ + source=SALGRADE | where SAL = HISAL OR HISAL > 1000.0 | stats COUNT() + ] > 0 + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalFilter(condition=[>($SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], COUNT()=[COUNT()])\n" + + " LogicalFilter(condition=[OR(=($cor0.SAL, $2), >($2, 1000.0E0:DOUBLE))])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n" + + "}), 0)], variablesSet=[[$cor0]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE (((SELECT COUNT(*) `COUNT()`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `EMP`.`SAL` = `HISAL` OR `HISAL` > 1.0000E3))) > 0"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testDisjunctiveCorrelatedScalarSubqueryInWhere2() { + String ppl = + """ + source=EMP + | where [ + source=SALGRADE | where (SAL = HISAL AND HISAL > 1000.0) OR (SAL = HISAL AND LOSAL > 1000.0) | stats COUNT() + ] > 0 + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalFilter(condition=[>($SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], COUNT()=[COUNT()])\n" + + " LogicalFilter(condition=[OR(AND(=($cor0.SAL, $2), >($2, 1000.0E0:DOUBLE))," + + " AND(=($cor0.SAL, $2), >($1, 1000.0E0:DOUBLE)))])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n" + + "}), 0)], variablesSet=[[$cor0]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE (((SELECT COUNT(*) `COUNT()`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `EMP`.`SAL` = `HISAL` AND `HISAL` > 1.0000E3 OR `EMP`.`SAL` = `HISAL` AND" + + " `LOSAL` > 1.0000E3))) > 0"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testTwoScalarSubqueriesInOr() { + String ppl = + """ + source=EMP + | where SAL = [ + source=SALGRADE | sort LOSAL | stats max(HISAL) + ] OR SAL = [ + source=SALGRADE | where LOSAL > 1000.0 | sort - HISAL | stats min(HISAL) + ] + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalFilter(condition=[OR(=($5, $SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], max(HISAL)=[MAX($2)])\n" + + " LogicalSort(sort0=[$1], dir0=[ASC])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n" + + "})), =($5, $SCALAR_QUERY({\n" + + "LogicalAggregate(group=[{}], min(HISAL)=[MIN($2)])\n" + + " LogicalSort(sort0=[$2], dir0=[DESC])\n" + + " LogicalFilter(condition=[>($1, 1000.0E0:DOUBLE)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n" + + "})))], variablesSet=[[$cor0]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `SAL` = (((SELECT MAX(`HISAL`) `max(HISAL)`\n" + + "FROM (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "FROM `scott`.`SALGRADE`\n" + + "ORDER BY `LOSAL` NULLS LAST) `t`))) OR `SAL` = (((SELECT MIN(`HISAL`) `min(HISAL)`\n" + + "FROM (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `LOSAL` > 1.0000E3\n" + + "ORDER BY `HISAL` DESC NULLS FIRST) `t2`)))"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testNestedScalarSubquery() { + String ppl = + """ + source=EMP + | where SAL = [ + source=SALGRADE + | where HISAL = [ + source=EMP + | stats max(SAL) as max_sal by JOB + | fields max_sal + ] + | stats max(HISAL) as max_hisal by GRADE + | fields max_hisal + | head 1 + ] + """; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalFilter(condition=[=($5, $SCALAR_QUERY({\n" + + "LogicalSort(fetch=[1])\n" + + " LogicalProject(max_hisal=[$1])\n" + + " LogicalAggregate(group=[{0}], max_hisal=[MAX($2)])\n" + + " LogicalFilter(condition=[=($2, $SCALAR_QUERY({\n" + + "LogicalProject(max_sal=[$1])\n" + + " LogicalAggregate(group=[{2}], max_sal=[MAX($5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + "}))], variablesSet=[[$cor1]])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n" + + "}))], variablesSet=[[$cor0]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + // SELECT * + // FROM scott.EMP + // WHERE SAL = ( + // SELECT MAX(HISAL) max_hisal + // FROM scott.SALGRADE + // WHERE HISAL = ( + // SELECT MAX(SAL) max_sal + // FROM scott.EMP + // GROUP BY JOB + // ) + // GROUP BY GRADE + // LIMIT 1 + // ) + String expectedSparkSql = + "" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `SAL` = (((SELECT MAX(`HISAL`) `max_hisal`\n" + + "FROM `scott`.`SALGRADE`\n" + + "WHERE `HISAL` = (((SELECT MAX(`SAL`) `max_sal`\n" + + "FROM `scott`.`EMP`\n" + + "GROUP BY `JOB`)))\n" + + "GROUP BY `GRADE`\n" + + "LIMIT 1)))"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + // TODO: With Calcite, we can add more complex scalar subquery, such as + // stats by a scalar subquery: + // | eval count_a = [ + // source=.. + // ] + // | stats .. by count_a + // But currently, statsBy an expression is unsupported in PPL. +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index 32a34096bf4..be397a2572c 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -230,6 +230,19 @@ public void testExistsSubquery() { anonymize("source=t | where exists [source=s | where id = uid ] | fields id")); } + @Test + public void testScalarSubquery() { + assertEquals( + "source=t | where id = [ source=s | stats max(b) ] | fields + id", + anonymize("source=t | where id = [ source=s | stats max(b) ] | fields id")); + assertEquals( + "source=t | eval id=[ source=s | stats max(b) ] | fields + id", + anonymize("source=t | eval id = [ source=s | stats max(b) ] | fields id")); + assertEquals( + "source=t | where id > [ source=s | where id = uid | stats max(b) ] | fields + id", + anonymize("source=t id > [ source=s | where id = uid | stats max(b) ] | fields id")); + } + private String anonymize(String query) { AstBuilder astBuilder = new AstBuilder(query); return anonymize(astBuilder.visit(parser.parse(query)));