From 7ddb1d77b9beeffe34f5a1f3360ca16753815e32 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Mon, 20 Oct 2025 17:09:00 +0800 Subject: [PATCH 1/9] convert sort aggregate metrics to term sort Signed-off-by: Lantao Jin --- .../sql/calcite/utils/PlanUtils.java | 46 ++++- .../sql/calcite/remote/CalciteExplainIT.java | 1 - .../calcite/explain_agg_sort_on_metrics1.yaml | 2 +- .../ExpandCollationOnProjectExprRule.java | 3 +- .../OpenSearchAggregateIndexScanRule.java | 28 ++- .../physical/OpenSearchDedupPushdownRule.java | 5 +- .../OpenSearchFilterIndexScanRule.java | 5 +- .../physical/OpenSearchIndexRules.java | 3 + .../physical/OpenSearchIndexScanRule.java | 74 ------- .../OpenSearchLimitIndexScanRule.java | 3 +- .../physical/OpenSearchSortIndexScanRule.java | 9 +- .../physical/SortAggregationMetricsRule.java | 181 ++++++++++++++++++ .../SortProjectExprTransposeRule.java | 3 +- .../opensearch/request/AggregateAnalyzer.java | 37 +++- .../request/OpenSearchRequestBuilder.java | 2 + .../scan/AbstractCalciteIndexScan.java | 25 ++- .../storage/scan/CalciteLogicalIndexScan.java | 17 +- .../storage/scan/PushDownContext.java | 7 +- .../request/AggregateAnalyzerTest.java | 15 +- .../calcite/CalcitePPLAggregationTest.java | 51 +++++ 20 files changed, 396 insertions(+), 121 deletions(-) delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java index aaeb089020c..153135c5cf8 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java @@ -14,8 +14,10 @@ import com.google.common.collect.ImmutableList; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -23,8 +25,11 @@ import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalSort; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCorrelVariable; @@ -38,6 +43,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.Pair; import org.apache.calcite.util.Util; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; @@ -474,13 +480,51 @@ public Void visitInputRef(RexInputRef inputRef) { return selectedColumns; } + // `RelDecorrelator` may generate a Project with duplicated fields, e.g. Project($0,$0). + // There will be problem if pushing down the pattern like `Aggregate(AGG($0),{1})-Project($0,$0)`, + // as it will lead to field-name conflict. + // We should wait and rely on `AggregateProjectMergeRule` to mitigate it by having this constraint + // Nevertheless, that rule cannot handle all cases if there is RexCall in the Project, + // e.g. Project($0, $0, +($0,1)). We cannot push down the Aggregate for this corner case. + // TODO: Simplify the Project where there is RexCall by adding a new rule. + static boolean distinctProjectList(LogicalProject project) { + // Change to Set> to resolve + // https://github.com/opensearch-project/sql/issues/4347 + Set> rexSet = new HashSet<>(); + return project.getNamedProjects().stream().allMatch(rexSet::add); + } + + static boolean containsRexOver(LogicalProject project) { + return project.getProjects().stream().anyMatch(RexOver::containsOver); + } + + /** + * The LogicalSort is a LIMIT that should be pushed down when its fetch field is not null and its + * collation is empty. For example: sort name | head 5 should not be pushed down + * because it has a field collation. + * + * @param sort The LogicalSort to check. + * @return True if the LogicalSort is a LIMIT, false otherwise. + */ + static boolean isLogicalSortLimit(LogicalSort sort) { + return sort.fetch != null; + } + + static boolean projectContainsExpr(Project project) { + return project.getProjects().stream().anyMatch(p -> p instanceof RexCall); + } + + static boolean sortByFieldsOnly(Sort sort) { + return !sort.getCollation().getFieldCollations().isEmpty() && sort.fetch == null; + } + /** * Get a string representation of the argument types expressed in ExprType for error messages. * * @param argTypes the list of argument types as {@link RelDataType} * @return a string in the format [type1,type2,...] representing the argument types */ - public static String getActualSignature(List argTypes) { + static String getActualSignature(List argTypes) { return "[" + argTypes.stream() .map(OpenSearchTypeFactory::convertRelDataTypeToExprType) diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 86f585e4547..f8ace549bdc 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1023,7 +1023,6 @@ public void testExplainCountsByAgg() throws IOException { @Test public void testExplainSortOnMetricsNoBucketNullable() throws IOException { - // TODO enhancement later: https://github.com/opensearch-project/sql/issues/4282 enabledOnlyWhenPushdownIsEnabled(); String expected = loadExpectedPlan("explain_agg_sort_on_metrics1.yaml"); assertYamlEqualsJsonIgnoreId( diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml index 81082ac86e7..fb7435b52c1 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml @@ -10,4 +10,4 @@ calcite: physical: | EnumerableLimit(fetch=[10000]) EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), PROJECT->[count(), state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[SORT_AGG_METRICS->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), PROJECT->[count(), state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","_source":{"includes":["count()","state"],"excludes":[]},"aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"count()":"asc"},{"_key":"asc"}]},"aggregations":{"count()":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java index 36acc3b0dab..57db35b092a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java @@ -19,6 +19,7 @@ import org.apache.calcite.rel.core.Project; import org.apache.commons.lang3.tuple.Pair; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.util.OpenSearchRelOptUtil; /** @@ -108,7 +109,7 @@ public interface Config extends RelRule.Config { .oneInput( b1 -> b1.operand(EnumerableProject.class) - .predicate(OpenSearchIndexScanRule::projectContainsExpr) + .predicate(PlanUtils::projectContainsExpr) .predicate(p -> !p.containsOver()) .anyInputs())); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java index 51539314718..c8ab3f46c91 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java @@ -23,7 +23,9 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.immutables.value.Value; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.expression.function.udf.binning.WidthBucketFunction; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; /** Planner rule that push a {@link LogicalAggregate} down to {@link CalciteLogicalIndexScan} */ @@ -82,7 +84,7 @@ protected void apply( LogicalAggregate aggregate, LogicalProject project, CalciteLogicalIndexScan scan) { - AbstractRelNode newRelNode = scan.pushDownAggregate(aggregate, project); + AbstractRelNode newRelNode = scan.pushDownAggregate(aggregate, project, null); if (newRelNode != null) { call.transformTo(newRelNode); } @@ -106,17 +108,17 @@ public interface Config extends RelRule.Config { // 1. No RexOver and no duplicate projection // 2. Contains width_bucket function on date field referring // to bin command with parameter bins - Predicate.not(OpenSearchIndexScanRule::containsRexOver) - .and(OpenSearchIndexScanRule::distinctProjectList) + Predicate.not(PlanUtils::containsRexOver) + .and(PlanUtils::distinctProjectList) .or(Config::containsWidthBucketFuncOnDate)) .oneInput( b2 -> b2.operand(CalciteLogicalIndexScan.class) .predicate( Predicate.not( - OpenSearchIndexScanRule::isLimitPushed) + AbstractCalciteIndexScan::isLimitPushed) .and( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::noAggregatePushed)) .noInputs()))); Config COUNT_STAR = @@ -138,8 +140,8 @@ public interface Config extends RelRule.Config { b1 -> b1.operand(CalciteLogicalIndexScan.class) .predicate( - Predicate.not(OpenSearchIndexScanRule::isLimitPushed) - .and(OpenSearchIndexScanRule::noAggregatePushed)) + Predicate.not(AbstractCalciteIndexScan::isLimitPushed) + .and(AbstractCalciteIndexScan::noAggregatePushed)) .noInputs())); // TODO: No need this rule once https://github.com/opensearch-project/sql/issues/4403 is // addressed @@ -173,22 +175,18 @@ public interface Config extends RelRule.Config { // 2. Contains width_bucket function on date // field referring // to bin command with parameter bins - Predicate.not( - OpenSearchIndexScanRule - ::containsRexOver) - .and( - OpenSearchIndexScanRule - ::distinctProjectList) + Predicate.not(PlanUtils::containsRexOver) + .and(PlanUtils::distinctProjectList) .or(Config::containsWidthBucketFuncOnDate)) .oneInput( b3 -> b3.operand(CalciteLogicalIndexScan.class) .predicate( Predicate.not( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::isLimitPushed) .and( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::noAggregatePushed)) .noInputs())))); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java index a51fa365905..19e4781ec37 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java @@ -23,6 +23,7 @@ import org.apache.logging.log4j.Logger; import org.immutables.value.Value; import org.opensearch.sql.calcite.utils.PlanUtils; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; @Value.Enclosing @@ -125,10 +126,10 @@ public interface Config extends RelRule.Config { b3.operand(CalciteLogicalIndexScan.class) .predicate( Predicate.not( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::isLimitPushed) .and( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::noAggregatePushed)) .noInputs())))); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java index 306617ae2cf..e82ccb6dffa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java @@ -12,6 +12,7 @@ import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.logical.LogicalFilter; import org.immutables.value.Value; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; /** Planner rule that push a {@link LogicalFilter} down to {@link CalciteLogicalIndexScan} */ @@ -64,8 +65,8 @@ public interface Config extends RelRule.Config { // handle filter pushdown after limit. Both "limit after // filter" and "filter after limit" result in the same // limit-after-filter DSL. - Predicate.not(OpenSearchIndexScanRule::isLimitPushed) - .and(OpenSearchIndexScanRule::noAggregatePushed)) + Predicate.not(AbstractCalciteIndexScan::isLimitPushed) + .and(AbstractCalciteIndexScan::noAggregatePushed)) .noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java index 0e947126314..c41254e1e63 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java @@ -31,6 +31,8 @@ public class OpenSearchIndexRules { SortProjectExprTransposeRule.Config.DEFAULT.toRule(); private static final ExpandCollationOnProjectExprRule EXPAND_COLLATION_ON_PROJECT_EXPR = ExpandCollationOnProjectExprRule.Config.DEFAULT.toRule(); + private static final SortAggregationMetricsRule SORT_AGGREGATION_METRICS_RULE = + SortAggregationMetricsRule.Config.DEFAULT.toRule(); // Rule that always pushes down relevance functions regardless of pushdown settings public static final OpenSearchRelevanceFunctionPushdownRule RELEVANCE_FUNCTION_PUSHDOWN = @@ -48,6 +50,7 @@ public class OpenSearchIndexRules { // TODO enable if https://github.com/opensearch-project/OpenSearch/issues/3725 resolved // DEDUP_PUSH_DOWN, SORT_PROJECT_EXPR_TRANSPOSE, + SORT_AGGREGATION_METRICS_RULE, EXPAND_COLLATION_ON_PROJECT_EXPR); // prevent instantiation diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java deleted file mode 100644 index 24abb3c3bc9..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.planner.physical; - -import java.util.HashSet; -import java.util.Set; -import org.apache.calcite.plan.RelOptTable; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rel.core.Sort; -import org.apache.calcite.rel.logical.LogicalProject; -import org.apache.calcite.rel.logical.LogicalSort; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexOver; -import org.apache.calcite.util.Pair; -import org.opensearch.sql.opensearch.storage.OpenSearchIndex; -import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; - -public interface OpenSearchIndexScanRule { - /** - * CalciteOpenSearchIndexScan doesn't allow push-down anymore (except Sort under some strict - * condition) after Aggregate push-down. - */ - static boolean noAggregatePushed(AbstractCalciteIndexScan scan) { - if (scan.getPushDownContext().isAggregatePushed()) return false; - final RelOptTable table = scan.getTable(); - return table.unwrap(OpenSearchIndex.class) != null; - } - - static boolean isLimitPushed(AbstractCalciteIndexScan scan) { - return scan.getPushDownContext().isLimitPushed(); - } - - // `RelDecorrelator` may generate a Project with duplicated fields, e.g. Project($0,$0). - // There will be problem if pushing down the pattern like `Aggregate(AGG($0),{1})-Project($0,$0)`, - // as it will lead to field-name conflict. - // We should wait and rely on `AggregateProjectMergeRule` to mitigate it by having this constraint - // Nevertheless, that rule cannot handle all cases if there is RexCall in the Project, - // e.g. Project($0, $0, +($0,1)). We cannot push down the Aggregate for this corner case. - // TODO: Simplify the Project where there is RexCall by adding a new rule. - static boolean distinctProjectList(LogicalProject project) { - // Change to Set> to resolve - // https://github.com/opensearch-project/sql/issues/4347 - Set> rexSet = new HashSet<>(); - return project.getNamedProjects().stream().allMatch(rexSet::add); - } - - static boolean containsRexOver(LogicalProject project) { - return project.getProjects().stream().anyMatch(RexOver::containsOver); - } - - /** - * The LogicalSort is a LIMIT that should be pushed down when its fetch field is not null and its - * collation is empty. For example: sort name | head 5 should not be pushed down - * because it has a field collation. - * - * @param sort The LogicalSort to check. - * @return True if the LogicalSort is a LIMIT, false otherwise. - */ - static boolean isLogicalSortLimit(LogicalSort sort) { - return sort.fetch != null; - } - - static boolean projectContainsExpr(Project project) { - return project.getProjects().stream().anyMatch(p -> p instanceof RexCall); - } - - static boolean sortByFieldsOnly(Sort sort) { - return !sort.getCollation().getFieldCollations().isEmpty() && sort.fetch == null; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java index a6832b06a24..31a7bf233d2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java @@ -13,6 +13,7 @@ import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; /** @@ -88,7 +89,7 @@ public interface Config extends RelRule.Config { .withOperandSupplier( b0 -> b0.operand(LogicalSort.class) - .predicate(OpenSearchIndexScanRule::isLogicalSortLimit) + .predicate(PlanUtils::isLogicalSortLimit) .oneInput(b1 -> b1.operand(CalciteLogicalIndexScan.class).noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java index 47274f467fc..b519c50a1ec 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java @@ -10,6 +10,7 @@ import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.core.Sort; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; @Value.Enclosing @@ -43,7 +44,7 @@ public interface Config extends RelRule.Config { .withOperandSupplier( b0 -> b0.operand(Sort.class) - .predicate(OpenSearchIndexScanRule::sortByFieldsOnly) + .predicate(PlanUtils::sortByFieldsOnly) .oneInput( b1 -> b1.operand(AbstractCalciteIndexScan.class) @@ -51,7 +52,11 @@ public interface Config extends RelRule.Config { // because pushing down a sort after a limit will be treated // as sort-then-limit by OpenSearch DSL. .predicate( - Predicate.not(OpenSearchIndexScanRule::isLimitPushed)) + Predicate.not(AbstractCalciteIndexScan::isLimitPushed) + .and( + Predicate.not( + AbstractCalciteIndexScan + ::isMetricsOrderPushed))) .noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java new file mode 100644 index 00000000000..3efdec21173 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.AbstractRelNode; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalSort; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.immutables.value.Value; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.calcite.utils.PlanUtils; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; + +@Value.Enclosing +public class SortAggregationMetricsRule extends RelRule { + + protected SortAggregationMetricsRule(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final LogicalSort sort = call.rel(0); + final LogicalProject projectAddedBySort = call.rel(1); + final LogicalAggregate aggregate = call.rel(2); + final LogicalFilter filter = call.rel(3); + final LogicalProject project = call.rel(4); + final CalciteLogicalIndexScan scan = call.rel(5); + // Only support single metric sort + if (sort.getCollation().getFieldCollations().size() != 1) { + return; + } + // Only support single metric in aggregate + if (aggregate.getAggCallList().size() != 1) { + return; + } + int possibleMetricsIndexInSort = + sort.getCollation().getFieldCollations().getFirst().getFieldIndex(); + RexNode possibleMetricsInProject = + projectAddedBySort.getProjects().get(possibleMetricsIndexInSort); + if (!(possibleMetricsInProject instanceof RexInputRef inputRef)) { + return; + } + int possibleMetricsIndexInProject = inputRef.getIndex(); + RelDataTypeField possibleMetricsInAggregate = + aggregate.getRowType().getFieldList().get(possibleMetricsIndexInProject); + if (possibleMetricsInAggregate.getType().getSqlTypeName().getFamily() + != SqlTypeFamily.NUMERIC) { + return; + } + if (!aggregate + .getAggCallList() + .getFirst() + .getName() + .equals(possibleMetricsInAggregate.getName())) { + return; + } + List groupSet = aggregate.getGroupSet().asList(); + RexNode condition = filter.getCondition(); + Function isNotNullFromAgg = + rex -> + rex instanceof RexCall rexCall + && rexCall.getOperator() == SqlStdOperatorTable.IS_NOT_NULL + && rexCall.getOperands().get(0) instanceof RexInputRef ref + && groupSet.contains(ref.getIndex()); + if (isNotNullFromAgg.apply(condition) + || (condition instanceof RexCall rexCall + && rexCall.getOperator() == SqlStdOperatorTable.AND + && rexCall.getOperands().stream().allMatch(isNotNullFromAgg::apply))) { + // Try to do the aggregate push down and ignore the filter if the filter sources from the + // aggregate's hint. See{@link CalciteRelNodeVisitor::visitAggregation} + RelFieldCollation.Direction direction = + sort.getCollation().getFieldCollations().getFirst().direction; + apply(call, projectAddedBySort, aggregate, project, scan, direction); + } + } + + protected void apply( + RelOptRuleCall call, + LogicalProject projectAddedBySort, + LogicalAggregate aggregate, + LogicalProject project, + CalciteLogicalIndexScan scan, + RelFieldCollation.Direction metricOrder) { + AbstractRelNode newScan = scan.pushDownAggregate(aggregate, project, metricOrder); + if (newScan != null) { + RelNode newScanWithProject = + call.builder().push(newScan).project(projectAddedBySort.getProjects()).build(); + call.transformTo(newScanWithProject); + } + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + SortAggregationMetricsRule.Config DEFAULT = + ImmutableSortAggregationMetricsRule.Config.builder() + .build() + .withDescription("Sort-Project-Agg-Filter-Project-TableScan") + .withOperandSupplier( + b0 -> + b0.operand(LogicalSort.class) + .predicate(PlanUtils::sortByFieldsOnly) + .oneInput( + b1 -> + b1.operand(LogicalProject.class) + .oneInput( + b2 -> + b2.operand(LogicalAggregate.class) + .predicate( + agg -> + agg.getHints().stream() + .anyMatch( + hint -> + hint.hintName.equals( + "stats_args") + && hint.kvOptions + .get( + Argument + .BUCKET_NULLABLE) + .equals("false"))) + .oneInput( + b3 -> + b3.operand(LogicalFilter.class) + .predicate( + OpenSearchAggregateIndexScanRule + .Config + ::mayBeFilterFromBucketNonNull) + .oneInput( + b4 -> + b4.operand(LogicalProject.class) + .predicate( + Predicate.not( + PlanUtils + ::containsRexOver) + .and( + PlanUtils + ::distinctProjectList) + .or( + OpenSearchAggregateIndexScanRule + .Config + ::containsWidthBucketFuncOnDate)) + .oneInput( + b5 -> + b5.operand( + CalciteLogicalIndexScan + .class) + .predicate( + Predicate + .not( + AbstractCalciteIndexScan + ::isLimitPushed) + .and( + AbstractCalciteIndexScan + ::noAggregatePushed)) + .noInputs())))))); + + @Override + default SortAggregationMetricsRule toRule() { + return new SortAggregationMetricsRule(this); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java index dfec6754908..2ed94292096 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java @@ -26,6 +26,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.commons.lang3.tuple.Pair; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.util.OpenSearchRelOptUtil; /** @@ -132,7 +133,7 @@ public interface Config extends RelRule.Config { b1.operand(LogicalProject.class) .predicate( Predicate.not(LogicalProject::containsOver) - .and(OpenSearchIndexScanRule::projectContainsExpr)) + .and(PlanUtils::projectContainsExpr)) .anyInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index 99d4c9eb235..ddf5f7c1564 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -42,6 +42,7 @@ import java.util.function.Function; import lombok.RequiredArgsConstructor; import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; @@ -141,6 +142,7 @@ static class AggregateBuilderHelper { final Map fieldTypes; final RelOptCluster cluster; final boolean bucketNullable; + final RelFieldCollation.Direction metricOrder; > T build(RexNode node, T aggBuilder) { return build(node, aggBuilder::field, aggBuilder::script); @@ -188,6 +190,7 @@ public static Pair, OpenSearchAggregationResponseParser RelDataType rowType, Map fieldTypes, List outputFields, + RelFieldCollation.Direction metricOrder, RelOptCluster cluster) throws ExpressionNotAnalyzableException { requireNonNull(aggregate, "aggregate"); @@ -201,7 +204,7 @@ public static Pair, OpenSearchAggregationResponseParser .orElseGet(() -> "true")); List groupList = aggregate.getGroupSet().asList(); AggregateBuilderHelper helper = - new AggregateBuilderHelper(rowType, fieldTypes, cluster, bucketNullable); + new AggregateBuilderHelper(rowType, fieldTypes, cluster, bucketNullable, metricOrder); List aggFieldNames = outputFields.subList(groupList.size(), outputFields.size()); // Process all aggregate calls Pair> builderAndParser = @@ -213,7 +216,7 @@ public static Pair, OpenSearchAggregationResponseParser // but only count() can apply doc_count optimization in bucket aggregation. boolean countAllOnly = !aggregate.getGroupSet().isEmpty(); Pair, Builder> countAggNameAndBuilderPair = - removeCountAggregationBuilders(metricBuilder, countAllOnly); + removeCountAggregationBuilders(metricBuilder, countAllOnly, metricOrder); Builder newMetricBuilder = countAggNameAndBuilderPair.getRight(); List countAggNames = countAggNameAndBuilderPair.getLeft(); @@ -228,8 +231,21 @@ public static Pair, OpenSearchAggregationResponseParser new NoBucketAggregationParser(metricParserList)); } } else if (aggregate.getGroupSet().length() == 1 - && isAutoDateSpan(project.getProjects().get(groupList.getFirst()))) { + && (isAutoDateSpan(project.getProjects().get(groupList.getFirst())) + || metricOrder != null)) { ValuesSourceAggregationBuilder bucketBuilder = createBucket(0, project, helper); + if (metricOrder != null + && bucketBuilder instanceof TermsAggregationBuilder termsAggregationBuilder) { + String path = + newMetricBuilder.getAggregatorFactories().stream() + .map(AggregationBuilder::getName) + .toList() + .getFirst(); + termsAggregationBuilder.order( + metricOrder == RelFieldCollation.Direction.ASCENDING + ? BucketOrder.aggregation(path, true) + : BucketOrder.aggregation(path, false)); + } if (newMetricBuilder != null) { bucketBuilder.subAggregations(newMetricBuilder); } @@ -277,7 +293,11 @@ && isAutoDateSpan(project.getProjects().get(groupList.getFirst()))) { * with the original metric builder. */ private static Pair, Builder> removeCountAggregationBuilders( - Builder metricBuilder, boolean countAllOnly) { + Builder metricBuilder, boolean countAllOnly, RelFieldCollation.Direction metricOrder) { + // if we have a specific metric order, skip the optimization of count agg removing + if (metricOrder != null) { + return Pair.of(List.of(), metricBuilder); + } List countAggregatorFactories = metricBuilder.getAggregatorFactories().stream() .filter(ValueCountAggregationBuilder.class::isInstance) @@ -628,11 +648,10 @@ private static CompositeValuesSourceBuilder createTermsSourceBuilder( private static ValuesSourceAggregationBuilder createTermsAggregationBuilder( String bucketName, RexNode group, AggregateBuilderHelper helper) { TermsAggregationBuilder sourceBuilder = - helper.build( - group, - new TermsAggregationBuilder(bucketName) - .size(AGGREGATION_BUCKET_SIZE) - .order(BucketOrder.key(true))); + helper.build(group, new TermsAggregationBuilder(bucketName).size(AGGREGATION_BUCKET_SIZE)); + if (helper.metricOrder == null) { + sourceBuilder.order(BucketOrder.key(true)); + } // Time types values are converted to LONG in ExpressionAggregationScript::execute if (List.of(TIMESTAMP, TIME, DATE) .contains(OpenSearchTypeFactory.convertRelDataTypeToExprType(group.getType()))) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index e40962d59fb..fa7db3e375a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -298,6 +298,8 @@ public void pushDownCollapse(String field) { sourceBuilder.collapse(new CollapseBuilder(field)); } + public void pushDownSortMetrics() {} + /** * Push down nested to sourceBuilder. * diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java index ad02a898128..e1862ad29d1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java @@ -103,7 +103,8 @@ public double estimateRowCount(RelMetadataQuery mq) { osIndex.getMaxResultWindow().doubleValue(), (rowCount, operation) -> switch (operation.type()) { - case AGGREGATION -> mq.getRowCount((RelNode) operation.digest()); + case AGGREGATION, SORT_AGG_METRICS -> mq.getRowCount( + (RelNode) operation.digest()); case PROJECT, SORT -> rowCount; // Refer the org.apache.calcite.rel.metadata.RelMdRowCount case COLLAPSE -> rowCount / 10; @@ -141,6 +142,10 @@ public double estimateRowCount(RelMetadataQuery mq) { // Ignored Project in cost accumulation, but it will affect the external cost case PROJECT -> {} case SORT -> dCpu += dRows; + case SORT_AGG_METRICS -> { + dRows = dRows * .9 / 10; // *.9 because always bucket IS_NOT_NULL + dCpu += dRows; + } // Refer the org.apache.calcite.rel.metadata.RelMdRowCount.getRowCount(Aggregate rel,...) case COLLAPSE -> { dRows = dRows / 10; @@ -328,4 +333,22 @@ public AbstractCalciteIndexScan pushDownSort(List collations) } return null; } + + /** + * CalciteOpenSearchIndexScan doesn't allow push-down anymore (except Sort under some strict + * condition) after Aggregate push-down. + */ + public boolean noAggregatePushed() { + if (this.getPushDownContext().isAggregatePushed()) return false; + final RelOptTable table = this.getTable(); + return table.unwrap(OpenSearchIndex.class) != null; + } + + public boolean isLimitPushed() { + return this.getPushDownContext().isLimitPushed(); + } + + public boolean isMetricsOrderPushed() { + return this.getPushDownContext().isMetricOrderPushed(); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index f4fbc66d6d8..4da8721efaa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -11,6 +11,7 @@ import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; +import javax.annotation.Nullable; import lombok.Getter; import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelOptCluster; @@ -273,7 +274,8 @@ private RelTraitSet reIndexCollations(List selectedColumns) { return newTraitSet; } - public AbstractRelNode pushDownAggregate(Aggregate aggregate, Project project) { + public AbstractRelNode pushDownAggregate( + Aggregate aggregate, Project project, @Nullable RelFieldCollation.Direction metricOrder) { try { CalciteLogicalIndexScan newScan = new CalciteLogicalIndexScan( @@ -293,7 +295,13 @@ public AbstractRelNode pushDownAggregate(Aggregate aggregate, Project project) { List outputFields = aggregate.getRowType().getFieldNames(); final Pair, OpenSearchAggregationResponseParser> aggregationBuilder = AggregateAnalyzer.analyze( - aggregate, project, getRowType(), fieldTypes, outputFields, getCluster()); + aggregate, + project, + getRowType(), + fieldTypes, + outputFields, + metricOrder, + getCluster()); Map extendedTypeMapping = aggregate.getRowType().getFieldList().stream() .collect( @@ -308,7 +316,10 @@ public AbstractRelNode pushDownAggregate(Aggregate aggregate, Project project) { aggregationBuilder, extendedTypeMapping, outputFields.subList(0, aggregate.getGroupSet().cardinality())); - newScan.pushDownContext.add(PushDownType.AGGREGATION, aggregate, action); + newScan.pushDownContext.add( + metricOrder == null ? PushDownType.AGGREGATION : PushDownType.SORT_AGG_METRICS, + aggregate, + action); if (aggregationBuilder.getLeft().size() == 1 && aggregationBuilder.getLeft().getFirst() instanceof AutoDateHistogramAggregationBuilder autoDateHistogram) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java index f7306604c1b..33bb9725841 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java @@ -48,6 +48,7 @@ public class PushDownContext extends AbstractCollection { private boolean isLimitPushed = false; private boolean isProjectPushed = false; + private boolean isMetricOrderPushed = false; public PushDownContext(OpenSearchIndex osIndex) { this.osIndex = osIndex; @@ -120,6 +121,9 @@ public boolean add(PushDownOperation operation) { if (operation.type() == PushDownType.PROJECT) { isProjectPushed = true; } + if (operation.type() == PushDownType.SORT_AGG_METRICS) { + isMetricOrderPushed = true; + } operation.action().transform(this, operation); return true; } @@ -149,7 +153,8 @@ enum PushDownType { SORT, LIMIT, SCRIPT, - COLLAPSE + COLLAPSE, + SORT_AGG_METRICS // HIGHLIGHT, // NESTED } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java index b3a1d766d8b..82a1430dd94 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java @@ -153,7 +153,8 @@ void analyze_aggCall_simple() throws ExpressionNotAnalyzableException { List.of(countCall, avgCall, sumCall, minCall, maxCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze(aggregate, project, rowType, fieldTypes, outputFields, null); + AggregateAnalyzer.analyze( + aggregate, project, rowType, fieldTypes, outputFields, null, null); assertEquals( "[{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}}," + " {\"avg\":{\"avg\":{\"field\":\"a\"}}}," @@ -234,7 +235,8 @@ void analyze_aggCall_extended() throws ExpressionNotAnalyzableException { List.of(varSampCall, varPopCall, stddevSampCall, stddevPopCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze(aggregate, project, rowType, fieldTypes, outputFields, null); + AggregateAnalyzer.analyze( + aggregate, project, rowType, fieldTypes, outputFields, null, null); assertEquals( "[{\"var_samp\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}}," + " {\"var_pop\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}}," @@ -273,7 +275,8 @@ void analyze_groupBy() throws ExpressionNotAnalyzableException { Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0, 1)); Project project = createMockProject(List.of(0, 1)); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze(aggregate, project, rowType, fieldTypes, outputFields, null); + AggregateAnalyzer.analyze( + aggregate, project, rowType, fieldTypes, outputFields, null, null); assertEquals( "[{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[" @@ -315,7 +318,7 @@ void analyze_aggCall_TextWithoutKeyword() { ExpressionNotAnalyzableException.class, () -> AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, List.of("sum"), null)); + aggregate, project, rowType, fieldTypes, List.of("sum"), null, null)); assertEquals("[field] must not be null: [sum]", exception.getCause().getMessage()); } @@ -342,7 +345,7 @@ void analyze_groupBy_TextWithoutKeyword() { ExpressionNotAnalyzableException.class, () -> AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null)); + aggregate, project, rowType, fieldTypes, outputFields, null, null)); assertEquals("[field] must not be null", exception.getCause().getMessage()); } @@ -687,7 +690,7 @@ void verify() throws ExpressionNotAnalyzableException { } Pair, OpenSearchAggregationResponseParser> result = AggregateAnalyzer.analyze( - agg, project, rowType, fieldTypes, outputFields, agg.getCluster()); + agg, project, rowType, fieldTypes, outputFields, null, agg.getCluster()); if (expectedDsl != null) { assertEquals(expectedDsl, result.getLeft().toString()); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java index 0e0068b5169..1446c7b0470 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java @@ -938,4 +938,55 @@ public void testMinOnTimeField() { String expectedSparkSql = "SELECT MIN(`HIREDATE`) `min_hire_date`\nFROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + @Test + public void testSortAggregationMetrics1() { + String ppl = "source=EMP | stats bucket_nullable=false avg(SAL) as avg by DEPTNO | sort - avg"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" + + " LogicalProject(avg=[$1], DEPTNO=[$0])\n" + + " LogicalAggregate(group=[{0}], avg=[AVG($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalFilter(condition=[IS NOT NULL($7)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "avg=2916.666666; DEPTNO=10\navg=2175.; DEPTNO=20\navg=1566.666666; DEPTNO=30\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT AVG(`SAL`) `avg`, `DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` IS NOT NULL\n" + + "GROUP BY `DEPTNO`\n" + + "ORDER BY 1 DESC"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testSortAggregationMetrics2() { + String ppl = + "source=EMP | stats avg(SAL) as avg by span(HIREDATE, 1year) as hiredate_span | sort" + + " avg"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])\n" + + " LogicalProject(avg=[$1], hiredate_span=[$0])\n" + + " LogicalAggregate(group=[{1}], avg=[AVG($0)])\n" + + " LogicalProject(SAL=[$5], hiredate_span=[SPAN($4, 1, 'y')])\n" + + " LogicalFilter(condition=[IS NOT NULL($4)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT AVG(`SAL`) `avg`, `SPAN`(`HIREDATE`, 1, 'y') `hiredate_span`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `HIREDATE` IS NOT NULL\n" + + "GROUP BY `SPAN`(`HIREDATE`, 1, 'y')\n" + + "ORDER BY 1"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } } From a8a75537b79de21cfaa6ac290ae28f85440ebdfe Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 21 Oct 2025 18:26:46 +0800 Subject: [PATCH 2/9] refactor Signed-off-by: Lantao Jin --- .../sql/calcite/remote/CalciteExplainIT.java | 31 +- .../calcite/explain_agg_sort_on_metrics1.json | 6 - .../calcite/explain_agg_sort_on_metrics1.yaml | 4 +- .../calcite/explain_agg_sort_on_metrics2.yaml | 16 +- .../calcite/explain_agg_sort_on_metrics3.yaml | 12 + .../calcite/explain_agg_sort_on_metrics4.yaml | 12 + ...lain_agg_sort_on_metrics_multi_terms1.yaml | 13 + .../OpenSearchAggregateIndexScanRule.java | 2 +- .../physical/SortAggregationMetricsRule.java | 136 +------- .../opensearch/request/AggregateAnalyzer.java | 57 +--- .../scan/AbstractCalciteIndexScan.java | 18 +- .../scan/CalciteEnumerableIndexScan.java | 1 + .../storage/scan/CalciteLogicalIndexScan.java | 76 ++++- .../storage/scan/context/AbstractAction.java | 12 + .../AggPushDownAction.java} | 312 +++++++----------- .../context/AggregationBuilderAction.java | 13 + .../storage/scan/context/FilterDigest.java | 15 + .../storage/scan/context/LimitDigest.java | 13 + .../scan/context/OSRequestBuilderAction.java | 15 + .../storage/scan/context/PushDownContext.java | 126 +++++++ .../scan/context/PushDownOperation.java | 19 ++ .../storage/scan/context/PushDownType.java | 19 ++ .../request/AggregateAnalyzerTest.java | 47 ++- .../scan/CalciteIndexScanCostTest.java | 7 + 24 files changed, 553 insertions(+), 429 deletions(-) delete mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.json create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics3.yaml create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics4.yaml create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java rename opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/{PushDownContext.java => context/AggPushDownAction.java} (53%) create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/FilterDigest.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/LimitDigest.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownOperation.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index c218f005249..0aabc795bb2 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1023,7 +1023,7 @@ public void testExplainCountsByAgg() throws IOException { } @Test - public void testExplainSortOnMetricsNoBucketNullable() throws IOException { + public void testExplainSortOnMetrics() throws IOException { enabledOnlyWhenPushdownIsEnabled(); String expected = loadExpectedPlan("explain_agg_sort_on_metrics1.yaml"); assertYamlEqualsJsonIgnoreId( @@ -1031,8 +1031,35 @@ public void testExplainSortOnMetricsNoBucketNullable() throws IOException { explainQueryToString( "source=opensearch-sql_test_index_account | stats bucket_nullable=false count() by" + " state | sort `count()`")); - expected = loadExpectedPlan("explain_agg_sort_on_metrics2.yaml"); + assertYamlEqualsJsonIgnoreId( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account | stats bucket_nullable=false sum(balance)" + + " as sum by state | sort - sum")); + // TODO limit should pushdown to non-composite agg + expected = loadExpectedPlan("explain_agg_sort_on_metrics3.yaml"); + assertYamlEqualsJsonIgnoreId( + expected, + explainQueryToString( + String.format( + "source=%s | stats count() as cnt by span(birthdate, 1d) | sort - cnt", + TEST_INDEX_BANK))); + expected = loadExpectedPlan("explain_agg_sort_on_metrics4.yaml"); + assertYamlEqualsJsonIgnoreId( + expected, + explainQueryToString( + String.format( + "source=%s | stats bucket_nullable=false sum(balance) by span(age, 5) | sort -" + + " `sum(balance)`", + TEST_INDEX_BANK))); + } + + @Ignore + public void testExplainSortOnMetricsMultiTerms() throws IOException { + // TODO support multi-terms + enabledOnlyWhenPushdownIsEnabled(); + String expected = loadExpectedPlan("explain_agg_sort_on_metrics_multi_terms1.yaml"); assertYamlEqualsJsonIgnoreId( expected, explainQueryToString( diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.json deleted file mode 100644 index 02352aa32e2..00000000000 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "calcite": { - "logical": "LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])\n LogicalProject(count()=[$1], state=[$0])\n LogicalAggregate(group=[{0}], count()=[COUNT()])\n LogicalProject(state=[$7])\n LogicalFilter(condition=[IS NOT NULL($7)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", - "physical": "EnumerableLimit(fetch=[10000])\n EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[FILTER->IS NOT NULL($7), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), PROJECT->[count(), state]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"exists\":{\"field\":\"state\",\"boost\":1.0}},\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":false,\"order\":\"asc\"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" - } -} diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml index fb7435b52c1..b837e4968d4 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml @@ -8,6 +8,4 @@ calcite: LogicalFilter(condition=[IS NOT NULL($7)]) CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) physical: | - EnumerableLimit(fetch=[10000]) - EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[SORT_AGG_METRICS->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), PROJECT->[count(), state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","_source":{"includes":["count()","state"],"excludes":[]},"aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"count()":"asc"},{"_key":"asc"}]},"aggregations":{"count()":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), SORT_AGG_METRICS->[1 ASC FIRST], PROJECT->[count(), state], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"count()":"asc"},{"_key":"asc"}]},"aggregations":{"count()":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml index 8a45ecc2f92..1808eba1f08 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml @@ -1,13 +1,11 @@ calcite: logical: | - LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) - LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) - LogicalProject(count()=[$2], gender=[$0], state=[$1]) - LogicalAggregate(group=[{0, 1}], count()=[COUNT()]) - LogicalProject(gender=[$4], state=[$7]) - LogicalFilter(condition=[AND(IS NOT NULL($4), IS NOT NULL($7))]) + LogicalSystemLimit(sort0=[$0], dir0=[DESC-nulls-last], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[DESC-nulls-last]) + LogicalProject(sum=[$1], state=[$0]) + LogicalAggregate(group=[{0}], sum=[SUM($1)]) + LogicalProject(state=[$7], balance=[$3]) + LogicalFilter(condition=[IS NOT NULL($7)]) CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) physical: | - EnumerableLimit(fetch=[10000]) - EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},count()=COUNT()), PROJECT->[count(), gender, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":false,"order":"asc"}}},{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},sum=SUM($0)), SORT_AGG_METRICS->[1 DESC LAST], PROJECT->[sum, state], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"sum":"desc"},{"_key":"asc"}]},"aggregations":{"sum":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics3.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics3.yaml new file mode 100644 index 00000000000..a40c5cec466 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics3.yaml @@ -0,0 +1,12 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[DESC-nulls-last], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[DESC-nulls-last]) + LogicalProject(cnt=[$1], span(birthdate,1d)=[$0]) + LogicalAggregate(group=[{0}], cnt=[COUNT()]) + LogicalProject(span(birthdate,1d)=[SPAN($3, 1, 'd')]) + LogicalFilter(condition=[IS NOT NULL($3)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]]) + physical: | + EnumerableLimit(fetch=[10000]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[birthdate], FILTER->IS NOT NULL($0), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},cnt=COUNT()), SORT_AGG_METRICS->[1 DESC LAST], PROJECT->[cnt, span(birthdate,1d)]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","query":{"exists":{"field":"birthdate","boost":1.0}},"_source":{"includes":["birthdate"],"excludes":[]},"aggregations":{"span(birthdate,1d)":{"date_histogram":{"field":"birthdate","fixed_interval":"1d","offset":0,"order":[{"cnt":"desc"},{"_key":"asc"}],"keyed":false,"min_doc_count":0},"aggregations":{"cnt":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics4.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics4.yaml new file mode 100644 index 00000000000..74ff751bcef --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics4.yaml @@ -0,0 +1,12 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[DESC-nulls-last], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[DESC-nulls-last]) + LogicalProject(sum(balance)=[$1], span(age,5)=[$0]) + LogicalAggregate(group=[{1}], sum(balance)=[SUM($0)]) + LogicalProject(balance=[$7], span(age,5)=[SPAN($10, 5, null:NULL)]) + LogicalFilter(condition=[IS NOT NULL($10)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]]) + physical: | + EnumerableLimit(fetch=[10000]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[balance, age], FILTER->IS NOT NULL($1), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},sum(balance)=SUM($0)), SORT_AGG_METRICS->[1 DESC LAST], PROJECT->[sum(balance), span(age,5)]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","query":{"exists":{"field":"age","boost":1.0}},"_source":{"includes":["balance","age"],"excludes":[]},"aggregations":{"span(age,5)":{"histogram":{"field":"age","interval":5.0,"offset":0.0,"order":[{"sum(balance)":"desc"},{"_key":"asc"}],"keyed":false,"min_doc_count":0},"aggregations":{"sum(balance)":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml new file mode 100644 index 00000000000..8a45ecc2f92 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml @@ -0,0 +1,13 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) + LogicalProject(count()=[$2], gender=[$0], state=[$1]) + LogicalAggregate(group=[{0, 1}], count()=[COUNT()]) + LogicalProject(gender=[$4], state=[$7]) + LogicalFilter(condition=[AND(IS NOT NULL($4), IS NOT NULL($7))]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},count()=COUNT()), PROJECT->[count(), gender, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":false,"order":"asc"}}},{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java index c8ab3f46c91..df47d1c9a77 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java @@ -84,7 +84,7 @@ protected void apply( LogicalAggregate aggregate, LogicalProject project, CalciteLogicalIndexScan scan) { - AbstractRelNode newRelNode = scan.pushDownAggregate(aggregate, project, null); + AbstractRelNode newRelNode = scan.pushDownAggregate(aggregate, project); if (newRelNode != null) { call.transformTo(newRelNode); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java index 3efdec21173..0fa0b4ebef0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java @@ -5,26 +5,11 @@ package org.opensearch.sql.opensearch.planner.physical; -import java.util.List; -import java.util.function.Function; import java.util.function.Predicate; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; -import org.apache.calcite.rel.AbstractRelNode; -import org.apache.calcite.rel.RelFieldCollation; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.logical.LogicalAggregate; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.logical.LogicalSort; -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexInputRef; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.immutables.value.Value; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; @@ -39,72 +24,14 @@ protected SortAggregationMetricsRule(Config config) { @Override public void onMatch(RelOptRuleCall call) { final LogicalSort sort = call.rel(0); - final LogicalProject projectAddedBySort = call.rel(1); - final LogicalAggregate aggregate = call.rel(2); - final LogicalFilter filter = call.rel(3); - final LogicalProject project = call.rel(4); - final CalciteLogicalIndexScan scan = call.rel(5); + final CalciteLogicalIndexScan scan = call.rel(1); // Only support single metric sort if (sort.getCollation().getFieldCollations().size() != 1) { return; } - // Only support single metric in aggregate - if (aggregate.getAggCallList().size() != 1) { - return; - } - int possibleMetricsIndexInSort = - sort.getCollation().getFieldCollations().getFirst().getFieldIndex(); - RexNode possibleMetricsInProject = - projectAddedBySort.getProjects().get(possibleMetricsIndexInSort); - if (!(possibleMetricsInProject instanceof RexInputRef inputRef)) { - return; - } - int possibleMetricsIndexInProject = inputRef.getIndex(); - RelDataTypeField possibleMetricsInAggregate = - aggregate.getRowType().getFieldList().get(possibleMetricsIndexInProject); - if (possibleMetricsInAggregate.getType().getSqlTypeName().getFamily() - != SqlTypeFamily.NUMERIC) { - return; - } - if (!aggregate - .getAggCallList() - .getFirst() - .getName() - .equals(possibleMetricsInAggregate.getName())) { - return; - } - List groupSet = aggregate.getGroupSet().asList(); - RexNode condition = filter.getCondition(); - Function isNotNullFromAgg = - rex -> - rex instanceof RexCall rexCall - && rexCall.getOperator() == SqlStdOperatorTable.IS_NOT_NULL - && rexCall.getOperands().get(0) instanceof RexInputRef ref - && groupSet.contains(ref.getIndex()); - if (isNotNullFromAgg.apply(condition) - || (condition instanceof RexCall rexCall - && rexCall.getOperator() == SqlStdOperatorTable.AND - && rexCall.getOperands().stream().allMatch(isNotNullFromAgg::apply))) { - // Try to do the aggregate push down and ignore the filter if the filter sources from the - // aggregate's hint. See{@link CalciteRelNodeVisitor::visitAggregation} - RelFieldCollation.Direction direction = - sort.getCollation().getFieldCollations().getFirst().direction; - apply(call, projectAddedBySort, aggregate, project, scan, direction); - } - } - - protected void apply( - RelOptRuleCall call, - LogicalProject projectAddedBySort, - LogicalAggregate aggregate, - LogicalProject project, - CalciteLogicalIndexScan scan, - RelFieldCollation.Direction metricOrder) { - AbstractRelNode newScan = scan.pushDownAggregate(aggregate, project, metricOrder); + CalciteLogicalIndexScan newScan = scan.pushDownSortAggregateMetrics(sort); if (newScan != null) { - RelNode newScanWithProject = - call.builder().push(newScan).project(projectAddedBySort.getProjects()).build(); - call.transformTo(newScanWithProject); + call.transformTo(newScan); } } @@ -114,64 +41,17 @@ public interface Config extends RelRule.Config { SortAggregationMetricsRule.Config DEFAULT = ImmutableSortAggregationMetricsRule.Config.builder() .build() - .withDescription("Sort-Project-Agg-Filter-Project-TableScan") + .withDescription("Sort-TableScan(agg-pushed)") .withOperandSupplier( b0 -> b0.operand(LogicalSort.class) .predicate(PlanUtils::sortByFieldsOnly) .oneInput( b1 -> - b1.operand(LogicalProject.class) - .oneInput( - b2 -> - b2.operand(LogicalAggregate.class) - .predicate( - agg -> - agg.getHints().stream() - .anyMatch( - hint -> - hint.hintName.equals( - "stats_args") - && hint.kvOptions - .get( - Argument - .BUCKET_NULLABLE) - .equals("false"))) - .oneInput( - b3 -> - b3.operand(LogicalFilter.class) - .predicate( - OpenSearchAggregateIndexScanRule - .Config - ::mayBeFilterFromBucketNonNull) - .oneInput( - b4 -> - b4.operand(LogicalProject.class) - .predicate( - Predicate.not( - PlanUtils - ::containsRexOver) - .and( - PlanUtils - ::distinctProjectList) - .or( - OpenSearchAggregateIndexScanRule - .Config - ::containsWidthBucketFuncOnDate)) - .oneInput( - b5 -> - b5.operand( - CalciteLogicalIndexScan - .class) - .predicate( - Predicate - .not( - AbstractCalciteIndexScan - ::isLimitPushed) - .and( - AbstractCalciteIndexScan - ::noAggregatePushed)) - .noInputs())))))); + b1.operand(CalciteLogicalIndexScan.class) + .predicate( + Predicate.not(AbstractCalciteIndexScan::noAggregatePushed)) + .noInputs())); @Override default SortAggregationMetricsRule toRule() { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index e0782b48d8e..74cc0d9d5eb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -42,7 +42,6 @@ import java.util.function.Function; import lombok.RequiredArgsConstructor; import org.apache.calcite.plan.RelOptCluster; -import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; @@ -71,7 +70,6 @@ import org.opensearch.search.aggregations.support.ValueType; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.search.sort.SortOrder; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.data.type.ExprCoreType; @@ -134,13 +132,12 @@ public static class CompositeAggUnSupportedException extends RuntimeException { private AggregateAnalyzer() {} @RequiredArgsConstructor - static class AggregateBuilderHelper { + public static class AggregateBuilderHelper { final RelDataType rowType; final Map fieldTypes; final RelOptCluster cluster; final boolean bucketNullable; final int bucketSize; - final RelFieldCollation.Direction metricOrder; > T build(RexNode node, T aggBuilder) { return build(node, aggBuilder::field, aggBuilder::script); @@ -185,26 +182,12 @@ T inferValue(RexNode node, Class clazz) { public static Pair, OpenSearchAggregationResponseParser> analyze( Aggregate aggregate, Project project, - RelDataType rowType, - Map fieldTypes, List outputFields, - RelFieldCollation.Direction metricOrder, - RelOptCluster cluster, - int bucketSize) + AggregateBuilderHelper helper) throws ExpressionNotAnalyzableException { requireNonNull(aggregate, "aggregate"); try { - boolean bucketNullable = - Boolean.parseBoolean( - aggregate.getHints().stream() - .filter(hits -> hits.hintName.equals("stats_args")) - .map(hint -> hint.kvOptions.getOrDefault(Argument.BUCKET_NULLABLE, "true")) - .findFirst() - .orElseGet(() -> "true")); List groupList = aggregate.getGroupSet().asList(); - AggregateBuilderHelper helper = - new AggregateBuilderHelper( - rowType, fieldTypes, cluster, bucketNullable, bucketSize, metricOrder); List aggFieldNames = outputFields.subList(groupList.size(), outputFields.size()); // Process all aggregate calls Pair> builderAndParser = @@ -216,7 +199,7 @@ public static Pair, OpenSearchAggregationResponseParser // but only count() can apply doc_count optimization in bucket aggregation. boolean countAllOnly = !aggregate.getGroupSet().isEmpty(); Pair, Builder> countAggNameAndBuilderPair = - removeCountAggregationBuilders(metricBuilder, countAllOnly, metricOrder); + removeCountAggregationBuilders(metricBuilder, countAllOnly); Builder newMetricBuilder = countAggNameAndBuilderPair.getRight(); List countAggNames = countAggNameAndBuilderPair.getLeft(); @@ -231,21 +214,8 @@ public static Pair, OpenSearchAggregationResponseParser new NoBucketAggregationParser(metricParserList)); } } else if (aggregate.getGroupSet().length() == 1 - && (isAutoDateSpan(project.getProjects().get(groupList.getFirst())) - || metricOrder != null)) { + && (isAutoDateSpan(project.getProjects().get(groupList.getFirst())))) { ValuesSourceAggregationBuilder bucketBuilder = createBucket(0, project, helper); - if (metricOrder != null - && bucketBuilder instanceof TermsAggregationBuilder termsAggregationBuilder) { - String path = - newMetricBuilder.getAggregatorFactories().stream() - .map(AggregationBuilder::getName) - .toList() - .getFirst(); - termsAggregationBuilder.order( - metricOrder == RelFieldCollation.Direction.ASCENDING - ? BucketOrder.aggregation(path, true) - : BucketOrder.aggregation(path, false)); - } if (newMetricBuilder != null) { bucketBuilder.subAggregations(newMetricBuilder); } @@ -258,7 +228,7 @@ public static Pair, OpenSearchAggregationResponseParser List> buckets = createCompositeBuckets(groupList, project, helper); aggregationBuilder = - AggregationBuilders.composite("composite_buckets", buckets).size(bucketSize); + AggregationBuilders.composite("composite_buckets", buckets).size(helper.bucketSize); if (newMetricBuilder != null) { aggregationBuilder.subAggregations(metricBuilder); } @@ -266,7 +236,7 @@ public static Pair, OpenSearchAggregationResponseParser Collections.singletonList(aggregationBuilder), new CompositeAggregationParser(metricParserList, countAggNames)); } catch (CompositeAggUnSupportedException e) { - if (bucketNullable) { + if (helper.bucketNullable) { throw new UnsupportedOperationException(e.getMessage()); } aggregationBuilder = createNestedBuckets(groupList, project, newMetricBuilder, helper); @@ -292,11 +262,7 @@ public static Pair, OpenSearchAggregationResponseParser * with the original metric builder. */ private static Pair, Builder> removeCountAggregationBuilders( - Builder metricBuilder, boolean countAllOnly, RelFieldCollation.Direction metricOrder) { - // if we have a specific metric order, skip the optimization of count agg removing - if (metricOrder != null) { - return Pair.of(List.of(), metricBuilder); - } + Builder metricBuilder, boolean countAllOnly) { List countAggregatorFactories = metricBuilder.getAggregatorFactories().stream() .filter(ValueCountAggregationBuilder.class::isInstance) @@ -643,10 +609,11 @@ private static CompositeValuesSourceBuilder createTermsSourceBuilder( private static ValuesSourceAggregationBuilder createTermsAggregationBuilder( String bucketName, RexNode group, AggregateBuilderHelper helper) { TermsAggregationBuilder sourceBuilder = - helper.build(group, new TermsAggregationBuilder(bucketName).size(helper.bucketSize)); - if (helper.metricOrder == null) { - sourceBuilder.order(BucketOrder.key(true)); - } + helper.build( + group, + new TermsAggregationBuilder(bucketName) + .size(helper.bucketSize) + .order(BucketOrder.key(true))); return withValueTypeHint( sourceBuilder, sourceBuilder::userValueTypeHint, diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java index e1862ad29d1..cfb2201a8ef 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java @@ -45,6 +45,15 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.context.AbstractAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggPushDownAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggregationBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.FilterDigest; +import org.opensearch.sql.opensearch.storage.scan.context.LimitDigest; +import org.opensearch.sql.opensearch.storage.scan.context.OSRequestBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownContext; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownOperation; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownType; /** An abstract relational operator representing a scan of an OpenSearchIndex type. */ @Getter @@ -103,9 +112,10 @@ public double estimateRowCount(RelMetadataQuery mq) { osIndex.getMaxResultWindow().doubleValue(), (rowCount, operation) -> switch (operation.type()) { - case AGGREGATION, SORT_AGG_METRICS -> mq.getRowCount( - (RelNode) operation.digest()); + case AGGREGATION -> mq.getRowCount((RelNode) operation.digest()); case PROJECT, SORT -> rowCount; + case SORT_AGG_METRICS -> NumberUtil.min( + rowCount, osIndex.getBucketSize().doubleValue()); // Refer the org.apache.calcite.rel.metadata.RelMdRowCount case COLLAPSE -> rowCount / 10; case FILTER, SCRIPT -> NumberUtil.multiply( @@ -213,7 +223,7 @@ protected abstract AbstractCalciteIndexScan buildScan( RelDataType schema, PushDownContext pushDownContext); - private List getCollationNames(List collations) { + protected List getCollationNames(List collations) { return collations.stream() .map(collation -> getRowType().getFieldNames().get(collation.getFieldIndex())) .toList(); @@ -227,7 +237,7 @@ private List getCollationNames(List collations) { * @param collations List of collation names to check against aggregators. * @return True if any collation name matches an aggregator output, false otherwise. */ - private boolean hasAggregatorInSortBy(List collations) { + protected boolean hasAggregatorInSortBy(List collations) { Stream aggregates = pushDownContext.stream() .filter(action -> action.type() == PushDownType.AGGREGATION) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java index 7fb32cc761a..29e0c6ceb3c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java @@ -31,6 +31,7 @@ import org.opensearch.sql.calcite.plan.Scannable; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownContext; /** The physical relational operator representing a scan of an OpenSearchIndex type. */ public class CalciteEnumerableIndexScan extends AbstractCalciteIndexScan diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index 9576507cb53..ee044cdff71 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -11,7 +11,6 @@ import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; -import javax.annotation.Nullable; import lombok.Getter; import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelOptCluster; @@ -26,6 +25,7 @@ import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalSort; @@ -40,8 +40,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder; import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.type.ExprCoreType; @@ -55,6 +57,14 @@ import org.opensearch.sql.opensearch.request.PredicateAnalyzer.QueryExpression; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.context.AbstractAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggPushDownAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggregationBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.FilterDigest; +import org.opensearch.sql.opensearch.storage.scan.context.LimitDigest; +import org.opensearch.sql.opensearch.storage.scan.context.OSRequestBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownContext; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownType; /** The logical relational operator representing a scan of an OpenSearchIndex type. */ @Getter @@ -104,6 +114,11 @@ public CalciteLogicalIndexScan copyWithNewSchema(RelDataType schema) { getCluster(), traitSet, hints, table, osIndex, schema, pushDownContext.clone()); } + public CalciteLogicalIndexScan copyWithNewTraitSet(RelTraitSet traitSet) { + return new CalciteLogicalIndexScan( + getCluster(), traitSet, hints, table, osIndex, schema, pushDownContext.clone()); + } + @Override public void register(RelOptPlanner planner) { super.register(planner); @@ -274,8 +289,38 @@ private RelTraitSet reIndexCollations(List selectedColumns) { return newTraitSet; } - public AbstractRelNode pushDownAggregate( - Aggregate aggregate, Project project, @Nullable RelFieldCollation.Direction metricOrder) { + public CalciteLogicalIndexScan pushDownSortAggregateMetrics(Sort sort) { + try { + if (!pushDownContext.isAggregatePushed()) return null; + List aggregationBuilders = + pushDownContext.getAggPushDownAction().getAggregationBuilder().getLeft(); + if (aggregationBuilders.size() != 1) { + return null; + } + if (!(aggregationBuilders.getFirst() instanceof CompositeAggregationBuilder)) { + return null; + } + List collationNames = getCollationNames(sort.getCollation().getFieldCollations()); + if (!hasAggregatorInSortBy(collationNames)) { + return null; + } + AbstractAction newAction = + (AggregationBuilderAction) + aggAction -> + aggAction.pushDownSortAggMetrics( + sort.getCollation().getFieldCollations(), rowType.getFieldNames()); + Object digest = sort.getCollation().getFieldCollations(); + pushDownContext.add(PushDownType.SORT_AGG_METRICS, digest, newAction); + return copyWithNewTraitSet(sort.getTraitSet()); + } catch (Exception e) { + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the sort aggregate {}", sort, e); + } + } + return null; + } + + public AbstractRelNode pushDownAggregate(Aggregate aggregate, Project project) { try { CalciteLogicalIndexScan newScan = new CalciteLogicalIndexScan( @@ -294,16 +339,18 @@ public AbstractRelNode pushDownAggregate( .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); List outputFields = aggregate.getRowType().getFieldNames(); int bucketSize = osIndex.getBucketSize(); + boolean bucketNullable = + Boolean.parseBoolean( + aggregate.getHints().stream() + .filter(hits -> hits.hintName.equals("stats_args")) + .map(hint -> hint.kvOptions.getOrDefault(Argument.BUCKET_NULLABLE, "true")) + .findFirst() + .orElseGet(() -> "true")); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper( + getRowType(), fieldTypes, getCluster(), bucketNullable, bucketSize); final Pair, OpenSearchAggregationResponseParser> aggregationBuilder = - AggregateAnalyzer.analyze( - aggregate, - project, - getRowType(), - fieldTypes, - outputFields, - metricOrder, - getCluster(), - bucketSize); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); Map extendedTypeMapping = aggregate.getRowType().getFieldList().stream() .collect( @@ -318,10 +365,7 @@ public AbstractRelNode pushDownAggregate( aggregationBuilder, extendedTypeMapping, outputFields.subList(0, aggregate.getGroupSet().cardinality())); - newScan.pushDownContext.add( - metricOrder == null ? PushDownType.AGGREGATION : PushDownType.SORT_AGG_METRICS, - aggregate, - action); + newScan.pushDownContext.add(PushDownType.AGGREGATION, aggregate, action); if (aggregationBuilder.getLeft().size() == 1 && aggregationBuilder.getLeft().getFirst() instanceof AutoDateHistogramAggregationBuilder autoDateHistogram) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java new file mode 100644 index 00000000000..fa48ceded6d --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +public interface AbstractAction { + void apply(T target); + + void transform(PushDownContext context, PushDownOperation operation); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java similarity index 53% rename from opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java rename to opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java index 33bb9725841..fbdc2b19203 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java @@ -3,29 +3,28 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.opensearch.storage.scan; +package org.opensearch.sql.opensearch.storage.scan.context; -import com.google.common.collect.Iterators; -import java.util.AbstractCollection; -import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Map; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.calcite.rel.RelFieldCollation; -import org.apache.calcite.rel.RelFieldCollation.Direction; -import org.apache.calcite.rel.RelFieldCollation.NullDirection; -import org.apache.calcite.rex.RexNode; import org.apache.commons.lang3.tuple.Pair; -import org.jetbrains.annotations.NotNull; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; -import org.opensearch.search.aggregations.AggregatorFactories.Builder; +import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.BucketOrder; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.DateHistogramValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.HistogramValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; @@ -33,185 +32,18 @@ import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.response.agg.BucketAggregationParser; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.MetricParserHelper; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; -import org.opensearch.sql.opensearch.storage.OpenSearchIndex; @Getter -public class PushDownContext extends AbstractCollection { - private final OpenSearchIndex osIndex; - private final OpenSearchRequestBuilder requestBuilder; - private ArrayDeque operationsForRequestBuilder; - - private boolean isAggregatePushed = false; - private AggPushDownAction aggPushDownAction; - private ArrayDeque operationsForAgg; - - private boolean isLimitPushed = false; - private boolean isProjectPushed = false; - private boolean isMetricOrderPushed = false; - - public PushDownContext(OpenSearchIndex osIndex) { - this.osIndex = osIndex; - this.requestBuilder = osIndex.createRequestBuilder(); - } - - @Override - public PushDownContext clone() { - PushDownContext newContext = new PushDownContext(osIndex); - newContext.addAll(this); - return newContext; - } - - /** - * Create a new {@link PushDownContext} without the collation action. - * - * @return A new push-down context without the collation action. - */ - public PushDownContext cloneWithoutSort() { - PushDownContext newContext = new PushDownContext(osIndex); - for (PushDownOperation action : this) { - if (action.type() != PushDownType.SORT) { - newContext.add(action); - } - } - return newContext; - } - - @NotNull - @Override - public Iterator iterator() { - if (operationsForRequestBuilder == null) { - return Collections.emptyIterator(); - } else if (operationsForAgg == null) { - return operationsForRequestBuilder.iterator(); - } else { - return Iterators.concat(operationsForRequestBuilder.iterator(), operationsForAgg.iterator()); - } - } - - @Override - public int size() { - return (operationsForRequestBuilder == null ? 0 : operationsForRequestBuilder.size()) - + (operationsForAgg == null ? 0 : operationsForAgg.size()); - } - - ArrayDeque getOperationsForRequestBuilder() { - if (operationsForRequestBuilder == null) { - this.operationsForRequestBuilder = new ArrayDeque<>(); - } - return operationsForRequestBuilder; - } - - ArrayDeque getOperationsForAgg() { - if (operationsForAgg == null) { - this.operationsForAgg = new ArrayDeque<>(); - } - return operationsForAgg; - } - - @Override - public boolean add(PushDownOperation operation) { - if (operation.type() == PushDownType.AGGREGATION) { - isAggregatePushed = true; - this.aggPushDownAction = (AggPushDownAction) operation.action(); - } - if (operation.type() == PushDownType.LIMIT) { - isLimitPushed = true; - } - if (operation.type() == PushDownType.PROJECT) { - isProjectPushed = true; - } - if (operation.type() == PushDownType.SORT_AGG_METRICS) { - isMetricOrderPushed = true; - } - operation.action().transform(this, operation); - return true; - } - - void add(PushDownType type, Object digest, AbstractAction action) { - add(new PushDownOperation(type, digest, action)); - } - - public boolean containsDigest(Object digest) { - return this.stream().anyMatch(action -> action.digest().equals(digest)); - } - - public OpenSearchRequestBuilder createRequestBuilder() { - OpenSearchRequestBuilder newRequestBuilder = osIndex.createRequestBuilder(); - if (operationsForRequestBuilder != null) { - operationsForRequestBuilder.forEach( - operation -> ((OSRequestBuilderAction) operation.action()).apply(newRequestBuilder)); - } - return newRequestBuilder; - } -} - -enum PushDownType { - FILTER, - PROJECT, - AGGREGATION, - SORT, - LIMIT, - SCRIPT, - COLLAPSE, - SORT_AGG_METRICS - // HIGHLIGHT, - // NESTED -} - -/** - * Represents a push down operation that can be applied to an OpenSearchRequestBuilder. - * - * @param type PushDownType enum - * @param digest the digest of the pushed down operator - * @param action the lambda action to apply on the OpenSearchRequestBuilder - */ -record PushDownOperation(PushDownType type, Object digest, AbstractAction action) { - public String toString() { - return type + "->" + digest; - } -} - -interface AbstractAction { - void apply(T target); - - void transform(PushDownContext context, PushDownOperation operation); -} - -interface OSRequestBuilderAction extends AbstractAction { - default void transform(PushDownContext context, PushDownOperation operation) { - apply(context.getRequestBuilder()); - context.getOperationsForRequestBuilder().add(operation); - } -} - -interface AggregationBuilderAction extends AbstractAction { - default void transform(PushDownContext context, PushDownOperation operation) { - apply(context.getAggPushDownAction()); - context.getOperationsForAgg().add(operation); - } -} - -record FilterDigest(int scriptCount, RexNode condition) { - @Override - public String toString() { - return condition.toString(); - } -} - -record LimitDigest(int limit, int offset) { - @Override - public String toString() { - return offset == 0 ? String.valueOf(limit) : "[" + limit + " from " + offset + "]"; - } -} - -// TODO: shall we do deep copy for this action since it's mutable? -class AggPushDownAction implements OSRequestBuilderAction { +@EqualsAndHashCode +public class AggPushDownAction implements OSRequestBuilderAction { private Pair, OpenSearchAggregationResponseParser> aggregationBuilder; private final Map extendedTypeMapping; - @Getter private final long scriptCount; + private final long scriptCount; // Record the output field names of all buckets as the sequence of buckets private List bucketNames; @@ -237,6 +69,110 @@ public void apply(OpenSearchRequestBuilder requestBuilder) { requestBuilder.pushTypeMapping(extendedTypeMapping); } + private BucketAggregationParser convertTo(OpenSearchAggregationResponseParser parser) { + if (parser instanceof BucketAggregationParser) { + return (BucketAggregationParser) parser; + } else if (parser instanceof CompositeAggregationParser) { + MetricParserHelper helper = ((CompositeAggregationParser) parser).getMetricsParser(); + return new BucketAggregationParser( + helper.getMetricParserMap().values().stream().toList(), helper.getCountAggNameList()); + } else { + throw new IllegalStateException("Unexpected parser type: " + parser.getClass()); + } + } + + public void pushDownSortAggMetrics(List collations, List fieldNames) { + if (aggregationBuilder.getLeft().isEmpty()) return; + AggregationBuilder builder = aggregationBuilder.getLeft().getFirst(); + if (builder instanceof CompositeAggregationBuilder composite) { + String path = getAggregationPath(collations, fieldNames, composite); + BucketOrder bucketOrder = + collations.get(0).getDirection() == RelFieldCollation.Direction.ASCENDING + ? BucketOrder.aggregation(path, true) + : BucketOrder.aggregation(path, false); + + if (composite.sources().size() == 1) { + if (composite.sources().get(0) instanceof TermsValuesSourceBuilder terms + && !terms.missingBucket()) { + TermsAggregationBuilder termsBuilder = new TermsAggregationBuilder(terms.name()); + termsBuilder.size(composite.size()); + termsBuilder.field(terms.field()); + termsBuilder.order(bucketOrder); + attachSubAggregations(composite.getSubAggregations(), path, termsBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(termsBuilder), + convertTo(aggregationBuilder.getRight())); + return; + } else if (composite.sources().get(0) + instanceof DateHistogramValuesSourceBuilder dateHisto) { + DateHistogramAggregationBuilder dateHistoBuilder = + new DateHistogramAggregationBuilder(dateHisto.name()); + dateHistoBuilder.field(dateHisto.field()); + try { + dateHistoBuilder.fixedInterval(dateHisto.getIntervalAsFixed()); + } catch (IllegalArgumentException e) { + dateHistoBuilder.calendarInterval(dateHisto.getIntervalAsCalendar()); + } + dateHistoBuilder.order(bucketOrder); + attachSubAggregations(composite.getSubAggregations(), path, dateHistoBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(dateHistoBuilder), + convertTo(aggregationBuilder.getRight())); + return; + } else if (composite.sources().get(0) instanceof HistogramValuesSourceBuilder histo + && !histo.missingBucket()) { + HistogramAggregationBuilder histoBuilder = new HistogramAggregationBuilder(histo.name()); + histoBuilder.field(histo.field()); + histoBuilder.interval(histo.interval()); + histoBuilder.order(bucketOrder); + attachSubAggregations(composite.getSubAggregations(), path, histoBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(histoBuilder), + convertTo(aggregationBuilder.getRight())); + return; + } + } else { + if (composite.sources().stream() + .allMatch( + src -> src instanceof TermsValuesSourceBuilder terms && !terms.missingBucket())) { + // multi-term agg + return; + } + } + throw new OpenSearchRequestBuilder.PushDownUnSupportedException( + "Cannot pushdown sort aggregate metrics"); + } + } + + private String getAggregationPath( + List collations, + List fieldNames, + CompositeAggregationBuilder composite) { + String path; + if (composite.getSubAggregations().isEmpty()) { + // count agg optimized, get the path name from field names + path = fieldNames.get(collations.get(0).getFieldIndex()); + } else { + path = composite.getSubAggregations().stream().toList().get(0).getName(); + } + return path; + } + + private > T attachSubAggregations( + Collection subAggregations, String path, T aggregationBuilder) { + AggregatorFactories.Builder metricBuilder = new AggregatorFactories.Builder(); + if (subAggregations.isEmpty()) { + metricBuilder.addAggregator(AggregationBuilders.count(path).field("_index")); + } else { + metricBuilder.addAggregator(subAggregations.stream().toList().get(0)); + } + aggregationBuilder.subAggregations(metricBuilder); + return aggregationBuilder; + } + public void pushDownSortIntoAggBucket( List collations, List fieldNames) { // aggregationBuilder.getLeft() could be empty when count agg optimization works @@ -260,10 +196,12 @@ public void pushDownSortIntoAggBucket( */ String bucketName = fieldNames.get(collation.getFieldIndex()); CompositeValuesSourceBuilder bucket = buckets.get(bucketNames.indexOf(bucketName)); - Direction direction = collation.getDirection(); - NullDirection nullDirection = collation.nullDirection; + RelFieldCollation.Direction direction = collation.getDirection(); + RelFieldCollation.NullDirection nullDirection = collation.nullDirection; SortOrder order = - Direction.DESCENDING.equals(direction) ? SortOrder.DESC : SortOrder.ASC; + RelFieldCollation.Direction.DESCENDING.equals(direction) + ? SortOrder.DESC + : SortOrder.ASC; if (bucket.missingBucket()) { MissingOrder missingOrder = switch (nullDirection) { @@ -285,7 +223,7 @@ public void pushDownSortIntoAggBucket( newBuckets.add(buckets.get(bucketNames.indexOf(name))); newBucketNames.add(name); }); - Builder newAggBuilder = new Builder(); + AggregatorFactories.Builder newAggBuilder = new AggregatorFactories.Builder(); compositeAggBuilder.getSubAggregations().forEach(newAggBuilder::addAggregator); aggregationBuilder = Pair.of( diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java new file mode 100644 index 00000000000..87b1fd45cad --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +public interface AggregationBuilderAction extends AbstractAction { + default void transform(PushDownContext context, PushDownOperation operation) { + apply(context.getAggPushDownAction()); + context.getOperationsForAgg().add(operation); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/FilterDigest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/FilterDigest.java new file mode 100644 index 00000000000..30c8d01feb1 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/FilterDigest.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import org.apache.calcite.rex.RexNode; + +public record FilterDigest(int scriptCount, RexNode condition) { + @Override + public String toString() { + return condition.toString(); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/LimitDigest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/LimitDigest.java new file mode 100644 index 00000000000..342cd33d4b8 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/LimitDigest.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +public record LimitDigest(int limit, int offset) { + @Override + public String toString() { + return offset == 0 ? String.valueOf(limit) : "[" + limit + " from " + offset + "]"; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java new file mode 100644 index 00000000000..82a8682587e --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; + +public interface OSRequestBuilderAction extends AbstractAction { + default void transform(PushDownContext context, PushDownOperation operation) { + apply(context.getRequestBuilder()); + context.getOperationsForRequestBuilder().add(operation); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java new file mode 100644 index 00000000000..693aff80466 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import com.google.common.collect.Iterators; +import java.util.AbstractCollection; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Iterator; +import lombok.Getter; +import org.jetbrains.annotations.NotNull; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; + +@Getter +public class PushDownContext extends AbstractCollection { + private final OpenSearchIndex osIndex; + private final OpenSearchRequestBuilder requestBuilder; + private ArrayDeque operationsForRequestBuilder; + + private boolean isAggregatePushed = false; + private AggPushDownAction aggPushDownAction; + private ArrayDeque operationsForAgg; + + private boolean isLimitPushed = false; + private boolean isProjectPushed = false; + private boolean isMetricOrderPushed = false; + + public PushDownContext(OpenSearchIndex osIndex) { + this.osIndex = osIndex; + this.requestBuilder = osIndex.createRequestBuilder(); + } + + @Override + public PushDownContext clone() { + PushDownContext newContext = new PushDownContext(osIndex); + newContext.addAll(this); + return newContext; + } + + /** + * Create a new {@link PushDownContext} without the collation action. + * + * @return A new push-down context without the collation action. + */ + public PushDownContext cloneWithoutSort() { + PushDownContext newContext = new PushDownContext(osIndex); + for (PushDownOperation action : this) { + if (action.type() != PushDownType.SORT) { + newContext.add(action); + } + } + return newContext; + } + + @NotNull + @Override + public Iterator iterator() { + if (operationsForRequestBuilder == null) { + return Collections.emptyIterator(); + } else if (operationsForAgg == null) { + return operationsForRequestBuilder.iterator(); + } else { + return Iterators.concat(operationsForRequestBuilder.iterator(), operationsForAgg.iterator()); + } + } + + @Override + public int size() { + return (operationsForRequestBuilder == null ? 0 : operationsForRequestBuilder.size()) + + (operationsForAgg == null ? 0 : operationsForAgg.size()); + } + + ArrayDeque getOperationsForRequestBuilder() { + if (operationsForRequestBuilder == null) { + this.operationsForRequestBuilder = new ArrayDeque<>(); + } + return operationsForRequestBuilder; + } + + ArrayDeque getOperationsForAgg() { + if (operationsForAgg == null) { + this.operationsForAgg = new ArrayDeque<>(); + } + return operationsForAgg; + } + + @Override + public boolean add(PushDownOperation operation) { + if (operation.type() == PushDownType.AGGREGATION) { + isAggregatePushed = true; + this.aggPushDownAction = (AggPushDownAction) operation.action(); + } + if (operation.type() == PushDownType.LIMIT) { + isLimitPushed = true; + } + if (operation.type() == PushDownType.PROJECT) { + isProjectPushed = true; + } + if (operation.type() == PushDownType.SORT_AGG_METRICS) { + isMetricOrderPushed = true; + } + operation.action().transform(this, operation); + return true; + } + + public void add(PushDownType type, Object digest, AbstractAction action) { + add(new PushDownOperation(type, digest, action)); + } + + public boolean containsDigest(Object digest) { + return this.stream().anyMatch(action -> action.digest().equals(digest)); + } + + public OpenSearchRequestBuilder createRequestBuilder() { + OpenSearchRequestBuilder newRequestBuilder = osIndex.createRequestBuilder(); + if (operationsForRequestBuilder != null) { + operationsForRequestBuilder.forEach( + operation -> ((OSRequestBuilderAction) operation.action()).apply(newRequestBuilder)); + } + return newRequestBuilder; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownOperation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownOperation.java new file mode 100644 index 00000000000..c5779564369 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownOperation.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +/** + * Represents a push down operation that can be applied to an OpenSearchRequestBuilder. + * + * @param type PushDownType enum + * @param digest the digest of the pushed down operator + * @param action the lambda action to apply on the OpenSearchRequestBuilder + */ +public record PushDownOperation(PushDownType type, Object digest, AbstractAction action) { + public String toString() { + return type + "->" + digest; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java new file mode 100644 index 00000000000..a241a7c0d48 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +public enum PushDownType { + FILTER, + PROJECT, + AGGREGATION, + SORT, + LIMIT, + SCRIPT, + COLLAPSE, + SORT_AGG_METRICS + // HIGHLIGHT, + // NESTED +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java index c2260911f75..2b79103a06f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java @@ -152,9 +152,10 @@ void analyze_aggCall_simple() throws ExpressionNotAnalyzableException { createMockAggregate( List.of(countCall, avgCall, sumCall, minCall, maxCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null, null, BUCKET_SIZE); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( "[{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}}," + " {\"avg\":{\"avg\":{\"field\":\"a\"}}}," @@ -234,9 +235,10 @@ void analyze_aggCall_extended() throws ExpressionNotAnalyzableException { createMockAggregate( List.of(varSampCall, varPopCall, stddevSampCall, stddevPopCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null, null, BUCKET_SIZE); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( "[{\"var_samp\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}}," + " {\"var_pop\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}}," @@ -274,9 +276,10 @@ void analyze_groupBy() throws ExpressionNotAnalyzableException { List outputFields = List.of("a", "b", "cnt"); Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0, 1)); Project project = createMockProject(List.of(0, 1)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null, null, BUCKET_SIZE); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( "[{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[" @@ -313,19 +316,12 @@ void analyze_aggCall_TextWithoutKeyword() { "sum"); Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(2)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); ExpressionNotAnalyzableException exception = assertThrows( ExpressionNotAnalyzableException.class, - () -> - AggregateAnalyzer.analyze( - aggregate, - project, - rowType, - fieldTypes, - List.of("sum"), - null, - null, - BUCKET_SIZE)); + () -> AggregateAnalyzer.analyze(aggregate, project, List.of("sum"), helper)); assertEquals("[field] must not be null: [sum]", exception.getCause().getMessage()); } @@ -347,19 +343,12 @@ void analyze_groupBy_TextWithoutKeyword() { List outputFields = List.of("c", "cnt"); Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0)); Project project = createMockProject(List.of(2)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); ExpressionNotAnalyzableException exception = assertThrows( ExpressionNotAnalyzableException.class, - () -> - AggregateAnalyzer.analyze( - aggregate, - project, - rowType, - fieldTypes, - outputFields, - null, - null, - BUCKET_SIZE)); + () -> AggregateAnalyzer.analyze(aggregate, project, outputFields, helper)); assertEquals("[field] must not be null", exception.getCause().getMessage()); } @@ -702,9 +691,11 @@ void verify() throws ExpressionNotAnalyzableException { if (agg.getInput(0) instanceof Project) { project = (Project) agg.getInput(0); } + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper( + rowType, fieldTypes, agg.getCluster(), true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - agg, project, rowType, fieldTypes, outputFields, null, agg.getCluster(), BUCKET_SIZE); + AggregateAnalyzer.analyze(agg, project, outputFields, helper); if (expectedDsl != null) { assertEquals(expectedDsl, result.getLeft().toString()); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java index f02b40b9eae..c67d7cfaa3e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java @@ -49,6 +49,13 @@ import org.opensearch.sql.common.setting.Settings.Key; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.context.AggPushDownAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggregationBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.FilterDigest; +import org.opensearch.sql.opensearch.storage.scan.context.LimitDigest; +import org.opensearch.sql.opensearch.storage.scan.context.OSRequestBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownOperation; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownType; @ExtendWith(MockitoExtension.class) public class CalciteIndexScanCostTest { From 4cf940df2f768b4890745ff58530304db63f9af5 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 21 Oct 2025 18:37:35 +0800 Subject: [PATCH 3/9] fix conflicts Signed-off-by: Lantao Jin --- .../org/opensearch/sql/calcite/remote/CalciteExplainIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 475aec1d154..7c71f87bcf1 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1069,7 +1069,7 @@ public void testExplainSortOnMetricsMultiTerms() throws IOException { // TODO support multi-terms enabledOnlyWhenPushdownIsEnabled(); String expected = loadExpectedPlan("explain_agg_sort_on_metrics_multi_terms1.yaml"); - assertYamlEqualsJsonIgnoreId( + assertYamlEqualsIgnoreId( expected, explainQueryYaml( "source=opensearch-sql_test_index_account | stats bucket_nullable=false count() by" From 1a1fe05af54a295a8d41c33d8172c56a7b2b0811 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 21 Oct 2025 18:41:17 +0800 Subject: [PATCH 4/9] fix conflicts2 Signed-off-by: Lantao Jin --- .../org/opensearch/sql/calcite/remote/CalciteExplainIT.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 7c71f87bcf1..771c26f7372 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1043,21 +1043,21 @@ public void testExplainSortOnMetrics() throws IOException { expected = loadExpectedPlan("explain_agg_sort_on_metrics2.yaml"); assertYamlEqualsIgnoreId( expected, - explainQueryToString( + explainQueryYaml( "source=opensearch-sql_test_index_account | stats bucket_nullable=false sum(balance)" + " as sum by state | sort - sum")); // TODO limit should pushdown to non-composite agg expected = loadExpectedPlan("explain_agg_sort_on_metrics3.yaml"); assertYamlEqualsIgnoreId( expected, - explainQueryToString( + explainQueryYaml( String.format( "source=%s | stats count() as cnt by span(birthdate, 1d) | sort - cnt", TEST_INDEX_BANK))); expected = loadExpectedPlan("explain_agg_sort_on_metrics4.yaml"); assertYamlEqualsIgnoreId( expected, - explainQueryToString( + explainQueryYaml( String.format( "source=%s | stats bucket_nullable=false sum(balance) by span(age, 5) | sort -" + " `sum(balance)`", From 71e87ef3c2e3b6c203d7cc75a82c72c05d93b42c Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 23 Oct 2025 14:07:22 +0800 Subject: [PATCH 5/9] Add more javadoc Signed-off-by: Lantao Jin --- .../storage/scan/context/AbstractAction.java | 11 +++++++++++ .../storage/scan/context/AggPushDownAction.java | 1 + .../scan/context/AggregationBuilderAction.java | 7 +++++++ .../storage/scan/context/OSRequestBuilderAction.java | 8 ++++++++ .../storage/scan/context/PushDownContext.java | 1 + .../opensearch/storage/scan/context/PushDownType.java | 1 + 6 files changed, 29 insertions(+) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java index fa48ceded6d..65ef6233ffb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java @@ -5,8 +5,19 @@ package org.opensearch.sql.opensearch.storage.scan.context; +/** + * A lambda action to apply on the target T + * + * @param the target type + */ public interface AbstractAction { void apply(T target); + /** + * Apply the action on the target T and add the operation to the context + * + * @param context the context to add the operation to + * @param operation the operation to add to the context + */ void transform(PushDownContext context, PushDownOperation operation); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java index fbdc2b19203..acb13d67286 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java @@ -37,6 +37,7 @@ import org.opensearch.sql.opensearch.response.agg.MetricParserHelper; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +/** A lambda aggregation pushdown action to apply on the {@link OpenSearchRequestBuilder} */ @Getter @EqualsAndHashCode public class AggPushDownAction implements OSRequestBuilderAction { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java index 87b1fd45cad..cd3e84bf7cf 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java @@ -5,7 +5,14 @@ package org.opensearch.sql.opensearch.storage.scan.context; +/** A lambda action to apply on the {@link AggPushDownAction} */ public interface AggregationBuilderAction extends AbstractAction { + /** + * Apply the action on the target {@link AggPushDownAction} and add the operation to the context + * + * @param context the context to add the operation to + * @param operation the operation to add to the context + */ default void transform(PushDownContext context, PushDownOperation operation) { apply(context.getAggPushDownAction()); context.getOperationsForAgg().add(operation); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java index 82a8682587e..bba33883b49 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java @@ -7,7 +7,15 @@ import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +/** A lambda action to apply on the {@link OpenSearchRequestBuilder} */ public interface OSRequestBuilderAction extends AbstractAction { + /** + * Apply the action on the target {@link OpenSearchRequestBuilder} and add the operation to the + * context + * + * @param context the context to add the operation to + * @param operation the operation to add to the context + */ default void transform(PushDownContext context, PushDownOperation operation) { apply(context.getRequestBuilder()); context.getOperationsForRequestBuilder().add(operation); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java index 693aff80466..dd36c2090b9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java @@ -15,6 +15,7 @@ import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +/** Push down context is used to store all the push down operations that are applied to the query */ @Getter public class PushDownContext extends AbstractCollection { private final OpenSearchIndex osIndex; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java index a241a7c0d48..2a9eccb7a0e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java @@ -5,6 +5,7 @@ package org.opensearch.sql.opensearch.storage.scan.context; +/** Push down types. */ public enum PushDownType { FILTER, PROJECT, From 0a6bc4f8c281d365f4b5d71968fe358598ced70f Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 23 Oct 2025 21:52:22 +0800 Subject: [PATCH 6/9] address comments Signed-off-by: Lantao Jin --- .../sql/calcite/remote/CalciteExplainIT.java | 41 +++++++++++++++++++ ...range_metric_sort_agg_metric_not_push.yaml | 14 +++++++ ...ite_autodate_sort_agg_metric_not_push.yaml | 14 +++++++ ...posite_range_sort_agg_metric_not_push.yaml | 14 +++++++ .../physical/SortAggregationMetricsRule.java | 10 ++--- .../scan/context/AggPushDownAction.java | 10 ++++- 6 files changed, 96 insertions(+), 7 deletions(-) create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_sort_agg_metric_not_push.yaml create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/agg_composite_range_sort_agg_metric_not_push.yaml diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 5ea19434b78..60edb3b06e6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1062,6 +1062,47 @@ public void testExplainSortOnMetricsMultiTerms() throws IOException { + " gender, state | sort `count()`")); } + @Test + public void testExplainCompositeRangeThenSortOnMetricsNotPushdown() throws IOException { + // For single bucket, only composite agg can apply pushdown sort agg metrics + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_range_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | eval value_range = case(value < 7000, 'small'" + + " else 'great') | stats bucket_nullable=false avg(value), count() as cnt by" + + " value_range, category | sort cnt", + TEST_INDEX_TIME_DATA))); + } + + @Test + public void testExplainCompositeAutoDateThenSortOnMetricsNotPushdown() throws IOException { + // For single bucket, only composite agg can apply pushdown sort agg metrics + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_autodate_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | bin timestamp bins=3 | stats bucket_nullable=false avg(value), count()" + + " as cnt by timestamp, category | sort cnt", + TEST_INDEX_TIME_DATA))); + } + + @Test + public void testExplainCompositeRangeAutoDateThenSortOnMetricsNotPushdown() throws IOException { + // For multiple buckets, only all term-buckets can apply multi-terms + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | bin timestamp bins=3 | eval value_range = case(value < 7000, 'small'" + + " else 'great') | stats bucket_nullable=false avg(value), count() as cnt by" + + " timestamp, value_range | sort cnt", + TEST_INDEX_TIME_DATA))); + } + @Test public void testExplainEvalMax() throws IOException { String expected = loadExpectedPlan("explain_eval_max.json"); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml new file mode 100644 index 00000000000..314c0fbbf6b --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml @@ -0,0 +1,14 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$2], cnt=[$3], timestamp=[$0], value_range=[$1]) + LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()]) + LogicalProject(timestamp=[$9], value_range=[$10], value=[$2]) + LogicalFilter(condition=[IS NOT NULL($9)]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'great':VARCHAR)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1, 2},avg(value)=AVG($0),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, value_range]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_sort_agg_metric_not_push.yaml new file mode 100644 index 00000000000..e3d4d9fba4d --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_sort_agg_metric_not_push.yaml @@ -0,0 +1,14 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$2], cnt=[$3], timestamp=[$0], category=[$1]) + LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()]) + LogicalProject(timestamp=[$9], category=[$1], value=[$2]) + LogicalFilter(condition=[AND(IS NOT NULL($9), IS NOT NULL($1))]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_range_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_range_sort_agg_metric_not_push.yaml new file mode 100644 index 00000000000..19846e9910b --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_range_sort_agg_metric_not_push.yaml @@ -0,0 +1,14 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$2], cnt=[$3], value_range=[$0], category=[$1]) + LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()]) + LogicalProject(value_range=[$10], category=[$1], value=[$2]) + LogicalFilter(condition=[IS NOT NULL($1)]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], timestamp=[$3], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'great':VARCHAR)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, value_range, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java index 0fa0b4ebef0..63b04e8c099 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java @@ -8,6 +8,7 @@ import java.util.function.Predicate; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.logical.LogicalSort; import org.immutables.value.Value; import org.opensearch.sql.calcite.utils.PlanUtils; @@ -25,10 +26,6 @@ protected SortAggregationMetricsRule(Config config) { public void onMatch(RelOptRuleCall call) { final LogicalSort sort = call.rel(0); final CalciteLogicalIndexScan scan = call.rel(1); - // Only support single metric sort - if (sort.getCollation().getFieldCollations().size() != 1) { - return; - } CalciteLogicalIndexScan newScan = scan.pushDownSortAggregateMetrics(sort); if (newScan != null) { call.transformTo(newScan); @@ -38,6 +35,9 @@ public void onMatch(RelOptRuleCall call) { /** Rule configuration. */ @Value.Immutable public interface Config extends RelRule.Config { + // TODO support multiple metrics, only support single metric sort + Predicate hasOneFieldCollation = + sort -> sort.getCollation().getFieldCollations().size() == 1; SortAggregationMetricsRule.Config DEFAULT = ImmutableSortAggregationMetricsRule.Config.builder() .build() @@ -45,7 +45,7 @@ public interface Config extends RelRule.Config { .withOperandSupplier( b0 -> b0.operand(LogicalSort.class) - .predicate(PlanUtils::sortByFieldsOnly) + .predicate(hasOneFieldCollation.and(PlanUtils::sortByFieldsOnly)) .oneInput( b1 -> b1.operand(CalciteLogicalIndexScan.class) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java index acb13d67286..44c894f89be 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java @@ -153,11 +153,17 @@ private String getAggregationPath( List fieldNames, CompositeAggregationBuilder composite) { String path; - if (composite.getSubAggregations().isEmpty()) { + AggregationBuilder metric = composite.getSubAggregations().stream().findFirst().orElse(null); + if (metric == null) { // count agg optimized, get the path name from field names path = fieldNames.get(collations.get(0).getFieldIndex()); + } else if (metric instanceof ValuesSourceAggregationBuilder.LeafOnly) { + path = metric.getName(); } else { - path = composite.getSubAggregations().stream().toList().get(0).getName(); + // we do not support pushdown sort aggregate metrics for nested aggregation + throw new OpenSearchRequestBuilder.PushDownUnSupportedException( + "Cannot pushdown sort aggregate metrics, composite.getSubAggregations() is not a" + + " LeafOnly"); } return path; } From e77ab8e169f7aa9402522f91bab4f83a8da285ac Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 23 Oct 2025 22:06:29 +0800 Subject: [PATCH 7/9] delete incorrect comments Signed-off-by: Lantao Jin --- .../org/opensearch/sql/calcite/remote/CalciteExplainIT.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 60edb3b06e6..0880953d711 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1064,7 +1064,6 @@ public void testExplainSortOnMetricsMultiTerms() throws IOException { @Test public void testExplainCompositeRangeThenSortOnMetricsNotPushdown() throws IOException { - // For single bucket, only composite agg can apply pushdown sort agg metrics enabledOnlyWhenPushdownIsEnabled(); assertYamlEqualsIgnoreId( loadExpectedPlan("agg_composite_range_sort_agg_metric_not_push.yaml"), @@ -1078,7 +1077,6 @@ public void testExplainCompositeRangeThenSortOnMetricsNotPushdown() throws IOExc @Test public void testExplainCompositeAutoDateThenSortOnMetricsNotPushdown() throws IOException { - // For single bucket, only composite agg can apply pushdown sort agg metrics enabledOnlyWhenPushdownIsEnabled(); assertYamlEqualsIgnoreId( loadExpectedPlan("agg_composite_autodate_sort_agg_metric_not_push.yaml"), @@ -1091,7 +1089,6 @@ public void testExplainCompositeAutoDateThenSortOnMetricsNotPushdown() throws IO @Test public void testExplainCompositeRangeAutoDateThenSortOnMetricsNotPushdown() throws IOException { - // For multiple buckets, only all term-buckets can apply multi-terms enabledOnlyWhenPushdownIsEnabled(); assertYamlEqualsIgnoreId( loadExpectedPlan("agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml"), From a7a1501a1dcb396caa861129b313ca9ce0bd5b20 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Fri, 24 Oct 2025 12:59:46 +0800 Subject: [PATCH 8/9] convert composite agg to multi-terms agg for sort metrics on multiple buckets Signed-off-by: Lantao Jin --- .../sql/calcite/remote/CalciteExplainIT.java | 20 ++++++++--- ...range_metric_sort_agg_metric_not_push.yaml | 10 +++--- ...rms_autodate_sort_agg_metric_not_push.yaml | 15 +++++++++ ...plain_agg_sort_on_metrics_multi_terms.yaml | 11 +++++++ ...lain_agg_sort_on_metrics_multi_terms1.yaml | 13 -------- .../scan/context/AggPushDownAction.java | 33 ++++++++++++++++++- 6 files changed, 79 insertions(+), 23 deletions(-) create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms.yaml delete mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 0880953d711..89cb7398cf6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1050,11 +1050,10 @@ public void testExplainSortOnMetrics() throws IOException { TEST_INDEX_BANK))); } - @Ignore + @Test public void testExplainSortOnMetricsMultiTerms() throws IOException { - // TODO support multi-terms enabledOnlyWhenPushdownIsEnabled(); - String expected = loadExpectedPlan("explain_agg_sort_on_metrics_multi_terms1.yaml"); + String expected = loadExpectedPlan("explain_agg_sort_on_metrics_multi_terms.yaml"); assertYamlEqualsIgnoreId( expected, explainQueryYaml( @@ -1062,6 +1061,19 @@ public void testExplainSortOnMetricsMultiTerms() throws IOException { + " gender, state | sort `count()`")); } + @Test + public void testExplainCompositeMultiBucketsAutoDateThenSortOnMetricsNotPushdown() + throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | bin timestamp bins=3 | stats bucket_nullable=false avg(value), count()" + + " as cnt by category, value, timestamp | sort cnt", + TEST_INDEX_TIME_DATA))); + } + @Test public void testExplainCompositeRangeThenSortOnMetricsNotPushdown() throws IOException { enabledOnlyWhenPushdownIsEnabled(); @@ -1096,7 +1108,7 @@ public void testExplainCompositeRangeAutoDateThenSortOnMetricsNotPushdown() thro String.format( "source=%s | bin timestamp bins=3 | eval value_range = case(value < 7000, 'small'" + " else 'great') | stats bucket_nullable=false avg(value), count() as cnt by" - + " timestamp, value_range | sort cnt", + + " timestamp, value_range, category | sort cnt", TEST_INDEX_TIME_DATA))); } diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml index 314c0fbbf6b..90e83946c38 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml @@ -2,13 +2,13 @@ calcite: logical: | LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) - LogicalProject(avg(value)=[$2], cnt=[$3], timestamp=[$0], value_range=[$1]) - LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()]) - LogicalProject(timestamp=[$9], value_range=[$10], value=[$2]) - LogicalFilter(condition=[IS NOT NULL($9)]) + LogicalProject(avg(value)=[$3], cnt=[$4], timestamp=[$0], value_range=[$1], category=[$2]) + LogicalAggregate(group=[{0, 1, 2}], avg(value)=[AVG($3)], cnt=[COUNT()]) + LogicalProject(timestamp=[$9], value_range=[$10], category=[$1], value=[$2]) + LogicalFilter(condition=[AND(IS NOT NULL($9), IS NOT NULL($1))]) LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'great':VARCHAR)]) CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) physical: | EnumerableLimit(fetch=[10000]) EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1, 2},avg(value)=AVG($0),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, value_range]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2, 3},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, value_range, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml new file mode 100644 index 00000000000..6995097f878 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml @@ -0,0 +1,15 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$3], cnt=[$4], category=[$0], value=[$1], timestamp=[$2]) + LogicalAggregate(group=[{0, 1, 2}], avg(value)=[AVG($1)], cnt=[COUNT()]) + LogicalProject(category=[$1], value=[$2], timestamp=[$9]) + LogicalFilter(condition=[AND(IS NOT NULL($1), IS NOT NULL($2), IS NOT NULL($9))]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DOUBLE], avg(value)=[$t4], cnt=[$t3], category=[$t0], value=[$t1], timestamp=[$t2]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1, 2},cnt=COUNT())], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}},{"value":{"terms":{"field":"value","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms.yaml new file mode 100644 index 00000000000..a7a2bbad9db --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms.yaml @@ -0,0 +1,11 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) + LogicalProject(count()=[$2], gender=[$0], state=[$1]) + LogicalAggregate(group=[{0, 1}], count()=[COUNT()]) + LogicalProject(gender=[$4], state=[$7]) + LogicalFilter(condition=[AND(IS NOT NULL($4), IS NOT NULL($7))]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},count()=COUNT()), SORT_AGG_METRICS->[2 ASC FIRST], PROJECT->[count(), gender, state], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"multi_terms_buckets":{"multi_terms":{"terms":[{"field":"gender.keyword"},{"field":"state.keyword"}],"size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"_count":"desc"},{"_key":"asc"}]},"aggregations":{"count()":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml deleted file mode 100644 index 8a45ecc2f92..00000000000 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms1.yaml +++ /dev/null @@ -1,13 +0,0 @@ -calcite: - logical: | - LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) - LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) - LogicalProject(count()=[$2], gender=[$0], state=[$1]) - LogicalAggregate(group=[{0, 1}], count()=[COUNT()]) - LogicalProject(gender=[$4], state=[$7]) - LogicalFilter(condition=[AND(IS NOT NULL($4), IS NOT NULL($7))]) - CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) - physical: | - EnumerableLimit(fetch=[10000]) - EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},count()=COUNT()), PROJECT->[count(), gender, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":false,"order":"asc"}}},{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java index 44c894f89be..1a68e93bc5d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java @@ -14,6 +14,7 @@ import lombok.Getter; import org.apache.calcite.rel.RelFieldCollation; import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.search.aggregations.AbstractAggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; @@ -28,6 +29,7 @@ import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; @@ -98,6 +100,9 @@ public void pushDownSortAggMetrics(List collations, List collations, List collations, List collations, List src instanceof TermsValuesSourceBuilder terms && !terms.missingBucket())) { // multi-term agg + MultiTermsAggregationBuilder multiTermsBuilder = + new MultiTermsAggregationBuilder("multi_terms_buckets"); + multiTermsBuilder.size(composite.size()); + multiTermsBuilder.terms( + composite.sources().stream() + .map(TermsValuesSourceBuilder.class::cast) + .map( + termValue -> { + MultiTermsValuesSourceConfig.Builder config = + new MultiTermsValuesSourceConfig.Builder(); + config.setFieldName(termValue.field()); + config.setUserValueTypeHint(termValue.userValuetypeHint()); + return config.build(); + }) + .toList()); + attachSubAggregations(composite.getSubAggregations(), path, multiTermsBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(multiTermsBuilder), + convertTo(aggregationBuilder.getRight())); return; } } @@ -168,7 +199,7 @@ private String getAggregationPath( return path; } - private > T attachSubAggregations( + private > T attachSubAggregations( Collection subAggregations, String path, T aggregationBuilder) { AggregatorFactories.Builder metricBuilder = new AggregatorFactories.Builder(); if (subAggregations.isEmpty()) { From 416ca6d8acdaf839abb1b17a0db0a5b1a0a28c02 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Fri, 24 Oct 2025 16:20:09 +0800 Subject: [PATCH 9/9] avoid the case of 'stat count, sum ... | sort count' Signed-off-by: Lantao Jin --- .../sql/calcite/remote/CalciteExplainIT.java | 18 ++++++++ ...agg_with_sort_on_one_metric_not_push1.yaml | 13 ++++++ ...agg_with_sort_on_one_metric_not_push2.yaml | 13 ++++++ .../scan/AbstractCalciteIndexScan.java | 45 ++++++++++++++++--- .../storage/scan/CalciteLogicalIndexScan.java | 2 +- 5 files changed, 83 insertions(+), 8 deletions(-) create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 89cb7398cf6..5f227c94472 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1112,6 +1112,24 @@ public void testExplainCompositeRangeAutoDateThenSortOnMetricsNotPushdown() thro TEST_INDEX_TIME_DATA))); } + @Test + public void testExplainMultipleAggregatorsWithSortOnOneMetricNotPushDown() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + String expected = + loadExpectedPlan("explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account | stats bucket_nullable=false count() as c," + + " sum(balance) as s by state | sort c")); + expected = loadExpectedPlan("explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account | stats bucket_nullable=false count() as c," + + " sum(balance) as s by state | sort c, s")); + } + @Test public void testExplainEvalMax() throws IOException { String expected = loadExpectedPlan("explain_eval_max.json"); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml new file mode 100644 index 00000000000..6a5bc8ea0f5 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml @@ -0,0 +1,13 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) + LogicalProject(c=[$1], s=[$2], state=[$0]) + LogicalAggregate(group=[{0}], c=[COUNT()], s=[SUM($1)]) + LogicalProject(state=[$7], balance=[$3]) + LogicalFilter(condition=[IS NOT NULL($7)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},c=COUNT(),s=SUM($0)), PROJECT->[c, s, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"s":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml new file mode 100644 index 00000000000..d1651f464a6 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml @@ -0,0 +1,13 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[ASC-nulls-first]) + LogicalProject(c=[$1], s=[$2], state=[$0]) + LogicalAggregate(group=[{0}], c=[COUNT()], s=[SUM($1)]) + LogicalProject(state=[$7], balance=[$3]) + LogicalFilter(condition=[IS NOT NULL($7)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},c=COUNT(),s=SUM($0)), PROJECT->[c, s, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"s":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java index cfb2201a8ef..c3f5d2fe4e4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java @@ -9,6 +9,7 @@ import static org.opensearch.sql.common.setting.Settings.Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.stream.Stream; import lombok.Getter; @@ -230,24 +231,53 @@ protected List getCollationNames(List collations) { } /** - * Check if the sort by collations contains any aggregators that are pushed down. E.g. In `stats - * avg(age) as avg_age by state | sort avg_age`, the sort clause has `avg_age` which is an - * aggregator. The function will return true in this case. + * Check if all sort-by collations equal aggregators that are pushed down. E.g. In `stats avg(age) + * as avg_age, sum(age) as sum_age by state | sort avg_age, sum_age`, the sort keys `avg_age`, + * `sum_age` which equal the pushed down aggregators `avg(age)`, `sum(age)`. + * + * @param collations List of collation names to check against aggregators. + * @return True if all collation names match all aggregator output, false otherwise. + */ + protected boolean isAllCollationNamesEqualAggregators(List collations) { + Stream aggregates = + pushDownContext.stream() + .filter(action -> action.type() == PushDownType.AGGREGATION) + .map(action -> ((LogicalAggregate) action.digest())); + return aggregates + .map(aggregate -> isAllCollationNamesEqualAggregators(aggregate, collations)) + .reduce(false, Boolean::logicalOr); + } + + private boolean isAllCollationNamesEqualAggregators( + LogicalAggregate aggregate, List collations) { + List fieldNames = aggregate.getRowType().getFieldNames(); + // The output fields of the aggregate are in the format of + // [...grouping fields, ...aggregator fields], so we set an offset to skip + // the grouping fields. + int groupOffset = aggregate.getGroupSet().cardinality(); + List fieldsWithoutGrouping = fieldNames.subList(groupOffset, fieldNames.size()); + return new HashSet<>(collations).equals(new HashSet<>(fieldsWithoutGrouping)); + } + + /** + * Check if any sort-by collations is in aggregators that are pushed down. E.g. In `stats avg(age) + * as avg_age by state | sort avg_age`, the sort clause has `avg_age` which is an aggregator. The + * function will return true in this case. * * @param collations List of collation names to check against aggregators. * @return True if any collation name matches an aggregator output, false otherwise. */ - protected boolean hasAggregatorInSortBy(List collations) { + protected boolean isAnyCollationNameInAggregators(List collations) { Stream aggregates = pushDownContext.stream() .filter(action -> action.type() == PushDownType.AGGREGATION) .map(action -> ((LogicalAggregate) action.digest())); return aggregates - .map(aggregate -> isAnyCollationNameInAggregateOutput(aggregate, collations)) + .map(aggregate -> isAnyCollationNameInAggregators(aggregate, collations)) .reduce(false, Boolean::logicalOr); } - private static boolean isAnyCollationNameInAggregateOutput( + private boolean isAnyCollationNameInAggregators( LogicalAggregate aggregate, List collations) { List fieldNames = aggregate.getRowType().getFieldNames(); // The output fields of the aggregate are in the format of @@ -268,7 +298,8 @@ private static boolean isAnyCollationNameInAggregateOutput( public AbstractCalciteIndexScan pushDownSort(List collations) { try { List collationNames = getCollationNames(collations); - if (getPushDownContext().isAggregatePushed() && hasAggregatorInSortBy(collationNames)) { + if (getPushDownContext().isAggregatePushed() + && isAnyCollationNameInAggregators(collationNames)) { // If aggregation is pushed down, we cannot push down sorts where its by fields contain // aggregators. return null; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index 4bc6d688e1b..97755b7b53d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -296,7 +296,7 @@ public CalciteLogicalIndexScan pushDownSortAggregateMetrics(Sort sort) { return null; } List collationNames = getCollationNames(sort.getCollation().getFieldCollations()); - if (!hasAggregatorInSortBy(collationNames)) { + if (!isAllCollationNamesEqualAggregators(collationNames)) { return null; } AbstractAction newAction =