diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 7c936b189fe..44b8e5ad4e8 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -13,6 +13,7 @@ import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; +import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_DEDUP; import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME; import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME_MAIN; import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME_SUBSEARCH; @@ -843,13 +844,14 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) { if (keepEmpty) { /* * | dedup 2 a, b keepempty=false - * DropColumns('_row_number_) - * +- Filter ('_row_number_ <= n OR isnull('a) OR isnull('b)) - * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * DropColumns('_row_number_dedup_) + * +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b)) + * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] * +- ... */ // Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, - // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC + // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a + // ASC // NULLS FIRST, 'b ASC NULLS FIRST] RexNode rowNumber = context @@ -859,23 +861,23 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) { .partitionBy(dedupeFields) .orderBy(dedupeFields) .rowsTo(RexWindowBounds.CURRENT_ROW) - .as("_row_number_"); + .as(ROW_NUMBER_COLUMN_FOR_DEDUP); context.relBuilder.projectPlus(rowNumber); - RexNode _row_number_ = context.relBuilder.field("_row_number_"); - // Filter (isnull('a) OR isnull('b) OR '_row_number_ <= n) + RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP); + // Filter (isnull('a) OR isnull('b) OR '_row_number_dedup_ <= n) context.relBuilder.filter( context.relBuilder.or( context.relBuilder.or(dedupeFields.stream().map(context.relBuilder::isNull).toList()), context.relBuilder.lessThanOrEqual( - _row_number_, context.relBuilder.literal(allowedDuplication)))); + _row_number_dedup_, context.relBuilder.literal(allowedDuplication)))); // DropColumns('_row_number_) - context.relBuilder.projectExcept(_row_number_); + context.relBuilder.projectExcept(_row_number_dedup_); } else { /* * | dedup 2 a, b keepempty=false - * DropColumns('_row_number_) - * +- Filter ('_row_number_ <= n) - * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * DropColumns('_row_number_dedup_) + * +- Filter ('_row_number_dedup_ <= n) + * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] * +- Filter (isnotnull('a) AND isnotnull('b)) * +- ... */ @@ -884,7 +886,8 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) { context.relBuilder.and( dedupeFields.stream().map(context.relBuilder::isNotNull).toList())); // Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, - // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC + // specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a + // ASC // NULLS FIRST, 'b ASC NULLS FIRST] RexNode rowNumber = context @@ -894,15 +897,15 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) { .partitionBy(dedupeFields) .orderBy(dedupeFields) .rowsTo(RexWindowBounds.CURRENT_ROW) - .as("_row_number_"); + .as(ROW_NUMBER_COLUMN_FOR_DEDUP); context.relBuilder.projectPlus(rowNumber); - RexNode _row_number_ = context.relBuilder.field("_row_number_"); - // Filter ('_row_number_ <= n) + RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP); + // Filter ('_row_number_dedup_ <= n) context.relBuilder.filter( context.relBuilder.lessThanOrEqual( - _row_number_, context.relBuilder.literal(allowedDuplication))); - // DropColumns('_row_number_) - context.relBuilder.projectExcept(_row_number_); + _row_number_dedup_, context.relBuilder.literal(allowedDuplication))); + // DropColumns('_row_number_dedup_) + context.relBuilder.projectExcept(_row_number_dedup_); } return context.relBuilder.peek(); } diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java index 42d0addb53e..f98aebd2e7c 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java @@ -20,11 +20,15 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelShuttle; import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexOver; import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.rex.RexWindow; import org.apache.calcite.rex.RexWindowBound; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; @@ -43,6 +47,9 @@ public interface PlanUtils { + /** this is only for dedup command, do not reuse it in other command */ + String ROW_NUMBER_COLUMN_FOR_DEDUP = "_row_number_dedup_"; + String ROW_NUMBER_COLUMN_NAME = "_row_number_"; String ROW_NUMBER_COLUMN_NAME_MAIN = "_row_number_main_"; String ROW_NUMBER_COLUMN_NAME_SUBSEARCH = "_row_number_subsearch_"; @@ -347,4 +354,41 @@ static RexNode derefMapCall(RexNode rexNode) { } return rexNode; } + + /** Check if contains RexOver */ + static boolean containsRowNumberDedup(LogicalProject project) { + return project.getProjects().stream() + .anyMatch(p -> p instanceof RexOver && p.getKind() == SqlKind.ROW_NUMBER); + } + + /** Get all RexWindow list from LogicalProject */ + static List getRexWindowFromProject(LogicalProject project) { + final List res = new ArrayList<>(); + final RexVisitorImpl visitor = + new RexVisitorImpl<>(true) { + @Override + public Void visitOver(RexOver over) { + res.add(over.getWindow()); + return null; + } + }; + visitor.visitEach(project.getProjects()); + return res; + } + + static List getSelectColumns(List rexNodes) { + final List selectedColumns = new ArrayList<>(); + final RexVisitorImpl visitor = + new RexVisitorImpl(true) { + @Override + public Void visitInputRef(RexInputRef inputRef) { + if (!selectedColumns.contains(inputRef.getIndex())) { + selectedColumns.add(inputRef.getIndex()); + } + return null; + } + }; + visitor.visitEach(rexNodes); + return selectedColumns; + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryService.java b/core/src/main/java/org/opensearch/sql/executor/QueryService.java index d43be38137a..a9f84c9bc63 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryService.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryService.java @@ -23,7 +23,6 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.logical.LogicalSort; -import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.tools.FrameworkConfig; @@ -298,7 +297,7 @@ private FrameworkConfig buildFrameworkConfig() { .parserConfig(SqlParser.Config.DEFAULT) // TODO check .defaultSchema(opensearchSchema) .traitDefs((List) null) - .programs(Programs.calc(DefaultRelMetadataProvider.INSTANCE)) + .programs(Programs.standard()) .typeSystem(OpenSearchTypeSystem.INSTANCE); return configBuilder.build(); } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/tpch/CalcitePPLTpchIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/tpch/CalcitePPLTpchIT.java index 27fe8c49edf..29e50cbcc19 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/tpch/CalcitePPLTpchIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/tpch/CalcitePPLTpchIT.java @@ -54,7 +54,7 @@ public void testQ1() throws IOException { schema("count_order", "bigint")); verifyDataRows( actual, - rows( + closeTo( "A", "F", 37474, @@ -65,7 +65,7 @@ public void testQ1() throws IOException { isPushdownEnabled() ? 25419.231826792962 : 25419.231826792948, isPushdownEnabled() ? 0.0508660351826793 : 0.050866035182679493, 1478), - rows( + closeTo( "N", "F", 1041, @@ -76,7 +76,7 @@ public void testQ1() throws IOException { 27402.659736842103, isPushdownEnabled() ? 0.04289473684210526 : 0.042894736842105284, 38), - rows( + closeTo( "N", "O", 75168, @@ -87,7 +87,7 @@ public void testQ1() throws IOException { isPushdownEnabled() ? 25632.42277116627 : 25632.422771166166, isPushdownEnabled() ? 0.049697381842910573 : 0.04969738184291069, 2941), - rows( + closeTo( "R", "F", 36511, diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 8d398b82832..5f300002caf 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -427,6 +427,36 @@ public void testStatsByTimeSpan() throws IOException { String.format("source=%s | stats count() by span(birthdate,1M)", TEST_INDEX_BANK))); } + @Test + public void testDedupPushdown() throws IOException { + String expected = loadExpectedPlan("explain_dedup_push.json"); + assertJsonEqualsIgnoreId( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account | fields account_number, gender, age" + + " | dedup 1 gender")); + } + + @Test + public void testDedupKeepEmptyTruePushdown() throws IOException { + String expected = loadExpectedPlan("explain_dedup_keepempty_true_push.json"); + assertJsonEqualsIgnoreId( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account | fields account_number, gender, age" + + " | dedup gender KEEPEMPTY=true")); + } + + @Test + public void testDedupKeepEmptyFalsePushdown() throws IOException { + String expected = loadExpectedPlan("explain_dedup_keepempty_false_push.json"); + assertJsonEqualsIgnoreId( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account | fields account_number, gender, age" + + " | dedup gender KEEPEMPTY=false")); + } + @Test public void testSingleFieldRelevanceQueryFunctionExplain() throws IOException { // This test is only applicable if pushdown is enabled diff --git a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java index de27506e230..b7e1bf150aa 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java @@ -292,16 +292,21 @@ protected boolean matchesSafely(JSONArray array) { }; } - public static TypeSafeMatcher closeTo(Number... values) { + public static TypeSafeMatcher closeTo(Object... values) { final double error = 1e-10; return new TypeSafeMatcher() { @Override protected boolean matchesSafely(JSONArray item) { - List expectedValues = new ArrayList<>(Arrays.asList(values)); - List actualValues = new ArrayList<>(); - item.iterator().forEachRemaining(v -> actualValues.add((Number) v)); + List expectedValues = new ArrayList<>(Arrays.asList(values)); + List actualValues = new ArrayList<>(); + item.iterator().forEachRemaining(v -> actualValues.add((Object) v)); return actualValues.stream() - .allMatch(v -> valuesAreClose(v, expectedValues.get(actualValues.indexOf(v)))); + .allMatch( + v -> + v instanceof Number + ? valuesAreClose( + (Number) v, (Number) expectedValues.get(actualValues.indexOf(v))) + : v.equals(expectedValues.get(actualValues.indexOf(v)))); } @Override diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_keepempty_false_push.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_keepempty_false_push.json new file mode 100644 index 00000000000..d5d7326fcd4 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_keepempty_false_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $1 ORDER BY $1)])\n LogicalFilter(condition=[IS NOT NULL($1)])\n LogicalProject(account_number=[$0], gender=[$4], age=[$8])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, gender, age], FILTER->IS NOT NULL($1), COLLAPSE->gender, LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"exists\":{\"field\":\"gender\",\"boost\":1.0}},\"_source\":{\"includes\":[\"account_number\",\"gender\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"collapse\":{\"field\":\"gender.keyword\"}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_keepempty_true_push.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_keepempty_true_push.json new file mode 100644 index 00000000000..92b6103864f --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_keepempty_true_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2])\n LogicalFilter(condition=[OR(IS NULL($1), <=($3, 1))])\n LogicalProject(account_number=[$0], gender=[$4], age=[$8], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $4 ORDER BY $4)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..3=[{inputs}], expr#4=[IS NULL($t1)], expr#5=[1], expr#6=[<=($t3, $t5)], expr#7=[OR($t4, $t6)], proj#0..2=[{exprs}], $condition=[$t7])\n EnumerableWindow(window#0=[window(partition {1} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, gender, age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"gender\",\"age\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_push.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_push.json new file mode 100644 index 00000000000..d5d7326fcd4 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_dedup_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $1 ORDER BY $1)])\n LogicalFilter(condition=[IS NOT NULL($1)])\n LogicalProject(account_number=[$0], gender=[$4], age=[$8])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, gender, age], FILTER->IS NOT NULL($1), COLLAPSE->gender, LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"exists\":{\"field\":\"gender\",\"boost\":1.0}},\"_source\":{\"includes\":[\"account_number\",\"gender\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"collapse\":{\"field\":\"gender.keyword\"}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_output.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_output.json index 9474e4d1e31..0e64b6580d9 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_output.json +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_output.json @@ -1,6 +1,6 @@ { "calcite": { - "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", "physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], age2=[$t1], $condition=[$t4])\n EnumerableWindow(window#0=[window(partition {1} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..2=[{inputs}], expr#3=[2], expr#4=[+($t2, $t3)], expr#5=[IS NOT NULL($t2)], state=[$t0], age2=[$t4], $condition=[$t5])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[city, state, age], FILTER->>($2, 30), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},avg_age=AVG($2)), SORT->[0 ASC FIRST]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"city\",\"state\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" } } diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_keepempty_false_push.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_keepempty_false_push.json new file mode 100644 index 00000000000..625dc968ab4 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_keepempty_false_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $1 ORDER BY $1)])\n LogicalFilter(condition=[IS NOT NULL($1)])\n LogicalProject(account_number=[$0], gender=[$4], age=[$8])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..17=[{inputs}], expr#18=[1], expr#19=[<=($t17, $t18)], account_number=[$t0], gender=[$t4], age=[$t8], $condition=[$t19])\n EnumerableWindow(window#0=[window(partition {4} order by [4] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..16=[{inputs}], expr#17=[IS NOT NULL($t4)], proj#0..16=[{exprs}], $condition=[$t17])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_keepempty_true_push.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_keepempty_true_push.json new file mode 100644 index 00000000000..d1592e9fa89 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_keepempty_true_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2])\n LogicalFilter(condition=[OR(IS NULL($1), <=($3, 1))])\n LogicalProject(account_number=[$0], gender=[$4], age=[$8], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $4 ORDER BY $4)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..17=[{inputs}], expr#18=[IS NULL($t4)], expr#19=[1], expr#20=[<=($t17, $t19)], expr#21=[OR($t18, $t20)], account_number=[$t0], gender=[$t4], age=[$t8], $condition=[$t21])\n EnumerableWindow(window#0=[window(partition {4} order by [4] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_push.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_push.json new file mode 100644 index 00000000000..625dc968ab4 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_dedup_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(account_number=[$0], gender=[$1], age=[$2], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $1 ORDER BY $1)])\n LogicalFilter(condition=[IS NOT NULL($1)])\n LogicalProject(account_number=[$0], gender=[$4], age=[$8])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..17=[{inputs}], expr#18=[1], expr#19=[<=($t17, $t18)], account_number=[$t0], gender=[$t4], age=[$t8], $condition=[$t19])\n EnumerableWindow(window#0=[window(partition {4} order by [4] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..16=[{inputs}], expr#17=[IS NOT NULL($t4)], proj#0..16=[{exprs}], $condition=[$t17])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_output.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_output.json index 1705f593a76..6b8e10d07b3 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_output.json +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_output.json @@ -1,6 +1,6 @@ { "calcite": { - "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", "physical": "EnumerableCalc(expr#0..2=[{inputs}], age2=[$t1])\n EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], proj#0..2=[{exprs}], $condition=[$t4])\n EnumerableWindow(window#0=[window(partition {1} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..3=[{inputs}], expr#4=[0], expr#5=[=($t3, $t4)], expr#6=[null:BIGINT], expr#7=[CASE($t5, $t6, $t2)], expr#8=[CAST($t7):DOUBLE], expr#9=[/($t8, $t3)], expr#10=[2], expr#11=[+($t9, $t10)], expr#12=[IS NOT NULL($t8)], state=[$t1], age2=[$t11], $condition=[$t12])\n EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first])\n EnumerableAggregate(group=[{5, 7}], agg#0=[$SUM0($8)], agg#1=[COUNT($8)])\n EnumerableCalc(expr#0..16=[{inputs}], expr#17=[30], expr#18=[>($t8, $t17)], proj#0..16=[{exprs}], $condition=[$t18])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" } } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_keepempty_false_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_keepempty_false_push.json new file mode 100644 index 00000000000..4f85572e388 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_keepempty_false_push.json @@ -0,0 +1,30 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[account_number, gender, age]" + }, + "children": [{ + "name": "DedupeOperator", + "description": { + "dedupeList": "[gender]", + "allowedDuplication": 1, + "keepEmpty": false, + "consecutive": false + }, + "children": [{ + "name": "ProjectOperator", + "description": { + "fields": "[account_number, gender, age]" + }, + "children": [{ + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"gender\",\"age\"],\"excludes\":[]}}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=1m, searchAfter=null, searchResponse=null)" + }, + "children": [] + }] + }] + }] + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_keepempty_true_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_keepempty_true_push.json new file mode 100644 index 00000000000..46fa0793af9 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_keepempty_true_push.json @@ -0,0 +1,30 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[account_number, gender, age]" + }, + "children": [{ + "name": "DedupeOperator", + "description": { + "dedupeList": "[gender]", + "allowedDuplication": 1, + "keepEmpty": true, + "consecutive": false + }, + "children": [{ + "name": "ProjectOperator", + "description": { + "fields": "[account_number, gender, age]" + }, + "children": [{ + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"gender\",\"age\"],\"excludes\":[]}}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=1m, searchAfter=null, searchResponse=null)" + }, + "children": [] + }] + }] + }] + } +} diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_push.json new file mode 100644 index 00000000000..e7728735ee0 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_dedup_push.json @@ -0,0 +1,30 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[account_number, gender, age]" + }, + "children": [{ + "name": "DedupeOperator", + "description": { + "dedupeList": "[gender]", + "allowedDuplication": 1, + "keepEmpty": false, + "consecutive": false + }, + "children": [{ + "name": "ProjectOperator", + "description": { + "fields": "[account_number, gender, age]" + }, + "children": [{ + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"gender\",\"age\"],\"excludes\":[]}}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=1m, searchAfter=null, searchResponse=null)" + }, + "children": [] + }] + }] + }] + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java new file mode 100644 index 00000000000..a070e3ef1ba --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.sql.opensearch.planner.physical; + +import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_DEDUP; + +import java.util.List; +import java.util.function.Predicate; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexWindow; +import org.apache.calcite.sql.SqlKind; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; + +@Value.Enclosing +public class OpenSearchDedupPushdownRule extends RelRule { + private static final Logger LOG = LogManager.getLogger(); + + protected OpenSearchDedupPushdownRule(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final LogicalProject finalOutput = call.rel(0); + // TODO Used when number of duplication is more than 1 + final LogicalFilter numOfDedupFilter = call.rel(1); + final LogicalProject projectWithWindow = call.rel(2); + final CalciteLogicalIndexScan scan = call.rel(3); + List windows = PlanUtils.getRexWindowFromProject(projectWithWindow); + if (windows.isEmpty() || windows.stream().anyMatch(w -> w.partitionKeys.size() > 1)) { + // TODO leverage inner_hits for multiple partition keys + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the dedup with multiple fields"); + } + return; + } + final List fieldNameList = projectWithWindow.getInput().getRowType().getFieldNames(); + List selectColumns = PlanUtils.getSelectColumns(windows.getFirst().partitionKeys); + String fieldName = fieldNameList.get(selectColumns.getFirst()); + + CalciteLogicalIndexScan newScan = scan.pushDownCollapse(finalOutput, fieldName); + if (newScan != null) { + call.transformTo(newScan); + } + } + + private static boolean validFilter(LogicalFilter filter) { + if (filter.getCondition().getKind() != SqlKind.LESS_THAN_OR_EQUAL) { + return false; + } + List operandsOfCondition = ((RexCall) filter.getCondition()).getOperands(); + RexNode leftOperand = operandsOfCondition.getFirst(); + if (!(leftOperand instanceof RexInputRef ref)) { + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the dedup since the left operand is not RexInputRef"); + } + return false; + } + String referenceName = filter.getRowType().getFieldNames().get(ref.getIndex()); + if (!referenceName.equals(ROW_NUMBER_COLUMN_FOR_DEDUP)) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Cannot pushdown the dedup since the left operand is not {}", + ROW_NUMBER_COLUMN_FOR_DEDUP); + } + return false; + } + RexNode rightOperand = operandsOfCondition.getLast(); + if (!(rightOperand instanceof RexLiteral numLiteral)) { + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the dedup since the right operand is not RexLiteral"); + } + return false; + } + Integer num = numLiteral.getValueAs(Integer.class); + if (num == null || num > 1) { + // TODO leverage inner_hits for num > 1 + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the dedup since number of duplicate events is larger than 1"); + } + return false; + } + return true; + } + + /** + * Match fixed pattern:
+ * LogicalProject(remove _row_number_dedup_)
+ * LogicalFilter(condition=[<=($1, numOfDedup)])
+ * LogicalProject(..., _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $0)])
+ * LogicalFilter(condition=[IS NOT NULL($0)])
+ */ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = + ImmutableOpenSearchDedupPushdownRule.Config.builder() + .build() + .withOperandSupplier( + b0 -> + b0.operand(LogicalProject.class) + .oneInput( + b1 -> + b1.operand(LogicalFilter.class) + .predicate(OpenSearchDedupPushdownRule::validFilter) + .oneInput( + b2 -> + b2.operand(LogicalProject.class) + .predicate(PlanUtils::containsRowNumberDedup) + .oneInput( + b3 -> + b3.operand(CalciteLogicalIndexScan.class) + .predicate( + Predicate.not( + OpenSearchIndexScanRule + ::isLimitPushed) + .and( + OpenSearchIndexScanRule + ::noAggregatePushed)) + .noInputs())))); + + @Override + default OpenSearchDedupPushdownRule toRule() { + return new OpenSearchDedupPushdownRule(this); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java index 05f3f6cc4f9..c619ef27da4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java @@ -22,6 +22,8 @@ public class OpenSearchIndexRules { OpenSearchLimitIndexScanRule.Config.DEFAULT.toRule(); private static final OpenSearchSortIndexScanRule SORT_INDEX_SCAN = OpenSearchSortIndexScanRule.Config.DEFAULT.toRule(); + private static final OpenSearchDedupPushdownRule DEDUP_PUSH_DOWN = + OpenSearchDedupPushdownRule.Config.DEFAULT.toRule(); public static final List OPEN_SEARCH_INDEX_SCAN_RULES = ImmutableList.of( @@ -30,7 +32,8 @@ public class OpenSearchIndexRules { AGGREGATE_INDEX_SCAN, COUNT_STAR_INDEX_SCAN, LIMIT_INDEX_SCAN, - SORT_INDEX_SCAN); + SORT_INDEX_SCAN, + DEDUP_PUSH_DOWN); // prevent instantiation private OpenSearchIndexRules() {} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 302751277ce..06638aeed93 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -35,6 +35,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.collapse.CollapseBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.sort.SortBuilder; @@ -297,6 +298,10 @@ public void pushTypeMapping(Map typeMapping) { exprValueFactory.extendTypeMapping(typeMapping); } + public void pushDownCollapse(String field) { + sourceBuilder.collapse(new CollapseBuilder(field)); + } + private boolean isSortByDocOnly() { List> sorts = sourceBuilder.sorts(); if (sorts != null) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java index 61df5b0282e..9f30b3b497c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java @@ -122,6 +122,8 @@ public double estimateRowCount(RelMetadataQuery mq) { switch (action.type) { case AGGREGATION -> mq.getRowCount((RelNode) action.digest); case PROJECT, SORT -> rowCount; + // Refer the org.apache.calcite.rel.core.Aggregate.estimateRowCount + case COLLAPSE -> rowCount * (1.0 - Math.pow(.5, 1)); case FILTER -> NumberUtil.multiply( rowCount, RelMdUtil.guessSelectivity((RexNode) action.digest)); case SCRIPT -> NumberUtil.multiply( @@ -138,6 +140,7 @@ public static class PushDownContext extends ArrayDeque { private boolean isAggregatePushed = false; @Getter private boolean isLimitPushed = false; + @Getter private boolean isProjectPushed = false; @Override public PushDownContext clone() { @@ -152,6 +155,9 @@ public boolean add(PushDownAction pushDownAction) { if (pushDownAction.type == PushDownType.LIMIT) { isLimitPushed = true; } + if (pushDownAction.type == PushDownType.PROJECT) { + isProjectPushed = true; + } return super.add(pushDownAction); } @@ -305,7 +311,8 @@ protected enum PushDownType { AGGREGATION, SORT, LIMIT, - SCRIPT + SCRIPT, + COLLAPSE // HIGHLIGHT, // NESTED } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index c5d3aa03072..29c380623bc 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -37,8 +37,10 @@ import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.planner.physical.EnumerableIndexScanRule; import org.opensearch.sql.opensearch.planner.physical.OpenSearchIndexRules; import org.opensearch.sql.opensearch.request.AggregateAnalyzer; @@ -153,6 +155,44 @@ private static RexNode constructCondition(List conditions, RexBuilder r : conditions.get(0); } + public CalciteLogicalIndexScan pushDownCollapse(Project finalOutput, String fieldName) { + ExprType fieldType = osIndex.getFieldTypes().get(fieldName); + if (fieldType == null) { + // the fieldName must be one of index fields + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the dedup '{}' due to it is not a index field", fieldName); + } + return null; + } + ExprType originalExprType = fieldType.getOriginalExprType(); + String originalFieldName = originalExprType.getOriginalPath().orElse(fieldName); + if (!ExprCoreType.numberTypes().contains(originalExprType) + && !originalExprType.legacyTypeName().equals("KEYWORD") + && !originalExprType.legacyTypeName().equals("TEXT")) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Cannot pushdown the dedup '{}' due to only keyword and number type are accepted, but" + + " its type is {}", + originalFieldName, + originalExprType.legacyTypeName()); + } + return null; + } + // For text, use its subfield if exists. + String field = OpenSearchTextType.toKeywordSubField(originalFieldName, fieldType); + if (field == null) { + LOG.debug("Cannot pushdown the dedup due to no keyword subfield for {}.", fieldName); + return null; + } + CalciteLogicalIndexScan newScan = this.copyWithNewSchema(finalOutput.getRowType()); + newScan.pushDownContext.add( + PushDownAction.of( + PushDownType.COLLAPSE, + fieldName, + requestBuilder -> requestBuilder.pushDownCollapse(field))); + return newScan; + } + /** * When pushing down a project, we need to create a new CalciteLogicalIndexScan with the updated * schema since we cannot override getRowType() which is defined to be final. diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLDedupTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLDedupTest.java index 52a0ad492da..6cf3c91fe7b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLDedupTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLDedupTest.java @@ -20,13 +20,12 @@ public void testDedup1() { String ppl = "source=EMP | dedup 1 DEPTNO"; RelNode root = getRelNode(ppl); String expectedLogical = - "" - + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + " COMM=[$6], DEPTNO=[$7])\n" + " LogicalFilter(condition=[<=($8, 1)])\n" + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $7" - + " ORDER BY $7)])\n" + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION" + + " BY $7 ORDER BY $7)])\n" + " LogicalFilter(condition=[IS NOT NULL($7)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); @@ -44,10 +43,10 @@ public void testDedup1() { + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + " ROW_NUMBER() OVER (PARTITION BY `DEPTNO` ORDER BY `DEPTNO` NULLS LAST)" - + " `_row_number_`\n" + + " `_row_number_dedup_`\n" + "FROM `scott`.`EMP`\n" + "WHERE `DEPTNO` IS NOT NULL) `t0`\n" - + "WHERE `_row_number_` <= 1"; + + "WHERE `_row_number_dedup_` <= 1"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -56,13 +55,12 @@ public void testDedup2() { String ppl = "source=EMP | dedup 2 DEPTNO"; RelNode root = getRelNode(ppl); String expectedLogical = - "" - + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + " COMM=[$6], DEPTNO=[$7])\n" + " LogicalFilter(condition=[<=($8, 2)])\n" + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $7" - + " ORDER BY $7)])\n" + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION" + + " BY $7 ORDER BY $7)])\n" + " LogicalFilter(condition=[IS NOT NULL($7)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); @@ -86,10 +84,10 @@ public void testDedup2() { + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + " ROW_NUMBER() OVER (PARTITION BY `DEPTNO` ORDER BY `DEPTNO` NULLS LAST)" - + " `_row_number_`\n" + + " `_row_number_dedup_`\n" + "FROM `scott`.`EMP`\n" + "WHERE `DEPTNO` IS NOT NULL) `t0`\n" - + "WHERE `_row_number_` <= 2"; + + "WHERE `_row_number_dedup_` <= 2"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -102,8 +100,8 @@ public void testDedupKeepEmpty1() { + " COMM=[$6], DEPTNO=[$7])\n" + " LogicalFilter(condition=[OR(IS NULL($7), IS NULL($2), <=($8, 1))])\n" + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $7," - + " $2 ORDER BY $7, $2)])\n" + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION" + + " BY $7, $2 ORDER BY $7, $2)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); String expectedResult = @@ -131,9 +129,9 @@ public void testDedupKeepEmpty1() { "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + " ROW_NUMBER() OVER (PARTITION BY `DEPTNO`, `JOB` ORDER BY `DEPTNO` NULLS LAST, `JOB`" - + " NULLS LAST) `_row_number_`\n" + + " NULLS LAST) `_row_number_dedup_`\n" + "FROM `scott`.`EMP`) `t`\n" - + "WHERE `DEPTNO` IS NULL OR `JOB` IS NULL OR `_row_number_` <= 1"; + + "WHERE `DEPTNO` IS NULL OR `JOB` IS NULL OR `_row_number_dedup_` <= 1"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -146,8 +144,8 @@ public void testDedupKeepEmpty2() { + " COMM=[$6], DEPTNO=[$7])\n" + " LogicalFilter(condition=[OR(IS NULL($7), IS NULL($2), <=($8, 2))])\n" + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $7," - + " $2 ORDER BY $7, $2)])\n" + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION" + + " BY $7, $2 ORDER BY $7, $2)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); String expectedResult = @@ -181,9 +179,9 @@ public void testDedupKeepEmpty2() { "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + " ROW_NUMBER() OVER (PARTITION BY `DEPTNO`, `JOB` ORDER BY `DEPTNO` NULLS LAST, `JOB`" - + " NULLS LAST) `_row_number_`\n" + + " NULLS LAST) `_row_number_dedup_`\n" + "FROM `scott`.`EMP`) `t`\n" - + "WHERE `DEPTNO` IS NULL OR `JOB` IS NULL OR `_row_number_` <= 2"; + + "WHERE `DEPTNO` IS NULL OR `JOB` IS NULL OR `_row_number_dedup_` <= 2"; verifyPPLToSparkSQL(root, expectedSparkSql); } }