Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_DEDUP;
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_JOIN_MAX_DEDUP;
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_MAIN;
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_RARE_TOP;
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_FOR_STREAMSTATS;
Expand Down Expand Up @@ -48,9 +49,6 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.hint.HintStrategyTable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFamily;
Expand Down Expand Up @@ -1054,7 +1052,7 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
List<String> intendedGroupKeyAliases = getGroupKeyNamesAfterAggregation(reResolved.getLeft());
context.relBuilder.aggregate(
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
if (hintBucketNonNull) addIgnoreNullBucketHintToAggregate(context);
if (hintBucketNonNull) PlanUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
// During aggregation, Calcite projects both input dependencies and output group-by fields.
// When names conflict, Calcite adds numeric suffixes (e.g., "value0").
// Apply explicit renaming to restore the intended aliases.
Expand Down Expand Up @@ -1317,7 +1315,7 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
: duplicatedFieldNames.stream()
.map(a -> (RexNode) context.relBuilder.field(a))
.toList();
buildDedupNotNull(context, dedupeFields, allowedDuplication);
buildDedupNotNull(context, dedupeFields, allowedDuplication, true);
}
context.relBuilder.join(
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
Expand Down Expand Up @@ -1373,7 +1371,7 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
List<RexNode> dedupeFields =
getRightColumnsInJoinCriteria(context.relBuilder, joinCondition);

buildDedupNotNull(context, dedupeFields, allowedDuplication);
buildDedupNotNull(context, dedupeFields, allowedDuplication, true);
}
context.relBuilder.join(
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
Expand Down Expand Up @@ -1538,24 +1536,20 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) {
if (keepEmpty) {
buildDedupOrNull(context, dedupeFields, allowedDuplication);
} else {
buildDedupNotNull(context, dedupeFields, allowedDuplication);
buildDedupNotNull(context, dedupeFields, allowedDuplication, false);
}
return context.relBuilder.peek();
}

private static void buildDedupOrNull(
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
/*
* | dedup 2 a, b keepempty=false
* DropColumns('_row_number_dedup_)
* +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b))
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
* | dedup 2 a, b keepempty=true
* LogicalProject(...)
* +- LogicalFilter(condition=[OR(IS NULL(a), IS NULL(b), <=(_row_number_dedup_, 1))])
* +- LogicalProject(..., _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY a, b)])
* +- ...
*/
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
// ASC
// NULLS FIRST, 'b ASC NULLS FIRST]
RexNode rowNumber =
context
.relBuilder
Expand All @@ -1578,16 +1572,21 @@ private static void buildDedupOrNull(
}

private static void buildDedupNotNull(
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
CalcitePlanContext context,
List<RexNode> dedupeFields,
Integer allowedDuplication,
boolean fromJoinMaxOption) {
/*
* | dedup 2 a, b keepempty=false
* DropColumns('_row_number_dedup_)
* +- Filter ('_row_number_dedup_ <= n)
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
* +- Filter (isnotnull('a) AND isnotnull('b))
* +- ...
* LogicalProject(...)
* +- LogicalFilter(condition=[<=(_row_number_dedup_, n)]))
* +- LogicalProject(..., _row_number_dedup_=[ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY a, b)])
* +- LogicalFilter(condition=[AND(IS NOT NULL(a), IS NOT NULL(b))])
* +- ...
*/
// Filter (isnotnull('a) AND isnotnull('b))
String rowNumberAlias =
fromJoinMaxOption ? ROW_NUMBER_COLUMN_FOR_JOIN_MAX_DEDUP : ROW_NUMBER_COLUMN_FOR_DEDUP;
context.relBuilder.filter(
context.relBuilder.and(dedupeFields.stream().map(context.relBuilder::isNotNull).toList()));
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
Expand All @@ -1601,15 +1600,15 @@ private static void buildDedupNotNull(
.partitionBy(dedupeFields)
.orderBy(dedupeFields)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will dedupe work without ordering by deduped fields?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it should work in the non-pushdown case. Maybe we could remove this orderBy in the window.

Copy link
Collaborator

@yuancu yuancu Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks a little strange to me because the sort keys in a window should be the same (the partition key)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It was not introduced by this PR. Let's fix it in a follow-up PR since it will change the plan a lot.

.rowsTo(RexWindowBounds.CURRENT_ROW)
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
.as(rowNumberAlias);
context.relBuilder.projectPlus(rowNumber);
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
RexNode rowNumberField = context.relBuilder.field(rowNumberAlias);
// Filter ('_row_number_dedup_ <= n)
context.relBuilder.filter(
context.relBuilder.lessThanOrEqual(
_row_number_dedup_, context.relBuilder.literal(allowedDuplication)));
rowNumberField, context.relBuilder.literal(allowedDuplication)));
// DropColumns('_row_number_dedup_)
context.relBuilder.projectExcept(_row_number_dedup_);
context.relBuilder.projectExcept(rowNumberField);
}

