diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index 91666e40c01..8fc6b7fcea6 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -37,6 +37,7 @@ public enum Key { CALCITE_PUSHDOWN_ENABLED("plugins.calcite.pushdown.enabled"), CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR( "plugins.calcite.pushdown.rowcount.estimation.factor"), + CALCITE_SUPPORT_ALL_JOIN_TYPES("plugins.calcite.all_join_types.allowed"), /** Query Settings. */ FIELD_TYPE_TOLERANCE("plugins.query.field_type_tolerance"), diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/Argument.java b/core/src/main/java/org/opensearch/sql/ast/expression/Argument.java index 7aa52e44631..08bb3a4a418 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/Argument.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/Argument.java @@ -39,15 +39,24 @@ public static class ArgumentMap { private final Map map; public ArgumentMap(List arguments) { - this.map = - arguments.stream() - .collect(java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue)); + if (arguments == null || arguments.isEmpty()) { + this.map = Map.of(); + } else { + this.map = + arguments.stream() + .collect( + java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue)); + } } public static ArgumentMap of(List arguments) { return new ArgumentMap(arguments); } + public static ArgumentMap empty() { + return new ArgumentMap(null); + } + /** * Get argument value by name. * diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/Literal.java b/core/src/main/java/org/opensearch/sql/ast/expression/Literal.java index edb66e805da..3d61d5dc5a3 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/Literal.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/Literal.java @@ -46,4 +46,8 @@ public R accept(AbstractNodeVisitor nodeVisitor, C context) { public String toString() { return String.valueOf(value); } + + public static Literal TRUE = new Literal(true, DataType.BOOLEAN); + public static Literal FALSE = new Literal(false, DataType.BOOLEAN); + public static Literal ZERO = new Literal(Integer.valueOf("0"), DataType.INTEGER); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Join.java b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java index 0d976cbea8e..a0f58d35124 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Join.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -15,6 +15,8 @@ import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.UnresolvedExpression; @ToString @@ -28,6 +30,8 @@ public class Join extends UnresolvedPlan { private final JoinType joinType; private final Optional joinCondition; private final JoinHint joinHint; + private final Optional> joinFields; + private final Argument.ArgumentMap argumentMap; public Join( UnresolvedPlan right, @@ -35,13 +39,17 @@ public Join( Optional rightAlias, JoinType joinType, Optional joinCondition, - JoinHint joinHint) { + JoinHint joinHint, + Optional> joinFields, + Argument.ArgumentMap argumentMap) { this.right = right; this.leftAlias = leftAlias; this.rightAlias = rightAlias; this.joinType = joinType; this.joinCondition = joinCondition; this.joinHint = joinHint; + this.joinFields = joinFields; + this.argumentMap = argumentMap; } @Override @@ -89,6 +97,11 @@ public enum JoinType { FULL } + /** RIGHT, CROSS, FULL are performance sensitive join types */ + public static List highCostJoinTypes() { + return List.of(JoinType.RIGHT, JoinType.CROSS, JoinType.FULL); + } + @Getter @RequiredArgsConstructor public static class JoinHint { diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 5261d438863..abab6179274 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -48,6 +48,7 @@ import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.rex.RexWindowBounds; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeFamily; @@ -903,6 +904,70 @@ private Optional extractAliasLiteral(RexNode node) { public RelNode visitJoin(Join node, CalcitePlanContext context) { List children = node.getChildren(); children.forEach(c -> analyze(c, context)); + if (node.getJoinCondition().isEmpty()) { + // join-with-field-list grammar + List leftColumns = context.relBuilder.peek(1).getRowType().getFieldNames(); + List rightColumns = context.relBuilder.peek().getRowType().getFieldNames(); + List duplicatedFieldNames = + leftColumns.stream().filter(rightColumns::contains).toList(); + RexNode joinCondition; + if (node.getJoinFields().isPresent()) { + joinCondition = + node.getJoinFields().get().stream() + .map(field -> buildJoinConditionByFieldName(context, field.getField().toString())) + .reduce(context.rexBuilder::and) + .orElse(context.relBuilder.literal(true)); + } else { + joinCondition = + duplicatedFieldNames.stream() + .map(fieldName -> buildJoinConditionByFieldName(context, fieldName)) + .reduce(context.rexBuilder::and) + .orElse(context.relBuilder.literal(true)); + } + if (node.getJoinType() == SEMI || node.getJoinType() == ANTI) { + // semi and anti join only return left table outputs + context.relBuilder.join( + JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition); + return context.relBuilder.peek(); + } + List toBeRemovedFields; + if (node.getArgumentMap().get("overwrite") == null // 'overwrite' default value is true + || (node.getArgumentMap().get("overwrite").equals(Literal.TRUE))) { + toBeRemovedFields = + duplicatedFieldNames.stream() + .map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, true, context)) + .toList(); + } else { + toBeRemovedFields = + duplicatedFieldNames.stream() + .map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, false, context)) + .toList(); + } + Literal max = node.getArgumentMap().get("max"); + if (max != null && !max.equals(Literal.ZERO)) { + // max != 0 means the right-side should be dedup + Integer allowedDuplication = (Integer) max.getValue(); + if (allowedDuplication < 0) { + throw new SemanticCheckException("max option must be a positive integer"); + } + List dedupeFields = + node.getJoinFields().isPresent() + ? node.getJoinFields().get().stream() + .map(a -> (RexNode) context.relBuilder.field(a.getField().toString())) + .toList() + : duplicatedFieldNames.stream() + .map(a -> (RexNode) context.relBuilder.field(a)) + .toList(); + buildDedupNotNull(context, dedupeFields, allowedDuplication); + } + context.relBuilder.join( + JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition); + if (!toBeRemovedFields.isEmpty()) { + context.relBuilder.projectExcept(toBeRemovedFields); + } + return context.relBuilder.peek(); + } + // The join-with-criteria grammar doesn't allow empty join condition RexNode joinCondition = node.getJoinCondition() .map(c -> rexVisitor.analyzeJoinCondition(c, context)) @@ -938,6 +1003,19 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) { .orElse(rightTableQualifiedName + "." + col) : col) .toList(); + + Literal max = node.getArgumentMap().get("max"); + if (max != null && !max.equals(Literal.ZERO)) { + // max != 0 means the right-side should be dedup + Integer allowedDuplication = (Integer) max.getValue(); + if (allowedDuplication < 0) { + throw new SemanticCheckException("max option must be a positive integer"); + } + List dedupeFields = + getRightColumnsInJoinCriteria(context.relBuilder, joinCondition); + + buildDedupNotNull(context, dedupeFields, allowedDuplication); + } context.relBuilder.join( JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition); JoinAndLookupUtils.renameToExpectedFields( @@ -946,6 +1024,37 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) { return context.relBuilder.peek(); } + private List getRightColumnsInJoinCriteria( + RelBuilder relBuilder, RexNode joinCondition) { + int stackSize = relBuilder.size(); + int leftFieldCount = relBuilder.peek(stackSize - 1).getRowType().getFieldCount(); + RelNode right = relBuilder.peek(stackSize - 2); + List allColumnNamesOfRight = right.getRowType().getFieldNames(); + + List rightColumnIndexes = new ArrayList<>(); + joinCondition.accept( + new RexVisitorImpl(true) { + @Override + public Void visitInputRef(RexInputRef inputRef) { + if (inputRef.getIndex() >= leftFieldCount) { + rightColumnIndexes.add(inputRef.getIndex() - leftFieldCount); + } + return super.visitInputRef(inputRef); + } + }); + return rightColumnIndexes.stream() + .map(allColumnNamesOfRight::get) + .map(n -> (RexNode) relBuilder.field(n)) + .toList(); + } + + private static RexNode buildJoinConditionByFieldName( + CalcitePlanContext context, String fieldName) { + RexNode lookupKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, false, context); + RexNode sourceKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, true, context); + return context.rexBuilder.equals(sourceKey, lookupKey); + } + @Override public RelNode visitSubqueryAlias(SubqueryAlias node, CalcitePlanContext context) { visitChildren(node, context); @@ -1068,74 +1177,82 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) { List dedupeFields = node.getFields().stream().map(f -> rexVisitor.analyze(f, context)).toList(); if (keepEmpty) { - /* - * | dedup 2 a, b keepempty=false - * DropColumns('_row_number_dedup_) - * +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b)) - * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] - * +- ... - */ - // Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, - // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a - // ASC - // NULLS FIRST, 'b ASC NULLS FIRST] - RexNode rowNumber = - context - .relBuilder - .aggregateCall(SqlStdOperatorTable.ROW_NUMBER) - .over() - .partitionBy(dedupeFields) - .orderBy(dedupeFields) - .rowsTo(RexWindowBounds.CURRENT_ROW) - .as(ROW_NUMBER_COLUMN_FOR_DEDUP); - context.relBuilder.projectPlus(rowNumber); - RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP); - // Filter (isnull('a) OR isnull('b) OR '_row_number_dedup_ <= n) - context.relBuilder.filter( - context.relBuilder.or( - context.relBuilder.or(dedupeFields.stream().map(context.relBuilder::isNull).toList()), - context.relBuilder.lessThanOrEqual( - _row_number_dedup_, context.relBuilder.literal(allowedDuplication)))); - // DropColumns('_row_number_) - context.relBuilder.projectExcept(_row_number_dedup_); + buildDedupOrNull(context, dedupeFields, allowedDuplication); } else { - /* - * | dedup 2 a, b keepempty=false - * DropColumns('_row_number_dedup_) - * +- Filter ('_row_number_dedup_ <= n) - * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] - * +- Filter (isnotnull('a) AND isnotnull('b)) - * +- ... - */ - // Filter (isnotnull('a) AND isnotnull('b)) - context.relBuilder.filter( - context.relBuilder.and( - dedupeFields.stream().map(context.relBuilder::isNotNull).toList())); - // Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, - // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a - // ASC - // NULLS FIRST, 'b ASC NULLS FIRST] - RexNode rowNumber = - context - .relBuilder - .aggregateCall(SqlStdOperatorTable.ROW_NUMBER) - .over() - .partitionBy(dedupeFields) - .orderBy(dedupeFields) - .rowsTo(RexWindowBounds.CURRENT_ROW) - .as(ROW_NUMBER_COLUMN_FOR_DEDUP); - context.relBuilder.projectPlus(rowNumber); - RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP); - // Filter ('_row_number_dedup_ <= n) - context.relBuilder.filter( - context.relBuilder.lessThanOrEqual( - _row_number_dedup_, context.relBuilder.literal(allowedDuplication))); - // DropColumns('_row_number_dedup_) - context.relBuilder.projectExcept(_row_number_dedup_); + buildDedupNotNull(context, dedupeFields, allowedDuplication); } return context.relBuilder.peek(); } + private static void buildDedupOrNull( + CalcitePlanContext context, List dedupeFields, Integer allowedDuplication) { + /* + * | dedup 2 a, b keepempty=false + * DropColumns('_row_number_dedup_) + * +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b)) + * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * +- ... + */ + // Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, + // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a + // ASC + // NULLS FIRST, 'b ASC NULLS FIRST] + RexNode rowNumber = + context + .relBuilder + .aggregateCall(SqlStdOperatorTable.ROW_NUMBER) + .over() + .partitionBy(dedupeFields) + .orderBy(dedupeFields) + .rowsTo(RexWindowBounds.CURRENT_ROW) + .as(ROW_NUMBER_COLUMN_FOR_DEDUP); + context.relBuilder.projectPlus(rowNumber); + RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP); + // Filter (isnull('a) OR isnull('b) OR '_row_number_dedup_ <= n) + context.relBuilder.filter( + context.relBuilder.or( + context.relBuilder.or(dedupeFields.stream().map(context.relBuilder::isNull).toList()), + context.relBuilder.lessThanOrEqual( + _row_number_dedup_, context.relBuilder.literal(allowedDuplication)))); + // DropColumns('_row_number_dedup_) + context.relBuilder.projectExcept(_row_number_dedup_); + } + + private static void buildDedupNotNull( + CalcitePlanContext context, List dedupeFields, Integer allowedDuplication) { + /* + * | dedup 2 a, b keepempty=false + * DropColumns('_row_number_dedup_) + * +- Filter ('_row_number_dedup_ <= n) + * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * +- Filter (isnotnull('a) AND isnotnull('b)) + * +- ... + */ + // Filter (isnotnull('a) AND isnotnull('b)) + context.relBuilder.filter( + context.relBuilder.and(dedupeFields.stream().map(context.relBuilder::isNotNull).toList())); + // Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, + // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC + // NULLS FIRST, 'b ASC NULLS FIRST] + RexNode rowNumber = + context + .relBuilder + .aggregateCall(SqlStdOperatorTable.ROW_NUMBER) + .over() + .partitionBy(dedupeFields) + .orderBy(dedupeFields) + .rowsTo(RexWindowBounds.CURRENT_ROW) + .as(ROW_NUMBER_COLUMN_FOR_DEDUP); + context.relBuilder.projectPlus(rowNumber); + RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP); + // Filter ('_row_number_dedup_ <= n) + context.relBuilder.filter( + context.relBuilder.lessThanOrEqual( + _row_number_dedup_, context.relBuilder.literal(allowedDuplication))); + // DropColumns('_row_number_dedup_) + context.relBuilder.projectExcept(_row_number_dedup_); + } + @Override public RelNode visitWindow(Window node, CalcitePlanContext context) { visitChildren(node, context); diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst index 177b911bdd1..9d5e45cd7b6 100644 --- a/docs/user/admin/settings.rst +++ b/docs/user/admin/settings.rst @@ -759,7 +759,7 @@ Check `introduce v3 engine <../../../dev/intro-v3-engine.md>`_ for more details. Check `join doc <../../ppl/cmd/join.rst>`_ for example. plugins.calcite.fallback.allowed -======================= +================================ Description ----------- @@ -771,7 +771,7 @@ If Calcite is enabled, you can use this setting to decide whether to allow fallb 3. This setting can be updated dynamically. plugins.calcite.pushdown.enabled -======================= +================================ Description ----------- @@ -783,7 +783,7 @@ If Calcite is enabled, you can use this setting to decide whether to enable the 3. This setting can be updated dynamically. plugins.calcite.pushdown.rowcount.estimation.factor -======================= +=================================================== Description ----------- @@ -793,3 +793,15 @@ If Calcite pushdown optimization is enabled, this setting is used to estimate th 1. The default value is 0.9 since 3.1.0. 2. This setting is node scope. 3. This setting can be updated dynamically. + +plugins.calcite.all_join_types.allowed +====================================== + +Description +----------- + +Join types ``inner``, ``left``, ``outer`` (alias of ``left``), ``semi`` and ``anti`` are supported by default. ``right``, ``full``, ``cross`` are performance sensitive join types which are disabled by default. Set config ``plugins.calcite.all_join_types.allowed = true`` to enable. + +1. The default value is false since 3.3.0. +2. This setting is node scope. +3. This setting can be updated dynamically. diff --git a/docs/user/ppl/admin/settings.rst b/docs/user/ppl/admin/settings.rst index fad7164d644..b15fc8159a7 100644 --- a/docs/user/ppl/admin/settings.rst +++ b/docs/user/ppl/admin/settings.rst @@ -17,7 +17,7 @@ Introduction When OpenSearch bootstraps, PPL plugin will register a few settings in OpenSearch cluster settings. Most of the settings are able to change dynamically so you can control the behavior of PPL plugin without need to bounce your cluster. plugins.ppl.enabled -====================== +=================== Description ----------- @@ -90,7 +90,7 @@ PPL query:: } plugins.query.memory_limit -================================= +========================== Description ----------- @@ -120,7 +120,7 @@ PPL query:: Note: the legacy settings of ``opendistro.ppl.query.memory_limit`` is deprecated, it will fallback to the new settings if you request an update with the legacy name. plugins.query.size_limit -=========================== +======================== Description ----------- @@ -159,3 +159,33 @@ Rollback to default value:: } Note: the legacy settings of ``opendistro.query.size_limit`` is deprecated, it will fallback to the new settings if you request an update with the legacy name. + +plugins.calcite.all_join_types.allowed +====================================== + +Description +----------- + +Since 3.3.0, join types ``inner``, ``left``, ``outer`` (alias of ``left``), ``semi`` and ``anti`` are supported by default. ``right``, ``full``, ``cross`` are performance sensitive join types which are disabled by default. Set config ``plugins.calcite.all_join_types.allowed = true`` to enable. + +Example +------- + +PPL query:: + + sh$ curl -sS -H 'Content-Type: application/json' \ + ... -X PUT localhost:9200/_plugins/_query/settings \ + ... -d '{"transient" : {"plugins.calcite.all_join_types.allowed" : "true"}}' + { + "acknowledged": true, + "persistent": {}, + "transient": { + "plugins": { + "calcite": { + "all_join_types": { + "allowed": "true" + } + } + } + } + } diff --git a/docs/user/ppl/cmd/join.rst b/docs/user/ppl/cmd/join.rst index e9c21cd1e50..45c091b5b4e 100644 --- a/docs/user/ppl/cmd/join.rst +++ b/docs/user/ppl/cmd/join.rst @@ -11,27 +11,35 @@ join Description =========== -| (Experimental) -| (From 3.0.0) | Using ``join`` command to combines two datasets together. The left side could be an index or results from a piped commands, the right side could be either an index or a subsearch. Version ======= 3.0.0 -Syntax -====== -[joinType] join [leftAlias] [rightAlias] on +Basic syntax in 3.0.0 +===================== +| [joinType] join [leftAlias] [rightAlias] (on | where) -* joinType: optional. The type of join to perform. The default is ``INNER`` if not specified. Other option is ``LEFT [OUTER]``, ``RIGHT [OUTER]``, ``FULL [OUTER]``, ``CROSS``, ``[LEFT] SEMI``, ``[LEFT] ANTI``. +* joinType: optional. The type of join to perform. The default is ``inner`` if not specified. Other option is ``left``, ``semi``, ``anti`` and performance sensitive types ``right``, ``full`` and ``cross``. * leftAlias: optional. The subsearch alias to use with the left join side, to avoid ambiguous naming. Fixed pattern: ``left = `` * rightAlias: optional. The subsearch alias to use with the right join side, to avoid ambiguous naming. Fixed pattern: ``right = `` -* joinCriteria: mandatory. It could be any comparison expression. -* right-dataset: mandatory. Right dataset could be either an index or a subsearch with/without alias. +* joinCriteria: mandatory. It could be any comparison expression. Must follow with ``on`` (since 3.0.0) or ``where`` (since 3.3.0) keyword. +* right-dataset: mandatory. Right dataset could be either an ``index`` or a ``subsearch`` with/without alias. + +Extended syntax since 3.3.0 +=========================== +| join [type=] [overwrite=] [max=n] ( | [leftAlias] [rightAlias] (on | where) ) +| From 3.3.0, the join syntax is enhanced to support more join options and join with field list. + +* type=: optional. The type of join to perform. The default is ``inner`` if not specified. Other option is ``left``, ``outer``(alias of ``left``), ``semi``, ``anti`` and performance sensitive types ``right``, ``full`` and ``cross``. +* overwrite=: optional. Only works with ``join-field-list``. Specifies whether duplicate-named fields from (subsearch results) should replace corresponding fields in the main search results. The default value is ``true``. +* max=n: optional. Controls how many subsearch results could be joined against to each row in main search. The default value is 0, means unlimited. +* join-field-list: optional. The fields used to build the join criteria. The join field list must exist on both sides. If no join field list is specified, all fields common to both sides will be used as join keys. The comma is optional. Configuration ============= -This command requires Calcite enabled. In 3.0.0-beta, as an experimental the Calcite configuration is disabled by default. +This command requires Calcite enabled. In 3.0.0, as an experimental the Calcite configuration is disabled by default. Enable Calcite:: @@ -58,13 +66,14 @@ Result set:: Usage ===== -Join:: +Join on criteria (in 3.0.0):: source = table1 | inner join left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c + source = table1 | inner join left = l right = r where l.a = r.a table2 | fields l.a, r.a, b, c source = table1 | left join left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c source = table1 | right join left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c source = table1 | full left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c - source = table1 | cross join left = l right = r table2 + source = table1 | cross join left = l right = r on 1=1 table2 source = table1 | left semi join left = l right = r on l.a = r.a table2 source = table1 | left anti join left = l right = r on l.a = r.a table2 source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] @@ -74,6 +83,16 @@ Join:: source = table1 as t1 | join left = l right = r on l.a = r.a table2 as t2 | fields t1.a, t2.a source = table1 | join left = l right = r on l.a = r.a [ source = table2 ] as s | fields l.a, s.a +Extended syntax and option supported (since 3.3.0):: + + source = table1 | join type=outer left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c + source = table1 | join type=left left = l right = r where l.a = r.a table2 | fields l.a, r.a, b, c + source = table1 | join type=inner max=1 left = l right = r where l.a = r.a table2 | fields l.a, r.a, b, c + source = table1 | join a table2 | fields a, b, c + source = table1 | join a, b table2 | fields a, b, c + source = table1 | join type=outer a b table2 | fields a, b, c + source = table1 | join type=inner max=1 a, b table2 | fields a, b, c + source = table1 | join type=left overwrite=false max=0 a, b [source=table2 | rename d as b] | fields a, b, c Example 1: Two indices join =========================== @@ -116,9 +135,49 @@ PPL query:: | 100000.0 | 70 | England | +-------------+----------+-----------+ +Example 3: Join with field list +=============================== + +PPL query:: + + PPL> source = state_country + | where country = 'USA' OR country = 'England' + | join type=left overwrite=true name [ + source = occupation + | where salary > 0 + | fields name, country, salary + | sort salary + | head 3 + ] + | stats avg(salary) by span(age, 10) as age_span, country; + fetched rows / total rows = 5/5 + +-------------+----------+-----------+ + | avg(salary) | age_span | country | + |-------------+----------+-----------| + | null | 40 | null | + | 70000.0 | 30 | USA | + | 100000.0 | 70 | England | + +-------------+----------+-----------+ + +Example 4: Join with options +============================ + +PPL query:: + + PPL> source = state_country | join type=inner overwrite=false max=1 name occupation | stats avg(salary) by span(age, 10) as age_span, country; + fetched rows / total rows = 5/5 + +-------------+----------+---------+ + | avg(salary) | age_span | country | + |-------------+----------+---------| + | 120000.0 | 40 | USA | + | 100000.0 | 70 | USA | + | 105000.0 | 20 | Canada | + | 70000.0 | 30 | USA | + +-------------+----------+---------+ + Limitation ========== -If fields in the left outputs and right outputs have the same name. Typically, in the join criteria +For basic syntax in 3.0.0, if fields in the left outputs and right outputs have the same name. Typically, in the join criteria ``ON t1.id = t2.id``, the names ``id`` in output are ambiguous. To avoid ambiguous, the ambiguous fields in output rename to ``.id``, or else ``.id`` if no alias existing. @@ -138,3 +197,7 @@ Assume table1 and table2 only contain field ``id``, following PPL queries and th - table1.id, t2.id, a * - source=table1 | join right=tt on table1.id=t2.id [ source=table2 as t2 | eval b = id ] | eval a = 1 - table1.id, tt.id, tt.b, a + +For extended syntax (join with field list) in 3.3.0, when duplicate-named fields in output results are deduplicated, the fields in output determined by the value of 'overwrite' option. + +Since 3.3.0, join types ``inner``, ``left``, ``outer`` (alias of ``left``), ``semi`` and ``anti`` are supported by default. ``right``, ``full``, ``cross`` are performance sensitive join types which are disabled by default. Set config ``plugins.calcite.all_join_types.allowed = true`` to enable. diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index d3a533a2649..7ec660b1f89 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -67,6 +67,39 @@ public void supportSearchSargPushDown_timeRange() throws IOException { assertJsonEqualsIgnoreId(expected, result); } + // Only for Calcite + @Ignore("https://github.com/opensearch-project/OpenSearch/issues/3725") + public void testJoinWithCriteriaAndMaxOption() throws IOException { + String query = + "source=opensearch-sql_test_index_bank | join max=1 left=l right=r on" + + " l.account_number=r.account_number opensearch-sql_test_index_bank"; + var result = explainQueryToString(query); + String expected = loadExpectedPlan("explain_join_with_criteria_max_option.json"); + assertJsonEqualsIgnoreId(expected, result); + } + + // Only for Calcite + @Ignore("https://github.com/opensearch-project/OpenSearch/issues/3725") + public void testJoinWithFieldListAndMaxOption() throws IOException { + String query = + "source=opensearch-sql_test_index_bank | join type=inner max=1 account_number" + + " opensearch-sql_test_index_bank"; + var result = explainQueryToString(query); + String expected = loadExpectedPlan("explain_join_with_fields_max_option.json"); + assertJsonEqualsIgnoreId(expected, result); + } + + // Only for Calcite + @Test + public void testJoinWithFieldList() throws IOException { + String query = + "source=opensearch-sql_test_index_bank | join type=outer account_number" + + " opensearch-sql_test_index_bank"; + var result = explainQueryToString(query); + String expected = loadExpectedPlan("explain_join_with_fields.json"); + assertJsonEqualsIgnoreId(expected, result); + } + // Only for Calcite @Test public void supportPushDownSortMergeJoin() throws IOException { diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendCommandIT.java index 66e0e61cc46..d971a6f3cb1 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendCommandIT.java @@ -18,6 +18,7 @@ import java.util.Locale; import org.json.JSONObject; import org.junit.Test; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.ppl.PPLIntegTestCase; public class CalcitePPLAppendCommandIT extends PPLIntegTestCase { @@ -90,71 +91,86 @@ public void testAppendEmptySearchCommand() throws IOException { @Test public void testAppendEmptySearchWithJoin() throws IOException { - List emptySourceWithJoinPPLs = - Arrays.asList( - String.format( - Locale.ROOT, - "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " - + " join left=L right=R on L.gender = R.gender %s ]", - TEST_INDEX_ACCOUNT, - TEST_INDEX_ACCOUNT), - String.format( - Locale.ROOT, - "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " - + " cross join left=L right=R on L.gender = R.gender %s ]", - TEST_INDEX_ACCOUNT, - TEST_INDEX_ACCOUNT), - String.format( - Locale.ROOT, - "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " - + " left join left=L right=R on L.gender = R.gender %s ]", - TEST_INDEX_ACCOUNT, - TEST_INDEX_ACCOUNT), - String.format( - Locale.ROOT, - "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " - + " semi join left=L right=R on L.gender = R.gender %s ]", - TEST_INDEX_ACCOUNT, - TEST_INDEX_ACCOUNT)); + withSettings( + Settings.Key.CALCITE_SUPPORT_ALL_JOIN_TYPES, + "true", + () -> { + List emptySourceWithJoinPPLs = + Arrays.asList( + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " + + " join left=L right=R on L.gender = R.gender %s ]", + TEST_INDEX_ACCOUNT, + TEST_INDEX_ACCOUNT), + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " + + " cross join left=L right=R on L.gender = R.gender %s ]", + TEST_INDEX_ACCOUNT, + TEST_INDEX_ACCOUNT), + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " + + " left join left=L right=R on L.gender = R.gender %s ]", + TEST_INDEX_ACCOUNT, + TEST_INDEX_ACCOUNT), + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | " + + " semi join left=L right=R on L.gender = R.gender %s ]", + TEST_INDEX_ACCOUNT, + TEST_INDEX_ACCOUNT)); - for (String ppl : emptySourceWithJoinPPLs) { - JSONObject actual = executeQuery(ppl); - verifySchemaInOrder( - actual, schema("sum_age_by_gender", "bigint"), schema("gender", "string")); - verifyDataRows(actual, rows(14947, "F"), rows(15224, "M")); - } + for (String ppl : emptySourceWithJoinPPLs) { + JSONObject actual = null; + try { + actual = executeQuery(ppl); + } catch (IOException e) { + throw new RuntimeException(e); + } + verifySchemaInOrder( + actual, schema("sum_age_by_gender", "bigint"), schema("gender", "string")); + verifyDataRows(actual, rows(14947, "F"), rows(15224, "M")); + } - List emptySourceWithRightOrFullJoinPPLs = - Arrays.asList( - String.format( - Locale.ROOT, - "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | where" - + " gender = 'F' | right join on gender = gender [source=%s | stats count() as" - + " cnt by gender ] ]", - TEST_INDEX_ACCOUNT, - TEST_INDEX_ACCOUNT), - String.format( - Locale.ROOT, - "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | where" - + " gender = 'F' | full join on gender = gender [source=%s | stats count() as" - + " cnt by gender ] ]", - TEST_INDEX_ACCOUNT, - TEST_INDEX_ACCOUNT)); + List emptySourceWithRightOrFullJoinPPLs = + Arrays.asList( + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | where" + + " gender = 'F' | right join on gender = gender [source=%s | stats" + + " count() as cnt by gender ] ]", + TEST_INDEX_ACCOUNT, + TEST_INDEX_ACCOUNT), + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age_by_gender by gender | append [ | where" + + " gender = 'F' | full join on gender = gender [source=%s | stats" + + " count() as cnt by gender ] ]", + TEST_INDEX_ACCOUNT, + TEST_INDEX_ACCOUNT)); - for (String ppl : emptySourceWithRightOrFullJoinPPLs) { - JSONObject actual = executeQuery(ppl); - verifySchemaInOrder( - actual, - schema("sum_age_by_gender", "bigint"), - schema("gender", "string"), - schema("cnt", "bigint")); - verifyDataRows( - actual, - rows(14947, "F", null), - rows(15224, "M", null), - rows(null, "F", 493), - rows(null, "M", 507)); - } + for (String ppl : emptySourceWithRightOrFullJoinPPLs) { + JSONObject actual = null; + try { + actual = executeQuery(ppl); + } catch (IOException e) { + throw new RuntimeException(e); + } + verifySchemaInOrder( + actual, + schema("sum_age_by_gender", "bigint"), + schema("gender", "string"), + schema("cnt", "bigint")); + verifyDataRows( + actual, + rows(14947, "F", null), + rows(15224, "M", null), + rows(null, "F", 493), + rows(null, "M", 507)); + } + }); } @Test diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLJoinIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLJoinIT.java index 890206f0a39..c9886f687c2 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLJoinIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLJoinIT.java @@ -13,11 +13,11 @@ import static org.opensearch.sql.util.MatcherUtils.schema; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; import static org.opensearch.sql.util.MatcherUtils.verifyDataRowsInOrder; +import static org.opensearch.sql.util.MatcherUtils.verifyNumOfRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; import java.io.IOException; import org.json.JSONObject; -import org.junit.Ignore; import org.junit.Test; import org.opensearch.client.Request; import org.opensearch.sql.legacy.TestsConstants; @@ -29,6 +29,7 @@ public class CalcitePPLJoinIT extends PPLIntegTestCase { public void init() throws Exception { super.init(); enableCalcite(); + supportAllJoinTypes(); loadIndex(Index.STATE_COUNTRY); loadIndex(Index.OCCUPATION); @@ -298,7 +299,7 @@ public void testComplexCrossJoin() throws IOException { executeQuery( String.format( "source = %s | where country = 'Canada' OR country = 'England' | join left=a," - + " right=b %s | sort a.age | stats count()", + + " right=b on 1=1 %s | sort a.age | stats count()", TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); verifySchema(actual, schema("count()", "bigint")); verifyDataRowsInOrder(actual, rows(30)); @@ -353,7 +354,7 @@ public void testCrossJoinWithJoinCriteriaFallbackToInnerJoin() throws IOExceptio assertJsonEquals(cross.toString(), inner.toString()); } - @Ignore // TODO seems a calcite bug + @Test public void testMultipleJoins() throws IOException { JSONObject actual = executeQuery( @@ -366,7 +367,7 @@ public void testMultipleJoins() throws IOException { + " a.a_name = b.name %s| eval aa_country = a.a_country| eval ab_country =" + " a.b_country| eval bb_country = b.country| fields a_name, age, state," + " aa_country, occupation, ab_country, salary, bb_country, hobby, language|" - + " cross join left=a, right=b %s| eval new_country = a.aa_country| eval" + + " cross join left=a, right=b on 1=1 %s| eval new_country = a.aa_country| eval" + " new_salary = b.salary| stats avg(new_salary) as avg_salary by span(age, 5)" + " as age_span, state| left semi join left=a, right=b ON a.state = b.state " + " %s| eval new_avg_salary = floor(avg_salary)| fields state, age_span," @@ -376,9 +377,10 @@ public void testMultipleJoins() throws IOException { TEST_INDEX_HOBBIES, TEST_INDEX_OCCUPATION, TEST_INDEX_STATE_COUNTRY)); + verifyNumOfRows(actual, 2); } - @Ignore // TODO seems a calcite bug + @Test public void testMultipleJoinsWithRelationSubquery() throws IOException { JSONObject actual = executeQuery( @@ -391,17 +393,18 @@ public void testMultipleJoinsWithRelationSubquery() throws IOException { + " ON a.a_name = b.name [ source = %s ]| eval aa_country =" + " a.a_country| eval ab_country = a.b_country| eval bb_country = b.country|" + " fields a_name, age, state, aa_country, occupation, ab_country, salary," - + " bb_country, hobby, language| cross join left=a, right=b [ source =" - + " %s ]| eval new_country = a.aa_country| eval new_salary = b.salary| stats" - + " avg(new_salary) as avg_salary by span(age, 5) as age_span, state| left semi" - + " join left=a, right=b ON a.state = b.state [ source = %s ]|" - + " eval new_avg_salary = floor(avg_salary)| fields state, age_span," - + " new_avg_salary", + + " bb_country, hobby, language| cross join left=a, right=b on 1=1 [ " + + " source = %s ]| eval new_country = a.aa_country| eval new_salary =" + + " b.salary| stats avg(new_salary) as avg_salary by span(age, 5) as age_span," + + " state| left semi join left=a, right=b ON a.state = b.state [ " + + " source = %s ]| eval new_avg_salary = floor(avg_salary)| fields state," + + " age_span, new_avg_salary", TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION, TEST_INDEX_HOBBIES, TEST_INDEX_OCCUPATION, TEST_INDEX_STATE_COUNTRY)); + verifyNumOfRows(actual, 2); } @Test @@ -723,4 +726,216 @@ public void testLeftJoinWithRelationSubquery() throws IOException { verifyDataRows( actual, rows(70000.0, 30, "USA"), rows(null, 40, null), rows(100000, 70, "England")); } + + @Test + public void testJoinWithFieldList() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | join type=inner name,year,month %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifySchema( + actual, + schema("name", "string"), + schema("age", "int"), + schema("state", "string"), + schema("country", "string"), + schema("year", "int"), + schema("month", "int"), + schema("occupation", "string"), + schema("salary", "int")); + JSONObject actual2 = + executeQuery( + String.format( + "source=%s | join type=inner name,year,month %s | fields name, country", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyDataRows( + actual2, + rows("Jake", "England"), + rows("Hello", "USA"), + rows("John", "Canada"), + rows("Jane", "Canada"), + rows("David", "USA"), + rows("David", "Canada")); + } + + @Test + public void testJoinWithFieldList2() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | join type=inner overwrite=false name,year,month %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifySchema( + actual, + schema("name", "string"), + schema("age", "int"), + schema("state", "string"), + schema("country", "string"), + schema("year", "int"), + schema("month", "int"), + schema("occupation", "string"), + schema("salary", "int")); + JSONObject actual2 = + executeQuery( + String.format( + "source=%s | join type=inner overwrite=false name,year,month %s | fields" + + " name, country", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyDataRows( + actual2, + rows("Jake", "USA"), + rows("Hello", "USA"), + rows("John", "Canada"), + rows("Jane", "Canada"), + rows("David", "USA"), + rows("David", "USA")); + } + + @Test + public void testJoinWithFieldListSelfJoin() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | join name,year,month %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_STATE_COUNTRY)); + verifySchema( + actual, + schema("name", "string"), + schema("age", "int"), + schema("state", "string"), + schema("country", "string"), + schema("year", "int"), + schema("month", "int")); + verifyNumOfRows(actual, 8); + } + + @Test + public void testJoinWithFieldListSelfJoin2() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | join type=inner overwrite=true name,year,month %s | join" + + " type=left overwrite=false name,year,month %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_STATE_COUNTRY, TEST_INDEX_STATE_COUNTRY)); + verifySchema( + actual, + schema("name", "string"), + schema("age", "int"), + schema("state", "string"), + schema("country", "string"), + schema("year", "int"), + schema("month", "int")); + verifyNumOfRows(actual, 8); + } + + @Test + public void testJoinWithoutFieldList() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | join type=inner overwrite=false %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_STATE_COUNTRY)); + verifySchema( + actual, + schema("name", "string"), + schema("age", "int"), + schema("state", "string"), + schema("country", "string"), + schema("year", "int"), + schema("month", "int")); + verifyNumOfRows(actual, 8); + } + + @Test + public void testJoinWithFieldListMaxEqualsOne() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | join type=inner max=1 name,year,month %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifySchema( + actual, + schema("name", "string"), + schema("age", "int"), + schema("state", "string"), + schema("country", "string"), + schema("year", "int"), + schema("month", "int"), + schema("occupation", "string"), + schema("salary", "int")); + JSONObject actual2 = + executeQuery( + String.format( + "source=%s | join type=inner max=1 name,year,month %s | fields name, country", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyDataRows( + actual2, + rows("Jake", "England"), + rows("Jane", "Canada"), + rows("John", "Canada"), + rows("Hello", "USA"), + rows("David", "USA")); + } + + @Test + public void testJoinComparing() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | where country = 'Canada' | join type=inner max=0 country %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyNumOfRows(actual, 15); + actual = + executeQuery( + String.format( + "source=%s | where country = 'Canada' | join type=inner max=1 country %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyNumOfRows(actual, 5); + actual = + executeQuery( + String.format( + "source=%s | where country = 'Canada' | join type=inner max=2 country %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyNumOfRows(actual, 10); + actual = + executeQuery( + String.format( + "source=%s | where country = 'Canada' | join max=0 left=l right=r on l.country =" + + " r.country %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyNumOfRows(actual, 15); + actual = + executeQuery( + String.format( + "source=%s | where country = 'Canada' | join max=1 left=l right=r on l.country =" + + " r.country %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyNumOfRows(actual, 5); + actual = + executeQuery( + String.format( + "source=%s | where country = 'Canada' | join max=2 left=l right=r on l.country =" + + " r.country %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_OCCUPATION)); + verifyNumOfRows(actual, 10); + } + + @Test + public void testJoinWithoutFieldListMaxEqualsOne() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | join type=inner overwrite=false max=1 %s", + TEST_INDEX_STATE_COUNTRY, TEST_INDEX_STATE_COUNTRY)); + verifySchema( + actual, + schema("name", "string"), + schema("age", "int"), + schema("state", "string"), + schema("country", "string"), + schema("year", "int"), + schema("month", "int")); + verifyNumOfRows(actual, 8); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index faf1b8784ad..4e143154fe7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -468,7 +468,7 @@ public void testStatsByTimeSpan() throws IOException { String.format("source=%s | stats count() by span(birthdate,1M)", TEST_INDEX_BANK))); } - @Test + @Ignore("https://github.com/opensearch-project/OpenSearch/issues/3725") public void testDedupPushdown() throws IOException { String expected = loadExpectedPlan("explain_dedup_push.json"); assertJsonEqualsIgnoreId( @@ -488,7 +488,7 @@ public void testDedupKeepEmptyTruePushdown() throws IOException { + " | dedup gender KEEPEMPTY=true")); } - @Test + @Ignore("https://github.com/opensearch-project/OpenSearch/issues/3725") public void testDedupKeepEmptyFalsePushdown() throws IOException { String expected = loadExpectedPlan("explain_dedup_keepempty_false_push.json"); assertJsonEqualsIgnoreId( diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/PPLIntegTestCase.java b/integ-test/src/test/java/org/opensearch/sql/ppl/PPLIntegTestCase.java index db1ee17f344..4800bad4e06 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/PPLIntegTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/PPLIntegTestCase.java @@ -168,6 +168,23 @@ public static void disableCalcite() throws IOException { "persistent", Settings.Key.CALCITE_ENGINE_ENABLED.getKeyValue(), "false")); } + public static void withCalciteEnabled(Runnable f) throws IOException { + boolean isCalciteEnabled = isCalciteEnabled(); + if (isCalciteEnabled) f.run(); + else { + try { + updateClusterSettings( + new SQLIntegTestCase.ClusterSetting( + "persistent", Key.CALCITE_ENGINE_ENABLED.getKeyValue(), "true")); + f.run(); + } finally { + updateClusterSettings( + new SQLIntegTestCase.ClusterSetting( + "persistent", Settings.Key.CALCITE_ENGINE_ENABLED.getKeyValue(), "false")); + } + } + } + public static void allowCalciteFallback() throws IOException { updateClusterSettings( new SQLIntegTestCase.ClusterSetting( @@ -208,6 +225,12 @@ public static void withFallbackEnabled(Runnable f, String msg) throws IOExceptio } } + public static void supportAllJoinTypes() throws IOException { + updateClusterSettings( + new SQLIntegTestCase.ClusterSetting( + "persistent", Key.CALCITE_SUPPORT_ALL_JOIN_TYPES.getKeyValue(), "true")); + } + public static void withSettings(Key setting, String value, Runnable f) throws IOException { String originalValue = getClusterSetting(setting.getKeyValue(), "transient"); if (originalValue.equals(value)) f.run(); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_criteria_max_option.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_criteria_max_option.json new file mode 100644 index 00000000000..08db116a7c9 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_criteria_max_option.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], r.account_number=[$13], r.firstname=[$14], r.address=[$15], r.birthdate=[$16], r.gender=[$17], r.city=[$18], r.lastname=[$19], r.balance=[$20], r.employer=[$21], r.state=[$22], r.age=[$23], r.email=[$24], r.male=[$25])\n LogicalJoin(condition=[=($0, $13)], joinType=[inner])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n LogicalFilter(condition=[<=($13, 1)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $0)])\n LogicalFilter(condition=[IS NOT NULL($0)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableMergeJoin(condition=[=($0, $13)], joinType=[inner])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[account_number, firstname, address, birthdate, gender, city, lastname, balance, employer, state, age, email, male], SORT->[{\n \"account_number\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_last\"\n }\n}]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"birthdate\",\"gender\",\"city\",\"lastname\",\"balance\",\"employer\",\"state\",\"age\",\"email\",\"male\"],\"excludes\":[]},\"sort\":[{\"account_number\":{\"order\":\"asc\",\"missing\":\"_last\"}}]}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[account_number, firstname, address, birthdate, gender, city, lastname, balance, employer, state, age, email, male], FILTER->IS NOT NULL($0), COLLAPSE->account_number, SORT->[{\n \"account_number\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_last\"\n }\n}]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"query\":{\"exists\":{\"field\":\"account_number\",\"boost\":1.0}},\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"birthdate\",\"gender\",\"city\",\"lastname\",\"balance\",\"employer\",\"state\",\"age\",\"email\",\"male\"],\"excludes\":[]},\"sort\":[{\"account_number\":{\"order\":\"asc\",\"missing\":\"_last\"}}],\"collapse\":{\"field\":\"account_number\"}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_fields.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_fields.json new file mode 100644 index 00000000000..0a662c047ee --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_fields.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$13], firstname=[$14], address=[$15], birthdate=[$16], gender=[$17], city=[$18], lastname=[$19], balance=[$20], employer=[$21], state=[$22], age=[$23], email=[$24], male=[$25])\n LogicalJoin(condition=[=($0, $13)], joinType=[left])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableCalc(expr#0..13=[{inputs}], account_number=[$t1], firstname=[$t2], address=[$t3], birthdate=[$t4], gender=[$t5], city=[$t6], lastname=[$t7], balance=[$t8], employer=[$t9], state=[$t10], age=[$t11], email=[$t12], male=[$t13])\n EnumerableLimit(fetch=[10000])\n EnumerableMergeJoin(condition=[=($0, $1)], joinType=[left])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[account_number], SORT->[{\n \"account_number\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_last\"\n }\n}]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\"],\"excludes\":[]},\"sort\":[{\"account_number\":{\"order\":\"asc\",\"missing\":\"_last\"}}]}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[account_number, firstname, address, birthdate, gender, city, lastname, balance, employer, state, age, email, male], SORT->[{\n \"account_number\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_last\"\n }\n}]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"birthdate\",\"gender\",\"city\",\"lastname\",\"balance\",\"employer\",\"state\",\"age\",\"email\",\"male\"],\"excludes\":[]},\"sort\":[{\"account_number\":{\"order\":\"asc\",\"missing\":\"_last\"}}]}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_fields_max_option.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_fields_max_option.json new file mode 100644 index 00000000000..c1ee2aa0b30 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_join_with_fields_max_option.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$13], firstname=[$14], address=[$15], birthdate=[$16], gender=[$17], city=[$18], lastname=[$19], balance=[$20], employer=[$21], state=[$22], age=[$23], email=[$24], male=[$25])\n LogicalJoin(condition=[=($0, $13)], joinType=[inner])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n LogicalFilter(condition=[<=($13, 1)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $0)])\n LogicalFilter(condition=[IS NOT NULL($0)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableCalc(expr#0..13=[{inputs}], proj#0..12=[{exprs}])\n EnumerableLimit(fetch=[10000])\n EnumerableMergeJoin(condition=[=($0, $13)], joinType=[inner])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[account_number, firstname, address, birthdate, gender, city, lastname, balance, employer, state, age, email, male], FILTER->IS NOT NULL($0), COLLAPSE->account_number, SORT->[{\n \"account_number\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_last\"\n }\n}]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"query\":{\"exists\":{\"field\":\"account_number\",\"boost\":1.0}},\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"birthdate\",\"gender\",\"city\",\"lastname\",\"balance\",\"employer\",\"state\",\"age\",\"email\",\"male\"],\"excludes\":[]},\"sort\":[{\"account_number\":{\"order\":\"asc\",\"missing\":\"_last\"}}],\"collapse\":{\"field\":\"account_number\"}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[account_number], SORT->[{\n \"account_number\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_last\"\n }\n}]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\"],\"excludes\":[]},\"sort\":[{\"account_number\":{\"order\":\"asc\",\"missing\":\"_last\"}}]}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_criteria_max_option.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_criteria_max_option.json new file mode 100644 index 00000000000..11ca44cdea2 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_criteria_max_option.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], r.account_number=[$13], r.firstname=[$14], r.address=[$15], r.birthdate=[$16], r.gender=[$17], r.city=[$18], r.lastname=[$19], r.balance=[$20], r.employer=[$21], r.state=[$22], r.age=[$23], r.email=[$24], r.male=[$25])\n LogicalJoin(condition=[=($0, $13)], joinType=[inner])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n LogicalFilter(condition=[<=($13, 1)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $0)])\n LogicalFilter(condition=[IS NOT NULL($0)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableMergeJoin(condition=[=($0, $13)], joinType=[inner])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n EnumerableCalc(expr#0..18=[{inputs}], proj#0..12=[{exprs}])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n EnumerableCalc(expr#0..19=[{inputs}], expr#20=[1], expr#21=[<=($t19, $t20)], proj#0..12=[{exprs}], $condition=[$t21])\n EnumerableWindow(window#0=[window(partition {0} order by [0] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..18=[{inputs}], expr#19=[IS NOT NULL($t0)], proj#0..18=[{exprs}], $condition=[$t19])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_fields.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_fields.json new file mode 100644 index 00000000000..21cbcfab737 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_fields.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$13], firstname=[$14], address=[$15], birthdate=[$16], gender=[$17], city=[$18], lastname=[$19], balance=[$20], employer=[$21], state=[$22], age=[$23], email=[$24], male=[$25])\n LogicalJoin(condition=[=($0, $13)], joinType=[left])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableCalc(expr#0..13=[{inputs}], account_number=[$t1], firstname=[$t2], address=[$t3], birthdate=[$t4], gender=[$t5], city=[$t6], lastname=[$t7], balance=[$t8], employer=[$t9], state=[$t10], age=[$t11], email=[$t12], male=[$t13])\n EnumerableLimit(fetch=[10000])\n EnumerableMergeJoin(condition=[=($0, $1)], joinType=[left])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n EnumerableCalc(expr#0..18=[{inputs}], account_number=[$t0])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n EnumerableCalc(expr#0..18=[{inputs}], proj#0..12=[{exprs}])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_fields_max_option.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_fields_max_option.json new file mode 100644 index 00000000000..a2b931bba32 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_join_with_fields_max_option.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$13], firstname=[$14], address=[$15], birthdate=[$16], gender=[$17], city=[$18], lastname=[$19], balance=[$20], employer=[$21], state=[$22], age=[$23], email=[$24], male=[$25])\n LogicalJoin(condition=[=($0, $13)], joinType=[inner])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n LogicalFilter(condition=[<=($13, 1)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $0)])\n LogicalFilter(condition=[IS NOT NULL($0)])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableCalc(expr#0..13=[{inputs}], account_number=[$t1], firstname=[$t2], address=[$t3], birthdate=[$t4], gender=[$t5], city=[$t6], lastname=[$t7], balance=[$t8], employer=[$t9], state=[$t10], age=[$t11], email=[$t12], male=[$t13])\n EnumerableLimit(fetch=[10000])\n EnumerableMergeJoin(condition=[=($0, $1)], joinType=[inner])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n EnumerableCalc(expr#0..18=[{inputs}], account_number=[$t0])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n EnumerableCalc(expr#0..19=[{inputs}], expr#20=[1], expr#21=[<=($t19, $t20)], proj#0..12=[{exprs}], $condition=[$t21])\n EnumerableWindow(window#0=[window(partition {0} order by [0] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..18=[{inputs}], expr#19=[IS NOT NULL($t0)], proj#0..18=[{exprs}], $condition=[$t19])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n" + } +} \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java index 2b6352537a1..f2bea42e5ab 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java @@ -37,7 +37,8 @@ public class OpenSearchIndexRules { COUNT_STAR_INDEX_SCAN, LIMIT_INDEX_SCAN, SORT_INDEX_SCAN, - DEDUP_PUSH_DOWN, + // TODO enable if https://github.com/opensearch-project/OpenSearch/issues/3725 resolved + // DEDUP_PUSH_DOWN, SORT_PROJECT_EXPR_TRANSPOSE, EXPAND_COLLATION_ON_PROJECT_EXPR); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 5f5240e67a2..9a695b8cc39 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -135,6 +135,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting CALCITE_SUPPORT_ALL_JOIN_TYPES_SETTING = + Setting.boolSetting( + Key.CALCITE_SUPPORT_ALL_JOIN_TYPES.getKeyValue(), + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + public static final Setting QUERY_MEMORY_LIMIT_SETTING = Setting.memorySizeSetting( Key.QUERY_MEMORY_LIMIT.getKeyValue(), @@ -365,6 +372,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR, CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR_SETTING, new Updater(Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR)); + register( + settingBuilder, + clusterSettings, + Key.CALCITE_SUPPORT_ALL_JOIN_TYPES, + CALCITE_SUPPORT_ALL_JOIN_TYPES_SETTING, + new Updater(Key.CALCITE_SUPPORT_ALL_JOIN_TYPES)); register( settingBuilder, clusterSettings, @@ -541,6 +554,7 @@ public static List> pluginSettings() { .add(CALCITE_FALLBACK_ALLOWED_SETTING) .add(CALCITE_PUSHDOWN_ENABLED_SETTING) .add(CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR_SETTING) + .add(CALCITE_SUPPORT_ALL_JOIN_TYPES_SETTING) .add(DEFAULT_PATTERN_METHOD_SETTING) .add(DEFAULT_PATTERN_MODE_SETTING) .add(DEFAULT_PATTERN_MAX_SAMPLE_COUNT_SETTING) diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index b8f12d3c491..94bd652bff4 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -87,6 +87,7 @@ STANDARD: 'STANDARD'; COST: 'COST'; EXTENDED: 'EXTENDED'; OVERRIDE: 'OVERRIDE'; +OVERWRITE: 'OVERWRITE'; // SORT FIELD KEYWORDS // TODO #3180: Fix broken sort functionality diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index c757ef497d2..24991b56ccc 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -8,6 +8,7 @@ parser grammar OpenSearchPPLParser; options { tokenVocab = OpenSearchPPLLexer; } + root : pplStatement? EOF ; @@ -459,25 +460,37 @@ sourceFilterArg // join joinCommand - : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableOrSubqueryClause + : JOIN (joinOption)* (fieldList)? right = tableOrSubqueryClause + | sqlLikeJoinType? JOIN (joinOption)* sideAlias joinHintList? joinCriteria right = tableOrSubqueryClause ; -joinType - : INNER? +sqlLikeJoinType + : INNER | CROSS - | LEFT OUTER? + | (LEFT OUTER? | OUTER) | RIGHT OUTER? | FULL OUTER? | LEFT? SEMI | LEFT? ANTI ; +joinType + : INNER + | CROSS + | OUTER + | LEFT + | RIGHT + | FULL + | SEMI + | ANTI + ; + sideAlias : (LEFT EQUAL leftAlias = qualifiedName)? COMMA? (RIGHT EQUAL rightAlias = qualifiedName)? ; joinCriteria - : ON logicalExpression + : (ON | WHERE) logicalExpression ; joinHintList @@ -489,6 +502,12 @@ hintPair | rightHintKey = RIGHT_HINT DOT ID EQUAL rightHintValue = ident #rightHint ; +joinOption + : OVERWRITE EQUAL booleanLiteral # overwriteOption + | TYPE EQUAL joinType # typeOption + | MAX EQUAL integerLiteral # maxOption + ; + renameClasue : orignalField = renameFieldExpression AS renamedField = renameFieldExpression ; @@ -682,7 +701,7 @@ tableFunction // fields fieldList - : fieldExpression (COMMA fieldExpression)* + : fieldExpression ((COMMA)? fieldExpression)* ; sortField @@ -1301,8 +1320,8 @@ keywordsCanBeId | multiFieldRelevanceFunctionName | commandName | collectionFunctionName - | comparisonOperator | explainMode + | REGEXP // commands assist keywords | CASE | ELSE diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 99ed93a5731..874450c3a8a 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -203,21 +203,40 @@ public UnresolvedPlan visitWhereCommand(WhereCommandContext ctx) { @Override public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ctx) { - Join.JoinType joinType = getJoinType(ctx.joinType()); - if (ctx.joinCriteria() == null) { - joinType = Join.JoinType.CROSS; + // a sql-like syntax if join criteria existed + boolean sqlLike = ctx.joinCriteria() != null; + Join.JoinType joinType = null; + if (sqlLike) { + joinType = ArgumentFactory.getJoinType(ctx.sqlLikeJoinType()); } + List arguments = + ctx.joinOption().stream().map(o -> (Argument) expressionBuilder.visit(o)).toList(); + Argument.ArgumentMap argumentMap = Argument.ArgumentMap.of(arguments); + if (argumentMap.get("type") != null) { + Join.JoinType joinTypeFromArgument = ArgumentFactory.getJoinType(argumentMap); + if (sqlLike && joinType != joinTypeFromArgument) { + throw new SemanticCheckException( + "Join type is ambiguous, remove either the join type before JOIN keyword or 'type='" + + " option."); + } + joinType = joinTypeFromArgument; + } + if (!sqlLike && argumentMap.get("type") == null) { + joinType = Join.JoinType.INNER; + } + validateJoinType(joinType); + Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - Optional leftAlias = - ctx.sideAlias().leftAlias != null - ? Optional.of(internalVisitExpression(ctx.sideAlias().leftAlias).toString()) - : Optional.empty(); + Optional leftAlias = Optional.empty(); Optional rightAlias = Optional.empty(); + if (ctx.sideAlias() != null && ctx.sideAlias().leftAlias != null) { + leftAlias = Optional.of(internalVisitExpression(ctx.sideAlias().leftAlias).toString()); + } if (ctx.tableOrSubqueryClause().alias != null) { rightAlias = Optional.of(internalVisitExpression(ctx.tableOrSubqueryClause().alias).toString()); } - if (ctx.sideAlias().rightAlias != null) { + if (ctx.sideAlias() != null && ctx.sideAlias().rightAlias != null) { rightAlias = Optional.of(internalVisitExpression(ctx.sideAlias().rightAlias).toString()); } @@ -236,8 +255,19 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); + Optional> joinFields = Optional.empty(); + if (ctx.fieldList() != null) { + joinFields = Optional.of(getFieldList(ctx.fieldList())); + } return new Join( - projectExceptMeta(right), leftAlias, rightAlias, joinType, joinCondition, joinHint); + projectExceptMeta(right), + leftAlias, + rightAlias, + joinType, + joinCondition, + joinHint, + joinFields, + argumentMap); } private Join.JoinHint getJoinHint(OpenSearchPPLParser.JoinHintListContext ctx) { @@ -261,28 +291,16 @@ private Join.JoinHint getJoinHint(OpenSearchPPLParser.JoinHintListContext ctx) { return joinHint; } - private Join.JoinType getJoinType(OpenSearchPPLParser.JoinTypeContext ctx) { - Join.JoinType joinType; - if (ctx == null) { - joinType = Join.JoinType.INNER; - } else if (ctx.INNER() != null) { - joinType = Join.JoinType.INNER; - } else if (ctx.SEMI() != null) { - joinType = Join.JoinType.SEMI; - } else if (ctx.ANTI() != null) { - joinType = Join.JoinType.ANTI; - } else if (ctx.LEFT() != null) { - joinType = Join.JoinType.LEFT; - } else if (ctx.RIGHT() != null) { - joinType = Join.JoinType.RIGHT; - } else if (ctx.CROSS() != null) { - joinType = Join.JoinType.CROSS; - } else if (ctx.FULL() != null) { - joinType = Join.JoinType.FULL; - } else { - joinType = Join.JoinType.INNER; + private void validateJoinType(Join.JoinType joinType) { + Object config = settings.getSettingValue(Key.CALCITE_SUPPORT_ALL_JOIN_TYPES); + if (config != null && !((Boolean) config)) { + if (Join.highCostJoinTypes().contains(joinType)) { + throw new SemanticCheckException( + String.format( + "Join type %s is performance sensitive. Set %s to true to enable it.", + joinType.name(), Key.CALCITE_SUPPORT_ALL_JOIN_TYPES.getKeyValue())); + } } - return joinType; } @Override diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 519ce384038..6451ea258ca 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -6,40 +6,6 @@ package org.opensearch.sql.ppl.parser; import static org.opensearch.sql.expression.function.BuiltinFunctionName.*; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountEvalFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DoubleLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FloatLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsQualifiedNameContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsTableQualifiedNameContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsWildcardQualifiedNameContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntervalLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalAndContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalNotContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.MultiFieldRelevanceFunctionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RenameFieldExpressionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SingleFieldRelevanceFunctionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableSourceContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WcFieldExpressionContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -60,6 +26,7 @@ import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Between; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; @@ -98,6 +65,7 @@ import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountEvalFunctionCallContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; @@ -118,6 +86,7 @@ import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.MultiFieldRelevanceFunctionContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RenameFieldExpressionContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SingleFieldRelevanceFunctionContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; @@ -732,6 +701,21 @@ public UnresolvedExpression visitWindowFunction(OpenSearchPPLParser.WindowFuncti return new WindowFunction(f); } + @Override + public UnresolvedExpression visitOverwriteOption(OpenSearchPPLParser.OverwriteOptionContext ctx) { + return new Argument("overwrite", (Literal) this.visit(ctx.booleanLiteral())); + } + + @Override + public UnresolvedExpression visitJoinType(OpenSearchPPLParser.JoinTypeContext ctx) { + return ArgumentFactory.getArgumentValue(ctx); + } + + @Override + public UnresolvedExpression visitMaxOption(OpenSearchPPLParser.MaxOptionContext ctx) { + return new Argument("max", (Literal) this.visit(ctx.integerLiteral())); + } + private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index ebe3d6d9e2f..2f64ba907fa 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -5,15 +5,6 @@ package org.opensearch.sql.ppl.utils; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DedupCommandContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RareCommandContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TopCommandContext; - import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -21,7 +12,10 @@ import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DedupCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; @@ -161,4 +155,66 @@ private static Literal getArgumentValue(ParserRuleContext ctx) { ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } + + /** + * parse argument value into Literal. + * + * @param ctx ParserRuleContext instance + * @return Literal + */ + public static Argument getArgumentValue(OpenSearchPPLParser.JoinTypeContext ctx) { + Join.JoinType type = getJoinType(ctx); + return new Argument("type", new Literal(type.name(), DataType.STRING)); + } + + public static Join.JoinType getJoinType(OpenSearchPPLParser.SqlLikeJoinTypeContext ctx) { + if (ctx == null) return Join.JoinType.INNER; + if (ctx.INNER() != null) return Join.JoinType.INNER; + if (ctx.SEMI() != null) return Join.JoinType.SEMI; + if (ctx.ANTI() != null) return Join.JoinType.ANTI; + if (ctx.LEFT() != null) return Join.JoinType.LEFT; + if (ctx.RIGHT() != null) return Join.JoinType.RIGHT; + if (ctx.CROSS() != null) return Join.JoinType.CROSS; + if (ctx.FULL() != null) return Join.JoinType.FULL; + if (ctx.OUTER() != null) return Join.JoinType.LEFT; + throw new SemanticCheckException(String.format("Unsupported join type %s", ctx.getText())); + } + + public static Join.JoinType getJoinType(OpenSearchPPLParser.JoinTypeContext ctx) { + if (ctx == null) return Join.JoinType.INNER; + if (ctx.INNER() != null) return Join.JoinType.INNER; + if (ctx.SEMI() != null) return Join.JoinType.SEMI; + if (ctx.ANTI() != null) return Join.JoinType.ANTI; + if (ctx.LEFT() != null) return Join.JoinType.LEFT; + if (ctx.RIGHT() != null) return Join.JoinType.RIGHT; + if (ctx.CROSS() != null) return Join.JoinType.CROSS; + if (ctx.FULL() != null) return Join.JoinType.FULL; + if (ctx.OUTER() != null) return Join.JoinType.LEFT; + throw new SemanticCheckException(String.format("Unsupported join type %s", ctx.getText())); + } + + public static Join.JoinType getJoinType(Argument.ArgumentMap argumentMap) { + Join.JoinType joinType; + String type = argumentMap.get("type").toString(); + if (type.equalsIgnoreCase(Join.JoinType.INNER.name())) { + joinType = Join.JoinType.INNER; + } else if (type.equalsIgnoreCase(Join.JoinType.SEMI.name())) { + joinType = Join.JoinType.SEMI; + } else if (type.equalsIgnoreCase(Join.JoinType.ANTI.name())) { + joinType = Join.JoinType.ANTI; + } else if (type.equalsIgnoreCase(Join.JoinType.LEFT.name())) { + joinType = Join.JoinType.LEFT; + } else if (type.equalsIgnoreCase(Join.JoinType.RIGHT.name())) { + joinType = Join.JoinType.RIGHT; + } else if (type.equalsIgnoreCase(Join.JoinType.CROSS.name())) { + joinType = Join.JoinType.CROSS; + } else if (type.equalsIgnoreCase(Join.JoinType.FULL.name())) { + joinType = Join.JoinType.FULL; + } else if (type.equalsIgnoreCase("OUTER")) { + joinType = Join.JoinType.LEFT; + } else { + throw new SemanticCheckException(String.format("Supported join type %s", type)); + } + return joinType; + } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index 3c5001f39df..6ff0d4f3306 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -158,13 +158,41 @@ public String visitJoin(Join node, String context) { rightTableOrSubquery.startsWith("source=") ? rightTableOrSubquery.substring("source=".length()) : rightTableOrSubquery; - String joinType = node.getJoinType().name().toLowerCase(Locale.ROOT); - String leftAlias = node.getLeftAlias().map(l -> " left = " + l).orElse(""); - String rightAlias = node.getRightAlias().map(r -> " right = " + r).orElse(""); - String condition = - node.getJoinCondition().map(c -> expressionAnalyzer.analyze(c, context)).orElse("true"); - return StringUtils.format( - "%s | %s join%s%s on %s %s", left, joinType, leftAlias, rightAlias, condition, right); + Argument.ArgumentMap argumentMap = node.getArgumentMap(); + String max = + argumentMap.get("max") == null + ? "0" + : argumentMap.get("max").toString().toLowerCase(Locale.ROOT); + if (node.getJoinCondition().isEmpty()) { + String joinType = + argumentMap.get("type") == null + ? "inner" + : argumentMap.get("type").toString().toLowerCase(Locale.ROOT); + String overwrite = + argumentMap.get("overwrite") == null + ? "true" + : argumentMap.get("overwrite").toString().toLowerCase(Locale.ROOT); + String fieldList = + node.getJoinFields().isEmpty() + ? "" + : String.join( + ",", + node.getJoinFields().get().stream() + .map(c -> expressionAnalyzer.analyze(c, context)) + .toList()); + return StringUtils.format( + "%s | join type=%s overwrite=%s max=%s %s %s", + left, joinType, overwrite, max, fieldList, right); + } else { + String joinType = node.getJoinType().name().toLowerCase(Locale.ROOT); + String leftAlias = node.getLeftAlias().map(l -> " left = " + l).orElse(""); + String rightAlias = node.getRightAlias().map(r -> " right = " + r).orElse(""); + String condition = + node.getJoinCondition().map(c -> expressionAnalyzer.analyze(c, context)).orElse("true"); + return StringUtils.format( + "%s | %s join max=%s%s%s on %s %s", + left, joinType, max, leftAlias, rightAlias, condition, right); + } } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java index 652d6e77e3a..01bfafbf3e8 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java @@ -33,6 +33,7 @@ import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelRunners; import org.apache.commons.lang3.StringUtils; +import org.junit.Before; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.calcite.CalcitePlanContext; @@ -48,6 +49,8 @@ public class CalcitePPLAbstractTest { private final CalciteRelNodeVisitor planTransformer; private final RelToSqlConverter converter; protected final Settings settings; + public PPLSyntaxParser pplParser = new PPLSyntaxParser(); + ; public CalcitePPLAbstractTest(CalciteAssert.SchemaSpec... schemaSpecs) { this.config = config(schemaSpecs); @@ -56,7 +59,11 @@ public CalcitePPLAbstractTest(CalciteAssert.SchemaSpec... schemaSpecs) { this.settings = mock(Settings.class); } - public PPLSyntaxParser pplParser = new PPLSyntaxParser(); + @Before + public void init() { + doReturn(true).when(settings).getSettingValue(Settings.Key.CALCITE_ENGINE_ENABLED); + doReturn(true).when(settings).getSettingValue(Settings.Key.CALCITE_SUPPORT_ALL_JOIN_TYPES); + } protected Frameworks.ConfigBuilder config(CalciteAssert.SchemaSpec... schemaSpecs) { final SchemaPlus rootSchema = Frameworks.createRootSchema(true); @@ -91,7 +98,6 @@ public RelNode getRelNode(String ppl) { } private Node plan(PPLSyntaxParser parser, String query) { - doReturn(true).when(settings).getSettingValue(Settings.Key.CALCITE_ENGINE_ENABLED); final AstStatementBuilder builder = new AstStatementBuilder( new AstBuilder(query, settings), diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java index 508b7a9d9bc..e8f0390a666 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLJoinTest.java @@ -5,9 +5,14 @@ package org.opensearch.sql.ppl.calcite; +import static org.mockito.Mockito.doReturn; + import org.apache.calcite.rel.RelNode; import org.apache.calcite.test.CalciteAssert; +import org.junit.Assert; import org.junit.Test; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.exception.SemanticCheckException; public class CalcitePPLJoinTest extends CalcitePPLAbstractTest { @@ -247,7 +252,7 @@ public void testFullOuter() { @Test public void testCrossJoin() { - String ppl = "source=EMP as e | cross join DEPT as d"; + String ppl = "source=EMP as e | cross join on 1=1 DEPT as d"; RelNode root = getRelNode(ppl); String expectedLogical = "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," @@ -518,37 +523,45 @@ public void testMultipleJoinsWithRelationSubquery() { source = BONUS | where JOB = 'SALESMAN' ] - | cross join left = l right = r + | join type=left overwrite=true SAL [ source = SALGRADE | where LOSAL <= 1500 | sort - GRADE + | rename HISAL as SAL ] """; RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalJoin(condition=[true], joinType=[inner])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], COMM=[$6]," + + " DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10], r.ENAME=[$11], r.JOB=[$12]," + + " r.SAL=[$13], r.COMM=[$14], GRADE=[$15], LOSAL=[$16], SAL=[$17])\n" + + " LogicalJoin(condition=[=($5, $17)], joinType=[left])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10]," + " r.ENAME=[$11], r.JOB=[$12], r.SAL=[$13], r.COMM=[$14])\n" - + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" - + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" - + " LogicalSort(fetch=[10])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO':VARCHAR))])\n" - + " LogicalTableScan(table=[[scott, DEPT]])\n" - + " LogicalFilter(condition=[=($1, 'SALESMAN':VARCHAR)])\n" - + " LogicalTableScan(table=[[scott, BONUS]])\n" - + " LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" - + " LogicalFilter(condition=[<=($1, 1500)])\n" - + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalSort(fetch=[10])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO':VARCHAR))])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalFilter(condition=[=($1, 'SALESMAN':VARCHAR)])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(GRADE=[$0], LOSAL=[$1], SAL=[$2])\n" + + " LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" + + " LogicalFilter(condition=[<=($1, 1500)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; verifyLogical(root, expectedLogical); - verifyResultCount(root, 15); + verifyResultCount(root, 5); String expectedSparkSql = - "SELECT *\n" + "SELECT `t3`.`EMPNO`, `t3`.`ENAME`, `t3`.`JOB`, `t3`.`MGR`, `t3`.`HIREDATE`, `t3`.`COMM`," + + " `t3`.`DEPTNO`, `t3`.`r.DEPTNO`, `t3`.`DNAME`, `t3`.`LOC`, `t3`.`r.ENAME`," + + " `t3`.`r.JOB`, `t3`.`r.SAL`, `t3`.`r.COMM`, `t6`.`GRADE`, `t6`.`LOSAL`," + + " `t6`.`SAL`\n" + "FROM (SELECT `t1`.`EMPNO`, `t1`.`ENAME`, `t1`.`JOB`, `t1`.`MGR`, `t1`.`HIREDATE`," + " `t1`.`SAL`, `t1`.`COMM`, `t1`.`DEPTNO`, `t1`.`r.DEPTNO`, `t1`.`DNAME`, `t1`.`LOC`," + " `t2`.`ENAME` `r.ENAME`, `t2`.`JOB` `r.JOB`, `t2`.`SAL` `r.SAL`, `t2`.`COMM`" @@ -566,10 +579,10 @@ public void testMultipleJoinsWithRelationSubquery() { + "LEFT JOIN (SELECT *\n" + "FROM `scott`.`BONUS`\n" + "WHERE `JOB` = 'SALESMAN') `t2` ON `t1`.`JOB` = `t2`.`JOB`) `t3`\n" - + "CROSS JOIN (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "LEFT JOIN (SELECT `GRADE`, `LOSAL`, `HISAL` `SAL`\n" + "FROM `scott`.`SALGRADE`\n" + "WHERE `LOSAL` <= 1500\n" - + "ORDER BY `GRADE` DESC) `t5`"; + + "ORDER BY `GRADE` DESC) `t6` ON `t3`.`SAL` = `t6`.`SAL`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -589,47 +602,55 @@ public void testMultipleJoinsWithRelationSubqueryWithAlias() { source = BONUS as t3 | where JOB = 'SALESMAN' ] - | cross join + | join type=left overwrite=true SAL [ source = SALGRADE as t4 | where LOSAL <= 1500 | sort - GRADE + | rename HISAL as SAL ] """; RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalJoin(condition=[true], joinType=[inner])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], COMM=[$6]," + + " DEPTNO=[$7], DEPT.DEPTNO=[$8], DNAME=[$9], LOC=[$10], BONUS.ENAME=[$11]," + + " BONUS.JOB=[$12], BONUS.SAL=[$13], BONUS.COMM=[$14], GRADE=[$15], LOSAL=[$16]," + + " SAL=[$17])\n" + + " LogicalJoin(condition=[=($5, $17)], joinType=[left])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], DEPT.DEPTNO=[$8], DNAME=[$9], LOC=[$10]," - + " BONUS.ENAME=[$11], BONUS.JOB=[$12], BONUS.SAL=[$13]," - + " BONUS.COMM=[$14])\n" - + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " BONUS.ENAME=[$11], BONUS.JOB=[$12], BONUS.SAL=[$13], BONUS.COMM=[$14])\n" + + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], DEPT.DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" - + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" - + " LogicalSort(fetch=[10])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO':VARCHAR))])\n" - + " LogicalTableScan(table=[[scott, DEPT]])\n" - + " LogicalFilter(condition=[=($1, 'SALESMAN':VARCHAR)])\n" - + " LogicalTableScan(table=[[scott, BONUS]])\n" - + " LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" - + " LogicalFilter(condition=[<=($1, 1500)])\n" - + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalSort(fetch=[10])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO':VARCHAR))])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalFilter(condition=[=($1, 'SALESMAN':VARCHAR)])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(GRADE=[$0], LOSAL=[$1], SAL=[$2])\n" + + " LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" + + " LogicalFilter(condition=[<=($1, 1500)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; verifyLogical(root, expectedLogical); - verifyResultCount(root, 15); + verifyResultCount(root, 5); String expectedSparkSql = - "SELECT *\n" + "SELECT `t3`.`EMPNO`, `t3`.`ENAME`, `t3`.`JOB`, `t3`.`MGR`, `t3`.`HIREDATE`, `t3`.`COMM`," + + " `t3`.`DEPTNO`, `t3`.`DEPT.DEPTNO`, `t3`.`DNAME`, `t3`.`LOC`, `t3`.`BONUS.ENAME`," + + " `t3`.`BONUS.JOB`, `t3`.`BONUS.SAL`, `t3`.`BONUS.COMM`, `t6`.`GRADE`, `t6`.`LOSAL`," + + " `t6`.`SAL`\n" + "FROM (SELECT `t1`.`EMPNO`, `t1`.`ENAME`, `t1`.`JOB`, `t1`.`MGR`, `t1`.`HIREDATE`," + " `t1`.`SAL`, `t1`.`COMM`, `t1`.`DEPTNO`, `t1`.`DEPT.DEPTNO`, `t1`.`DNAME`," - + " `t1`.`LOC`, `t2`.`ENAME` `BONUS.ENAME`, `t2`.`JOB` `BONUS.JOB`," - + " `t2`.`SAL` `BONUS.SAL`, `t2`.`COMM` `BONUS.COMM`\n" + + " `t1`.`LOC`, `t2`.`ENAME` `BONUS.ENAME`, `t2`.`JOB` `BONUS.JOB`, `t2`.`SAL`" + + " `BONUS.SAL`, `t2`.`COMM` `BONUS.COMM`\n" + "FROM (SELECT `t`.`EMPNO`, `t`.`ENAME`, `t`.`JOB`, `t`.`MGR`, `t`.`HIREDATE`," - + " `t`.`SAL`, `t`.`COMM`, `t`.`DEPTNO`, `t0`.`DEPTNO` `DEPT.DEPTNO`," - + " `t0`.`DNAME`, `t0`.`LOC`\n" + + " `t`.`SAL`, `t`.`COMM`, `t`.`DEPTNO`, `t0`.`DEPTNO` `DEPT.DEPTNO`, `t0`.`DNAME`," + + " `t0`.`LOC`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + "FROM `scott`.`EMP`\n" + "LIMIT 10) `t`\n" @@ -640,10 +661,10 @@ public void testMultipleJoinsWithRelationSubqueryWithAlias() { + "LEFT JOIN (SELECT *\n" + "FROM `scott`.`BONUS`\n" + "WHERE `JOB` = 'SALESMAN') `t2` ON `t1`.`JOB` = `t2`.`JOB`) `t3`\n" - + "CROSS JOIN (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "LEFT JOIN (SELECT `GRADE`, `LOSAL`, `HISAL` `SAL`\n" + "FROM `scott`.`SALGRADE`\n" + "WHERE `LOSAL` <= 1500\n" - + "ORDER BY `GRADE` DESC) `t5`"; + + "ORDER BY `GRADE` DESC) `t6` ON `t3`.`SAL` = `t6`.`SAL`"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -663,38 +684,45 @@ public void testMultipleJoinsWithRelationSubqueryWithAlias2() { source = BONUS as t3 | where JOB = 'SALESMAN' ] - | cross join + | join type=left overwrite=true SAL [ source = SALGRADE as t4 | where LOSAL <= 1500 | sort - GRADE + | rename HISAL as SAL ] """; RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalJoin(condition=[true], joinType=[inner])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], COMM=[$6]," + + " DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10], r.ENAME=[$11], r.JOB=[$12]," + + " r.SAL=[$13], r.COMM=[$14], GRADE=[$15], LOSAL=[$16], SAL=[$17])\n" + + " LogicalJoin(condition=[=($5, $17)], joinType=[left])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10]," + " r.ENAME=[$11], r.JOB=[$12], r.SAL=[$13], r.COMM=[$14])\n" - + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " LogicalJoin(condition=[=($2, $12)], joinType=[left])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" - + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" - + " LogicalSort(fetch=[10])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO':VARCHAR))])\n" - + " LogicalTableScan(table=[[scott, DEPT]])\n" - + " LogicalFilter(condition=[=($1, 'SALESMAN':VARCHAR)])\n" - + " LogicalTableScan(table=[[scott, BONUS]])\n" - + " LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" - + " LogicalFilter(condition=[<=($1, 1500)])\n" - + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalSort(fetch=[10])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[AND(>($0, 10), =($2, 'CHICAGO':VARCHAR))])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalFilter(condition=[=($1, 'SALESMAN':VARCHAR)])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(GRADE=[$0], LOSAL=[$1], SAL=[$2])\n" + + " LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" + + " LogicalFilter(condition=[<=($1, 1500)])\n" + + " LogicalTableScan(table=[[scott, SALGRADE]])\n"; verifyLogical(root, expectedLogical); - verifyResultCount(root, 15); + verifyResultCount(root, 5); String expectedSparkSql = - "SELECT *\n" + "SELECT `t3`.`EMPNO`, `t3`.`ENAME`, `t3`.`JOB`, `t3`.`MGR`, `t3`.`HIREDATE`, `t3`.`COMM`," + + " `t3`.`DEPTNO`, `t3`.`r.DEPTNO`, `t3`.`DNAME`, `t3`.`LOC`, `t3`.`r.ENAME`," + + " `t3`.`r.JOB`, `t3`.`r.SAL`, `t3`.`r.COMM`, `t6`.`GRADE`, `t6`.`LOSAL`, `t6`.`SAL`\n" + "FROM (SELECT `t1`.`EMPNO`, `t1`.`ENAME`, `t1`.`JOB`, `t1`.`MGR`, `t1`.`HIREDATE`," + " `t1`.`SAL`, `t1`.`COMM`, `t1`.`DEPTNO`, `t1`.`r.DEPTNO`, `t1`.`DNAME`, `t1`.`LOC`," + " `t2`.`ENAME` `r.ENAME`, `t2`.`JOB` `r.JOB`, `t2`.`SAL` `r.SAL`, `t2`.`COMM`" @@ -712,10 +740,340 @@ public void testMultipleJoinsWithRelationSubqueryWithAlias2() { + "LEFT JOIN (SELECT *\n" + "FROM `scott`.`BONUS`\n" + "WHERE `JOB` = 'SALESMAN') `t2` ON `t1`.`JOB` = `t2`.`JOB`) `t3`\n" - + "CROSS JOIN (SELECT `GRADE`, `LOSAL`, `HISAL`\n" + + "LEFT JOIN (SELECT `GRADE`, `LOSAL`, `HISAL` `SAL`\n" + "FROM `scott`.`SALGRADE`\n" + "WHERE `LOSAL` <= 1500\n" - + "ORDER BY `GRADE` DESC) `t5`"; + + "ORDER BY `GRADE` DESC) `t6` ON `t3`.`SAL` = `t6`.`SAL`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithFieldList() { + String ppl = "source=EMP | join DEPTNO DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `DEPT`.`DEPTNO`, `DEPT`.`DNAME`, `DEPT`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testSplJoinWithJoinArguments() { + String ppl = "source=EMP | join type=inner overwrite=false DEPTNO DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `EMP`.`DEPTNO`, `DEPT`.`DNAME`, `DEPT`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithFieldListAndJoinArguments2() { + String ppl = "source=EMP | join type=left overwrite=false DEPTNO DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[left])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `EMP`.`DEPTNO`, `DEPT`.`DNAME`, `DEPT`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "LEFT JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testSemiJoinWithFieldList() { + String ppl = "source=EMP | join type=semi overwrite=true DEPTNO DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalJoin(condition=[=($7, $8)], joinType=[semi])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE EXISTS (SELECT 1\n" + + "FROM `scott`.`DEPT`\n" + + "WHERE `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`)"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithoutFieldList() { + String ppl = "source=EMP | join DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `DEPT`.`DEPTNO`, `DEPT`.`DNAME`, `DEPT`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithoutFieldListSelfJoin() { + String ppl = "source=EMP | join EMP"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$8], ENAME=[$9], JOB=[$10], MGR=[$11], HIREDATE=[$12], SAL=[$13]," + + " COMM=[$14], DEPTNO=[$15])\n" + + " LogicalJoin(condition=[AND(=($0, $8), =($1, $9), =($2, $10), =($3, $11), =($4," + + " $12), =($5, $13), =($6, $14), =($7, $15))], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 4); + + String expectedSparkSql = + "SELECT `EMP0`.`EMPNO`, `EMP0`.`ENAME`, `EMP0`.`JOB`, `EMP0`.`MGR`, `EMP0`.`HIREDATE`," + + " `EMP0`.`SAL`, `EMP0`.`COMM`, `EMP0`.`DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "INNER JOIN `scott`.`EMP` `EMP0` ON `EMP`.`EMPNO` = `EMP0`.`EMPNO` AND `EMP`.`ENAME`" + + " = `EMP0`.`ENAME` AND (`EMP`.`JOB` = `EMP0`.`JOB` AND `EMP`.`MGR` = `EMP0`.`MGR`)" + + " AND (`EMP`.`HIREDATE` = `EMP0`.`HIREDATE` AND `EMP`.`SAL` = `EMP0`.`SAL` AND" + + " (`EMP`.`COMM` = `EMP0`.`COMM` AND `EMP`.`DEPTNO` = `EMP0`.`DEPTNO`))"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + @Test + public void testJoinWithMultiplePredicatesWithWhere() { + String ppl = + "source=EMP | join left = l right = r where l.DEPTNO = r.DEPTNO AND l.DEPTNO > 10 AND" + + " EMP.SAL < 3000 DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[AND(=($7, $8), >($7, 10), <($5, 3000))]," + + " joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 9); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `EMP`.`DEPTNO`, `DEPT`.`DEPTNO` `r.DEPTNO`," + + " `DEPT`.`DNAME`, `DEPT`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO` AND `EMP`.`DEPTNO` >" + + " 10 AND `EMP`.`SAL` < 3000"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithFieldListSelfJoinOverrideIsFalse() { + String ppl = "source=EMP | join type=outer overwrite=false EMPNO ENAME JOB, MGR EMP"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7])\n" + + " LogicalJoin(condition=[AND(=($0, $8), =($1, $9), =($2, $10), =($3, $11))]," + + " joinType=[left])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `EMP`.`DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "LEFT JOIN `scott`.`EMP` `EMP0` ON `EMP`.`EMPNO` = `EMP0`.`EMPNO` AND `EMP`.`ENAME` =" + + " `EMP0`.`ENAME` AND `EMP`.`JOB` = `EMP0`.`JOB` AND `EMP`.`MGR` = `EMP0`.`MGR`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithFieldListSelfJoinOverrideIsTrue() { + String ppl = "source=EMP | join type=outer overwrite=true EMPNO ENAME JOB, MGR EMP"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$8], ENAME=[$9], JOB=[$10], MGR=[$11], HIREDATE=[$12], SAL=[$13]," + + " COMM=[$14], DEPTNO=[$15])\n" + + " LogicalJoin(condition=[AND(=($0, $8), =($1, $9), =($2, $10), =($3, $11))]," + + " joinType=[left])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP0`.`EMPNO`, `EMP0`.`ENAME`, `EMP0`.`JOB`, `EMP0`.`MGR`, `EMP0`.`HIREDATE`," + + " `EMP0`.`SAL`, `EMP0`.`COMM`, `EMP0`.`DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "LEFT JOIN `scott`.`EMP` `EMP0` ON `EMP`.`EMPNO` = `EMP0`.`EMPNO` AND `EMP`.`ENAME` =" + + " `EMP0`.`ENAME` AND `EMP`.`JOB` = `EMP0`.`JOB` AND `EMP`.`MGR` = `EMP0`.`MGR`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithFieldListSelfJoin() { + String ppl = "source=EMP | join type=outer EMPNO ENAME JOB, MGR EMP"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$8], ENAME=[$9], JOB=[$10], MGR=[$11], HIREDATE=[$12], SAL=[$13]," + + " COMM=[$14], DEPTNO=[$15])\n" + + " LogicalJoin(condition=[AND(=($0, $8), =($1, $9), =($2, $10), =($3, $11))]," + + " joinType=[left])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP0`.`EMPNO`, `EMP0`.`ENAME`, `EMP0`.`JOB`, `EMP0`.`MGR`, `EMP0`.`HIREDATE`," + + " `EMP0`.`SAL`, `EMP0`.`COMM`, `EMP0`.`DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "LEFT JOIN `scott`.`EMP` `EMP0` ON `EMP`.`EMPNO` = `EMP0`.`EMPNO` AND `EMP`.`ENAME` =" + + " `EMP0`.`ENAME` AND `EMP`.`JOB` = `EMP0`.`JOB` AND `EMP`.`MGR` = `EMP0`.`MGR`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testDisableHighCostJoinTypes() { + String ppl1 = "source=EMP as e | full join on e.DEPTNO = d.DEPTNO DEPT as d"; + String ppl2 = "source=EMP | join type=full overwrite=false EMPNO ENAME JOB, MGR EMP"; + String err = + "Join type FULL is performance sensitive. Set plugins.calcite.all_join_types.allowed to" + + " true to enable it."; + + // disable high cost join types + doReturn(false).when(settings).getSettingValue(Settings.Key.CALCITE_SUPPORT_ALL_JOIN_TYPES); + Throwable t = Assert.assertThrows(SemanticCheckException.class, () -> getRelNode(ppl1)); + verifyErrorMessageContains(t, err); + t = Assert.assertThrows(SemanticCheckException.class, () -> getRelNode(ppl2)); + verifyErrorMessageContains(t, err); + // enable high cost types + doReturn(true).when(settings).getSettingValue(Settings.Key.CALCITE_SUPPORT_ALL_JOIN_TYPES); + getRelNode(ppl1); + getRelNode(ppl2); + } + + @Test + public void testJoinWithFieldListMaxGreaterThanZero() { + String ppl = "source=EMP | join type=outer max=1 DEPTNO DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[left])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n" + + " LogicalFilter(condition=[<=($3, 1)])\n" + + " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2]," + + " _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $0)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `t1`.`DEPTNO`, `t1`.`DNAME`, `t1`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "LEFT JOIN (SELECT `DEPTNO`, `DNAME`, `LOC`\n" + + "FROM (SELECT `DEPTNO`, `DNAME`, `LOC`, ROW_NUMBER() OVER (PARTITION BY `DEPTNO`" + + " ORDER BY `DEPTNO` NULLS LAST) `_row_number_dedup_`\n" + + "FROM `scott`.`DEPT`) `t`\n" + + "WHERE `_row_number_dedup_` <= 1) `t1` ON `EMP`.`DEPTNO` = `t1`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithCriteriaMaxGreaterThanZero() { + String ppl = "source=EMP | outer join max=1 left=l right=r on l.DEPTNO=r.DEPTNO DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], r.DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[left])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n" + + " LogicalFilter(condition=[<=($3, 1)])\n" + + " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2]," + + " _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $0)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `EMP`.`DEPTNO`, `t1`.`DEPTNO` `r.DEPTNO`, `t1`.`DNAME`," + + " `t1`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "LEFT JOIN (SELECT `DEPTNO`, `DNAME`, `LOC`\n" + + "FROM (SELECT `DEPTNO`, `DNAME`, `LOC`, ROW_NUMBER() OVER (PARTITION BY `DEPTNO`" + + " ORDER BY `DEPTNO` NULLS LAST) `_row_number_dedup_`\n" + + "FROM `scott`.`DEPT`) `t`\n" + + "WHERE `_row_number_dedup_` <= 1) `t1` ON `EMP`.`DEPTNO` = `t1`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithMaxEqualsZero() { + String ppl = "source=EMP | join type=outer max=0 DEPTNO DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$8], DNAME=[$9], LOC=[$10])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[left])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 14); + + String expectedSparkSql = + "SELECT `EMP`.`EMPNO`, `EMP`.`ENAME`, `EMP`.`JOB`, `EMP`.`MGR`, `EMP`.`HIREDATE`," + + " `EMP`.`SAL`, `EMP`.`COMM`, `DEPT`.`DEPTNO`, `DEPT`.`DNAME`, `DEPT`.`LOC`\n" + + "FROM `scott`.`EMP`\n" + + "LEFT JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testJoinWithMaxLessThanZero() { + String ppl = "source=EMP | join type=outer max=-1 DEPTNO DEPT"; + Throwable t = Assert.assertThrows(SemanticCheckException.class, () -> getRelNode(ppl)); + verifyErrorMessageContains(t, "max option must be a positive integer"); + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index b55d36dd367..d851730fe3b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -83,10 +83,10 @@ public class AstBuilderTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private final PPLSyntaxParser parser = new PPLSyntaxParser(); - private final Settings settings = Mockito.mock(Settings.class); + private final PPLSyntaxParser parser = new PPLSyntaxParser(); + @Test public void testDynamicSourceClauseThrowsUnsupportedException() { String query = "source=[myindex, logs, fieldIndex=\"test\"]"; diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index 76643c20a8c..73bb6a3b42a 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -402,30 +402,53 @@ public void testSubqueryAlias() { @Test public void testJoin() { assertEquals( - "source=t | cross join on true s | fields + id", - anonymize("source=t | cross join s | fields id")); + "source=t | cross join max=0 on *** = *** s | fields + id", + anonymize("source=t | cross join on 1=1 s | fields id")); assertEquals( - "source=t | inner join on id = uid s | fields + id", + "source=t | inner join max=0 on id = uid s | fields + id", anonymize("source=t | inner join on id = uid s | fields id")); assertEquals( - "source=t as l | inner join left = l right = r on id = uid s as r | fields + id", + "source=t as l | inner join max=0 left = l right = r on id = uid s as r | fields + id", anonymize("source=t | join left = l right = r on id = uid s | fields id")); assertEquals( - "source=t | left join right = r on id = uid s as r | fields + id", + "source=t | left join max=0 right = r on id = uid s as r | fields + id", anonymize("source=t | left join right = r on id = uid s | fields id")); assertEquals( - "source=t as t1 | inner join left = t1 right = t2 on id = uid s as t2 | fields + t1.id", + "source=t as t1 | inner join max=0 left = t1 right = t2 on id = uid s as t2 | fields +" + + " t1.id", anonymize("source=t as t1 | inner join on id = uid s as t2 | fields t1.id")); assertEquals( - "source=t as t1 | right join left = t1 right = t2 on t1.id = t2.id s as t2 | fields +" + "source=t as t1 | right join max=0 left = t1 right = t2 on t1.id = t2.id s as t2 | fields +" + " t1.id", - anonymize("source=t as t1 | right join on t1.id = t2.id s as t2 | fields t1.id")); + anonymize("source=t as t1 | right join max=0 on t1.id = t2.id s as t2 | fields t1.id")); assertEquals( - "source=t as t1 | right join left = t1 right = t2 on t1.id = t2.id [ source=s | fields + id" - + " ] as t2 | fields + t1.id", + "source=t as t1 | right join max=0 left = t1 right = t2 on t1.id = t2.id [ source=s |" + + " fields + id ] as t2 | fields + t1.id", anonymize( - "source=t as t1 | right join on t1.id = t2.id [ source=s | fields id] as t2 | fields" - + " t1.id")); + "source=t as t1 | right join max=0 on t1.id = t2.id [ source=s | fields id] as t2 |" + + " fields t1.id")); + assertEquals( + "source=t | inner join max=2 on id = uid s | fields + id", + anonymize("source=t | inner join max=2 on id = uid s | fields id")); + } + + @Test + public void testJoinWithFieldList() { + assertEquals( + "source=t | join type=inner overwrite=true max=0 s | fields + id", + anonymize("source=t | join s | fields id")); + assertEquals( + "source=t | join type=inner overwrite=true max=0 id s | fields + id", + anonymize("source=t | join id s | fields id")); + assertEquals( + "source=t | join type=left overwrite=false max=0 id1,id2 s | fields + id1", + anonymize("source=t | join type=left overwrite=false id1,id2 s | fields id1")); + assertEquals( + "source=t | join type=left overwrite=false max=0 id1,id2 s | fields + id1", + anonymize("source=t | join type=outer overwrite=false id1 id2 s | fields id1")); + assertEquals( + "source=t | join type=left overwrite=true max=2 id1,id2 s | fields + id1", + anonymize("source=t | join type=outer max=2 id1 id2 s | fields id1")); } @Test