Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -2558,8 +2558,8 @@ test_that("coalesce, repartition, numPartitions", {

df2 <- repartition(df1, 10)
expect_equal(getNumPartitions(df2), 10)
expect_equal(getNumPartitions(coalesce(df2, 13)), 5)
expect_equal(getNumPartitions(coalesce(df2, 7)), 5)
expect_equal(getNumPartitions(coalesce(df2, 13)), 10)
expect_equal(getNumPartitions(coalesce(df2, 7)), 7)
expect_equal(getNumPartitions(coalesce(df2, 3)), 3)
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ package object dsl {

def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan, None)

/**
 * Wraps this plan in a [[Repartition]] targeting `num` partitions with shuffling disabled.
 *
 * @param num the requested number of partitions
 * @return the [[Repartition]] node placed on top of this plan
 */
def coalesce(num: Integer): LogicalPlan =
  Repartition(numPartitions = num, shuffle = false, child = logicalPlan)

/**
 * Wraps this plan in a [[Repartition]] targeting `num` partitions with shuffling enabled.
 *
 * @param num the requested number of partitions
 * @return the [[Repartition]] node placed on top of this plan
 */
def repartition(num: Integer): LogicalPlan =
  Repartition(numPartitions = num, shuffle = true, child = logicalPlan)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,27 +562,23 @@ object CollapseProject extends Rule[LogicalPlan] {
}

/**
* Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations
* by keeping only the one.
* 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]].
* 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]].
* 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single
* [[RepartitionByExpression]] with the expression and last number of partition.
* Combines adjacent [[RepartitionOperation]] operators
*/
object CollapseRepartition extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
// Case 1
case Repartition(numPartitions, shuffle, Repartition(_, _, child)) =>
Repartition(numPartitions, shuffle, child)
// Case 2
case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) =>
RepartitionByExpression(exprs, child, numPartitions)
// Case 3
case Repartition(numPartitions, _, r: RepartitionByExpression) =>
r.copy(numPartitions = numPartitions)
// Case 3
case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
RepartitionByExpression(exprs, child, numPartitions)
// Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
// 1) When the top node does not enable the shuffle (i.e., coalesce API) but the child
// enables the shuffle: return the child node if the top node's numPartitions is greater
// than or equal to the child's numPartitions; otherwise, keep the plan unchanged.
// 2) In all other cases, return the top node with the child's child.
case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match {
case (false, true) => if (r.numPartitions >= child.numPartitions) child else r
case _ => r.copy(child = child.child)
}
// Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression
// we can remove the child.
case r @ RepartitionByExpression(_, child: RepartitionOperation, _) =>
r.copy(child = child.child)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -835,16 +835,24 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

/**
 * Common parent of [[RepartitionByExpression]] and [[Repartition]]. It exposes the
 * properties both operators share, so that code handling either of them can be written
 * against a single type.
 */
abstract class RepartitionOperation extends UnaryNode {
  /** The output attributes are exactly those of the child. */
  override def output: Seq[Attribute] = child.output

  /** The number of partitions this operator asks for. */
  def numPartitions: Int

  /** Whether this operator enables a shuffle. */
  def shuffle: Boolean
}

/**
* Returns a new RDD that has exactly `numPartitions` partitions. Differs from
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
* asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer
* of the output requires some specific ordering or distribution of the data.
*/
case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
extends UnaryNode {
extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
override def output: Seq[Attribute] = child.output
}

/**
Expand All @@ -856,12 +864,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
case class RepartitionByExpression(
partitionExpressions: Seq[Expression],
child: LogicalPlan,
numPartitions: Int) extends UnaryNode {
numPartitions: Int) extends RepartitionOperation {

require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

override def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
override def shuffle: Boolean = true
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,47 +32,168 @@ class CollapseRepartitionSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int)


test("collapse two adjacent coalesces into one") {
  // The top coalesce always wins and the coalesce beneath it is dropped, whether the lower
  // one asked for fewer (10) or more (30) partitions than the top one (20).
  val optimizedPlans = Seq(10, 30).map { lowerNumPartitions =>
    Optimize.execute(testRelation.coalesce(lowerNumPartitions).coalesce(20).analyze)
  }
  val expected = testRelation.coalesce(20).analyze

  optimizedPlans.foreach(comparePlans(_, expected))
}

test("collapse two adjacent repartitions into one") {
val query = testRelation
// Always respects the top repartition and removes the useless repartition below it
val query1 = testRelation
.repartition(10)
.repartition(20)
val query2 = testRelation
.repartition(30)
.repartition(20)

val optimized1 = Optimize.execute(query1.analyze)
val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.repartition(20).analyze

comparePlans(optimized1, correctAnswer)
comparePlans(optimized2, correctAnswer)
}

test("coalesce above repartition") {
  // The coalesce asks for more partitions (20) than the repartition below it produces (10),
  // so the coalesce is useless and only the repartition should remain.
  val largerCoalesce = testRelation.repartition(10).coalesce(20)
  comparePlans(
    Optimize.execute(largerCoalesce.analyze),
    testRelation.repartition(10).analyze)

  // The coalesce asks for fewer partitions (20) than the repartition below it (30):
  // the optimizer must leave this plan as it is.
  val smallerCoalesce = testRelation.repartition(30).coalesce(20)
  comparePlans(
    Optimize.execute(smallerCoalesce.analyze),
    smallerCoalesce.analyze)
}

test("repartition above coalesce") {
// Always respects the top repartition and removes the useless coalesce below repartition
val query1 = testRelation
.coalesce(10)
.repartition(20)
val query2 = testRelation
.coalesce(30)
.repartition(20)

val optimized = Optimize.execute(query.analyze)
val optimized1 = Optimize.execute(query1.analyze)
val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.repartition(20).analyze

comparePlans(optimized, correctAnswer)
comparePlans(optimized1, correctAnswer)
comparePlans(optimized2, correctAnswer)
}

test("collapse repartition and repartitionBy into one") {
val query = testRelation
test("repartitionBy above repartition") {
// Always respects the top repartitionBy and removes the useless repartition
val query1 = testRelation
.repartition(10)
.distribute('a)(20)
val query2 = testRelation
.repartition(30)
.distribute('a)(20)

val optimized = Optimize.execute(query.analyze)
val optimized1 = Optimize.execute(query1.analyze)
val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze

comparePlans(optimized, correctAnswer)
comparePlans(optimized1, correctAnswer)
comparePlans(optimized2, correctAnswer)
}

test("collapse repartitionBy and repartition into one") {
val query = testRelation
test("repartitionBy above coalesce") {
// Always respects the top repartitionBy and removes the useless coalesce below repartitionBy
val query1 = testRelation
.coalesce(10)
.distribute('a)(20)
val query2 = testRelation
.coalesce(30)
.distribute('a)(20)
.repartition(10)

val optimized = Optimize.execute(query.analyze)
val correctAnswer = testRelation.distribute('a)(10).analyze
val optimized1 = Optimize.execute(query1.analyze)
val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze

comparePlans(optimized, correctAnswer)
comparePlans(optimized1, correctAnswer)
comparePlans(optimized2, correctAnswer)
}

test("repartition above repartitionBy") {
  // The top repartition always wins and the distribute beneath it is dropped, whether the
  // distribute asked for fewer (10) or more (30) partitions than the repartition (20).
  val optimizedPlans = Seq(10, 30).map { lowerNumPartitions =>
    Optimize.execute(testRelation.distribute('a)(lowerNumPartitions).repartition(20).analyze)
  }
  val expected = testRelation.repartition(20).analyze

  optimizedPlans.foreach(comparePlans(_, expected))
}

test("coalesce above repartitionBy") {
  // The coalesce asks for more partitions (20) than the distribute below it produces (10),
  // so the coalesce is useless and only the distribute should remain.
  val largerCoalesce = testRelation.distribute('a)(10).coalesce(20)
  comparePlans(
    Optimize.execute(largerCoalesce.analyze),
    testRelation.distribute('a)(10).analyze)

  // The coalesce asks for fewer partitions (20) than the distribute below it (30):
  // the optimizer must leave this plan as it is.
  val smallerCoalesce = testRelation.distribute('a)(30).coalesce(20)
  comparePlans(
    Optimize.execute(smallerCoalesce.analyze),
    smallerCoalesce.analyze)
}

test("collapse two adjacent repartitionBys into one") {
val query = testRelation
// Always respects the top repartitionBy
val query1 = testRelation
.distribute('b)(10)
.distribute('a)(20)
val query2 = testRelation
.distribute('b)(30)
.distribute('a)(20)

val optimized = Optimize.execute(query.analyze)
val optimized1 = Optimize.execute(query1.analyze)
val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze

comparePlans(optimized, correctAnswer)
comparePlans(optimized1, correctAnswer)
comparePlans(optimized2, correctAnswer)
}
}
10 changes: 5 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2441,11 +2441,11 @@ class Dataset[T] private[sql](
}

/**
* Returns a new Dataset that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g.
* if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
* the 100 new partitions will claim 10 of the current partitions. If a larger number of
* partitions is requested, it will stay at the current number of partitions.
* Returns a new Dataset that has exactly `numPartitions` partitions when fewer partitions
* are requested. If a larger number of partitions is requested, it will stay at the current
* number of partitions. Similar to coalesce defined on an `RDD`, this operation results in
* a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not
* be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.
*
* However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
* this may result in your computation taking place on fewer nodes than
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,12 @@ class PlannerSuite extends SharedSQLContext {
val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1)
assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2)
doubleRepartitioned.queryExecution.optimizedPlan match {
case r: Repartition =>
assert(r.numPartitions === 5)
assert(r.shuffle === false)
case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) =>
assert(numPartitions === 5)
assert(shuffle === false)
assert(shuffleChild === true)
}
}

Expand Down