@@ -120,7 +120,8 @@ case class InsertAdaptiveSparkPlan(
if !subqueryMap.contains(exprId.id) =>
val executedPlan = compileSubquery(p)
verifyAdaptivePlan(executedPlan, p)
-val subquery = SubqueryExec(s"subquery#${exprId.id}", executedPlan)
+val subquery = SubqueryExec.createForScalarSubquery(
+s"subquery#${exprId.id}", executedPlan)
subqueryMap.put(exprId.id, subquery)
case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
if !subqueryMap.contains(exprId.id) =>
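
For context, a minimal sketch of a query shape that reaches this adaptive-execution path, assuming a local SparkSession; spark.sql.adaptive.enabled is the standard AQE switch, while the table name and query are illustrative only:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("aqe-scalar-subquery-sketch")
  .config("spark.sql.adaptive.enabled", "true") // route planning through InsertAdaptiveSparkPlan
  .getOrCreate()

spark.range(100).createOrReplaceTempView("t")

// The uncorrelated scalar subquery below becomes a SubqueryExec built by
// createForScalarSubquery, so the driver collects at most two of its rows.
spark.sql("SELECT id FROM t WHERE id > (SELECT avg(id) FROM t)").show()
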
@@ -765,7 +765,7 @@ abstract class BaseSubqueryExec extends SparkPlan {
/**
* Physical plan for a subquery.
*/
-case class SubqueryExec(name: String, child: SparkPlan)
+case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: Option[Int] = None)
extends BaseSubqueryExec with UnaryExecNode {

override lazy val metrics = Map(
@@ -784,7 +784,11 @@ case class SubqueryExec(name: String, child: SparkPlan)
SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
-val rows: Array[InternalRow] = child.executeCollect()
+val rows: Array[InternalRow] = if (maxNumRows.isDefined) {
+child.executeTake(maxNumRows.get)
+} else {
+child.executeCollect()
+}
val beforeBuild = System.nanoTime()
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
@@ -797,28 +801,45 @@
}

protected override def doCanonicalize(): SparkPlan = {
SubqueryExec("Subquery", child.canonicalized)
SubqueryExec("Subquery", child.canonicalized, maxNumRows)
}

protected override def doPrepare(): Unit = {
relationFuture
}

+// `SubqueryExec` should only be used by calling `executeCollect`. It launches a new thread to
+// collect the result of `child`. We should not trigger codegen of `child` again in other threads,
+// as generating code is not thread-safe.
+override def executeCollect(): Array[InternalRow] = {
+ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+}
+
protected override def doExecute(): RDD[InternalRow] = {
-child.execute()
+throw new IllegalStateException("SubqueryExec.doExecute should never be called")
}

-override def executeCollect(): Array[InternalRow] = {
-ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+override def executeTake(n: Int): Array[InternalRow] = {
+throw new IllegalStateException("SubqueryExec.executeTake should never be called")
}

+override def executeTail(n: Int): Array[InternalRow] = {
+throw new IllegalStateException("SubqueryExec.executeTail should never be called")
+}
+
-override def stringArgs: Iterator[Any] = super.stringArgs ++ Iterator(s"[id=#$id]")
+override def stringArgs: Iterator[Any] = Iterator(name, child) ++ Iterator(s"[id=#$id]")
}

object SubqueryExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("subquery",
SQLConf.get.getConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD)))

+def createForScalarSubquery(name: String, child: SparkPlan): SubqueryExec = {
+// Scalar subquery needs only one row. We take two rows here so we can detect an invalid
+// scalar query (one that returns more than one row) without collecting all rows, which may OOM.
+SubqueryExec(name, child, maxNumRows = Some(2))
+}
}

/**
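
A framework-free sketch of the optional-cap pattern that maxNumRows introduces; collectRows and its Iterator-based signature are illustrative stand-ins, not Spark APIs. When a cap is supplied, only that many elements are pulled (mirroring executeTake); otherwise everything is collected (mirroring executeCollect):

def collectRows[T](data: Iterator[T], maxNumRows: Option[Int] = None): Vector[T] =
  maxNumRows match {
    case Some(n) => data.take(n).toVector // bounded, analogous to child.executeTake(n)
    case None => data.toVector            // unbounded, analogous to child.executeCollect()
  }

// collectRows(Iterator(1, 2, 3), maxNumRows = Some(2)) // Vector(1, 2)
// collectRows(Iterator(1, 2, 3))                       // Vector(1, 2, 3)
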
@@ -80,8 +80,7 @@ case class ScalarSubquery(
@volatile private var updated: Boolean = false

def updateResult(): Unit = {
-// Only return the first two rows as an array to avoid Driver OOM.
-val rows = plan.executeTake(2)
+val rows = plan.executeCollect()
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n$plan")
}
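
The check above needs only the first two rows to distinguish the three possible outcomes of a scalar subquery. A simplified sketch of that contract, using plain Scala values in place of InternalRow; the helper name interpretScalarRows is illustrative:

def interpretScalarRows[T](rows: Array[T]): Option[T] = rows.length match {
  case 0 => None            // empty result: the scalar subquery evaluates to NULL
  case 1 => Some(rows.head) // exactly one row: the scalar value
  case _ => sys.error("more than one row returned by a subquery used as an expression")
}

// Because the SubqueryExec is created with maxNumRows = Some(2), executeCollect()
// returns at most two rows here, which is still enough to hit the error case.
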
@@ -178,7 +177,8 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
case subquery: expressions.ScalarSubquery =>
val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
ScalarSubquery(
SubqueryExec(s"scalar-subquery#${subquery.exprId.id}", executedPlan),
SubqueryExec.createForScalarSubquery(
s"scalar-subquery#${subquery.exprId.id}", executedPlan),
subquery.exprId)
case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) =>
val expr = if (values.length == 1) {
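
For completeness, a sketch of the non-adaptive path that PlanSubqueries covers, including the failure mode the two-row cap is meant to catch; the session setup, view name s, and queries are illustrative only:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("scalar-subquery-sketch").getOrCreate()
spark.range(10).createOrReplaceTempView("s")

// Valid: the subquery yields exactly one row, which becomes the scalar value.
spark.sql("SELECT (SELECT max(id) FROM s) AS max_id").show()

// Invalid: the subquery yields ten rows. With maxNumRows = Some(2) the driver fetches
// only two of them before updateResult fails with
// "more than one row returned by a subquery used as an expression".
// spark.sql("SELECT (SELECT id FROM s) AS bad").show()
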