Commit 7b2ae80

[FLINK-24860][python] Fix the wrong position mappings in the Python UDTF
This closes apache#17752.
1 parent ccf6f1a commit 7b2ae80
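
The gist of the fix: when the planner extracts expressions used as Python UDTF arguments and projects them below the correlate, field accesses on the correlation variable (for example $cor0.b) must be rewritten into plain input references at the matching field index, otherwise the UDTF call ends up pointing at the wrong positions. The sketch below illustrates just that rewrite using Calcite's public Rex API directly rather than the Flink planner classes touched by this commit; the class name, row type, and field layout are illustrative assumptions, not part of the change.

// Standalone illustration (not Flink code): rewrite a field access on a
// correlation variable into an input reference at the same field index.
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;

public class CorrelFieldAccessRewriteSketch {
    public static void main(String[] args) {
        SqlTypeFactoryImpl typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
        RexBuilder rexBuilder = new RexBuilder(typeFactory);

        // Row type of the correlate's left input: (a INT, b INT, c VARCHAR).
        RelDataType rowType =
                typeFactory.builder()
                        .add("a", SqlTypeName.INTEGER)
                        .add("b", SqlTypeName.INTEGER)
                        .add("c", SqlTypeName.VARCHAR)
                        .build();

        // $cor0.b: a field access on the correlation variable $cor0.
        RexNode correlVar = rexBuilder.makeCorrel(rowType, new CorrelationId(0));
        RexNode fieldOnCorrel = rexBuilder.makeFieldAccess(correlVar, "b", true);

        // The rewrite: if the accessed expression is the correlation variable,
        // reference the field by its position in the left input instead.
        RexNode rewritten = fieldOnCorrel;
        if (fieldOnCorrel instanceof RexFieldAccess
                && ((RexFieldAccess) fieldOnCorrel).getReferenceExpr()
                        instanceof RexCorrelVariable) {
            RelDataTypeField field = ((RexFieldAccess) fieldOnCorrel).getField();
            rewritten = new RexInputRef(field.getIndex(), field.getType());
        }

        // Prints roughly "$cor0.b -> $1": field b is now referenced by position.
        System.out.println(fieldOnCorrel + " -> " + rewritten);
    }
}

In the planner rules changed below, this same rewrite is applied by a RexDefaultVisitor in the Java rules and in ScalarFunctionSplitter.visitFieldAccess in the Scala rules, before the extracted RexCalls are added to the new left Calc.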

File tree

6 files changed (+161 / -43 lines)

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java

+59-12
@@ -23,14 +23,17 @@
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
 import org.apache.flink.table.planner.plan.rules.physical.stream.StreamExecCorrelateRule;
 import org.apache.flink.table.planner.plan.utils.PythonUtil;
+import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;

 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.rex.RexFieldAccess;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
@@ -112,10 +115,41 @@ private List<String> createNewFieldNames(
         for (int i = 0; i < primitiveFieldCount; i++) {
             calcProjects.add(RexInputRef.of(i, rowType));
         }
+        // change RexCorrelVariable to RexInputRef.
+        RexDefaultVisitor<RexNode> visitor =
+                new RexDefaultVisitor<RexNode>() {
+                    @Override
+                    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+                        RexNode expr = fieldAccess.getReferenceExpr();
+                        if (expr instanceof RexCorrelVariable) {
+                            RelDataTypeField field = fieldAccess.getField();
+                            return new RexInputRef(field.getIndex(), field.getType());
+                        } else {
+                            return rexBuilder.makeFieldAccess(
+                                    expr.accept(this), fieldAccess.getField().getIndex());
+                        }
+                    }
+
+                    @Override
+                    public RexNode visitNode(RexNode rexNode) {
+                        return rexNode;
+                    }
+                };
         // add the fields of the extracted rex calls.
         Iterator<RexNode> iterator = extractedRexNodes.iterator();
         while (iterator.hasNext()) {
-            calcProjects.add(iterator.next());
+            RexNode rexNode = iterator.next();
+            if (rexNode instanceof RexCall) {
+                RexCall rexCall = (RexCall) rexNode;
+                List<RexNode> newProjects =
+                        rexCall.getOperands().stream()
+                                .map(x -> x.accept(visitor))
+                                .collect(Collectors.toList());
+                RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
+                calcProjects.add(newRexCall);
+            } else {
+                calcProjects.add(rexNode);
+            }
         }

         List<String> nameList = new LinkedList<>();
@@ -252,18 +286,31 @@ public void onMatch(RelOptRuleCall call) {
                     mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
         }

-        FlinkLogicalCalc leftCalc =
-                createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
+        FlinkLogicalCorrelate newCorrelate;
+        if (extractedRexNodes.size() > 0) {
+            FlinkLogicalCalc leftCalc =
+                    createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);

-        FlinkLogicalCorrelate newCorrelate =
-                new FlinkLogicalCorrelate(
-                        correlate.getCluster(),
-                        correlate.getTraitSet(),
-                        leftCalc,
-                        rightNewInput,
-                        correlate.getCorrelationId(),
-                        correlate.getRequiredColumns(),
-                        correlate.getJoinType());
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            leftCalc,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        } else {
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            left,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        }

         FlinkLogicalCalc newTopCalc =
                 createTopCalc(

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala

+8-2
@@ -22,7 +22,7 @@ import java.util.function.Function

 import org.apache.calcite.plan.RelOptRule.{any, operand}
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rex.{RexBuilder, RexCall, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
+import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
 import org.apache.calcite.sql.validate.SqlValidatorUtil
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.functions.python.PythonFunctionKind
@@ -393,7 +393,13 @@ private class ScalarFunctionSplitter(
       expr match {
         case localRef: RexLocalRef if containsPythonCall(program.expandLocalRef(localRef))
           => getExtractedRexFieldAccess(fieldAccess, localRef.getIndex)
-        case _ => getExtractedRexNode(fieldAccess)
+        case _: RexCorrelVariable =>
+          val field = fieldAccess.getField
+          new RexInputRef(field.getIndex, field.getType)
+        case _ =>
+          val newFieldAccess = rexBuilder.makeFieldAccess(
+            expr.accept(this), fieldAccess.getField.getIndex)
+          getExtractedRexNode(newFieldAccess)
       }
     } else {
       fieldAccess

flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml

+4-4
@@ -31,12 +31,12 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3], y=[$4])
 </Resource>
 <Resource name="planAfter">
 <![CDATA[
-FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f10 AS f1], where=[AND(=(f0, 2), =(+(f10, 1), *(f10, f10)), =(f00, a))])
-+- FlinkLogicalCalc(select=[a, b, c, f00, f10, pyFunc(f00, f00) AS f0])
+FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f1], where=[AND(=(f0, 2), =(+(f1, 1), *(f1, f1)), =(f00, a))])
++- FlinkLogicalCalc(select=[a, b, c, f00, f1, pyFunc(f00, f00) AS f0])
    +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])
-      :- FlinkLogicalCalc(select=[a, b, c, *($cor0.a, $cor0.a) AS f0, $cor0.b AS f1])
+      :- FlinkLogicalCalc(select=[a, b, c, *(a, a) AS f0])
       :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-      +- FlinkLogicalTableFunctionScan(invocation=[func($3, $4)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+      +- FlinkLogicalTableFunctionScan(invocation=[func($3, $1)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
 ]]>
 </Resource>
 </TestCase>

flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRuleTest.xml

+8-9
@@ -32,9 +32,8 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3])
 <![CDATA[
 FlinkLogicalCalc(select=[a, b, c, f00 AS f0])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
-   :- FlinkLogicalCalc(select=[a, b, c, pyFunc(f0) AS f0])
-   :  +- FlinkLogicalCalc(select=[a, b, c, $cor0.c AS f0])
-   :     +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+   :- FlinkLogicalCalc(select=[a, b, c, pyFunc(c) AS f0])
+   :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
    +- FlinkLogicalTableFunctionScan(invocation=[javaFunc($3)], rowType=[RecordType(VARCHAR(2147483647) f0)], elementType=[class [Ljava.lang.Object;])
 ]]>
 </Resource>
@@ -53,11 +52,11 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3], y=[$4])
 </Resource>
 <Resource name="planAfter">
 <![CDATA[
-FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f10 AS f1])
+FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f1])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1, 2}])
-   :- FlinkLogicalCalc(select=[a, b, c, *($cor0.a, $cor0.a) AS f0, $cor0.b AS f1, $cor0.c AS f2])
+   :- FlinkLogicalCalc(select=[a, b, c, *(a, a) AS f0])
    :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-   +- FlinkLogicalTableFunctionScan(invocation=[func($3, pyFunc($4, $5))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+   +- FlinkLogicalTableFunctionScan(invocation=[func($3, pyFunc($1, $2))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
 ]]>
 </Resource>
 </TestCase>
@@ -78,7 +77,7 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$4])
 FlinkLogicalCalc(select=[a, b, c, f00 AS x])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3}])
    :- FlinkLogicalCalc(select=[a, b, c, d, pyFunc(f0) AS f0])
-   :  +- FlinkLogicalCalc(select=[a, b, c, d, $cor0.d._1 AS f0])
+   :  +- FlinkLogicalCalc(select=[a, b, c, d, d._1 AS f0])
    :     +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
    +- FlinkLogicalTableFunctionScan(invocation=[javaFunc($4)], rowType=[RecordType(VARCHAR(2147483647) f0)], elementType=[class [Ljava.lang.Object;])
 ]]>
@@ -100,9 +99,9 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$4], y=[$5])
 <![CDATA[
 FlinkLogicalCalc(select=[a, b, c, f00 AS x, f10 AS y])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 2, 3}])
-   :- FlinkLogicalCalc(select=[a, b, c, d, *($cor0.d._1, $cor0.a) AS f0, $cor0.d._2 AS f1, $cor0.c AS f2])
+   :- FlinkLogicalCalc(select=[a, b, c, d, *(d._1, a) AS f0, d._2 AS f1])
    :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
-   +- FlinkLogicalTableFunctionScan(invocation=[func($4, pyFunc($5, $6))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+   +- FlinkLogicalTableFunctionScan(invocation=[func($4, pyFunc($5, $2))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
 ]]>
 </Resource>
 </TestCase>

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/PythonCorrelateSplitRule.java

+69-14
@@ -23,14 +23,17 @@
 import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan;
 import org.apache.flink.table.plan.util.CorrelateUtil;
 import org.apache.flink.table.plan.util.PythonUtil;
+import org.apache.flink.table.plan.util.RexDefaultVisitor;

 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.rex.RexFieldAccess;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
@@ -119,10 +122,41 @@ private List<String> createNewFieldNames(
         for (int i = 0; i < primitiveFieldCount; i++) {
             calcProjects.add(RexInputRef.of(i, rowType));
         }
+        // change RexCorrelVariable to RexInputRef.
+        RexDefaultVisitor<RexNode> visitor =
+                new RexDefaultVisitor<RexNode>() {
+                    @Override
+                    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+                        RexNode expr = fieldAccess.getReferenceExpr();
+                        if (expr instanceof RexCorrelVariable) {
+                            RelDataTypeField field = fieldAccess.getField();
+                            return new RexInputRef(field.getIndex(), field.getType());
+                        } else {
+                            return rexBuilder.makeFieldAccess(
+                                    expr.accept(this), fieldAccess.getField().getIndex());
+                        }
+                    }
+
+                    @Override
+                    public RexNode visitNode(RexNode rexNode) {
+                        return rexNode;
+                    }
+                };
         // add the fields of the extracted rex calls.
         Iterator<RexNode> iterator = extractedRexNodes.iterator();
         while (iterator.hasNext()) {
-            calcProjects.add(iterator.next());
+            RexNode rexNode = iterator.next();
+            if (rexNode instanceof RexCall) {
+                RexCall rexCall = (RexCall) rexNode;
+                List<RexNode> newProjects =
+                        rexCall.getOperands().stream()
+                                .map(x -> x.accept(visitor))
+                                .collect(Collectors.toList());
+                RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
+                calcProjects.add(newRexCall);
+            } else {
+                calcProjects.add(rexNode);
+            }
         }

         List<String> nameList = new LinkedList<>();
@@ -196,10 +230,12 @@ private FlinkLogicalCalc createTopCalc(
     }

     private ScalarFunctionSplitter createScalarFunctionSplitter(
+            RexBuilder rexBuilder,
             int primitiveLeftFieldCount,
             ArrayBuffer<RexNode> extractedRexNodes,
             RexNode tableFunctionNode) {
         return new ScalarFunctionSplitter(
+                rexBuilder,
                 primitiveLeftFieldCount,
                 extractedRexNodes,
                 node -> {
@@ -233,7 +269,10 @@ public void onMatch(RelOptRuleCall call) {
                     createNewScan(
                             scan,
                             createScalarFunctionSplitter(
-                                    primitiveLeftFieldCount, extractedRexNodes, scan.getCall()));
+                                    rexBuilder,
+                                    primitiveLeftFieldCount,
+                                    extractedRexNodes,
+                                    scan.getCall()));
         } else {
             FlinkLogicalCalc calc = (FlinkLogicalCalc) right;
             FlinkLogicalTableFunctionScan scan = CorrelateUtil.getTableFunctionScan(calc).get();
@@ -242,23 +281,39 @@ public void onMatch(RelOptRuleCall call) {
                     createNewScan(
                             scan,
                             createScalarFunctionSplitter(
-                                    primitiveLeftFieldCount, extractedRexNodes, scan.getCall()));
+                                    rexBuilder,
+                                    primitiveLeftFieldCount,
+                                    extractedRexNodes,
+                                    scan.getCall()));
             rightNewInput =
                     mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
         }

-        FlinkLogicalCalc leftCalc =
-                createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
+        FlinkLogicalCorrelate newCorrelate;
+        if (extractedRexNodes.size() > 0) {
+            FlinkLogicalCalc leftCalc =
+                    createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);

-        FlinkLogicalCorrelate newCorrelate =
-                new FlinkLogicalCorrelate(
-                        correlate.getCluster(),
-                        correlate.getTraitSet(),
-                        leftCalc,
-                        rightNewInput,
-                        correlate.getCorrelationId(),
-                        correlate.getRequiredColumns(),
-                        correlate.getJoinType());
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            leftCalc,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        } else {
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            left,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        }

         FlinkLogicalCalc newTopCalc =
                 createTopCalc(

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonCalcSplitRule.scala

+13-2
@@ -22,7 +22,7 @@ import java.util.function.Function

 import org.apache.calcite.plan.RelOptRule.{any, operand}
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rex.{RexBuilder, RexCall, RexFieldAccess, RexInputRef, RexNode, RexProgram}
+import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexNode, RexProgram}
 import org.apache.calcite.sql.validate.SqlValidatorUtil
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.functions.python.PythonFunctionKind
@@ -53,6 +53,7 @@ abstract class PythonCalcSplitRuleBase(description: String)

     val extractedFunctionOffset = input.getRowType.getFieldCount
     val splitter = new ScalarFunctionSplitter(
+      rexBuilder,
       extractedFunctionOffset,
       extractedRexNodes,
       new Function[RexNode, Boolean] {
@@ -304,6 +305,7 @@ object PythonCalcRewriteProjectionRule extends PythonCalcSplitRuleBase(
 }

 private class ScalarFunctionSplitter(
+    rexBuilder: RexBuilder,
     extractedFunctionOffset: Int,
     extractedRexNodes: mutable.ArrayBuffer[RexNode],
     needConvert: Function[RexNode, Boolean])
@@ -319,7 +321,16 @@ private class ScalarFunctionSplitter(

   override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = {
     if (needConvert(fieldAccess)) {
-      getExtractedRexNode(fieldAccess)
+      val expr = fieldAccess.getReferenceExpr
+      expr match {
+        case _: RexCorrelVariable =>
+          val field = fieldAccess.getField
+          new RexInputRef(field.getIndex, field.getType)
+        case _ =>
+          val newFieldAccess = rexBuilder.makeFieldAccess(
+            expr.accept(this), fieldAccess.getField.getIndex)
+          getExtractedRexNode(newFieldAccess)
+      }
     } else {
       fieldAccess
     }
