Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public enum Key {
CALCITE_PUSHDOWN_ENABLED("plugins.calcite.pushdown.enabled"),
CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR(
"plugins.calcite.pushdown.rowcount.estimation.factor"),
CALCITE_SUPPORT_ALL_JOIN_TYPES("plugins.calcite.all_join_types.allowed"),

/** Query Settings. */
FIELD_TYPE_TOLERANCE("plugins.query.field_type_tolerance"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,24 @@ public static class ArgumentMap {
private final Map<String, Literal> map;

public ArgumentMap(List<Argument> arguments) {
this.map =
arguments.stream()
.collect(java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue));
if (arguments == null || arguments.isEmpty()) {
this.map = Map.of();
} else {
this.map =
arguments.stream()
.collect(
java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue));
}
}

public static ArgumentMap of(List<Argument> arguments) {
return new ArgumentMap(arguments);
}

public static ArgumentMap empty() {
return new ArgumentMap(null);
}

/**
* Get argument value by name.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,8 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
public String toString() {
return String.valueOf(value);
}

public static Literal TRUE = new Literal(true, DataType.BOOLEAN);
public static Literal FALSE = new Literal(false, DataType.BOOLEAN);
public static Literal ZERO = new Literal(Integer.valueOf("0"), DataType.INTEGER);
}
15 changes: 14 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/tree/Join.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

@ToString
Expand All @@ -28,20 +30,26 @@ public class Join extends UnresolvedPlan {
private final JoinType joinType;
private final Optional<UnresolvedExpression> joinCondition;
private final JoinHint joinHint;
private final Optional<List<Field>> joinFields;
private final Argument.ArgumentMap argumentMap;

public Join(
UnresolvedPlan right,
Optional<String> leftAlias,
Optional<String> rightAlias,
JoinType joinType,
Optional<UnresolvedExpression> joinCondition,
JoinHint joinHint) {
JoinHint joinHint,
Optional<List<Field>> joinFields,
Argument.ArgumentMap argumentMap) {
this.right = right;
this.leftAlias = leftAlias;
this.rightAlias = rightAlias;
this.joinType = joinType;
this.joinCondition = joinCondition;
this.joinHint = joinHint;
this.joinFields = joinFields;
this.argumentMap = argumentMap;
}

@Override
Expand Down Expand Up @@ -89,6 +97,11 @@ public enum JoinType {
FULL
}

/** RIGHT, CROSS, FULL are performance sensitive join types */
public static List<JoinType> highCostJoinTypes() {
return List.of(JoinType.RIGHT, JoinType.CROSS, JoinType.FULL);
}

@Getter
@RequiredArgsConstructor
public static class JoinHint {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFamily;
Expand Down Expand Up @@ -903,6 +904,70 @@ private Optional<RexLiteral> extractAliasLiteral(RexNode node) {
public RelNode visitJoin(Join node, CalcitePlanContext context) {
List<UnresolvedPlan> children = node.getChildren();
children.forEach(c -> analyze(c, context));
if (node.getJoinCondition().isEmpty()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np: not sure if we can simplify the current visitor by some more expressive DSL.

// join-with-field-list grammar
List<String> leftColumns = context.relBuilder.peek(1).getRowType().getFieldNames();
List<String> rightColumns = context.relBuilder.peek().getRowType().getFieldNames();
List<String> duplicatedFieldNames =
leftColumns.stream().filter(rightColumns::contains).toList();
RexNode joinCondition;
if (node.getJoinFields().isPresent()) {
joinCondition =
node.getJoinFields().get().stream()
.map(field -> buildJoinConditionByFieldName(context, field.getField().toString()))
.reduce(context.rexBuilder::and)
.orElse(context.relBuilder.literal(true));
} else {
joinCondition =
duplicatedFieldNames.stream()
.map(fieldName -> buildJoinConditionByFieldName(context, fieldName))
.reduce(context.rexBuilder::and)
.orElse(context.relBuilder.literal(true));
}
if (node.getJoinType() == SEMI || node.getJoinType() == ANTI) {
// semi and anti join only return left table outputs
context.relBuilder.join(
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
return context.relBuilder.peek();
}
List<RexNode> toBeRemovedFields;
if (node.getArgumentMap().get("overwrite") == null // 'overwrite' default value is true
|| (node.getArgumentMap().get("overwrite").equals(Literal.TRUE))) {
toBeRemovedFields =
duplicatedFieldNames.stream()
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, true, context))
.toList();
} else {
toBeRemovedFields =
duplicatedFieldNames.stream()
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, false, context))
.toList();
}
Literal max = node.getArgumentMap().get("max");
if (max != null && !max.equals(Literal.ZERO)) {
// max != 0 means the right-side should be dedup
Integer allowedDuplication = (Integer) max.getValue();
if (allowedDuplication < 0) {
throw new SemanticCheckException("max option must be a positive integer");
}
List<RexNode> dedupeFields =
node.getJoinFields().isPresent()
? node.getJoinFields().get().stream()
.map(a -> (RexNode) context.relBuilder.field(a.getField().toString()))
.toList()
: duplicatedFieldNames.stream()
.map(a -> (RexNode) context.relBuilder.field(a))
.toList();
buildDedupNotNull(context, dedupeFields, allowedDuplication);
}
context.relBuilder.join(
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
if (!toBeRemovedFields.isEmpty()) {
context.relBuilder.projectExcept(toBeRemovedFields);
}
return context.relBuilder.peek();
}
// The join-with-criteria grammar doesn't allow empty join condition
RexNode joinCondition =
node.getJoinCondition()
.map(c -> rexVisitor.analyzeJoinCondition(c, context))
Expand Down Expand Up @@ -938,6 +1003,19 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
.orElse(rightTableQualifiedName + "." + col)
: col)
.toList();

Literal max = node.getArgumentMap().get("max");
if (max != null && !max.equals(Literal.ZERO)) {
// max != 0 means the right-side should be dedup
Integer allowedDuplication = (Integer) max.getValue();
if (allowedDuplication < 0) {
throw new SemanticCheckException("max option must be a positive integer");
}
List<RexNode> dedupeFields =
getRightColumnsInJoinCriteria(context.relBuilder, joinCondition);

buildDedupNotNull(context, dedupeFields, allowedDuplication);
}
context.relBuilder.join(
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
JoinAndLookupUtils.renameToExpectedFields(
Expand All @@ -946,6 +1024,37 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
return context.relBuilder.peek();
}

private List<RexNode> getRightColumnsInJoinCriteria(
RelBuilder relBuilder, RexNode joinCondition) {
int stackSize = relBuilder.size();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[question] Why not using 2 directly here? I think we should join the top 2 operators in the stack. And In which case will it be greater than 2?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method invoked after join condition resolving, if the joinCondition contains subsearch, the top of stack is the subsearch table. So I pick the stackSize - 1 and stackSize - 2 as left and right.

int leftFieldCount = relBuilder.peek(stackSize - 1).getRowType().getFieldCount();
RelNode right = relBuilder.peek(stackSize - 2);
List<String> allColumnNamesOfRight = right.getRowType().getFieldNames();

List<Integer> rightColumnIndexes = new ArrayList<>();
joinCondition.accept(
new RexVisitorImpl<Void>(true) {
@Override
public Void visitInputRef(RexInputRef inputRef) {
if (inputRef.getIndex() >= leftFieldCount) {
rightColumnIndexes.add(inputRef.getIndex() - leftFieldCount);
}
return super.visitInputRef(inputRef);
}
});
return rightColumnIndexes.stream()
.map(allColumnNamesOfRight::get)
.map(n -> (RexNode) relBuilder.field(n))
.toList();
}

private static RexNode buildJoinConditionByFieldName(
CalcitePlanContext context, String fieldName) {
RexNode lookupKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, false, context);
RexNode sourceKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, true, context);
return context.rexBuilder.equals(sourceKey, lookupKey);
}