@Override
Expand Down Expand Up @@ -2378,25 +2377,6 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
return context.relBuilder.peek();
}

/**
 * Attaches a {@code stats_args} hint with option {@code Argument.BUCKET_NULLABLE = "false"} to the
 * node currently on top of the builder stack, and installs a hint strategy table on the cluster
 * in which {@code stats_args} applies to {@link LogicalAggregate} nodes only.
 *
 * @param context plan context whose {@code relBuilder} is expected to have a {@link
 *     LogicalAggregate} on top (checked with an {@code assert} only)
 */
private static void addIgnoreNullBucketHintToAggregate(CalcitePlanContext context) {
  final RelHint bucketNonNullHint =
      RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
  assert context.relBuilder.peek() instanceof LogicalAggregate
      : "Stats hits should be added to LogicalAggregate";
  context.relBuilder.hints(bucketNonNullHint);

  // NOTE(review): setHintStrategies installs a whole new table on the cluster; confirm that no
  // other hint strategies need to survive this call.
  final HintStrategyTable aggregateOnlyStrategies =
      HintStrategyTable.builder()
          .hintStrategy("stats_args", (hint, rel) -> rel instanceof LogicalAggregate)
          .build();
  context.relBuilder.getCluster().setHintStrategies(aggregateOnlyStrategies);
}

