@@ -747,6 +747,10 @@ class PlanGenerationTestSuite
simple.repartitionByRange(fn.col("a").asc, fn.col("id").desc_nulls_first)
}

test("repartitionById") {
simple.repartitionById(10, fn.col("partition_id"))
}

test("coalesce") {
simple.coalesce(5)
}
@@ -450,4 +450,44 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
Map("one" -> "1", "two" -> "2"))
assert(df.as(StringEncoder).collect().toSet == Set("one", "two"))
}

test("repartitionById e2e") {
val session = spark
import session.implicits._
import org.apache.spark.sql.functions._

// Each row's target partition is its id modulo 10; the shuffle should place it there exactly.
val df = spark.range(100).withColumn("expected_p_id", col("id") % 10)
val repartitioned = df.repartitionById(10, $"expected_p_id".cast("int"))
val result = repartitioned.withColumn("actual_p_id", spark_partition_id())

assert(result.filter(col("expected_p_id") =!= col("actual_p_id")).count() == 0)
assert(result.rdd.getNumPartitions == 10)

// Negative partition ids are expected to wrap around via a positive modulus.
val negativeDf = spark.range(10).toDF("id")
val negativeRepartitioned = negativeDf.repartitionById(10, ($"id" - 5).cast("int"))
val negativeResult = negativeRepartitioned.withColumn("actual_p_id", spark_partition_id()).collect()

assert(negativeResult.forall(row => {
val actualPartitionId = row.getAs[Int]("actual_p_id")
val id = row.getAs[Long]("id")
val expectedPartitionId = {
val mod = (id - 5) % 10
if (mod < 0) mod + 10 else mod
}.toInt
actualPartitionId == expectedPartitionId
}))

// Rows whose partition id expression evaluates to null are expected to land in partition 0.
val nullDf = spark.range(10).toDF("id")
val nullExpr = when($"id" < 5, $"id").otherwise(lit(null)).cast("int")
val nullRepartitioned = nullDf.repartitionById(10, nullExpr)
val nullResult = nullRepartitioned.withColumn("actual_p_id", spark_partition_id()).collect()

val nullRows = nullResult.filter(_.getAs[Long]("id") >= 5)
assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0))

// Ids at or above numPartitions should still be accepted, with no rows dropped.
val outOfBoundsDf = spark.range(20).toDF("id")
val outOfBoundsRepartitioned = outOfBoundsDf.repartitionById(10, $"id".cast("int"))
assert(outOfBoundsRepartitioned.collect().length == 20)
assert(outOfBoundsRepartitioned.rdd.getNumPartitions == 10)
}
}
@@ -53,6 +53,7 @@ message Expression {
MergeAction merge_action = 19;
TypedAggregateExpression typed_aggregate_expression = 20;
SubqueryExpression subquery_expression = 21;
DirectShufflePartitionID direct_shuffle_partition_id = 22;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
@@ -559,3 +560,11 @@ message SubqueryExpression
optional bool with_single_partition = 3;
}
}

// Expression that takes a partition ID value and passes it through directly for use in
// shuffle partitioning. This is used with RepartitionByExpression to allow users to
// directly specify target partition IDs.
message DirectShufflePartitionID {
// (Required) The expression that evaluates to the partition ID.
Expression child = 1;
}
@@ -35,7 +35,7 @@ import org.apache.spark.sql.{functions, AnalysisException, Column, Encoder, Obse
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
import org.apache.spark.sql.catalyst.expressions.{DirectShufflePartitionID, OrderUtils}
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
@@ -1046,6 +1046,21 @@ class Dataset[T] private[sql] (
buildRepartitionByExpression(numPartitions, sortExprs)
}

/**
* Repartitions the Dataset into the given number of partitions using the specified
* partition ID expression.
*
* @param numPartitions the number of partitions to use.
* @param partitionIdExpr the expression used as the target partition ID; it must evaluate to an integer type.
*
* @group typedrel
* @since 4.1.0
*/
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T] = {
val directShufflePartitionIdCol = Column(DirectShufflePartitionID(partitionIdExpr.expr))
repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol))
}

Review comment (Member): why not implement it in sql/api ?
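
For illustration, a minimal usage sketch of the new API (column name and data are made up; it assumes a session named spark and org.apache.spark.sql.functions._ in scope, as in the e2e test above).

// Route every row to the partition named by an existing integer column.
val buckets = spark.range(100).withColumn("bucket", (col("id") % 8).cast("int"))
val routed = buckets.repartitionById(8, col("bucket"))

// After the shuffle, spark_partition_id() should line up with the bucket column.
routed.select(col("bucket"), spark_partition_id().as("actual_p_id")).show()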

/** @inheritdoc */
def coalesce(numPartitions: Int): Dataset[T] = {
buildRepartition(numPartitions, shuffle = false)
@@ -26,6 +26,7 @@ import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_D
import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType}
import org.apache.spark.sql.{functions, Column, Encoder}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
@@ -233,6 +234,12 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
assert(relation.hasCommon && relation.getCommon.hasPlanId)
b.setPlanId(relation.getCommon.getPlanId)

case DirectShufflePartitionID(child, _) =>
builder.setDirectShufflePartitionId(
builder.getDirectShufflePartitionIdBuilder
.setChild(apply(child, e, additionalTransformation))
)

case ProtoColumnNode(e, _) =>
return e