@Override
public RelNode visitSubqueryAlias(SubqueryAlias node, CalcitePlanContext context) {
visitChildren(node, context);
Expand Down Expand Up @@ -1068,74 +1177,82 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) {
List<RexNode> dedupeFields =
node.getFields().stream().map(f -> rexVisitor.analyze(f, context)).toList();
if (keepEmpty) {
/*
* | dedup 2 a, b keepempty=false
* DropColumns('_row_number_dedup_)
* +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b))
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
* +- ...
*/
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
// ASC
// NULLS FIRST, 'b ASC NULLS FIRST]
RexNode rowNumber =
context
.relBuilder
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
.over()
.partitionBy(dedupeFields)
.orderBy(dedupeFields)
.rowsTo(RexWindowBounds.CURRENT_ROW)
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
context.relBuilder.projectPlus(rowNumber);
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
// Filter (isnull('a) OR isnull('b) OR '_row_number_dedup_ <= n)
context.relBuilder.filter(
context.relBuilder.or(
context.relBuilder.or(dedupeFields.stream().map(context.relBuilder::isNull).toList()),
context.relBuilder.lessThanOrEqual(
_row_number_dedup_, context.relBuilder.literal(allowedDuplication))));
// DropColumns('_row_number_)
context.relBuilder.projectExcept(_row_number_dedup_);
buildDedupOrNull(context, dedupeFields, allowedDuplication);
} else {
/*
* | dedup 2 a, b keepempty=false
* DropColumns('_row_number_dedup_)
* +- Filter ('_row_number_dedup_ <= n)
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
* +- Filter (isnotnull('a) AND isnotnull('b))
* +- ...
*/
// Filter (isnotnull('a) AND isnotnull('b))
context.relBuilder.filter(
context.relBuilder.and(
dedupeFields.stream().map(context.relBuilder::isNotNull).toList()));
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
// ASC
// NULLS FIRST, 'b ASC NULLS FIRST]
RexNode rowNumber =
context
.relBuilder
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
.over()
.partitionBy(dedupeFields)
.orderBy(dedupeFields)
.rowsTo(RexWindowBounds.CURRENT_ROW)
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
context.relBuilder.projectPlus(rowNumber);
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
// Filter ('_row_number_dedup_ <= n)
context.relBuilder.filter(
context.relBuilder.lessThanOrEqual(
_row_number_dedup_, context.relBuilder.literal(allowedDuplication)));
// DropColumns('_row_number_dedup_)
context.relBuilder.projectExcept(_row_number_dedup_);
buildDedupNotNull(context, dedupeFields, allowedDuplication);
}
return context.relBuilder.peek();
}