@Override
public RelNode visitTableFunction(TableFunction node, CalcitePlanContext context) {
throw new CalciteUnsupportedException("Table function is unsupported in Calcite");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,21 @@
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.hint.HintStrategyTable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexVisitorImpl;
Expand All @@ -45,8 +51,11 @@
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.WindowBound;
Expand All @@ -62,6 +71,7 @@ public interface PlanUtils {
/** This is only for the dedup command; do not reuse it in other commands. */
String ROW_NUMBER_COLUMN_FOR_DEDUP = "_row_number_dedup_";

String ROW_NUMBER_COLUMN_FOR_JOIN_MAX_DEDUP = "_row_number_join_max_dedup_";
String ROW_NUMBER_COLUMN_FOR_RARE_TOP = "_row_number_rare_top_";
String ROW_NUMBER_COLUMN_FOR_MAIN = "_row_number_main_";
String ROW_NUMBER_COLUMN_FOR_SUBSEARCH = "_row_number_subsearch_";
Expand Down Expand Up @@ -449,18 +459,15 @@ static RexNode derefMapCall(RexNode rexNode) {
return rexNode;
}

/** Check if contains RexOver introduced by dedup */
static boolean containsRowNumberDedup(LogicalProject project) {
return project.getProjects().stream()
.anyMatch(p -> p instanceof RexOver && p.getKind() == SqlKind.ROW_NUMBER)
&& project.getRowType().getFieldNames().contains(ROW_NUMBER_COLUMN_FOR_DEDUP);
/**
 * Checks whether the node's row type has a field named {@code ROW_NUMBER_COLUMN_FOR_DEDUP}.
 *
 * @param node the relational node whose output field names are inspected
 * @return {@code true} if one of the field names equals {@code ROW_NUMBER_COLUMN_FOR_DEDUP}
 */
static boolean containsRowNumberDedup(RelNode node) {
  List<String> outputFieldNames = node.getRowType().getFieldNames();
  return outputFieldNames.contains(ROW_NUMBER_COLUMN_FOR_DEDUP);
}

/** Check if contains RexOver introduced by dedup top/rare */
static boolean containsRowNumberRareTop(LogicalProject project) {
return project.getProjects().stream()
.anyMatch(p -> p instanceof RexOver && p.getKind() == SqlKind.ROW_NUMBER)
&& project.getRowType().getFieldNames().contains(ROW_NUMBER_COLUMN_FOR_RARE_TOP);
/**
 * Checks whether the node's row type has a field named {@code ROW_NUMBER_COLUMN_FOR_RARE_TOP}.
 *
 * @param node the relational node whose output field names are inspected
 * @return {@code true} if one of the field names equals {@code ROW_NUMBER_COLUMN_FOR_RARE_TOP}
 */
static boolean containsRowNumberRareTop(RelNode node) {
  List<String> outputFieldNames = node.getRowType().getFieldNames();
  return outputFieldNames.contains(ROW_NUMBER_COLUMN_FOR_RARE_TOP);
}

/** Get all RexWindow list from LogicalProject */
Expand Down Expand Up @@ -508,10 +515,6 @@ static boolean distinctProjectList(LogicalProject project) {
return project.getNamedProjects().stream().allMatch(rexSet::add);
}

/**
 * Checks whether any expression projected by the given project contains a windowed ({@code
 * OVER}) call, as reported by {@code RexOver.containsOver}.
 *
 * @param project the project whose expressions are scanned
 * @return {@code true} as soon as one projected expression contains an over call
 */
static boolean containsRexOver(LogicalProject project) {
  for (RexNode projected : project.getProjects()) {
    if (RexOver.containsOver(projected)) {
      return true;
    }
  }
  return false;
}

/**
* The LogicalSort is a LIMIT that should be pushed down when its fetch field is not null and its
* collation is empty. For example: <code>sort name | head 5</code> should not be pushed down
Expand All @@ -524,7 +527,7 @@ static boolean isLogicalSortLimit(LogicalSort sort) {
return sort.fetch != null;
}

static boolean projectContainsExpr(Project project) {
/**
 * Checks whether at least one top-level expression of the project is a {@link RexCall}. Only the
 * projected expressions themselves are tested; their operands are not traversed.
 *
 * @param project the project whose expressions are scanned
 * @return {@code true} if any projected expression is an instance of {@link RexCall}
 */
static boolean containsRexCall(Project project) {
  List<RexNode> projectedExprs = project.getProjects();
  return projectedExprs.stream().anyMatch(RexCall.class::isInstance);
}

Expand Down Expand Up @@ -595,4 +598,58 @@ static void replaceTop(RelBuilder relBuilder, RelNode relNode) {
throw new IllegalStateException("Unable to invoke RelBuilder.replaceTop", e);
}
}

/**
 * Attaches a {@code stats_args} hint with option {@code Argument.BUCKET_NULLABLE = "false"} to the
 * node currently on top of the builder stack, and installs a hint strategy table on the cluster
 * in which {@code stats_args} applies to {@link LogicalAggregate} nodes only.
 *
 * @param relBuilder builder expected to have a {@link LogicalAggregate} on top of its stack
 *     (checked with an {@code assert} only)
 */
static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
  final RelHint bucketNonNullHint =
      RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
  assert relBuilder.peek() instanceof LogicalAggregate
      : "Stats hits should be added to LogicalAggregate";
  relBuilder.hints(bucketNonNullHint);

  // NOTE(review): setHintStrategies installs a whole new table on the cluster; confirm that no
  // other hint strategies need to survive this call.
  final HintStrategyTable aggregateOnlyStrategies =
      HintStrategyTable.builder()
          .hintStrategy("stats_args", (hint, rel) -> rel instanceof LogicalAggregate)
          .build();
  relBuilder.getCluster().setHintStrategies(aggregateOnlyStrategies);
}

/**
 * Extracts the literal carried by a {@code LITERAL_AGG} aggregate call.
 *
 * @param aggCall the aggregate call to inspect
 * @return a {@link RexLiteral} found in {@code aggCall.rexList}, or {@code null} when the call is
 *     not a {@code LITERAL_AGG} or its {@code rexList} contains no literal
 */
static @Nullable RexLiteral getObjectFromLiteralAgg(AggregateCall aggCall) {
  // Guard: only LITERAL_AGG calls are searched for a literal.
  if (aggCall.getAggregation().kind != SqlKind.LITERAL_AGG) {
    return null;
  }
  return aggCall.rexList.stream()
      .filter(RexLiteral.class::isInstance)
      .map(RexLiteral.class::cast)
      .findAny()
      .orElse(null);
}

/**
 * Convenience wrapper around {@link Mappings#target(List, int)}: the first argument is the result
 * of {@code getSelectColumns(rexNodes)} and the second is the field count of {@code schema}.
 *
 * @param rexNodes the rex list in schema
 * @param schema the schema which contains the rex list
 * @return the target mapping
 */
static Mapping mapping(List<RexNode> rexNodes, RelDataType schema) {
  final var selectedColumns = getSelectColumns(rexNodes);
  final int schemaFieldCount = schema.getFieldCount();
  return Mappings.target(selectedColumns, schemaFieldCount);
}

/**
 * Tells whether a filter's condition has the shape of a pure not-null check: either a single
 * {@code IS NOT NULL} over an input reference, or an {@code AND} call whose operands are all such
 * checks.
 *
 * @param filter the filter whose condition is inspected
 * @return {@code true} if the condition matches one of the two shapes above
 */
static boolean mayBeFilterFromBucketNonNull(LogicalFilter filter) {
  final RexNode condition = filter.getCondition();
  if (isNotNullOnRef(condition)) {
    return true;
  }
  if (!(condition instanceof RexCall conjunction)) {
    return false;
  }
  if (!conjunction.getOperator().equals(SqlStdOperatorTable.AND)) {
    return false;
  }
  // Every conjunct must itself be IS NOT NULL on an input reference.
  for (RexNode conjunct : conjunction.getOperands()) {
    if (!isNotNullOnRef(conjunct)) {
      return false;
    }
  }
  return true;
}

/**
 * Tells whether the expression is an {@code IS NOT NULL} call whose first operand is a {@link
 * RexInputRef}.
 *
 * @param rex the expression to test
 * @return {@code true} only for {@code IS NOT NULL(<input ref>)}
 */
private static boolean isNotNullOnRef(RexNode rex) {
  if (!(rex instanceof RexCall call)) {
    return false;
  }
  if (!call.isA(SqlKind.IS_NOT_NULL)) {
    return false;
  }
  return call.getOperands().get(0) instanceof RexInputRef;
}
}
12 changes: 10 additions & 2 deletions core/src/main/java/org/opensearch/sql/data/type/ExprType.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,18 @@ default Optional<String> getOriginalPath() {
}

/**
 * Get the original expression type. Types like alias type should be derived from the type of the
 * original field.
 *
 * <p>The default implementation returns this type itself.
 *
 * @return the original expression type; {@code this} unless an implementation overrides it
 */
default ExprType getOriginalExprType() {
return this;
}

/**
 * Get the original data type. Types like alias type should be derived from the type of the
 * original field.
 *
 * <p>The default implementation returns this type itself.
 *
 * <p>NOTE(review): how this differs from {@code getOriginalExprType()} is not visible in this
 * excerpt; both defaults return {@code this}. Confirm against the overriding implementations.
 *
 * @return the original data type; {@code this} unless an implementation overrides it
 */
default ExprType getOriginalType() {
return this;
}
}
Loading
Loading