5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -42,7 +42,10 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.scripting.SqlScriptingExecution.withLocalVariableManager"),

// [SPARK-53391][CORE] Remove unused PrimitiveKeyOpenHashMap
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.PrimitiveKeyOpenHashMap*")
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.PrimitiveKeyOpenHashMap*"),

// [SPARK-54041][SQL] Enable Direct Passthrough Partitioning in the DataFrame API
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById")
)

// Default exclude rules
170 changes: 86 additions & 84 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -307,6 +307,28 @@ class Expression(google.protobuf.message.Message):
],
) -> None: ...

class DirectShufflePartitionID(google.protobuf.message.Message):
"""Expression that takes a partition ID value and passes it through directly for use in
shuffle partitioning. This is used with RepartitionByExpression to allow users to
directly specify target partition IDs.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

CHILD_FIELD_NUMBER: builtins.int
@property
def child(self) -> global___Expression:
"""(Required) The expression that evaluates to the partition ID."""
def __init__(
self,
*,
child: global___Expression | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["child", b"child"]
) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["child", b"child"]) -> None: ...

class Cast(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -1401,6 +1423,7 @@
MERGE_ACTION_FIELD_NUMBER: builtins.int
TYPED_AGGREGATE_EXPRESSION_FIELD_NUMBER: builtins.int
SUBQUERY_EXPRESSION_FIELD_NUMBER: builtins.int
DIRECT_SHUFFLE_PARTITION_ID_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def common(self) -> global___ExpressionCommon: ...
@@ -1447,6 +1470,8 @@
@property
def subquery_expression(self) -> global___SubqueryExpression: ...
@property
def direct_shuffle_partition_id(self) -> global___Expression.DirectShufflePartitionID: ...
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins generate arbitrary
relations they can add them here. During the planning the correct resolution is done.
@@ -1476,6 +1501,7 @@
merge_action: global___MergeAction | None = ...,
typed_aggregate_expression: global___TypedAggregateExpression | None = ...,
subquery_expression: global___SubqueryExpression | None = ...,
direct_shuffle_partition_id: global___Expression.DirectShufflePartitionID | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1491,6 +1517,8 @@
b"common",
"common_inline_user_defined_function",
b"common_inline_user_defined_function",
"direct_shuffle_partition_id",
b"direct_shuffle_partition_id",
"expr_type",
b"expr_type",
"expression_string",
@@ -1542,6 +1570,8 @@
b"common",
"common_inline_user_defined_function",
b"common_inline_user_defined_function",
"direct_shuffle_partition_id",
b"direct_shuffle_partition_id",
"expr_type",
b"expr_type",
"expression_string",
@@ -1604,6 +1634,7 @@
"merge_action",
"typed_aggregate_expression",
"subquery_expression",
"direct_shuffle_partition_id",
"extension",
]
| None
14 changes: 14 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2967,6 +2967,20 @@ abstract class Dataset[T] extends Serializable {
repartitionByRange(None, partitionExprs)
}

/**
* Repartition the Dataset into the given number of partitions using the specified partition ID
* expression.
*
* @param numPartitions
* the number of partitions to use.
* @param partitionIdExpr
* the expression to be used as the partition ID. Must be an integer type.
*
* @group typedrel
* @since 4.1.0
*/
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T]
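
A minimal usage sketch of the new API (a hedged example, assuming a `SparkSession` named `spark`; the data and column names are illustrative and mirror the E2E test later in this diff):

import org.apache.spark.sql.functions.col

val df = spark.range(100).toDF("id")
// Each row is routed to the partition computed by the integer expression, here id % 10.
val byId = df.repartitionById(10, (col("id") % 10).cast("int"))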

/**
* Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions
* are requested. If a larger number of partitions is requested, it will stay at the current
@@ -747,6 +747,10 @@ class PlanGenerationTestSuite
simple.repartitionByRange(fn.col("a").asc, fn.col("id").desc_nulls_first)
}

test("repartitionById") {
simple.repartitionById(10, fn.col("id").cast("int"))
}

test("coalesce") {
simple.coalesce(5)
}
@@ -1716,6 +1716,39 @@ class ClientE2ETestSuite
schema.fields.head.dataType.asInstanceOf[MapType].valueContainsNull === valueContainsNull)
}
}

test("SPARK-54043: DirectShufflePartitionID should be supported") {
val df = spark.range(100).withColumn("expected_p_id", col("id") % 10)
val repartitioned = df.repartitionById(10, col("expected_p_id").cast("int"))
val result = repartitioned.withColumn("actual_p_id", spark_partition_id())

assert(result.filter(col("expected_p_id") =!= col("actual_p_id")).count() == 0)

// Negative expression values are wrapped into [0, numPartitions) via a non-negative modulus.
val negativeDf = spark.range(10).toDF("id")
val negativeRepartitioned = negativeDf.repartitionById(10, (col("id") - 5).cast("int"))
val negativeResult =
negativeRepartitioned
.withColumn("actual_p_id", spark_partition_id())
.collect()

assert(negativeResult.forall(row => {
val actualPartitionId = row.getAs[Int]("actual_p_id")
val id = row.getAs[Long]("id")
val expectedPartitionId = {
val mod = (id - 5) % 10
if (mod < 0) mod + 10 else mod
}.toInt
actualPartitionId == expectedPartitionId
}))

// A null partition ID expression falls back to partition 0.
val nullDf = spark.range(10).toDF("id")
val nullExpr = when(col("id") < 5, col("id")).otherwise(lit(null)).cast("int")
val nullRepartitioned = nullDf.repartitionById(10, nullExpr)
val nullResult = nullRepartitioned.withColumn("actual_p_id", spark_partition_id()).collect()

val nullRows = nullResult.filter(_.getAs[Long]("id") >= 5)
assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0))
}
}
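
Taken together, the assertions above pin down the row-to-partition contract; a hedged sketch of that mapping (the helper name is mine, not Spark's):

// Expected routing for repartitionById(numPartitions, expr), per the test above:
// null falls back to partition 0, any other value to its non-negative modulus.
def expectedPartition(value: java.lang.Integer, numPartitions: Int): Int =
  if (value == null) 0
  else ((value % numPartitions) + numPartitions) % numPartitions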

private[sql] case class ClassData(a: String, b: Int)
@@ -53,6 +53,7 @@ message Expression {
MergeAction merge_action = 19;
TypedAggregateExpression typed_aggregate_expression = 20;
SubqueryExpression subquery_expression = 21;
DirectShufflePartitionID direct_shuffle_partition_id = 22;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
@@ -142,6 +143,14 @@
}
}

// Expression that takes a partition ID value and passes it through directly for use in
// shuffle partitioning. This is used with RepartitionByExpression to allow users to
// directly specify target partition IDs.
message DirectShufflePartitionID {
// (Required) The expression that evaluates to the partition ID.
Expression child = 1;
}

message Cast {
// (Required) the expression to be casted.
Expression expr = 1;
@@ -1015,6 +1015,21 @@ class Dataset[T] private[sql] (
}
}

private def buildRepartitionById(numPartitions: Int, partitionExpr: Column): Dataset[T] = {
val exprBuilder = proto.Expression.newBuilder()
val directShufflePartitionIdExpr = exprBuilder
.setDirectShufflePartitionId(
exprBuilder.getDirectShufflePartitionIdBuilder
.setChild(toExpr(partitionExpr)))
.build()
sparkSession.newDataset(agnosticEncoder, partitionExpr :: Nil) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(Seq(directShufflePartitionIdExpr).asJava)
repartitionBuilder.setNumPartitions(numPartitions)
}
}

/** @inheritdoc */
def repartition(numPartitions: Int): Dataset[T] = {
buildRepartition(numPartitions, shuffle = true)
@@ -1046,6 +1061,11 @@
buildRepartitionByExpression(numPartitions, sortExprs)
}

/** @inheritdoc */
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T] = {
buildRepartitionById(numPartitions, partitionIdExpr)
}

Review comment (Member) on repartitionById: why not implement it in sql/api ?

/** @inheritdoc */
def coalesce(numPartitions: Int): Dataset[T] = {
buildRepartition(numPartitions, shuffle = false)
@@ -0,0 +1,2 @@
RepartitionByExpression [direct_shuffle_partition_id(cast(id#0L as int))], 10
+- LocalRelation <empty>, [id#0L, a#0, b#0]
@@ -0,0 +1,67 @@
{
"common": {
"planId": "1"
},
"repartitionByExpression": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
"partitionExprs": [{
"directShufflePartitionId": {
"child": {
"cast": {
"expr": {
"unresolvedAttribute": {
"unparsedIdentifier": "id"
},
"common": {
"origin": {
"jvmOrigin": {
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.functions$",
"methodName": "col",
"fileName": "functions.scala"
}, {
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
"methodName": "~~trimmed~anonfun~~",
"fileName": "PlanGenerationTestSuite.scala"
}]
}
}
}
},
"type": {
"integer": {
}
}
},
"common": {
"origin": {
"jvmOrigin": {
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.Column",
"methodName": "cast",
"fileName": "Column.scala"
}, {
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
"methodName": "~~trimmed~anonfun~~",
"fileName": "PlanGenerationTestSuite.scala"
}]
}
}
}
}
}
}],
"numPartitions": 10
}
}
Binary file not shown.
@@ -1774,6 +1774,8 @@ class SparkConnectPlanner(
transformTypedAggregateExpression(exp.getTypedAggregateExpression, baseRelationOpt)
case proto.Expression.ExprTypeCase.SUBQUERY_EXPRESSION =>
transformSubqueryExpression(exp.getSubqueryExpression)
case proto.Expression.ExprTypeCase.DIRECT_SHUFFLE_PARTITION_ID =>
transformDirectShufflePartitionID(exp.getDirectShufflePartitionId)
case other =>
throw InvalidInputErrors.invalidOneOfField(other, exp.getDescriptorForType)
}
@@ -4109,6 +4111,11 @@
}
}

private def transformDirectShufflePartitionID(
directShufflePartitionID: proto.Expression.DirectShufflePartitionID): Expression = {
DirectShufflePartitionID(transformExpression(directShufflePartitionID.getChild))
}

private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
if (isValidSQLWithRefs(getWithRelations)) {
transformSqlWithRefs(getWithRelations)
@@ -1544,16 +1544,7 @@ class Dataset[T] private[sql](
}
}

/**
* Repartitions the Dataset into the given number of partitions using the specified
* partition ID expression.
*
* @param numPartitions the number of partitions to use.
* @param partitionIdExpr the expression to be used as the partition ID. Must be an integer type.
*
* @group typedrel
* @since 4.1.0
*/
/** @inheritdoc */
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T] = {
val directShufflePartitionIdCol = Column(DirectShufflePartitionID(partitionIdExpr.expr))
repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol))
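
For the classic path above, a hedged sanity check against the golden explain output shown earlier in this diff (assumes a classic `SparkSession` named `spark`):

import org.apache.spark.sql.functions.col

val plan = spark.range(3).toDF("id")
  .repartitionById(10, col("id").cast("int"))
  .queryExecution.logical
// The tree should be a RepartitionByExpression whose single partition expression
// is the direct_shuffle_partition_id wrapper, cf. the golden explain file above.
println(plan)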