Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
if (sortOrders.nonEmpty) throw new IllegalArgumentException(
s"""Invalid partitionExprs specified: $sortOrders
Expand All @@ -208,11 +208,11 @@ object ResolveHints {
throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) if shuffle =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) if shuffle =>
createRepartitionByExpression(conf.numShufflePartitions, param)
createRepartitionByExpression(None, param)
}
}

Expand All @@ -224,7 +224,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
if (invalidParams.nonEmpty) {
throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
Expand All @@ -239,11 +239,11 @@ object ResolveHints {

hint.parameters match {
case param @ Seq(IntegerLiteral(numPartitions), _*) =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) =>
createRepartitionByExpression(conf.numShufflePartitions, param)
createRepartitionByExpression(None, param)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.random.RandomSampler

Expand Down Expand Up @@ -948,16 +949,18 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
}

/**
* This method repartitions data using [[Expression]]s into `numPartitions`, and receives
* This method repartitions data using [[Expression]]s into `optNumPartitions`, and receives
* information about the number of partitions during execution. Used when a specific ordering or
* distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
* `coalesce` and `repartition`.
* `coalesce` and `repartition`. If no `optNumPartitions` is given, by default it partitions data
* into `numShufflePartitions` defined in `SQLConf`, and could be coalesced by AQE.
*/
case class RepartitionByExpression(
partitionExpressions: Seq[Expression],
child: LogicalPlan,
numPartitions: Int) extends RepartitionOperation {
optNumPartitions: Option[Int]) extends RepartitionOperation {

val numPartitions = optNumPartitions.getOrElse(SQLConf.get.numShufflePartitions)
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

val partitioning: Partitioning = {
Expand Down Expand Up @@ -985,6 +988,15 @@ case class RepartitionByExpression(
override def shuffle: Boolean = true
}

/**
 * Companion providing a convenience factory for callers that always have a concrete
 * partition count at hand.
 */
object RepartitionByExpression {
  /**
   * Builds a [[RepartitionByExpression]] with an explicitly specified number of partitions.
   *
   * @param partitionExpressions expressions the data is partitioned by
   * @param child the plan whose output is repartitioned
   * @param numPartitions the partition count, forwarded as `Some(numPartitions)`
   */
  def apply(
      partitionExpressions: Seq[Expression],
      child: LogicalPlan,
      numPartitions: Int): RepartitionByExpression =
    new RepartitionByExpression(partitionExpressions, child, Some(numPartitions))
}

/**
* A relation with one row. This is used in "SELECT ..." without a from clause.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
checkAnalysis(
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
Seq(AttributeReference("a", IntegerType)()), testRelation, None))

val e = intercept[IllegalArgumentException] {
checkAnalysis(
Expand All @@ -187,7 +187,7 @@ class ResolveHintsSuite extends AnalysisTest {
"REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
testRelation, conf.numShufflePartitions))
testRelation, None))

val errMsg2 = "REPARTITION Hint parameter should include columns, but"

Expand Down
54 changes: 33 additions & 21 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2991,17 +2991,9 @@ class Dataset[T] private[sql](
Repartition(numPartitions, shuffle = true, logicalPlan)
}

/**
* Returns a new Dataset partitioned by the given partitioning expressions into
* `numPartitions`. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
private def repartitionByExpression(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = {
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
Expand All @@ -3015,6 +3007,20 @@ class Dataset[T] private[sql](
}
}

/**
 * Hash-partitions this Dataset by the supplied partitioning expressions, producing exactly
 * `numPartitions` partitions in the returned Dataset.
 *
 * Equivalent to "DISTRIBUTE BY" in SQL (Hive QL).
 *
 * @param numPartitions the number of partitions of the resulting Dataset
 * @param partitionExprs the expressions to partition by
 *
 * @group typedrel
 * @since 2.0.0
 */
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] =
  repartitionByExpression(
    numPartitions = Some(numPartitions),
    partitionExprs = partitionExprs)

/**
* Returns a new Dataset partitioned by the given partitioning expressions, using
* `spark.sql.shuffle.partitions` as number of partitions.
Expand All @@ -3027,7 +3033,20 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def repartition(partitionExprs: Column*): Dataset[T] = {
repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
repartitionByExpression(None, partitionExprs)
}

/**
 * Shared implementation of the public `repartitionByRange` overloads.
 *
 * Every partitioning column is turned into a [[SortOrder]]: columns that already wrap a
 * [[SortOrder]] are used as they are, any other expression is sorted ascending.
 *
 * @param numPartitions the partition count to pass to the logical operator, or `None`
 *                      when the caller did not specify one
 * @param partitionExprs the columns to range-partition by; must be non-empty
 */
private def repartitionByRange(
    numPartitions: Option[Int],
    partitionExprs: Seq[Column]): Dataset[T] = {
  require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
  val ordering: Seq[SortOrder] = partitionExprs.map { column =>
    column.expr match {
      case explicit: SortOrder => explicit
      case plain: Expression => SortOrder(plain, Ascending)
    }
  }
  withTypedPlan(RepartitionByExpression(ordering, logicalPlan, numPartitions))
}

/**
Expand All @@ -3049,14 +3068,7 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
case expr: SortOrder => expr
case expr: Expression => SortOrder(expr, Ascending)
})
withTypedPlan {
RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
}
repartitionByRange(Some(numPartitions), partitionExprs)
}

/**
Expand All @@ -3078,7 +3090,7 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
repartitionByRange(None, partitionExprs)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ctx: QueryOrganizationContext,
expressions: Seq[Expression],
query: LogicalPlan): LogicalPlan = {
RepartitionByExpression(expressions, query, conf.numShufflePartitions)
RepartitionByExpression(expressions, query, None)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
Expand Down Expand Up @@ -754,7 +754,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
planLater(child), noUserSpecifiedNumPartition = false) :: Nil
planLater(child), REPARTITION_WITH_NUM) :: Nil
} else {
execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
}
Expand Down Expand Up @@ -787,8 +787,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: logical.Range =>
execution.RangeExec(r) :: Nil
case r: logical.RepartitionByExpression =>
exchange.ShuffleExchangeExec(
r.partitioning, planLater(r.child), noUserSpecifiedNumPartition = false) :: Nil
val shuffleOrigin = if (r.optNumPartitions.isEmpty) {
REPARTITION
} else {
REPARTITION_WITH_NUM
}
exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
import org.apache.spark.sql.internal.SQLConf

/**
Expand Down Expand Up @@ -50,7 +52,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
val shuffleStages = collectShuffleStages(plan)
// ShuffleExchanges introduced by a repartition with a user-specified partition number, and
// those producing a single partition, do not support changing the number of partitions.
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
plan
} else {
// `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
Expand Down Expand Up @@ -85,6 +87,11 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
}
}
}

/**
 * Whether the partition number of the given shuffle may be changed by this rule.
 *
 * A shuffle qualifies only when its output partitioning is not `SinglePartition` and its
 * origin is either `ENSURE_REQUIREMENTS` or `REPARTITION`.
 */
private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
  if (s.outputPartitioning == SinglePartition) {
    false
  } else {
    s.shuffleOrigin match {
      case ENSURE_REQUIREMENTS | REPARTITION => true
      case _ => false
    }
  }
}
}

object CoalesceShufflePartitions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.internal.SQLConf

Expand Down Expand Up @@ -142,9 +143,13 @@ object OptimizeLocalShuffleReader {

def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
case s: ShuffleQueryStageExec =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
s.mapStats.isDefined && supportLocalReader(s.shuffle)
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _) =>
s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle)
case _ => false
}

/**
 * Whether the given shuffle is eligible for a local shuffle reader: its output partitioning
 * must not be `SinglePartition` and its origin must be `ENSURE_REQUIREMENTS`.
 */
private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
  val isSinglePartition = s.outputPartitioning == SinglePartition
  !isSinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange {
def numPartitions: Int

/**
* Returns whether the shuffle partition number can be changed.
* The origin of this shuffle operator.
*/
def canChangeNumPartitions: Boolean
def shuffleOrigin: ShuffleOrigin

/**
* The asynchronous job that materializes the shuffle.
Expand All @@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange {
def runtimeStatistics: Statistics
}

// Describes where the shuffle operator comes from. The trait is sealed, so the three case
// objects below are the only possible origins.
sealed trait ShuffleOrigin

// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
// means that the shuffle operator is used to ensure internal data partitioning requirements and
// Spark is free to optimize it as long as the requirements are still ensured.
// This is the default origin of `ShuffleExchangeExec`.
case object ENSURE_REQUIREMENTS extends ShuffleOrigin

// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
// can still optimize it via changing shuffle partition number, as data partitioning won't change.
// `CoalesceShufflePartitions` accepts this origin; `OptimizeLocalShuffleReader` does not.
case object REPARTITION extends ShuffleOrigin

// Indicates that the shuffle operator was added by the user-specified repartition operator with
// a certain partition number. Spark can't optimize it.
case object REPARTITION_WITH_NUM extends ShuffleOrigin

/**
* Performs a shuffle that will result in the desired partitioning.
*/
case class ShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike {

// If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
// For `SinglePartition`, it requires exactly one partition and we can't change it either.
def canChangeNumPartitions: Boolean =
noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
extends ShuffleExchangeLike {

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
Expand Down
Loading