private static void buildDedupOrNull(
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
/*
* | dedup 2 a, b keepempty=false
* DropColumns('_row_number_dedup_)
* +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b))
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
* +- ...
*/
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
// ASC
// NULLS FIRST, 'b ASC NULLS FIRST]
RexNode rowNumber =
context
.relBuilder
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
.over()
.partitionBy(dedupeFields)
.orderBy(dedupeFields)
.rowsTo(RexWindowBounds.CURRENT_ROW)
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
context.relBuilder.projectPlus(rowNumber);
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
// Filter (isnull('a) OR isnull('b) OR '_row_number_dedup_ <= n)
context.relBuilder.filter(
context.relBuilder.or(
context.relBuilder.or(dedupeFields.stream().map(context.relBuilder::isNull).toList()),
context.relBuilder.lessThanOrEqual(
_row_number_dedup_, context.relBuilder.literal(allowedDuplication))));
// DropColumns('_row_number_dedup_)
context.relBuilder.projectExcept(_row_number_dedup_);
}

private static void buildDedupNotNull(
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
/*
* | dedup 2 a, b keepempty=false
* DropColumns('_row_number_dedup_)
* +- Filter ('_row_number_dedup_ <= n)
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
* +- Filter (isnotnull('a) AND isnotnull('b))
* +- ...
*/
// Filter (isnotnull('a) AND isnotnull('b))
context.relBuilder.filter(
context.relBuilder.and(dedupeFields.stream().map(context.relBuilder::isNotNull).toList()));
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC
// NULLS FIRST, 'b ASC NULLS FIRST]
RexNode rowNumber =
context
.relBuilder
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
.over()
.partitionBy(dedupeFields)
.orderBy(dedupeFields)
.rowsTo(RexWindowBounds.CURRENT_ROW)
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
context.relBuilder.projectPlus(rowNumber);
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
// Filter ('_row_number_dedup_ <= n)
context.relBuilder.filter(
context.relBuilder.lessThanOrEqual(
_row_number_dedup_, context.relBuilder.literal(allowedDuplication)));
// DropColumns('_row_number_dedup_)
context.relBuilder.projectExcept(_row_number_dedup_);
}

@Override
public RelNode visitWindow(Window node, CalcitePlanContext context) {
visitChildren(node, context);
Expand Down
Loading
Loading