Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import scala.collection.JavaConversions._
* for the standard logical algebra.
*/
class FlinkRelMdUpsertKeys private extends MetadataHandler[UpsertKeys] {
private val MaxGeneratedEnrichedKeys = 128

override def getDef: MetadataDef[UpsertKeys] = UpsertKeys.DEF

Expand Down Expand Up @@ -348,10 +349,12 @@ class FlinkRelMdUpsertKeys private extends MetadataHandler[UpsertKeys] {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
val leftKeys = fmq.getUpsertKeys(left)
val rightKeys = fmq.getUpsertKeys(right)
val leftFieldCount = left.getRowType.getFieldCount

FlinkRelMdUniqueKeys.INSTANCE.getJoinUniqueKeys(
// First get the base join unique keys
val baseKeys = FlinkRelMdUniqueKeys.INSTANCE.getJoinUniqueKeys(
joinRelType,
left.getRowType.getFieldCount,
leftFieldCount,
// Retain only keys whose columns are contained in the join's equi-join columns
// (the distribution keys), ensuring the result remains an upsert key.
// Note: An Exchange typically applies this filtering already via fmq.getUpsertKeys(...).
Expand All @@ -361,6 +364,97 @@ class FlinkRelMdUpsertKeys private extends MetadataHandler[UpsertKeys] {
isSideUnique(leftKeys, joinInfo.leftSet),
isSideUnique(rightKeys, joinInfo.rightSet)
)

// Enrich the keys by substituting equivalent columns from equi-join conditions
// The base keys are in joined output space, so enrichment works directly
enrichJoinedKeys(baseKeys, joinInfo, joinRelType, leftFieldCount)
}

/**
* Enriches join result keys by substituting columns with their equivalents from equi-join
* conditions.
*
* @param keys
* The upsert keys in joined output coordinate space
* @param joinInfo
* The join information containing equi-join column pairs
* @param joinRelType
* The join type (to check nullability constraints)
* @param leftFieldCount
* The number of fields from the left side
* @return
* The enriched set of upsert keys
*/
private def enrichJoinedKeys(
    keys: JSet[ImmutableBitSet],
    joinInfo: JoinInfo,
    joinRelType: JoinRelType,
    leftFieldCount: Int): JSet[ImmutableBitSet] = {
  // Left equi-join columns are used as-is: they are not shifted below.
  val leftColumns = joinInfo.leftKeys.map(_.intValue())
  // Right equi-join columns are shifted by the left field count so that they address
  // the right-hand fields inside the joined output row.
  val rightColumns = joinInfo.rightKeys.map(_.intValue() + leftFieldCount)
  // Pair the i-th left column with the i-th right column of the equi-join condition.
  val columnEquivalences = leftColumns.zip(rightColumns)
  enrichKeysWithEquivalences(keys, columnEquivalences, joinRelType)
}

/**
* Core enrichment logic: for each key and each column equivalence pair, generates enriched
* versions by substituting one column with its equivalent.
*
* For example, if a key is {a2, b2} and there's an equivalence a1 = a2, this generates the
* additional key {a1, b2}.
*
* The enrichment respects join type nullability by controlling substitution directions:
* - Right→Left (replace right col with left): only if left side is never NULL
* - Left→Right (replace left col with right): only if right side is never NULL
*
* This prevents invalid keys where the substituted column might be NULL, causing the remaining
* columns (which may not be unique by themselves) to incorrectly appear as a valid key.
*
* @param keys
* The upsert keys to enrich
* @param equivalentPairs
* Column equivalence pairs (leftCol, rightCol) in joined output coordinate space
* @param joinRelType
* The join type (determines allowed substitution directions)
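* @return
* The enriched set of upsert keys (includes original keys). Derivation of additional keys
* stops once the set holds MaxGeneratedEnrichedKeys entries, so not every derivable key is
* guaranteed to be present.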
*/
private def enrichKeysWithEquivalences(
    keys: JSet[ImmutableBitSet],
    equivalentPairs: java.lang.Iterable[(Int, Int)],
    joinRelType: JoinRelType): JSet[ImmutableBitSet] = {

  // A null key set is passed through unchanged rather than being replaced by an empty set.
  if (keys == null) return null

  // A column may only be replaced by its equivalent on a side the join never null-pads.
  val allowRightToLeft = !joinRelType.generatesNullsOnLeft()
  val allowLeftToRight = !joinRelType.generatesNullsOnRight()

  val seen = new util.HashSet[ImmutableBitSet](keys.size() * 2)
  val queue = new util.ArrayDeque[ImmutableBitSet]()

  // Seed with every original key unconditionally. The size cap limits only the keys derived
  // by substitution; applying it while seeding would silently drop valid base keys whenever
  // the input already holds more than MaxGeneratedEnrichedKeys entries.
  val seedIt = keys.iterator()
  while (seedIt.hasNext) {
    val original = seedIt.next()
    if (seen.add(original)) queue.add(original)
  }

  // Records a derived key unless the cap is reached or the key is already known.
  @inline def enqueueDerived(k: ImmutableBitSet): Unit =
    if (seen.size() < MaxGeneratedEnrichedKeys && seen.add(k)) queue.add(k)

  // Generates every single-column substitution of `key` permitted by the join type.
  @inline def expand(key: ImmutableBitSet): Unit = {
    val it = equivalentPairs.iterator()
    while (it.hasNext) {
      val (l, r) = it.next()
      if (allowRightToLeft && key.get(r)) enqueueDerived(key.clear(r).set(l))
      if (allowLeftToRight && key.get(l)) enqueueDerived(key.clear(l).set(r))
    }
  }

  // Fixpoint: keep expanding newly discovered keys until nothing new appears or the cap is
  // hit. Once the cap is reached no further key can be added, so expansion stops early.
  while (!queue.isEmpty && seen.size() < MaxGeneratedEnrichedKeys) {
    expand(queue.poll())
  }

  seen
}

def getUpsertKeys(rel: SetOp, mq: RelMetadataQuery): JSet[ImmutableBitSet] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,190 @@ LogicalProject(a1=[$0], b1=[$1], a2=[$2], b2=[$3])
<Resource name="optimized exec plan">
<![CDATA[
Values(tuples=[[]])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinUpsertKeyEnrichmentInnerJoinBasic">
<Resource name="explain">
<![CDATA[== Abstract Syntax Tree ==
LogicalSink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1])
+- LogicalProject(a1=[$0], b2=[$3], x1=[$1])
+- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
:- LogicalTableScan(table=[[default_catalog, default_database, src1]])
+- LogicalTableScan(table=[[default_catalog, default_database, src2]])

== Optimized Physical Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1], changelogMode=[NONE])
+- Calc(select=[a1, b2, x1], changelogMode=[I,UA,D])
+- Join(joinType=[InnerJoin], where=[=(a1, a2)], select=[a1, x1, a2, b2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[HasUniqueKey], changelogMode=[I,UA,D])
:- Exchange(distribution=[hash[a1]], changelogMode=[I,UA,D])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, x1], changelogMode=[I,UA,D])
+- Exchange(distribution=[hash[a2]], changelogMode=[I,UA,D])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2, b2], metadata=[]]], fields=[a2, b2], changelogMode=[I,UA,D])

== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1])
+- Calc(select=[a1, b2, x1])
+- Join(joinType=[InnerJoin], where=[(a1 = a2)], select=[a1, x1, a2, b2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[HasUniqueKey])
:- Exchange(distribution=[hash[a1]])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, x1])
+- Exchange(distribution=[hash[a2]])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2, b2], metadata=[]]], fields=[a2, b2])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinUpsertKeyEnrichmentInnerJoinReverse">
<Resource name="explain">
<![CDATA[== Abstract Syntax Tree ==
LogicalSink(table=[default_catalog.default_database.sink], fields=[a2, b1, x1])
+- LogicalProject(a2=[$3], b1=[$1], x1=[$2])
+- LogicalJoin(condition=[=($0, $3)], joinType=[inner])
:- LogicalTableScan(table=[[default_catalog, default_database, src1]])
+- LogicalTableScan(table=[[default_catalog, default_database, src2]])

== Optimized Physical Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a2, b1, x1], changelogMode=[NONE])
+- Calc(select=[a2, b1, x1], changelogMode=[I,UA,D])
+- Join(joinType=[InnerJoin], where=[=(a1, a2)], select=[a1, b1, x1, a2], leftInputSpec=[HasUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey], changelogMode=[I,UA,D])
:- Exchange(distribution=[hash[a1]], changelogMode=[I,UA,D])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, b1, x1], changelogMode=[I,UA,D])
+- Exchange(distribution=[hash[a2]], changelogMode=[I,UA,D])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2], metadata=[]]], fields=[a2], changelogMode=[I,UA,D])

== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a2, b1, x1])
+- Calc(select=[a2, b1, x1])
+- Join(joinType=[InnerJoin], where=[(a1 = a2)], select=[a1, b1, x1, a2], leftInputSpec=[HasUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])
:- Exchange(distribution=[hash[a1]])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, b1, x1])
+- Exchange(distribution=[hash[a2]])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2], metadata=[]]], fields=[a2])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinUpsertKeyEnrichmentLeftJoin">
<Resource name="explain">
<![CDATA[== Abstract Syntax Tree ==
LogicalSink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1])
+- LogicalProject(a1=[$0], b2=[$3], x1=[$1])
+- LogicalJoin(condition=[=($0, $2)], joinType=[left])
:- LogicalTableScan(table=[[default_catalog, default_database, src1]])
+- LogicalTableScan(table=[[default_catalog, default_database, src2]])

== Optimized Physical Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1], changelogMode=[NONE])
+- Calc(select=[a1, b2, x1], changelogMode=[I,UA,D])
+- Join(joinType=[LeftOuterJoin], where=[=(a1, a2)], select=[a1, x1, a2, b2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[HasUniqueKey], changelogMode=[I,UA,D])
:- Exchange(distribution=[hash[a1]], changelogMode=[I,UA,D])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, x1], changelogMode=[I,UA,D])
+- Exchange(distribution=[hash[a2]], changelogMode=[I,UA,D])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2, b2], metadata=[]]], fields=[a2, b2], changelogMode=[I,UA,D])

== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1])
+- Calc(select=[a1, b2, x1])
+- Join(joinType=[LeftOuterJoin], where=[(a1 = a2)], select=[a1, x1, a2, b2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[HasUniqueKey])
:- Exchange(distribution=[hash[a1]])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, x1])
+- Exchange(distribution=[hash[a2]])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2, b2], metadata=[]]], fields=[a2, b2])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinUpsertKeyEnrichmentMultipleEquiConditions">
<Resource name="explain">
<![CDATA[== Abstract Syntax Tree ==
LogicalSink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1])
+- LogicalProject(a1=[$0], b2=[$4], x1=[$2])
+- LogicalJoin(condition=[AND(=($0, $3), =($1, $4))], joinType=[inner])
:- LogicalTableScan(table=[[default_catalog, default_database, src1]])
+- LogicalTableScan(table=[[default_catalog, default_database, src2]])

== Optimized Physical Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1], changelogMode=[NONE])
+- Calc(select=[a1, b2, x1], changelogMode=[I,UA,D])
+- Join(joinType=[InnerJoin], where=[AND(=(a1, a2), =(b1, b2))], select=[a1, b1, x1, a2, b2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey], changelogMode=[I,UA,D])
:- Exchange(distribution=[hash[a1, b1]], changelogMode=[I,UA,D])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, b1, x1], changelogMode=[I,UA,D])
+- Exchange(distribution=[hash[a2, b2]], changelogMode=[I,UA,D])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2, b2], metadata=[]]], fields=[a2, b2], changelogMode=[I,UA,D])

== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, b2, x1])
+- Calc(select=[a1, b2, x1])
+- Join(joinType=[InnerJoin], where=[((a1 = a2) AND (b1 = b2))], select=[a1, b1, x1, a2, b2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])
:- Exchange(distribution=[hash[a1, b1]])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, b1, x1])
+- Exchange(distribution=[hash[a2, b2]])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2, b2], metadata=[]]], fields=[a2, b2])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinUpsertKeyEnrichmentNegativeCase">
<Resource name="explain">
<![CDATA[== Abstract Syntax Tree ==
LogicalSink(table=[default_catalog.default_database.sink], fields=[a1, c2, x1])
+- LogicalProject(a1=[$0], c2=[$4], x1=[$1])
+- LogicalJoin(condition=[=($0, $2)], joinType=[inner])
:- LogicalTableScan(table=[[default_catalog, default_database, src1]])
+- LogicalTableScan(table=[[default_catalog, default_database, src2]])

== Optimized Physical Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, c2, x1], upsertMaterialize=[true], changelogMode=[NONE])
+- Calc(select=[a1, c2, x1], changelogMode=[I,UB,UA,D])
+- Join(joinType=[InnerJoin], where=[=(a1, a2)], select=[a1, x1, a2, c2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[NoUniqueKey], changelogMode=[I,UB,UA,D])
:- Exchange(distribution=[hash[a1]], changelogMode=[I,UB,UA,D])
: +- ChangelogNormalize(key=[a1], changelogMode=[I,UB,UA,D])
: +- Exchange(distribution=[hash[a1]], changelogMode=[I,UA,D])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, x1], changelogMode=[I,UA,D])
+- Exchange(distribution=[hash[a2]], changelogMode=[I,UB,UA,D])
+- Calc(select=[a2, c2], changelogMode=[I,UB,UA,D])
+- ChangelogNormalize(key=[a2, b2], changelogMode=[I,UB,UA,D])
+- Exchange(distribution=[hash[a2, b2]], changelogMode=[I,UA,D])
+- TableSourceScan(table=[[default_catalog, default_database, src2]], fields=[a2, b2, c2], changelogMode=[I,UA,D])

== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a1, c2, x1], upsertMaterialize=[true])
+- Calc(select=[a1, c2, x1])
+- Join(joinType=[InnerJoin], where=[(a1 = a2)], select=[a1, x1, a2, c2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[NoUniqueKey])
:- Exchange(distribution=[hash[a1]])
: +- ChangelogNormalize(key=[a1])
: +- Exchange(distribution=[hash[a1]])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, x1])
+- Exchange(distribution=[hash[a2]])
+- Calc(select=[a2, c2])
+- ChangelogNormalize(key=[a2, b2])
+- Exchange(distribution=[hash[a2, b2]])
+- TableSourceScan(table=[[default_catalog, default_database, src2]], fields=[a2, b2, c2])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinUpsertKeyEnrichmentRightJoin">
<Resource name="explain">
<![CDATA[== Abstract Syntax Tree ==
LogicalSink(table=[default_catalog.default_database.sink], fields=[a2, b1, x1])
+- LogicalProject(a2=[$3], b1=[$1], x1=[$2])
+- LogicalJoin(condition=[=($0, $3)], joinType=[right])
:- LogicalTableScan(table=[[default_catalog, default_database, src1]])
+- LogicalTableScan(table=[[default_catalog, default_database, src2]])

== Optimized Physical Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a2, b1, x1], changelogMode=[NONE])
+- Calc(select=[a2, b1, x1], changelogMode=[I,UA,D])
+- Join(joinType=[RightOuterJoin], where=[=(a1, a2)], select=[a1, b1, x1, a2], leftInputSpec=[HasUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey], changelogMode=[I,UA,D])
:- Exchange(distribution=[hash[a1]], changelogMode=[I,UA,D])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, b1, x1], changelogMode=[I,UA,D])
+- Exchange(distribution=[hash[a2]], changelogMode=[I,UA,D])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2], metadata=[]]], fields=[a2], changelogMode=[I,UA,D])

== Optimized Execution Plan ==
Sink(table=[default_catalog.default_database.sink], fields=[a2, b1, x1])
+- Calc(select=[a2, b1, x1])
+- Join(joinType=[RightOuterJoin], where=[(a1 = a2)], select=[a1, b1, x1, a2], leftInputSpec=[HasUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])
:- Exchange(distribution=[hash[a1]])
: +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a1, b1, x1])
+- Exchange(distribution=[hash[a2]])
+- TableSourceScan(table=[[default_catalog, default_database, src2, project=[a2], metadata=[]]], fields=[a2])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,23 +320,23 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase {
@Test
def testGetUpsertKeysOnJoin(): Unit = {
assertEquals(
toBitSet(Array(1), Array(5), Array(1, 5), Array(5, 6), Array(1, 5, 6)),
toBitSet(Array(1), Array(5), Array(1, 5), Array(1, 6), Array(5, 6), Array(1, 5, 6)),
mq.getUpsertKeys(logicalInnerJoinOnUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinNotOnUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinOnRHSUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinWithoutEquiCond).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinWithEquiAndNonEquiCond).toSet)

assertEquals(
toBitSet(Array(1), Array(1, 5), Array(1, 5, 6)),
toBitSet(Array(1), Array(1, 5), Array(1, 6), Array(1, 5, 6)),
mq.getUpsertKeys(logicalLeftJoinOnUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinNotOnUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinOnRHSUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinWithoutEquiCond).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinWithEquiAndNonEquiCond).toSet)

assertEquals(
toBitSet(Array(5), Array(1, 5), Array(5, 6), Array(1, 5, 6)),
toBitSet(Array(5), Array(5, 6), Array(1, 5), Array(1, 5, 6)),
mq.getUpsertKeys(logicalRightJoinOnUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinNotOnUniqueKeys).toSet)
assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinOnLHSUniqueKeys).toSet)
Expand Down
Loading