diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java index f16758a78836c..dae2f477488fc 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java @@ -130,6 +130,17 @@ public class OptimizerConfigOptions { + " if the source extends from SupportsStatisticReport and the statistics from catalog is UNKNOWN." + "Default value is true."); + @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH) + public static final ConfigOption TABLE_OPTIMIZER_STORAGE_PARTITION_JOIN_ENABLED = + key("table.optimizer.storage-partition-join-enabled") + .booleanType() + .defaultValue(false) + .withDescription( + "When it is true, the optimizer will try to use storage partition join for the join operation " + + "if the source table is partitioned by the join key. Default value is false. " + + "Note that this option only works in batch mode and requires the source table to be partitioned by the join key." + + " If the source table is not partitioned by the join key, it will fall back to other join strategies."); + @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING) public static final ConfigOption TABLE_OPTIMIZER_JOIN_REORDER_ENABLED = key("table.optimizer.join-reorder-enabled") diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsPartitioning.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsPartitioning.java new file mode 100644 index 0000000000000..6a6942ef2d907 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsPartitioning.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.connector.source.abilities; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.connector.source.ScanTableSource; +import org.apache.flink.table.connector.source.partitioning.Partitioning; + +/** + * Enables {@link ScanTableSource} to discover source partitions and inform the optimizer + * accordingly. + * + *

Partitions split the data stored in an external system into smaller portions that are + * identified by partition keys. + * + *

For example, data can be partitioned by dt and, within each dt, further partitioned by user_id. + * The table definition could look like partition by (dt, bucket(user_id, 10)), and the partition + * values can be ("2023-10-01", 0), ("2023-10-01", 1), ("2023-10-02", 0), ... + * + *
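+ * <p>For illustration, a connector exposing such a layout might describe it with a
+ * {@code KeyGroupedPartitioning} (the column names, bucket count, and partition count below are
+ * examples only):
+ *
+ * <pre>{@code
+ * TransformExpression[] keys = {
+ *     new TransformExpression("dt", null, null),
+ *     new TransformExpression("user_id", "bucket", 10)
+ * };
+ * Row[] partitionValues = {
+ *     Row.of("2023-10-01", 0), Row.of("2023-10-01", 1), Row.of("2023-10-02", 0)
+ * };
+ * Partitioning partitioning = new KeyGroupedPartitioning(keys, partitionValues, 3);
+ * }</pre>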

In the example above, the partition keys = [dt, bucket(user_id, 10)]. The optimizer might utilize + * this pre-partitioned data source to eliminate a possible shuffle operation. */ +@PublicEvolving +public interface SupportsPartitioning { + + /** Returns the output data partitioning that this reader guarantees. */ + Partitioning outputPartitioning(); + + /** Applies partitioned reading to the source operator. */ + void applyPartitionedRead(); +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/partitioning/KeyGroupedPartitioning.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/partitioning/KeyGroupedPartitioning.java new file mode 100644 index 0000000000000..eb65578e30d26 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/partitioning/KeyGroupedPartitioning.java @@ -0,0 +1,103 @@ +package org.apache.flink.table.connector.source.partitioning; + +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.expressions.TransformExpression; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +/** + * Key-grouped partitioning implementation for table sources. + * + *

TODO Consider relaxing this constraint in a future version + * Preconditions: + * 1. keys are ordered by the partition columns defined in the table schema. + * 2. the partition values are ordered by the values in Row, comparing the values from 1st to last. + * for example: + * if a table is partitioned by (dt, bucket(128, user_id)) + * then the partition keys = [dt, bucket(128, user_id)]. It cannot be [bucket(128, user_id), dt]. + * the partition values can be ("2023-10-01", 0), ("2023-10-01", 1), ("2023-10-02", 0), ... + * it cannot be ("2023-10-01", 1), ("2023-10-01", 0), ("2023-10-02", 0), ... + */ +public class KeyGroupedPartitioning implements Partitioning { + private final TransformExpression[] keys; + private final int numPartitions; + private final Row[] partitionValues; + + public KeyGroupedPartitioning(TransformExpression[] keys, Row[] partitionValues, int numPartitions) { + this.keys = keys; + this.numPartitions = numPartitions; + this.partitionValues = partitionValues; + } + + /** + * Returns the partition transform expressions for this partitioning. + */ + public TransformExpression[] keys() { + return keys; + } + + public Row[] getPartitionValues() { + return partitionValues; + } + + @Override + public int numPartitions() { + return numPartitions; + } + + /** + * Checks if this partitioning is compatible with another KeyGroupedPartitioning. + * conditions: + * 1. numPartitions is the same + * 2. keys length is the same and for each key,keys are compatible + * 3. RowData length is the same. values are the same. + * + * @param other the other KeyGroupedPartitioning to check compatibility with + * @return true if compatible, false otherwise + */ + public boolean isCompatible(KeyGroupedPartitioning other) { + if (other == null) { + return false; + } + + // 1. Check numPartitions is the same + if (this.numPartitions != other.numPartitions) { + return false; + } + + // 2. Check keys length is the same and each key is compatible + if (this.keys.length != other.keys.length) { + return false; + } + + for (int i = 0; i < this.keys.length; i++) { + if (!this.keys[i].isCompatible(other.keys[i])) { + return false; + } + } + + // 3. Check RowData length and values are the same + if (this.partitionValues.length != other.partitionValues.length) { + return false; + } + + for (int i = 0; i < this.partitionValues.length; i++) { + Row thisRow = this.partitionValues[i]; + Row otherRow = other.partitionValues[i]; + + if (thisRow.getArity() != otherRow.getArity()) { + return false; + } + + for (int j = 0; j < thisRow.getArity(); j++) { + // filed in row cannot be null + Preconditions.checkArgument(thisRow.getField(j) != null); + if (!thisRow.getField(j).equals(otherRow.getField(j))) { + return false; + } + } + } + + return true; + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/partitioning/Partitioning.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/partitioning/Partitioning.java new file mode 100644 index 0000000000000..0f9faf3bd43c2 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/partitioning/Partitioning.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.connector.source.partitioning; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.connector.source.abilities.SupportsPartitioning; + +import java.util.Optional; + +/** + * Base interface for defining how data is partitioned across multiple partitions. + * Used by table sources that support partitioning. + */ +@PublicEvolving +public interface Partitioning { + /** + * Returns the number of partitions that the data is split across. + */ + int numPartitions(); +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TransformExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TransformExpression.java new file mode 100644 index 0000000000000..f1e8dda0ce587 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TransformExpression.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions; + +import org.apache.flink.annotation.PublicEvolving; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.Objects; +import java.util.Optional; + +/** + * Represents a transform expression that can be used for partitioning or other transformations. + * It consists of a key, an optional function name, and an optional number of buckets. + */ +@PublicEvolving +public class TransformExpression { + private final String key; + private final Optional functionName; + private final Optional numBucketsOpt; + + /** + * Creates a new TransformExpression with the given key, function name, and number of buckets. 
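+ *
+ * <p>For example (the values below are illustrative only):
+ *
+ * <pre>{@code
+ * // bucket transform, e.g. bucket(128, user_id)
+ * new TransformExpression("user_id", "bucket", 128);
+ * // identity transform on a plain partition column, e.g. dt
+ * new TransformExpression("dt", null, null);
+ * }</pre>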
+ * + * @param key the key to be transformed + * @param functionName the name of the transform function, can be null + * @param numBuckets the number of buckets for bucket transforms, can be null + */ + public TransformExpression( + @Nonnull String key, + @Nullable String functionName, + @Nullable Integer numBuckets) { + this.key = Objects.requireNonNull(key, "Key must not be null"); + this.functionName = Optional.ofNullable(functionName); + this.numBucketsOpt = Optional.ofNullable(numBuckets); + } + + /** + * Returns the key to be transformed. + * + * @return the key + */ + public String getKey() { + return key; + } + + /** + * Returns the name of the transform function, if present. + * + * @return the function name, or empty if not set + */ + public Optional getFunctionName() { + return functionName; + } + + /** + * Returns the number of buckets if this is a bucket transform, or empty otherwise. + * + * @return the number of buckets, or empty if not a bucket transform + */ + public Optional getNumBucketsOpt() { + return numBucketsOpt; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TransformExpression that = (TransformExpression) o; + return key.equals(that.key) && + functionName.equals(that.functionName) && + numBucketsOpt.equals(that.numBucketsOpt); + } + + @Override + public int hashCode() { + return Objects.hash(key, functionName, numBucketsOpt); + } + + @Override + public String toString() { + if (functionName.isPresent()) { + StringBuilder builder = new StringBuilder() + .append(functionName.get()) + .append("(") + .append(key); + if (numBucketsOpt.isPresent()) { + builder.append(", ").append(numBucketsOpt.get()); + } + return builder.append(")").toString(); + } + return key; + } + + /** * Checks if this TransformExpression is compatible with another TransformExpression. + * Compatibility is defined by having the same function name and number of buckets. + * examples: + * - bucket(128, user_id) is compatible with bucket(128, user_id_2) + * - year(dt) is compatible with year(dt) but not compatible with month(dt) + * + * TODO Support partial compatibility, e.g., bucket(128, user_id) is compatible with bucket(64, user_id_2) + */ + public boolean isCompatible(TransformExpression other) { + return + this.functionName.equals(other.functionName) && + this.numBucketsOpt.equals(other.numBucketsOpt); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/PartitioningSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/PartitioningSpec.java new file mode 100644 index 0000000000000..650b163d9aa54 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/PartitioningSpec.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.abilities.source; + +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.connector.source.DynamicTableSource; +import org.apache.flink.table.connector.source.abilities.SupportsPartitioning; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName; + +import java.util.Objects; + +/** + * A sub-class of {@link SourceAbilitySpec} that can be serialized/deserialized to/from JSON and + * applies partitioned reading to a {@link SupportsPartitioning} table source. + */ +@JsonTypeName("Partitioning") +public final class PartitioningSpec extends SourceAbilitySpecBase { + + // Applies partitioned reading to sources that implement SupportsPartitioning. + @Override + public void apply(DynamicTableSource tableSource, SourceAbilityContext context) { + if (tableSource instanceof SupportsPartitioning) { + ((SupportsPartitioning) tableSource).applyPartitionedRead(); + } else { + throw new TableException( + String.format( + "%s does not support SupportsPartitioning.", + tableSource.getClass().getName())); + } + } + + @Override + public boolean needAdjustFieldReferenceAfterProjection() { + return false; + } + + @Override + public String getDigests(SourceAbilityContext context) { + return "partitionedReading"; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + return super.equals(o); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode()); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java index e51328d5e9f25..594996ccff9ad 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java @@ -43,7 +43,8 @@ @JsonSubTypes.Type(value = ReadingMetadataSpec.class), @JsonSubTypes.Type(value = WatermarkPushDownSpec.class), @JsonSubTypes.Type(value = SourceWatermarkSpec.class), - @JsonSubTypes.Type(value = AggregatePushDownSpec.class) + @JsonSubTypes.Type(value = AggregatePushDownSpec.class), + @JsonSubTypes.Type(value = PartitioningSpec.class) }) @Internal public interface SourceAbilitySpec { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortMergeJoinRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortMergeJoinRule.scala index 676a2aac6a020..b6a817d77c81e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortMergeJoinRule.scala +++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortMergeJoinRule.scala @@ -20,21 +20,26 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.annotation.Experimental import org.apache.flink.configuration.ConfigOption import org.apache.flink.configuration.ConfigOptions.key +import org.apache.flink.table.api.config.OptimizerConfigOptions +import org.apache.flink.table.connector.source.partitioning.KeyGroupedPartitioning import org.apache.flink.table.planner.hint.JoinStrategy import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalJoin -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSortMergeJoin -import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil +import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalJoin, FlinkLogicalTableSourceScan} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalSortMergeJoin, BatchPhysicalTableSourceScan} +import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, ScanUtil} import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTableConfig import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.volcano.RelSubset import org.apache.calcite.rel.{RelCollations, RelNode} import org.apache.calcite.rel.core.Join import org.apache.calcite.util.ImmutableIntList +import org.apache.flink.table.connector.source.abilities.SupportsPartitioning import java.lang.{Boolean => JBoolean} +import java.util import scala.collection.JavaConversions._ @@ -54,20 +59,148 @@ class BatchPhysicalSortMergeJoinRule canUseJoinStrategy(join, tableConfig, JoinStrategy.SHUFFLE_MERGE) } + // TODO confirm this is the best practice for getting the table scan from a RelNode + private def getTableScan(relNode: RelNode): Option[FlinkLogicalTableSourceScan] = { + relNode match { + // Handle RelSubset by getting the best input + case subset: RelSubset => + // Get the best input or first input if no best is set + val best = Option(subset.getBest).orElse( + if (!subset.getRelList.isEmpty) Some(subset.getRelList.get(0)) + else None + ) + best.flatMap(getTableScan) + + // Handle different types of table scan nodes + case scan: FlinkLogicalTableSourceScan => Some(scan) + + // For other nodes with a single input + case node if node.getInputs.size() == 1 && !node.getInput(0).equals(node) => + getTableScan(node.getInput(0)) + + case _ => None + } + } + + private def isPartitionBy( + partition: KeyGroupedPartitioning, + fieldNames: util.List[String]): Boolean = { + val partitionKeys = partition.keys() + // Example: query with both tables partitioned by [dt, user_id]: + // SELECT count(*) FROM t1 JOIN t2 ON t1.dt = t2.dt AND t1.user_id = t2.user_id + // WHERE t1.dt = '2025-05-01' AND t2.dt = '2025-05-01' + // + // After filter pushdown optimization, the constant filter WHERE dt = '2025-05-01' + // may cause the 'dt' field to be pruned from fieldNames + // leaving only fieldNames = [user_id]. However, the original partition spec still + // contains [dt, user_id]. 
So we must check that the remaining join keys are still part + // of the partition spec. + fieldNames.forall(fieldName => + partitionKeys.exists(partitionKey => partitionKey.getKey == fieldName) + ) + } + + private def canApplyStoragePartitionJoin(join: Join): Boolean = { + // TODO ensure the join condition is a plain equi-join, not an expression like col1 + 1 = col2 or func(col1) = func(col2) + val joinInfo = join.analyzeCondition() + + // Return false if it's not an equi-join + if (joinInfo.nonEquiConditions != null && !joinInfo.nonEquiConditions.isEmpty) { + return false + } + + // Return false if there are no equi-join conditions + if (joinInfo.leftKeys.isEmpty || joinInfo.rightKeys.isEmpty) { + return false + } + + // Find all table scans in both branches + val leftTableScan = getTableScan(join.getLeft) + val rightTableScan = getTableScan(join.getRight) + + if (leftTableScan.isEmpty || rightTableScan.isEmpty) { + return false + } + + // Get the field names from the table scans + val leftFieldNames = leftTableScan.get.getRowType.getFieldNames + val rightFieldNames = rightTableScan.get.getRowType.getFieldNames + + // TODO: this won't work if a projection between the join and the table adds extra columns + // for example: in testCannotPushDownProbeSideWithCalc, + // "select * from dim inner join (select fact_date_sk, RAND(10) as random from fact) " + // + "as factSide on dim.amount = factSide.random and dim.price < 500"; + // the join condition is dim.amount = factSide.random, but the right side table doesn't contain random + + // Map join keys to field names + val leftJoinFields = (0 until joinInfo.leftKeys.size()).map { + i => leftFieldNames.get(joinInfo.leftKeys.get(i)) + } + + val rightJoinFields = (0 until joinInfo.rightKeys.size()).map { + i => rightFieldNames.get(joinInfo.rightKeys.get(i)) + } + + // conditions: + // 1. leftPartition is partitioned by leftFieldNames + // 2. rightPartition is partitioned by rightFieldNames + // 3.
leftPartition is compatible with rightPartition + + val leftPartition = ScanUtil.getPartition(leftTableScan.get.relOptTable) + val rightPartition = ScanUtil.getPartition(rightTableScan.get.relOptTable) + + // ensure both leftPartition and rightPartition are KeyGroupedPartitioning class + if (leftPartition.isEmpty || rightPartition.isEmpty) { + return false + } + if ( + !leftPartition.get.isInstanceOf[KeyGroupedPartitioning] || + !rightPartition.get.isInstanceOf[KeyGroupedPartitioning] + ) { + return false + } + val leftKeyGroupedPartitioning = leftPartition.get.asInstanceOf[KeyGroupedPartitioning] + val rightKeyGroupedPartitioning = rightPartition.get.asInstanceOf[KeyGroupedPartitioning] + isPartitionBy(leftKeyGroupedPartitioning, leftJoinFields) && + isPartitionBy(rightKeyGroupedPartitioning, rightJoinFields) && + leftKeyGroupedPartitioning.isCompatible(rightKeyGroupedPartitioning) + } + override def onMatch(call: RelOptRuleCall): Unit = { val join: Join = call.rel(0) val joinInfo = join.analyzeCondition val left = join.getLeft val right = join.getRight + val tableConfig = unwrapTableConfig(join) + val canApplyPartitionJoin = + tableConfig.get(OptimizerConfigOptions.TABLE_OPTIMIZER_STORAGE_PARTITION_JOIN_ENABLED) && + canApplyStoragePartitionJoin(join) + if (canApplyPartitionJoin) { + ScanUtil.applyPartitionedRead(getTableScan(join.getLeft).get.relOptTable) + ScanUtil.applyPartitionedRead(getTableScan(join.getRight).get.relOptTable) + } + def getTraitSetByShuffleKeys( shuffleKeys: ImmutableIntList, requireStrict: Boolean, requireCollation: Boolean): RelTraitSet = { - var traitSet = call.getPlanner - .emptyTraitSet() - .replace(FlinkConventions.BATCH_PHYSICAL) - .replace(FlinkRelDistribution.hash(shuffleKeys, requireStrict)) + var traitSet = if (canApplyPartitionJoin) { + // precondition requireCollation is always false when using ANY distribution + // this is related to TABLE_OPTIMIZER_SMJ_REMOVE_SORT_ENABLED + // this can only be true if that is set + assert(!requireCollation, "requireCollation should be false when using ANY distribution") + call.getPlanner + .emptyTraitSet() + .replace(FlinkConventions.BATCH_PHYSICAL) + .replace(FlinkRelDistribution.ANY) + } else { + call.getPlanner + .emptyTraitSet() + .replace(FlinkConventions.BATCH_PHYSICAL) + .replace(FlinkRelDistribution.hash(shuffleKeys, requireStrict)) + } + if (requireCollation) { val fieldCollations = shuffleKeys.map(FlinkRelOptUtil.ofRelFieldCollation(_)) val relCollation = RelCollations.of(fieldCollations) @@ -105,7 +238,6 @@ class BatchPhysicalSortMergeJoinRule call.transformTo(newJoin) } - val tableConfig = unwrapTableConfig(join) val candidates = if (tableConfig.get(BatchPhysicalSortMergeJoinRule.TABLE_OPTIMIZER_SMJ_REMOVE_SORT_ENABLED)) { // add more possibility to remove redundant sort, and longer optimization time diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index 0f1005e51f067..234c4bada289f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -25,7 +25,7 @@ import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFu import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction import 
org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction} import org.apache.flink.table.planner.functions.utils.AggSqlFunction -import org.apache.flink.table.runtime.functions.aggregate.{BuiltInAggregateFunction, CollectAggFunction, FirstValueAggFunction, FirstValueWithRetractAggFunction, JsonArrayAggFunction, JsonObjectAggFunction, LagAggFunction, LastValueAggFunction, LastValueWithRetractAggFunction, ListAggWithRetractAggFunction, ListAggWsWithRetractAggFunction, MaxWithRetractAggFunction, MinWithRetractAggFunction} +import org.apache.flink.table.runtime.functions.aggregate.{ArrayAggFunction, BuiltInAggregateFunction, CollectAggFunction, FirstValueAggFunction, FirstValueWithRetractAggFunction, JsonArrayAggFunction, JsonObjectAggFunction, LagAggFunction, LastValueAggFunction, LastValueWithRetractAggFunction, ListAggWithRetractAggFunction, ListAggWsWithRetractAggFunction, MaxWithRetractAggFunction, MinWithRetractAggFunction} import org.apache.flink.table.runtime.functions.aggregate.BatchApproxCountDistinctAggFunctions._ import org.apache.flink.table.types.logical._ import org.apache.flink.table.types.logical.LogicalTypeRoot._ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/ScanUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/ScanUtil.scala index 0d72a479aaec9..82689fdf5aa67 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/ScanUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/ScanUtil.scala @@ -19,11 +19,16 @@ package org.apache.flink.table.planner.plan.utils import org.apache.flink.api.dag.Transformation import org.apache.flink.table.api.TableException +import org.apache.flink.table.catalog.ResolvedCatalogTable +import org.apache.flink.table.connector.source.abilities.SupportsPartitioning +import org.apache.flink.table.connector.source.partitioning.Partitioning import org.apache.flink.table.data.{GenericRowData, RowData} import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ExprCodeGenerator, OperatorCodeGenerator} import org.apache.flink.table.planner.codegen.CodeGenUtils.{DEFAULT_INPUT1_TERM, GENERIC_ROW} import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.generateCollect import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil +import org.apache.flink.table.planner.plan.schema.TableSourceTable +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType import org.apache.flink.table.runtime.typeutils.InternalTypeInfo @@ -160,4 +165,21 @@ object ScanUtil { index }.toArray } + + def getPartition(tableSourceTable: TableSourceTable): Option[Partitioning] = { + val tableSource = tableSourceTable.tableSource + if (!tableSource.isInstanceOf[SupportsPartitioning]) { + return None + } + Some(tableSource.asInstanceOf[SupportsPartitioning].outputPartitioning) + } + + def applyPartitionedRead(tableSourceTable: TableSourceTable): Unit = { + val tableSource = tableSourceTable.tableSource + tableSource match { + case partitioning: SupportsPartitioning => + partitioning.applyPartitionedRead() + case _ => + } + } } diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/connector/source/PartitionSerializer.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/connector/source/PartitionSerializer.java new file mode 100644 index 0000000000000..c8ca7b5ea85f2 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/connector/source/PartitionSerializer.java @@ -0,0 +1,135 @@ +package org.apache.flink.connector.source; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ArrayNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.table.connector.source.partitioning.KeyGroupedPartitioning; +import org.apache.flink.table.connector.source.partitioning.Partitioning; +import org.apache.flink.table.expressions.TransformExpression; +import org.apache.flink.types.Row; +import org.apache.flink.util.jackson.JacksonMapperFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Utility class for serializing and deserializing Partitioning instances to/from JSON. + */ +public class PartitionSerializer { + + private static final ObjectMapper objectMapper = JacksonMapperFactory.createObjectMapper(); + + /** + * Serializes a Partitioning instance to JSON string. + * Currently only supports KeyGroupedPartitioning. + * + * @param partitioning the Partitioning instance to serialize + * @return JSON string representation + * @throws IOException if serialization fails + * @throws IllegalArgumentException if partitioning is not a KeyGroupedPartitioning + */ + public static String serialize(Partitioning partitioning) throws IOException { + if (!(partitioning instanceof KeyGroupedPartitioning)) { + throw new IllegalArgumentException( + "Only KeyGroupedPartitioning is supported. 
Got: " + partitioning.getClass().getSimpleName()); + } + + KeyGroupedPartitioning keyGroupedPartitioning = (KeyGroupedPartitioning) partitioning; + ObjectNode rootNode = objectMapper.createObjectNode(); + // Serialize numPartitions + rootNode.put("numPartitions", keyGroupedPartitioning.numPartitions()); + // Serialize keys (TransformExpression array) + ArrayNode keysNode = objectMapper.createArrayNode(); + for (TransformExpression key : keyGroupedPartitioning.keys()) { + ObjectNode keyNode = objectMapper.createObjectNode(); + keyNode.put("key", key.getKey()); + if (key.getFunctionName().isPresent()) { + keyNode.put("functionName", key.getFunctionName().get()); + } + if (key.getNumBucketsOpt().isPresent()) { + keyNode.put("numBuckets", key.getNumBucketsOpt().get()); + } + keysNode.add(keyNode); + } + rootNode.set("keys", keysNode); + // Serialize partition values (Row array) + ArrayNode partitionValuesNode = objectMapper.createArrayNode(); + for (Row row : keyGroupedPartitioning.getPartitionValues()) { + ArrayNode rowNode = objectMapper.createArrayNode(); + for (int i = 0; i < row.getArity(); i++) { + Object field = row.getField(i); + if (field == null) { + rowNode.addNull(); + } else { + // Convert field to JSON node based on its type + JsonNode fieldNode = objectMapper.valueToTree(field); + rowNode.add(fieldNode); + } + } + partitionValuesNode.add(rowNode); + } + rootNode.set("partitionValues", partitionValuesNode); + + return objectMapper.writeValueAsString(rootNode); + } + + /** + * Deserializes a JSON string to Partitioning instance. + * Currently only supports KeyGroupedPartitioning. + * + * @param json the JSON string to deserialize + * @return Partitioning instance + * @throws IOException if deserialization fails + */ + public static Partitioning deserialize(String json) throws IOException { + JsonNode rootNode = objectMapper.readTree(json); + + // Deserialize numPartitions + int numPartitions = rootNode.get("numPartitions").asInt(); + + // Deserialize keys + JsonNode keysNode = rootNode.get("keys"); + List keysList = new ArrayList<>(); + for (JsonNode keyNode : keysNode) { + String key = keyNode.get("key").asText(); + String functionName = keyNode.has("functionName") ? keyNode.get("functionName").asText() : null; + Integer numBuckets = keyNode.has("numBuckets") ? 
keyNode.get("numBuckets").asInt() : null; + + TransformExpression transformExpression = new TransformExpression(key, functionName, numBuckets); + keysList.add(transformExpression); + } + TransformExpression[] keys = keysList.toArray(new TransformExpression[0]); + + // Deserialize partition values + JsonNode partitionValuesNode = rootNode.get("partitionValues"); + List partitionValuesList = new ArrayList<>(); + for (JsonNode rowNode : partitionValuesNode) { + List fields = new ArrayList<>(); + for (JsonNode fieldNode : rowNode) { + if (fieldNode.isNull()) { + fields.add(null); + } else if (fieldNode.isTextual()) { + fields.add(fieldNode.asText()); + } else if (fieldNode.isInt()) { + fields.add(fieldNode.asInt()); + } else if (fieldNode.isLong()) { + fields.add(fieldNode.asLong()); + } else if (fieldNode.isDouble()) { + fields.add(fieldNode.asDouble()); + } else if (fieldNode.isBoolean()) { + fields.add(fieldNode.asBoolean()); + } else { + // For other types, try to convert to string + fields.add(fieldNode.asText()); + } + } + Row row = Row.of(fields.toArray()); + partitionValuesList.add(row); + } + Row[] partitionValues = partitionValuesList.toArray(new Row[0]); + + return new KeyGroupedPartitioning(keys, partitionValues, numPartitions); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/connector/source/PartitionSerializerTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/connector/source/PartitionSerializerTest.java new file mode 100644 index 0000000000000..380a10bd5bf6f --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/connector/source/PartitionSerializerTest.java @@ -0,0 +1,321 @@ +package org.apache.flink.connector.source; + +import org.apache.flink.table.connector.source.partitioning.KeyGroupedPartitioning; +import org.apache.flink.table.connector.source.partitioning.Partitioning; +import org.apache.flink.table.expressions.TransformExpression; +import org.apache.flink.types.Row; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for {@link PartitionSerializer}. 
+ */ +public class PartitionSerializerTest { + + @Test + public void testSerializeAndDeserializeBasic() throws IOException { + // Create test data + TransformExpression[] keys = { + new TransformExpression("dt", null, null), + new TransformExpression("user_id", "bucket", 128) + }; + + Row[] partitionValues = { + Row.of("2023-10-01", 0), + Row.of("2023-10-01", 1), + Row.of("2023-10-02", 0) + }; + + int numPartitions = 3; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning( + keys, + partitionValues, + numPartitions); + + // Serialize + String json = PartitionSerializer.serialize(original); + assertNotNull(json); + assertFalse(json.isEmpty()); + + // Deserialize + Partitioning deserialized = PartitionSerializer.deserialize(json); + + // Verify round-trip + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + assertEquals(original.numPartitions(), deserialized.numPartitions()); + assertEquals(original.keys().length, ((KeyGroupedPartitioning) deserialized).keys().length); + assertEquals( + original.getPartitionValues().length, + ((KeyGroupedPartitioning) deserialized).getPartitionValues().length); + } + + @Test + public void testSerializeWithOnlyKeyNames() throws IOException { + // Test with transform expressions that only have key names (no functions) + TransformExpression[] keys = { + new TransformExpression("year", null, null), + new TransformExpression("month", null, null) + }; + + Row[] partitionValues = { + Row.of(2023, 10), + Row.of(2023, 11), + Row.of(2024, 1) + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 3); + + String json = PartitionSerializer.serialize(original); + Partitioning deserialized = PartitionSerializer.deserialize(json); + + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + } + + @Test + public void testSerializeWithMixedDataTypes() throws IOException { + // Test with various data types in partition values + TransformExpression[] keys = { + new TransformExpression("category", null, null), + new TransformExpression("count", null, null), + new TransformExpression("rate", null, null), + new TransformExpression("active", null, null)}; + + Row[] partitionValues = { + Row.of("electronics", 100, 3.14, true), + Row.of("books", 50, 2.75, false), + Row.of("clothing", 200, 4.99, true) + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 3); + + String json = PartitionSerializer.serialize(original); + Partitioning deserialized = PartitionSerializer.deserialize(json); + + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + + // Verify specific values + Row[] deserializedValues = ((KeyGroupedPartitioning) deserialized).getPartitionValues(); + assertEquals("electronics", deserializedValues[0].getField(0)); + assertEquals(100, deserializedValues[0].getField(1)); + assertEquals(3.14, deserializedValues[0].getField(2)); + assertEquals(true, deserializedValues[0].getField(3)); + } + + @Test + public void testSerializeWithFunctionAndBuckets() throws IOException { + // Test with transform expressions having function names and bucket counts + TransformExpression[] keys = { + new TransformExpression("user_id", "bucket", 256), + new TransformExpression("timestamp", "hour", null), + new TransformExpression("country", "hash", 64) + }; + + Row[] partitionValues = { + Row.of(123, 14, 1), + Row.of(45, 15, 2), + Row.of(67, 16, 3) + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 3); + + 
String json = PartitionSerializer.serialize(original); + Partitioning deserialized = PartitionSerializer.deserialize(json); + + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + + // Verify transform expressions + TransformExpression[] deserializedKeys = ((KeyGroupedPartitioning) deserialized).keys(); + assertEquals("user_id", deserializedKeys[0].getKey()); + assertEquals("bucket", deserializedKeys[0].getFunctionName().get()); + assertEquals(256, deserializedKeys[0].getNumBucketsOpt().get().intValue()); + + assertEquals("timestamp", deserializedKeys[1].getKey()); + assertEquals("hour", deserializedKeys[1].getFunctionName().get()); + assertFalse(deserializedKeys[1].getNumBucketsOpt().isPresent()); + + assertEquals("country", deserializedKeys[2].getKey()); + assertEquals("hash", deserializedKeys[2].getFunctionName().get()); + assertEquals(64, deserializedKeys[2].getNumBucketsOpt().get().intValue()); + } + + @Test + public void testSerializeEmptyPartitionValues() throws IOException { + // Test with empty partition values array + TransformExpression[] keys = { + new TransformExpression("dt", null, null) + }; + + Row[] partitionValues = {}; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 0); + + String json = PartitionSerializer.serialize(original); + Partitioning deserialized = PartitionSerializer.deserialize(json); + + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + assertEquals(0, ((KeyGroupedPartitioning) deserialized).getPartitionValues().length); + assertEquals(0, deserialized.numPartitions()); + } + + @Test + public void testSerializeWithLongValues() throws IOException { + // Test with long values + TransformExpression[] keys = { + new TransformExpression("timestamp", null, null), + new TransformExpression("id", null, null) + }; + + Row[] partitionValues = { + Row.of(1698768000000L, 123456789012345L), + Row.of(1698854400000L, 987654321098765L) + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 2); + + String json = PartitionSerializer.serialize(original); + Partitioning deserialized = PartitionSerializer.deserialize(json); + + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + + // Verify long values are preserved + Row[] deserializedValues = ((KeyGroupedPartitioning) deserialized).getPartitionValues(); + assertEquals(1698768000000L, deserializedValues[0].getField(0)); + assertEquals(123456789012345L, deserializedValues[0].getField(1)); + } + + @Test + public void testRoundTripConsistency() throws IOException { + // Test multiple round trips to ensure consistency + TransformExpression[] keys = { + new TransformExpression("partition_key", "bucket", 100) + }; + + Row[] partitionValues = { + Row.of("value1", 42), + Row.of("value2", 84) + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 2); + + // First round trip + String json1 = PartitionSerializer.serialize(original); + Partitioning deserialized1 = PartitionSerializer.deserialize(json1); + + // Second round trip + String json2 = PartitionSerializer.serialize(deserialized1); + Partitioning deserialized2 = PartitionSerializer.deserialize(json2); + + // Both should be compatible with original + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized1)); + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized2)); + assertTrue(((KeyGroupedPartitioning) deserialized1).isCompatible((KeyGroupedPartitioning) 
deserialized2)); + + // JSON strings should be identical + assertEquals(json1, json2); + } + + @Test + public void testSerializeWithSingleKey() throws IOException { + // Test with single key and single partition value + TransformExpression[] keys = { + new TransformExpression("single_key", null, null) + }; + + Row[] partitionValues = { + Row.of("single_value") + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 1); + + String json = PartitionSerializer.serialize(original); + Partitioning deserialized = PartitionSerializer.deserialize(json); + + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + assertEquals(1, ((KeyGroupedPartitioning) deserialized).keys().length); + assertEquals(1, ((KeyGroupedPartitioning) deserialized).getPartitionValues().length); + assertEquals(1, deserialized.numPartitions()); + assertEquals("single_key", ((KeyGroupedPartitioning) deserialized).keys()[0].getKey()); + assertEquals( + "single_value", + ((KeyGroupedPartitioning) deserialized).getPartitionValues()[0].getField(0)); + } + + @Test + public void testSerializeWithNumericStrings() throws IOException { + // Test with string values that look like numbers + TransformExpression[] keys = { + new TransformExpression("code", null, null) + }; + + Row[] partitionValues = { + Row.of("001"), + Row.of("002"), + Row.of("999") + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 3); + + String json = PartitionSerializer.serialize(original); + Partitioning deserialized = PartitionSerializer.deserialize(json); + + assertTrue(original.isCompatible((KeyGroupedPartitioning) deserialized)); + + // Verify string values are preserved as strings + Row[] deserializedValues = ((KeyGroupedPartitioning) deserialized).getPartitionValues(); + assertEquals("001", deserializedValues[0].getField(0)); + assertEquals("002", deserializedValues[1].getField(0)); + assertEquals("999", deserializedValues[2].getField(0)); + } + + @Test + public void testJsonStructure() throws IOException { + // Test that JSON contains expected structure + TransformExpression[] keys = { + new TransformExpression("test_key", "test_func", 42) + }; + + Row[] partitionValues = { + Row.of("test_value", 123) + }; + + KeyGroupedPartitioning original = new KeyGroupedPartitioning(keys, partitionValues, 1); + + String json = PartitionSerializer.serialize(original); + + // Basic JSON structure validation + assertTrue(json.contains("\"numPartitions\"")); + assertTrue(json.contains("\"keys\"")); + assertTrue(json.contains("\"partitionValues\"")); + assertTrue(json.contains("\"key\":\"test_key\"")); + assertTrue(json.contains("\"functionName\":\"test_func\"")); + assertTrue(json.contains("\"numBuckets\":42")); + assertTrue(json.contains("\"test_value\"")); + assertTrue(json.contains("123")); + } + + @Test + public void testSerializeUnsupportedPartitioningType() { + // Test that unsupported partitioning types throw IllegalArgumentException + Partitioning unsupportedPartitioning = new Partitioning() { + @Override + public int numPartitions() { + return 1; + } + }; + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + PartitionSerializer.serialize(unsupportedPartitioning); + }); + + assertTrue(exception.getMessage().contains("Only KeyGroupedPartitioning is supported")); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index a116ff9044ed5..d8670e48c4b4e 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -29,6 +29,7 @@ import org.apache.flink.configuration.ConfigOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.source.DynamicFilteringValuesSource; +import org.apache.flink.connector.source.PartitionSerializer; import org.apache.flink.connector.source.ValuesSource; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStreamSink; @@ -49,7 +50,6 @@ import org.apache.flink.table.connector.sink.DynamicTableSink; import org.apache.flink.table.connector.sink.OutputFormatProvider; import org.apache.flink.table.connector.sink.SinkFunctionProvider; -import org.apache.flink.table.connector.sink.abilities.SupportsPartitioning; import org.apache.flink.table.connector.sink.abilities.SupportsWritingMetadata; import org.apache.flink.table.connector.source.AsyncTableFunctionProvider; import org.apache.flink.table.connector.source.DataStreamScanProvider; @@ -65,6 +65,7 @@ import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown; import org.apache.flink.table.connector.source.abilities.SupportsLimitPushDown; import org.apache.flink.table.connector.source.abilities.SupportsPartitionPushDown; +import org.apache.flink.table.connector.source.abilities.SupportsPartitioning; import org.apache.flink.table.connector.source.abilities.SupportsProjectionPushDown; import org.apache.flink.table.connector.source.abilities.SupportsReadingMetadata; import org.apache.flink.table.connector.source.abilities.SupportsSourceWatermark; @@ -79,6 +80,7 @@ import org.apache.flink.table.connector.source.lookup.cache.LookupCache; import org.apache.flink.table.connector.source.lookup.cache.trigger.CacheReloadTrigger; import org.apache.flink.table.connector.source.lookup.cache.trigger.PeriodicCacheReloadTrigger; +import org.apache.flink.table.connector.source.partitioning.Partitioning; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.util.DataFormatConverters; @@ -413,6 +415,12 @@ private static RowKind parseRowKind(String rowKindShortString) { "Option to specify the amount of time to sleep after processing every N elements. " + "The default value is 0, which means that no sleep is performed"); + private static final ConfigOption SOURCE_PARTITIONING = + ConfigOptions.key("source.partitioning") + .stringType() + .noDefaultValue() + .withDescription("specify the partitioning"); + /** * Parse partition list from Options with the format as * "key1:val1,key2:val2;key1:val3,key2:val4". 
@@ -453,6 +461,15 @@ public DynamicTableSource createDynamicTableSource(Context context) { int lookupThreshold = helper.getOptions().get(LOOKUP_THRESHOLD); int sleepAfterElements = helper.getOptions().get(SOURCE_SLEEP_AFTER_ELEMENTS); long sleepTimeMillis = helper.getOptions().get(SOURCE_SLEEP_TIME).toMillis(); + String partitioning = helper.getOptions().get(SOURCE_PARTITIONING); + Partitioning sourcePartitioning = null; + if (partitioning != null) { + try { + sourcePartitioning = PartitionSerializer.deserialize(partitioning); + } catch (IOException e) { + throw new RuntimeException(e); + } + } DefaultLookupCache cache = null; if (helper.getOptions().get(CACHE_TYPE).equals(LookupOptions.LookupCacheType.PARTIAL)) { cache = DefaultLookupCache.fromConfig(helper.getOptions()); @@ -516,7 +533,8 @@ public DynamicTableSource createDynamicTableSource(Context context) { Long.MAX_VALUE, partitions, readableMetadata, - null); + null, + sourcePartitioning); } if (disableLookup) { @@ -537,7 +555,8 @@ public DynamicTableSource createDynamicTableSource(Context context) { Long.MAX_VALUE, partitions, readableMetadata, - null); + null, + sourcePartitioning); } else { return new TestValuesScanTableSource( producedDataType, @@ -555,7 +574,8 @@ public DynamicTableSource createDynamicTableSource(Context context) { Long.MAX_VALUE, partitions, readableMetadata, - null); + null, + sourcePartitioning); } } else { return new TestValuesScanLookupTableSource( @@ -580,7 +600,8 @@ public DynamicTableSource createDynamicTableSource(Context context) { null, cache, reloadTrigger, - lookupThreshold); + lookupThreshold, + sourcePartitioning); } } else { try { @@ -697,7 +718,8 @@ public Set> optionalOptions() { FULL_CACHE_PERIODIC_RELOAD_INTERVAL, FULL_CACHE_PERIODIC_RELOAD_SCHEDULE_MODE, FULL_CACHE_TIMED_RELOAD_ISO_TIME, - FULL_CACHE_TIMED_RELOAD_INTERVAL_IN_DAYS)); + FULL_CACHE_TIMED_RELOAD_INTERVAL_IN_DAYS, + SOURCE_PARTITIONING)); } private static int validateAndExtractRowtimeIndex( @@ -833,7 +855,8 @@ private static class TestValuesScanTableSourceWithoutProjectionPushDown SupportsPartitionPushDown, SupportsReadingMetadata, SupportsAggregatePushDown, - SupportsDynamicFiltering { + SupportsDynamicFiltering, + SupportsPartitioning { protected DataType producedDataType; protected final ChangelogMode changelogMode; @@ -852,6 +875,7 @@ private static class TestValuesScanTableSourceWithoutProjectionPushDown protected List> allPartitions; protected final Map readableMetadata; protected @Nullable int[] projectedMetadataFields; + protected final @Nullable Partitioning partitioning; private @Nullable int[] groupingSet; private List aggregateExpressions; @@ -873,7 +897,8 @@ private TestValuesScanTableSourceWithoutProjectionPushDown( long limit, List> allPartitions, Map readableMetadata, - @Nullable int[] projectedMetadataFields) { + @Nullable int[] projectedMetadataFields, + @Nullable Partitioning partitioning) { this.producedDataType = producedDataType; this.changelogMode = changelogMode; this.bounded = bounded; @@ -890,6 +915,7 @@ private TestValuesScanTableSourceWithoutProjectionPushDown( this.allPartitions = allPartitions; this.readableMetadata = readableMetadata; this.projectedMetadataFields = projectedMetadataFields; + this.partitioning = partitioning; this.groupingSet = null; this.aggregateExpressions = Collections.emptyList(); } @@ -1025,7 +1051,8 @@ public DynamicTableSource copy() { limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + partitioning); } @Override @@ -1342,6 
+1369,16 @@ public List listAcceptedFilterFields() { public void applyDynamicFiltering(List candidateFilterFields) { acceptedPartitionFilterFields = candidateFilterFields; } + + @Override + public Partitioning outputPartitioning() { + return partitioning; + } + + @Override + public void applyPartitionedRead() { + // Do nothing as per requirement + } } /** Values {@link ScanTableSource} for testing that supports projection push down. */ @@ -1365,7 +1402,8 @@ private TestValuesScanTableSource( long limit, List> allPartitions, Map readableMetadata, - @Nullable int[] projectedMetadataFields) { + @Nullable int[] projectedMetadataFields, + @Nullable Partitioning partitioning) { super( producedDataType, changelogMode, @@ -1382,7 +1420,8 @@ private TestValuesScanTableSource( limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + partitioning); } @Override @@ -1403,7 +1442,8 @@ public DynamicTableSource copy() { limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + partitioning); } @Override @@ -1444,7 +1484,8 @@ private TestValuesScanTableSourceWithWatermarkPushDown( long limit, List> allPartitions, Map readableMetadata, - @Nullable int[] projectedMetadataFields) { + @Nullable int[] projectedMetadataFields, + @Nullable Partitioning partitioning) { super( producedDataType, changelogMode, @@ -1461,7 +1502,8 @@ private TestValuesScanTableSourceWithWatermarkPushDown( limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + partitioning); this.tableName = tableName; } @@ -1514,7 +1556,8 @@ public DynamicTableSource copy() { limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + partitioning); newSource.watermarkStrategy = watermarkStrategy; return newSource; } @@ -1559,7 +1602,8 @@ private TestValuesScanLookupTableSource( @Nullable int[] projectedMetadataFields, @Nullable LookupCache cache, @Nullable CacheReloadTrigger reloadTrigger, - int lookupThreshold) { + int lookupThreshold, + @Nullable Partitioning partitioning) { super( producedDataType, changelogMode, @@ -1576,7 +1620,8 @@ private TestValuesScanLookupTableSource( limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + partitioning); this.originType = originType; this.lookupFunctionClass = lookupFunctionClass; this.isAsync = isAsync; @@ -1765,7 +1810,8 @@ public DynamicTableSource copy() { projectedMetadataFields, cache, reloadTrigger, - lookupThreshold); + lookupThreshold, + partitioning); } } @@ -1836,7 +1882,7 @@ public String asSummaryString() { /** Values {@link DynamicTableSink} for testing. 
     private static class TestValuesTableSink
-            implements DynamicTableSink, SupportsWritingMetadata, SupportsPartitioning {
+            implements DynamicTableSink, SupportsWritingMetadata, org.apache.flink.table.connector.sink.abilities.SupportsPartitioning {
 
         private DataType consumedDataType;
         private int[] primaryKeyIndices;
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/join/TestStoragePartitionJoin.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/join/TestStoragePartitionJoin.java
new file mode 100644
index 0000000000000..5108d7c1d37f4
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/join/TestStoragePartitionJoin.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.batch.sql.join;
+
+import org.apache.flink.connector.source.PartitionSerializer;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.connector.source.partitioning.KeyGroupedPartitioning;
+import org.apache.flink.table.expressions.TransformExpression;
+import org.apache.flink.table.planner.runtime.utils.BatchTestBase;
+import org.apache.flink.types.Row;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Test for storage partition join using sort merge join.
+ *
+ * <p>TODO: right now only the query plan is verified; result verification still needs to be added.
+ */
+public class TestStoragePartitionJoin extends BatchTestBase {
+    private TableEnvironment tEnv;
+
+    // Common test data constants
+    private static final String[] TABLE1_COLUMNS = {"id", "name", "salary"};
+    private static final String[] TABLE2_COLUMNS = {"id", "department", "location"};
+    private static final String JOIN_SQL =
+            "SELECT t1.id, t1.name, t1.salary, t2.department, t2.location "
+                    + "FROM %s t1 INNER JOIN %s t2 ON t1.id = t2.id";
+
+    @BeforeEach
+    @Override
+    public void before() throws Exception {
+        super.before();
+        tEnv = tEnv();
+
+        // Disable other join operators to force sort merge join usage
+        tEnv.getConfig()
+                .set(
+                        ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS,
+                        "HashJoin, NestedLoopJoin");
+
+        // Reset the storage partition join config to false before each test
+        tEnv.getConfig()
+                .set(OptimizerConfigOptions.TABLE_OPTIMIZER_STORAGE_PARTITION_JOIN_ENABLED, false);
+
+        // Create test tables with sample data
+        setupTestTables();
+    }
+
+    /**
+     * Test that a basic sort merge join adds hash exchanges for both sides of the join.
+     */
+    @Test
+    public void testBasicSortMergeJoinWithEqualCondition() {
+        String sql = String.format(JOIN_SQL, "table1", "table2");
+        verifySortMergeJoinPlan(
+                sql,
+                "Exchange(distribution=[hash[id]])",
+                2,
+                "Basic Sort Merge Join");
+    }
+
+    /**
+     * Test that storage partition join is disabled by default, so regular hash exchanges are
+     * still added even though both tables are partitioned by the join key.
+     */
+    @Test
+    public void testStoragePartitionJoinDisabledByDefault() {
+        createStandardPartitionedTables("partitioned_table1", "partitioned_table2");
+
+        String sql = String.format(JOIN_SQL, "partitioned_table1", "partitioned_table2");
+        verifySortMergeJoinPlan(
+                sql,
+                "Exchange(distribution=[hash[id]])",
+                2,
+                "Storage Partition Join Disabled");
+    }
+
+    @Test
+    public void testStoragePartitionJoin() {
+        tEnv.getConfig()
+                .set(OptimizerConfigOptions.TABLE_OPTIMIZER_STORAGE_PARTITION_JOIN_ENABLED, true);
+
+        createStandardPartitionedTables("partitioned_table1", "partitioned_table2");
+
+        String sql = String.format(JOIN_SQL, "partitioned_table1", "partitioned_table2");
+        verifySortMergeJoinPlan(
+                sql,
+                "Exchange(distribution=[keep_input_as_is[hash[id]]])",
+                2,
+                "Storage Partition Join Enabled");
+    }
+
+    /**
+     * Common utility to create a partitioned table with the given configuration.
+     */
+    private void createPartitionedTable(
+            String tableName,
+            String[] columns,
+            KeyGroupedPartitioning partitioning) {
+        // Build column definitions - all columns are VARCHAR
+        StringBuilder columnDefs = new StringBuilder();
+        for (int i = 0; i < columns.length; i++) {
+            if (i > 0) {
+                columnDefs.append(",\n");
+            }
+            columnDefs.append(" ").append(columns[i]).append(" VARCHAR");
+        }
+
+        // Serialize the partitioning using PartitionSerializer
+        String partitionString;
+        try {
+            partitionString = PartitionSerializer.serialize(partitioning);
+        } catch (IOException e) {
+            throw new RuntimeException("Failed to serialize partitioning", e);
+        }
+
+        // Create the table SQL with the partitioning property
+        String createTableSql =
+                "CREATE TABLE " + tableName + " (\n"
+                        + columnDefs.toString() + "\n"
+                        + ") WITH (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'bounded' = 'true',\n"
+                        + " 'source.partitioning' = '" + partitionString + "'\n"
+                        + ")";
+
+        // Execute the SQL statement
+        tEnv.executeSql(createTableSql);
+    }
+
+    /**
+     * Common utility to create standard partitioned tables for testing.
+     */
+    private void createStandardPartitionedTables(String table1Name, String table2Name) {
+        // Create partitioned table1 with columns: id, name, salary
+        TransformExpression[] table1Keys = {new TransformExpression("id", null, null)};
+        Row[] table1PartitionValues = {Row.of("1"), Row.of("2")};
+        KeyGroupedPartitioning table1Partitioning =
+                new KeyGroupedPartitioning(table1Keys, table1PartitionValues, 2);
+        createPartitionedTable(table1Name, TABLE1_COLUMNS, table1Partitioning);
+
+        // Create partitioned table2 with columns: id, department, location
+        TransformExpression[] table2Keys = {new TransformExpression("id", null, null)};
+        Row[] table2PartitionValues = {Row.of("1"), Row.of("2")};
+        KeyGroupedPartitioning table2Partitioning =
+                new KeyGroupedPartitioning(table2Keys, table2PartitionValues, 2);
+        createPartitionedTable(table2Name, TABLE2_COLUMNS, table2Partitioning);
+    }
+
+    /**
+     * Common utility to execute SQL and extract the optimized execution plan.
+     */
+    private String getOptimizedExecutionPlan(String sql) {
+        String explainResult = tEnv.explainSql(sql);
+
+        // Extract only the Optimized Execution Plan section
+        String[] sections = explainResult.split("== Optimized Execution Plan ==");
+        return sections.length > 1 ? sections[1].trim() : explainResult;
+    }
+
+    /**
+     * Common utility to verify exchange patterns in the execution plan.
+     */
+    private void verifyExchangePattern(
+            String optimizedExecutionPlan,
+            String exchangePattern,
+            int expectedCount) {
+        long exchangeCount =
+                Arrays.stream(optimizedExecutionPlan.split("\n"))
+                        .filter(line -> line.contains(exchangePattern))
+                        .count();
+        assertThat(exchangeCount).isEqualTo(expectedCount);
+    }
+
+    /**
+     * Common utility to verify the sort merge join execution plan.
+     */
+    private void verifySortMergeJoinPlan(
+            String sql,
+            String exchangePattern,
+            int expectedExchangeCount,
+            String testDescription) {
+        String optimizedExecutionPlan = getOptimizedExecutionPlan(sql);
+        System.out.println(
+                testDescription + " - Optimized Execution Plan: " + optimizedExecutionPlan);
+
+        assertThat(optimizedExecutionPlan).contains("SortMergeJoin");
+        verifyExchangePattern(optimizedExecutionPlan, exchangePattern, expectedExchangeCount);
+    }
+
+    private void setupTestTables() {
+        // Create table1
+        tEnv.executeSql(
+                "CREATE TABLE table1 (\n"
+                        + " id INT,\n"
+                        + " name VARCHAR,\n"
+                        + " salary INT\n"
+                        + ") WITH (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'bounded' = 'true'\n"
+                        + ")");
+
+        // Create table2
+        tEnv.executeSql(
+                "CREATE TABLE table2 (\n"
+                        + " id INT,\n"
+                        + " department VARCHAR,\n"
+                        + " location VARCHAR\n"
+                        + ") WITH (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'bounded' = 'true'\n"
+                        + ")");
+    }
+}
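
Reviewer note (not part of the patch): for connector authors, the new source-side SupportsPartitioning ability boils down to reporting a Partitioning via outputPartitioning() and reacting to applyPartitionedRead(). The sketch below illustrates what an implementation might look like. It only reuses constructor shapes that appear in the test above (KeyGroupedPartitioning(TransformExpression[], Row[], int) and TransformExpression(String, ...)); the class name PartitionedExampleSource is made up, the semantics assumed for applyPartitionedRead() are inferred from the option's description, and the scan runtime provider is omitted, so treat this as an illustration under those assumptions rather than a reference implementation.

// Hypothetical sketch, not part of the patch.
import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.ScanTableSource;
import org.apache.flink.table.connector.source.abilities.SupportsPartitioning;
import org.apache.flink.table.connector.source.partitioning.KeyGroupedPartitioning;
import org.apache.flink.table.connector.source.partitioning.Partitioning;
import org.apache.flink.table.expressions.TransformExpression;
import org.apache.flink.types.Row;

public class PartitionedExampleSource implements ScanTableSource, SupportsPartitioning {

    // Set when the planner decides to exploit the reported partitioning.
    private boolean partitionedReadRequested = false;

    @Override
    public Partitioning outputPartitioning() {
        // The data is stored in two buckets keyed by "id"; reporting this lets the
        // optimizer skip the hash exchange when both join inputs share the partitioning.
        TransformExpression[] keys = {new TransformExpression("id", null, null)};
        Row[] partitionValues = {Row.of("1"), Row.of("2")};
        return new KeyGroupedPartitioning(keys, partitionValues, 2);
    }

    @Override
    public void applyPartitionedRead() {
        // Presumably the source must now emit each partition's data grouped together.
        this.partitionedReadRequested = true;
    }

    @Override
    public ChangelogMode getChangelogMode() {
        return ChangelogMode.insertOnly();
    }

    @Override
    public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderContext) {
        // Reading logic is omitted; a real source would honor partitionedReadRequested here.
        throw new UnsupportedOperationException("Omitted in this sketch");
    }

    @Override
    public DynamicTableSource copy() {
        return new PartitionedExampleSource();
    }

    @Override
    public String asSummaryString() {
        return "PartitionedExampleSource";
    }
}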