From 72de2c5cf4d8aa6f1424aa7283595dead3568976 Mon Sep 17 00:00:00 2001 From: liuyongvs Date: Thu, 2 Jan 2025 14:30:27 +0800 Subject: [PATCH 1/2] [FLINK-36988][table] Migrate LogicalCorrelateToJoinFromTemporalTableFunctionRule to java --- ...teToJoinFromTemporalTableFunctionRule.java | 361 ++++++++++++++++++ ...eToJoinFromTemporalTableFunctionRule.scala | 238 ------------ 2 files changed, 361 insertions(+), 238 deletions(-) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java new file mode 100644 index 0000000000000..a0b0c11fb8bd6 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.TemporalTableFunction; +import org.apache.flink.table.functions.TemporalTableFunctionImpl; +import org.apache.flink.table.operations.QueryOperation; +import org.apache.flink.table.planner.calcite.FlinkRelBuilder; +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; +import org.apache.flink.table.planner.functions.utils.TableSqlFunction; +import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecTemporalJoin; +import org.apache.flink.table.planner.plan.optimize.program.FlinkOptimizeContext; +import org.apache.flink.table.planner.plan.utils.ExpandTableScanShuttle; +import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor; +import org.apache.flink.table.planner.plan.utils.TemporalJoinUtil; +import org.apache.flink.table.planner.utils.ShortcutUtils; +import org.apache.flink.table.types.logical.LogicalTypeRoot; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.TableFunctionScan; +import org.apache.calcite.rel.logical.LogicalCorrelate; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.sql.SqlOperator; +import org.immutables.value.Value; + +import java.util.Optional; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isProctimeAttribute; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * The initial temporal TableFunction join (LATERAL TemporalTableFunction(o.proctime)) is a + * correlate. Rewrite it into a Join with a special temporal join condition wraps time attribute and + * primary key information. The join will be translated into {@link StreamExecTemporalJoin} in + * physical. + */ +@Value.Enclosing +public class LogicalCorrelateToJoinFromTemporalTableFunctionRule + extends RelRule< + LogicalCorrelateToJoinFromTemporalTableFunctionRule + .LogicalCorrelateToJoinFromTemporalTableFunctionRuleConfig> { + + public static final LogicalCorrelateToJoinFromTemporalTableFunctionRule INSTANCE = + LogicalCorrelateToJoinFromTemporalTableFunctionRule + .LogicalCorrelateToJoinFromTemporalTableFunctionRuleConfig.DEFAULT + .toRule(); + + private LogicalCorrelateToJoinFromTemporalTableFunctionRule( + LogicalCorrelateToJoinFromTemporalTableFunctionRuleConfig config) { + super(config); + } + + private String extractNameFromTimeAttribute(Expression timeAttribute) { + if (timeAttribute instanceof FieldReferenceExpression) { + FieldReferenceExpression f = (FieldReferenceExpression) timeAttribute; + if (f.getOutputDataType() + .getLogicalType() + .isAnyOf( + LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE, + LogicalTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE)) { + return f.getName(); + } + } + throw new ValidationException( + "Invalid timeAttribute [" + timeAttribute + "] in TemporalTableFunction"); + } + + private boolean isProctimeReference(TemporalTableFunctionImpl temporalTableFunction) { + FieldReferenceExpression fieldRef = + (FieldReferenceExpression) temporalTableFunction.getTimeAttribute(); + return isProctimeAttribute(fieldRef.getOutputDataType().getLogicalType()); + } + + private String extractNameFromPrimaryKeyAttribute(Expression expression) { + if (expression instanceof FieldReferenceExpression) { + FieldReferenceExpression f = (FieldReferenceExpression) expression; + return f.getName(); + } + throw new ValidationException( + "Unsupported expression [" + + expression + + "] as primary key. " + + "Only top-level (not nested) field references are supported."); + } + + @Override + public void onMatch(RelOptRuleCall call) { + LogicalCorrelate logicalCorrelate = call.rel(0); + RelNode leftNode = call.rel(1); + TableFunctionScan rightTableFunctionScan = call.rel(2); + + RelOptCluster cluster = logicalCorrelate.getCluster(); + + Optional temporalTableFunctionCall = + new GetTemporalTableFunctionCall(cluster.getRexBuilder(), leftNode) + .visit(rightTableFunctionScan.getCall()); + + if (temporalTableFunctionCall.isPresent() + && temporalTableFunctionCall.get().getTemporalTableFunction() + instanceof TemporalTableFunctionImpl) { + TemporalTableFunctionImpl rightTemporalTableFunction = + (TemporalTableFunctionImpl) + temporalTableFunctionCall.get().getTemporalTableFunction(); + RexNode leftTimeAttribute = temporalTableFunctionCall.get().getTimeAttribute(); + + // If TemporalTableFunction was found, rewrite LogicalCorrelate to TemporalJoin + QueryOperation underlyingHistoryTable = + rightTemporalTableFunction.getUnderlyingHistoryTable(); + RexBuilder rexBuilder = cluster.getRexBuilder(); + + FlinkOptimizeContext flinkContext = + (FlinkOptimizeContext) ShortcutUtils.unwrapContext(call.getPlanner()); + FlinkRelBuilder relBuilder = flinkContext.getFlinkRelBuilder(); + + RelNode temporalTable = relBuilder.queryOperation(underlyingHistoryTable).build(); + // expand QueryOperationCatalogViewTable in Table Scan + ExpandTableScanShuttle shuttle = new ExpandTableScanShuttle(); + RelNode rightNode = temporalTable.accept(shuttle); + + RexNode rightTimeIndicatorExpression = + createRightExpression( + rexBuilder, + leftNode, + rightNode, + extractNameFromTimeAttribute( + rightTemporalTableFunction.getTimeAttribute())); + + RexNode rightPrimaryKeyExpression = + createRightExpression( + rexBuilder, + leftNode, + rightNode, + extractNameFromPrimaryKeyAttribute( + rightTemporalTableFunction.getPrimaryKey())); + + relBuilder.push(leftNode); + relBuilder.push(rightNode); + + RexNode condition; + if (isProctimeReference(rightTemporalTableFunction)) { + condition = + TemporalJoinUtil.makeProcTimeTemporalFunctionJoinConCall( + rexBuilder, leftTimeAttribute, rightPrimaryKeyExpression); + } else { + condition = + TemporalJoinUtil.makeRowTimeTemporalFunctionJoinConCall( + rexBuilder, + leftTimeAttribute, + rightTimeIndicatorExpression, + rightPrimaryKeyExpression); + } + + relBuilder.join(JoinRelType.INNER, condition); + call.transformTo(relBuilder.build()); + } else { + // Do nothing and handle standard TableFunction + } + } + + private RexNode createRightExpression( + RexBuilder rexBuilder, RelNode leftNode, RelNode rightNode, String field) { + int rightReferencesOffset = leftNode.getRowType().getFieldCount(); + RelDataTypeField rightDataTypeField = rightNode.getRowType().getField(field, false, false); + return rexBuilder.makeInputRef( + rightDataTypeField.getType(), + rightReferencesOffset + rightDataTypeField.getIndex()); + } + + /** Rule configuration. */ + @Value.Immutable(singleton = false) + public interface LogicalCorrelateToJoinFromTemporalTableFunctionRuleConfig + extends RelRule.Config { + LogicalCorrelateToJoinFromTemporalTableFunctionRule + .LogicalCorrelateToJoinFromTemporalTableFunctionRuleConfig + DEFAULT = + ImmutableLogicalCorrelateToJoinFromTemporalTableFunctionRule + .LogicalCorrelateToJoinFromTemporalTableFunctionRuleConfig.builder() + .build() + .withOperandSupplier( + b0 -> + b0.operand(LogicalCorrelate.class) + .inputs( + b1 -> + b1.operand(RelNode.class) + .anyInputs(), + b2 -> + b2.operand( + TableFunctionScan + .class) + .noInputs())) + .withDescription( + "LogicalCorrelateToJoinFromTemporalTableFunctionRule"); + + @Override + default LogicalCorrelateToJoinFromTemporalTableFunctionRule toRule() { + return new LogicalCorrelateToJoinFromTemporalTableFunctionRule(this); + } + } +} + +/** + * Simple pojo class for extracted {@link TemporalTableFunction} with time attribute extracted from + * RexNode with {@link TemporalTableFunction} call. + */ +class TemporalTableFunctionCall { + private TemporalTableFunction temporalTableFunction; + private RexNode timeAttribute; + + public TemporalTableFunctionCall( + TemporalTableFunction temporalTableFunction, RexNode timeAttribute) { + this.temporalTableFunction = temporalTableFunction; + this.timeAttribute = timeAttribute; + } + + public TemporalTableFunction getTemporalTableFunction() { + return temporalTableFunction; + } + + public void setTemporalTableFunction(TemporalTableFunction temporalTableFunction) { + this.temporalTableFunction = temporalTableFunction; + } + + public RexNode getTimeAttribute() { + return timeAttribute; + } + + public void setTimeAttribute(RexNode timeAttribute) { + this.timeAttribute = timeAttribute; + } +} + +/** + * Find {@link TemporalTableFunction} call and run {@link CorrelatedFieldAccessRemoval} on it's + * operand. + */ +class GetTemporalTableFunctionCall extends RexVisitorImpl { + private final RexBuilder rexBuilder; + private final RelNode leftSide; + + GetTemporalTableFunctionCall(RexBuilder rexBuilder, RelNode leftSide) { + super(false); + this.rexBuilder = rexBuilder; + this.leftSide = leftSide; + } + + Optional visit(RexNode node) { + TemporalTableFunctionCall result = node.accept(this); + return result != null ? Optional.of(result) : Optional.empty(); + } + + @Override + public TemporalTableFunctionCall visitCall(RexCall rexCall) { + FunctionDefinition functionDefinition; + SqlOperator sqlOperator = rexCall.getOperator(); + if (sqlOperator instanceof TableSqlFunction) { + functionDefinition = ((TableSqlFunction) sqlOperator).udtf(); + } else if (sqlOperator instanceof BridgingSqlFunction) { + functionDefinition = ((BridgingSqlFunction) sqlOperator).getDefinition(); + } else { + return null; + } + + if (!(functionDefinition instanceof TemporalTableFunctionImpl)) { + return null; + } + TemporalTableFunctionImpl temporalTableFunction = + (TemporalTableFunctionImpl) functionDefinition; + + checkState( + rexCall.getOperands().size() == 1, + "TemporalTableFunction call [%s] must have exactly one argument", + rexCall); + CorrelatedFieldAccessRemoval correlatedFieldAccessRemoval = + new CorrelatedFieldAccessRemoval(temporalTableFunction, rexBuilder, leftSide); + return new TemporalTableFunctionCall( + temporalTableFunction, + rexCall.getOperands().get(0).accept(correlatedFieldAccessRemoval)); + } +} + +/** + * This converts field accesses like `$cor0.o_rowtime` to valid input references for join condition + * context without `$cor` reference. + */ +class CorrelatedFieldAccessRemoval extends RexDefaultVisitor { + private TemporalTableFunctionImpl temporalTableFunction; + private RexBuilder rexBuilder; + private RelNode leftSide; + + public CorrelatedFieldAccessRemoval( + TemporalTableFunctionImpl temporalTableFunction, + RexBuilder rexBuilder, + RelNode leftSide) { + this.temporalTableFunction = temporalTableFunction; + this.rexBuilder = rexBuilder; + this.leftSide = leftSide; + } + + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + int leftIndex = leftSide.getRowType().getFieldList().indexOf(fieldAccess.getField()); + if (leftIndex < 0) { + throw new IllegalStateException( + "Failed to find reference to field [" + + fieldAccess.getField() + + "] in node [" + + leftSide + + "]"); + } + return rexBuilder.makeInputRef(leftSide, leftIndex); + } + + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + return inputRef; + } + + @Override + public RexNode visitNode(RexNode rexNode) { + throw new ValidationException( + "Unsupported argument [" + + rexNode + + "] " + + "in " + + TemporalTableFunction.class.getSimpleName() + + " call of " + + "[" + + temporalTableFunction.getUnderlyingHistoryTable() + + "] table"); + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala deleted file mode 100644 index 79b09d5e427aa..0000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.rules.logical - -import org.apache.flink.table.api.ValidationException -import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.{TemporalTableFunction, TemporalTableFunctionImpl} -import org.apache.flink.table.operations.QueryOperation -import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction -import org.apache.flink.table.planner.functions.utils.TableSqlFunction -import org.apache.flink.table.planner.plan.optimize.program.FlinkOptimizeContext -import org.apache.flink.table.planner.plan.utils.{ExpandTableScanShuttle, RexDefaultVisitor} -import org.apache.flink.table.planner.plan.utils.TemporalJoinUtil.{makeProcTimeTemporalFunctionJoinConCall, makeRowTimeTemporalFunctionJoinConCall} -import org.apache.flink.table.planner.utils.ShortcutUtils -import org.apache.flink.table.types.logical.LogicalTypeRoot.{TIMESTAMP_WITH_LOCAL_TIME_ZONE, TIMESTAMP_WITHOUT_TIME_ZONE} -import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isProctimeAttribute -import org.apache.flink.util.Preconditions.checkState - -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} -import org.apache.calcite.plan.RelOptRule.{any, none, operand, some} -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.core.{JoinRelType, TableFunctionScan} -import org.apache.calcite.rel.logical.LogicalCorrelate -import org.apache.calcite.rex._ - -/** - * The initial temporal TableFunction join (LATERAL TemporalTableFunction(o.proctime)) is a - * correlate. Rewrite it into a Join with a special temporal join condition wraps time attribute and - * primary key information. The join will be translated into [[StreamExecTemporalJoin]] in physical. - */ -class LogicalCorrelateToJoinFromTemporalTableFunctionRule - extends RelOptRule( - operand( - classOf[LogicalCorrelate], - some(operand(classOf[RelNode], any()), operand(classOf[TableFunctionScan], none()))), - "LogicalCorrelateToJoinFromTemporalTableFunctionRule") { - - private def extractNameFromTimeAttribute(timeAttribute: Expression): String = { - timeAttribute match { - case f: FieldReferenceExpression - if f.getOutputDataType.getLogicalType.isAnyOf( - TIMESTAMP_WITHOUT_TIME_ZONE, - TIMESTAMP_WITH_LOCAL_TIME_ZONE) => - f.getName - case _ => - throw new ValidationException( - s"Invalid timeAttribute [$timeAttribute] in TemporalTableFunction") - } - } - - private def isProctimeReference(temporalTableFunction: TemporalTableFunctionImpl): Boolean = { - val fieldRef = temporalTableFunction.getTimeAttribute.asInstanceOf[FieldReferenceExpression] - isProctimeAttribute(fieldRef.getOutputDataType.getLogicalType) - } - - private def extractNameFromPrimaryKeyAttribute(expression: Expression): String = { - expression match { - case f: FieldReferenceExpression => - f.getName - case _ => - throw new ValidationException( - s"Unsupported expression [$expression] as primary key. " + - s"Only top-level (not nested) field references are supported.") - } - } - - override def onMatch(call: RelOptRuleCall): Unit = { - val logicalCorrelate: LogicalCorrelate = call.rel(0) - val leftNode: RelNode = call.rel(1) - val rightTableFunctionScan: TableFunctionScan = call.rel(2) - - val cluster = logicalCorrelate.getCluster - - new GetTemporalTableFunctionCall(cluster.getRexBuilder, leftNode) - .visit(rightTableFunctionScan.getCall) match { - case None => - // Do nothing and handle standard TableFunction - case Some( - TemporalTableFunctionCall( - rightTemporalTableFunction: TemporalTableFunctionImpl, - leftTimeAttribute)) => - // If TemporalTableFunction was found, rewrite LogicalCorrelate to TemporalJoin - val underlyingHistoryTable: QueryOperation = - rightTemporalTableFunction.getUnderlyingHistoryTable - val rexBuilder = cluster.getRexBuilder - - val flinkContext = ShortcutUtils - .unwrapContext(call.getPlanner) - .asInstanceOf[FlinkOptimizeContext] - val relBuilder = flinkContext.getFlinkRelBuilder - - val temporalTable: RelNode = relBuilder.queryOperation(underlyingHistoryTable).build() - // expand QueryOperationCatalogViewTable in Table Scan - val shuttle = new ExpandTableScanShuttle - val rightNode = temporalTable.accept(shuttle) - - val rightTimeIndicatorExpression = createRightExpression( - rexBuilder, - leftNode, - rightNode, - extractNameFromTimeAttribute(rightTemporalTableFunction.getTimeAttribute)) - - val rightPrimaryKeyExpression = createRightExpression( - rexBuilder, - leftNode, - rightNode, - extractNameFromPrimaryKeyAttribute(rightTemporalTableFunction.getPrimaryKey)) - - relBuilder.push(leftNode) - relBuilder.push(rightNode) - - val condition = - if (isProctimeReference(rightTemporalTableFunction)) { - makeProcTimeTemporalFunctionJoinConCall( - rexBuilder, - leftTimeAttribute, - rightPrimaryKeyExpression) - } else { - makeRowTimeTemporalFunctionJoinConCall( - rexBuilder, - leftTimeAttribute, - rightTimeIndicatorExpression, - rightPrimaryKeyExpression) - } - relBuilder.join(JoinRelType.INNER, condition) - - call.transformTo(relBuilder.build()) - } - } - - private def createRightExpression( - rexBuilder: RexBuilder, - leftNode: RelNode, - rightNode: RelNode, - field: String): RexNode = { - val rightReferencesOffset = leftNode.getRowType.getFieldCount - val rightDataTypeField = rightNode.getRowType.getField(field, false, false) - rexBuilder.makeInputRef( - rightDataTypeField.getType, - rightReferencesOffset + rightDataTypeField.getIndex) - } - -} - -object LogicalCorrelateToJoinFromTemporalTableFunctionRule { - val INSTANCE: RelOptRule = new LogicalCorrelateToJoinFromTemporalTableFunctionRule -} - -/** - * Simple pojo class for extracted [[TemporalTableFunction]] with time attribute extracted from - * RexNode with [[TemporalTableFunction]] call. - */ -case class TemporalTableFunctionCall( - var temporalTableFunction: TemporalTableFunction, - var timeAttribute: RexNode) {} - -/** Find [[TemporalTableFunction]] call and run [[CorrelatedFieldAccessRemoval]] on it's operand. */ -class GetTemporalTableFunctionCall(var rexBuilder: RexBuilder, var leftSide: RelNode) - extends RexVisitorImpl[TemporalTableFunctionCall](false) { - - def visit(node: RexNode): Option[TemporalTableFunctionCall] = { - val result = node.accept(this) - if (result == null) { - return None - } - Some(result) - } - - override def visitCall(rexCall: RexCall): TemporalTableFunctionCall = { - val functionDefinition = rexCall.getOperator match { - case tsf: TableSqlFunction => tsf.udtf - case bsf: BridgingSqlFunction => bsf.getDefinition - case _ => return null - } - - if (!functionDefinition.isInstanceOf[TemporalTableFunction]) { - return null - } - val temporalTableFunction = - functionDefinition.asInstanceOf[TemporalTableFunctionImpl] - - checkState( - rexCall.getOperands.size().equals(1), - "TemporalTableFunction call [%s] must have exactly one argument", - rexCall) - val correlatedFieldAccessRemoval = - new CorrelatedFieldAccessRemoval(temporalTableFunction, rexBuilder, leftSide) - TemporalTableFunctionCall( - temporalTableFunction, - rexCall.getOperands.get(0).accept(correlatedFieldAccessRemoval)) - } -} - -/** - * This converts field accesses like `$cor0.o_rowtime` to valid input references for join condition - * context without `$cor` reference. - */ -class CorrelatedFieldAccessRemoval( - var temporalTableFunction: TemporalTableFunctionImpl, - var rexBuilder: RexBuilder, - var leftSide: RelNode) - extends RexDefaultVisitor[RexNode] { - - override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = { - val leftIndex = leftSide.getRowType.getFieldList.indexOf(fieldAccess.getField) - if (leftIndex < 0) { - throw new IllegalStateException( - s"Failed to find reference to field [${fieldAccess.getField}] in node [$leftSide]") - } - rexBuilder.makeInputRef(leftSide, leftIndex) - } - - override def visitInputRef(inputRef: RexInputRef): RexNode = { - inputRef - } - - override def visitNode(rexNode: RexNode): RexNode = { - throw new ValidationException( - s"Unsupported argument [$rexNode] " + - s"in ${classOf[TemporalTableFunction].getSimpleName} call of " + - s"[${temporalTableFunction.getUnderlyingHistoryTable}] table") - } -} From 9a86bbefc31a9347a6e12303e6e6acd04f2d1de9 Mon Sep 17 00:00:00 2001 From: yongliu Date: Mon, 12 Jan 2026 10:16:14 +0800 Subject: [PATCH 2/2] [FLINK-36988][table] fix review --- ...teToJoinFromTemporalTableFunctionRule.java | 122 +++++++++--------- 1 file changed, 59 insertions(+), 63 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java index a0b0c11fb8bd6..d80062919e505 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java @@ -125,66 +125,66 @@ public void onMatch(RelOptRuleCall call) { new GetTemporalTableFunctionCall(cluster.getRexBuilder(), leftNode) .visit(rightTableFunctionScan.getCall()); - if (temporalTableFunctionCall.isPresent() + if (!(temporalTableFunctionCall.isPresent() && temporalTableFunctionCall.get().getTemporalTableFunction() - instanceof TemporalTableFunctionImpl) { - TemporalTableFunctionImpl rightTemporalTableFunction = - (TemporalTableFunctionImpl) - temporalTableFunctionCall.get().getTemporalTableFunction(); - RexNode leftTimeAttribute = temporalTableFunctionCall.get().getTimeAttribute(); - - // If TemporalTableFunction was found, rewrite LogicalCorrelate to TemporalJoin - QueryOperation underlyingHistoryTable = - rightTemporalTableFunction.getUnderlyingHistoryTable(); - RexBuilder rexBuilder = cluster.getRexBuilder(); - - FlinkOptimizeContext flinkContext = - (FlinkOptimizeContext) ShortcutUtils.unwrapContext(call.getPlanner()); - FlinkRelBuilder relBuilder = flinkContext.getFlinkRelBuilder(); - - RelNode temporalTable = relBuilder.queryOperation(underlyingHistoryTable).build(); - // expand QueryOperationCatalogViewTable in Table Scan - ExpandTableScanShuttle shuttle = new ExpandTableScanShuttle(); - RelNode rightNode = temporalTable.accept(shuttle); - - RexNode rightTimeIndicatorExpression = - createRightExpression( - rexBuilder, - leftNode, - rightNode, - extractNameFromTimeAttribute( - rightTemporalTableFunction.getTimeAttribute())); - - RexNode rightPrimaryKeyExpression = - createRightExpression( - rexBuilder, - leftNode, - rightNode, - extractNameFromPrimaryKeyAttribute( - rightTemporalTableFunction.getPrimaryKey())); - - relBuilder.push(leftNode); - relBuilder.push(rightNode); - - RexNode condition; - if (isProctimeReference(rightTemporalTableFunction)) { - condition = - TemporalJoinUtil.makeProcTimeTemporalFunctionJoinConCall( - rexBuilder, leftTimeAttribute, rightPrimaryKeyExpression); - } else { - condition = - TemporalJoinUtil.makeRowTimeTemporalFunctionJoinConCall( - rexBuilder, - leftTimeAttribute, - rightTimeIndicatorExpression, - rightPrimaryKeyExpression); - } + instanceof TemporalTableFunctionImpl)) { + return; + } - relBuilder.join(JoinRelType.INNER, condition); - call.transformTo(relBuilder.build()); + TemporalTableFunctionImpl rightTemporalTableFunction = + (TemporalTableFunctionImpl) + temporalTableFunctionCall.get().getTemporalTableFunction(); + RexNode leftTimeAttribute = temporalTableFunctionCall.get().getTimeAttribute(); + + // If TemporalTableFunction was found, rewrite LogicalCorrelate to TemporalJoin + QueryOperation underlyingHistoryTable = + rightTemporalTableFunction.getUnderlyingHistoryTable(); + RexBuilder rexBuilder = cluster.getRexBuilder(); + + FlinkOptimizeContext flinkContext = + (FlinkOptimizeContext) ShortcutUtils.unwrapContext(call.getPlanner()); + FlinkRelBuilder relBuilder = flinkContext.getFlinkRelBuilder(); + + RelNode temporalTable = relBuilder.queryOperation(underlyingHistoryTable).build(); + // expand QueryOperationCatalogViewTable in Table Scan + ExpandTableScanShuttle shuttle = new ExpandTableScanShuttle(); + RelNode rightNode = temporalTable.accept(shuttle); + + RexNode rightTimeIndicatorExpression = + createRightExpression( + rexBuilder, + leftNode, + rightNode, + extractNameFromTimeAttribute( + rightTemporalTableFunction.getTimeAttribute())); + + RexNode rightPrimaryKeyExpression = + createRightExpression( + rexBuilder, + leftNode, + rightNode, + extractNameFromPrimaryKeyAttribute( + rightTemporalTableFunction.getPrimaryKey())); + + relBuilder.push(leftNode); + relBuilder.push(rightNode); + + RexNode condition; + if (isProctimeReference(rightTemporalTableFunction)) { + condition = + TemporalJoinUtil.makeProcTimeTemporalFunctionJoinConCall( + rexBuilder, leftTimeAttribute, rightPrimaryKeyExpression); } else { - // Do nothing and handle standard TableFunction + condition = + TemporalJoinUtil.makeRowTimeTemporalFunctionJoinConCall( + rexBuilder, + leftTimeAttribute, + rightTimeIndicatorExpression, + rightPrimaryKeyExpression); } + + relBuilder.join(JoinRelType.INNER, condition); + call.transformTo(relBuilder.build()); } private RexNode createRightExpression( @@ -246,10 +246,6 @@ public TemporalTableFunction getTemporalTableFunction() { return temporalTableFunction; } - public void setTemporalTableFunction(TemporalTableFunction temporalTableFunction) { - this.temporalTableFunction = temporalTableFunction; - } - public RexNode getTimeAttribute() { return timeAttribute; } @@ -313,9 +309,9 @@ public TemporalTableFunctionCall visitCall(RexCall rexCall) { * context without `$cor` reference. */ class CorrelatedFieldAccessRemoval extends RexDefaultVisitor { - private TemporalTableFunctionImpl temporalTableFunction; - private RexBuilder rexBuilder; - private RelNode leftSide; + private final TemporalTableFunctionImpl temporalTableFunction; + private final RexBuilder rexBuilder; + private final RelNode leftSide; public CorrelatedFieldAccessRemoval( TemporalTableFunctionImpl temporalTableFunction,