SparkPlan.scala
@@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
@@ -392,11 +393,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
results.toArray
}

private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = {
val countsAndBytes = getByteArrayRdd().collect()
val total = countsAndBytes.map(_._1).sum
val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
(total, rows)
/**
* Runs this query asynchronously and returns a future of the collected result.
*/
private[spark] def executeCollectIteratorFuture()(
implicit executor: ExecutionContext): Future[(Long, Iterator[InternalRow])] = {
val future = getByteArrayRdd().collectAsync()
future.map(countsAndBytes => {
val total = countsAndBytes.map(_._1).sum
val rows = countsAndBytes.iterator
.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
(total, rows)
})
}

/**
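For context, here is a minimal sketch of how an in-Spark caller might use the new method. The helper name collectEagerly is an assumption for illustration only; since executeCollectIteratorFuture() is private[spark], such a caller would have to live under the org.apache.spark package.

import scala.concurrent.{Await, ExecutionContext}
import scala.concurrent.duration.Duration
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SparkPlan

def collectEagerly(plan: SparkPlan)(
    implicit ec: ExecutionContext): (Long, Iterator[InternalRow]) = {
  // The underlying collectAsync() job is submitted as soon as the future is created ...
  val collectFuture = plan.executeCollectIteratorFuture()
  // ... so only this final step blocks, once the rows are actually needed
  // (the patch itself awaits with ThreadUtils.awaitResult in the same way).
  Await.result(collectFuture, Duration.Inf)
}
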
BroadcastExchangeExec.scala
@@ -21,7 +21,7 @@ import java.util.UUID
import java.util.concurrent._

import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.concurrent.duration.{Duration, NANOSECONDS}
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, SparkException}
@@ -107,15 +107,22 @@ case class BroadcastExchangeExec(

@transient
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
interruptOnCancel = true)
val beforeCollect = System.nanoTime()
// SPARK-35414: use executeCollectIteratorFuture() to submit a job before getting the relationFuture.
// This ensures the broadcast job is submitted before the shuffle map job in AQE.
val collectFuture = child.executeCollectIteratorFuture()(
BroadcastExchangeExec.executionContext)

SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
try {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
interruptOnCancel = true)
val beforeCollect = System.nanoTime()
// Use executeCollect/executeCollectIterator to avoid conversion to Scala types
val (numRows, input) = child.executeCollectIterator()
val (numRows, input) = ThreadUtils.awaitResult(collectFuture, Duration.Inf)
longMetric("numOutputRows") += numRows
if (numRows >= MAX_BROADCAST_TABLE_ROWS) {
throw new SparkException(
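The ordering fix relies on collectAsync() submitting the collect job on the calling thread, before the body of relationFuture is scheduled onto the broadcast-exchange pool. Below is a standalone sketch of that submit-early, await-late pattern against a plain SparkContext; buildSideFuture and the example RDD are illustrative assumptions, not part of the patch.

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import org.apache.spark.SparkContext

def buildSideFuture(sc: SparkContext)(
    implicit ec: ExecutionContext): Future[Array[Long]] = {
  // collectAsync() submits the Spark job right here, on the calling thread, so its
  // stage reaches the scheduler ahead of any stage submitted after this call.
  val collected = sc.range(0, 1000, step = 1, numSlices = 2).collectAsync()

  // The slower post-processing still runs on the dedicated pool; it merely waits
  // for the job that is already running (the patch uses ThreadUtils.awaitResult).
  Future {
    Await.result(collected, Duration.Inf).toArray
  }
}
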
AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
import org.apache.log4j.Level
import org.scalatest.PrivateMethodTester

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart, SparkListenerStageSubmitted, StageInfo}
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
@@ -1817,4 +1817,37 @@ class AdaptiveQueryExecSuite
}
}
}

test("SPARK-35414: Submit broadcast job first to avoid broadcast timeout in AQE") {
val shuffleMapTaskParallelism = 10

val df = spark.range(0, 1000, 1, shuffleMapTaskParallelism)
.select($"id" % 26, $"id" % 10)
.toDF("index", "pv")
val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
.toDF("index", "name")
.coalesce(1)
val testDf = df.groupBy("index")
.agg(sum($"pv").alias("pv"))
.join(dim, Seq("index"))

val stageInfos = scala.collection.mutable.ArrayBuffer[StageInfo]()
val listener = new SparkListener {
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
stageInfos += stageSubmitted.stageInfo
}
}
spark.sparkContext.addSparkListener(listener)

withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val result = testDf.collect()
assert(result.length == 26)
val sortedStageInfos = stageInfos.sortBy(_.submissionTime)
assert(sortedStageInfos.size > 2)
// this is the broadcast stage
assert(sortedStageInfos(0).numTasks == 1)
// this is the shuffle map stage
assert(sortedStageInfos(1).numTasks == shuffleMapTaskParallelism)
}
}
}