diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index f8478f860b2d5..cd0503fb8a147 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -120,7 +120,8 @@ case class InsertAdaptiveSparkPlan(
           if !subqueryMap.contains(exprId.id) =>
         val executedPlan = compileSubquery(p)
         verifyAdaptivePlan(executedPlan, p)
-        val subquery = SubqueryExec(s"subquery#${exprId.id}", executedPlan)
+        val subquery = SubqueryExec.createForScalarSubquery(
+          s"subquery#${exprId.id}", executedPlan)
         subqueryMap.put(exprId.id, subquery)
       case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
           if !subqueryMap.contains(exprId.id) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 80a4090ce03f3..fcf77e588fc60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -765,7 +765,7 @@ abstract class BaseSubqueryExec extends SparkPlan {
 /**
  * Physical plan for a subquery.
  */
-case class SubqueryExec(name: String, child: SparkPlan)
+case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: Option[Int] = None)
   extends BaseSubqueryExec with UnaryExecNode {
 
   override lazy val metrics = Map(
@@ -784,7 +784,11 @@ case class SubqueryExec(name: String, child: SparkPlan)
       SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert data to Scala types
-        val rows: Array[InternalRow] = child.executeCollect()
+        val rows: Array[InternalRow] = if (maxNumRows.isDefined) {
+          child.executeTake(maxNumRows.get)
+        } else {
+          child.executeCollect()
+        }
         val beforeBuild = System.nanoTime()
         longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
         val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
@@ -797,28 +801,45 @@ case class SubqueryExec(name: String, child: SparkPlan)
   }
 
   protected override def doCanonicalize(): SparkPlan = {
-    SubqueryExec("Subquery", child.canonicalized)
+    SubqueryExec("Subquery", child.canonicalized, maxNumRows)
   }
 
   protected override def doPrepare(): Unit = {
     relationFuture
   }
 
+  // `SubqueryExec` should only be used by calling `executeCollect`, which launches a new thread
+  // to collect the result of `child`. We should not trigger codegen of `child` again in other
+  // threads, as code generation is not thread-safe.
+  override def executeCollect(): Array[InternalRow] = {
+    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute()
+    throw new IllegalStateException("SubqueryExec.doExecute should never be called")
   }
 
-  override def executeCollect(): Array[InternalRow] = {
-    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+  override def executeTake(n: Int): Array[InternalRow] = {
+    throw new IllegalStateException("SubqueryExec.executeTake should never be called")
+  }
+
+  override def executeTail(n: Int): Array[InternalRow] = {
+    throw new IllegalStateException("SubqueryExec.executeTail should never be called")
   }
 
-  override def stringArgs: Iterator[Any] = super.stringArgs ++ Iterator(s"[id=#$id]")
+  override def stringArgs: Iterator[Any] = Iterator(name, child) ++ Iterator(s"[id=#$id]")
 }
 
 object SubqueryExec {
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("subquery",
       SQLConf.get.getConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD)))
+
+  def createForScalarSubquery(name: String, child: SparkPlan): SubqueryExec = {
+    // A scalar subquery needs only one row. We take 2 rows here so an invalid scalar query
+    // (one returning more than one row) can be detected without collecting all rows, which may OOM.
+    SubqueryExec(name, child, maxNumRows = Some(2))
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 5e222d2e48769..0080b73575de1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -80,8 +80,7 @@ case class ScalarSubquery(
   @volatile private var updated: Boolean = false
 
   def updateResult(): Unit = {
-    // Only return the first two rows as an array to avoid Driver OOM.
-    val rows = plan.executeTake(2)
+    val rows = plan.executeCollect()
     if (rows.length > 1) {
       sys.error(s"more than one row returned by a subquery used as an expression:\n$plan")
     }
@@ -178,7 +177,8 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
       case subquery: expressions.ScalarSubquery =>
         val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
         ScalarSubquery(
-          SubqueryExec(s"scalar-subquery#${subquery.exprId.id}", executedPlan),
+          SubqueryExec.createForScalarSubquery(
+            s"scalar-subquery#${subquery.exprId.id}", executedPlan),
           subquery.exprId)
       case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) =>
         val expr = if (values.length == 1) {
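
For context, a minimal, self-contained sketch of the pattern this patch introduces: a bounded collect that still lets the caller detect the "more than one row" error without materializing the whole subquery result on the driver. `ScalarSubqueryLimitSketch`, `FakePlan`, and `boundedCollect` are hypothetical stand-ins for illustration, not Spark APIs.

  object ScalarSubqueryLimitSketch {
    // Stand-in for a physical plan, exposing the two collection paths used in the patch.
    final case class FakePlan(rows: Seq[Int]) {
      def executeTake(n: Int): Array[Int] = rows.take(n).toArray // bounded collect
      def executeCollect(): Array[Int] = rows.toArray            // unbounded collect
    }

    // Mirrors the new SubqueryExec behavior: take only maxNumRows when a bound is set.
    def boundedCollect(plan: FakePlan, maxNumRows: Option[Int]): Array[Int] =
      maxNumRows.map(plan.executeTake).getOrElse(plan.executeCollect())

    def main(args: Array[String]): Unit = {
      // Mirrors ScalarSubquery.updateResult: 2 rows are enough to detect an invalid
      // scalar subquery while keeping driver memory bounded.
      val rows = boundedCollect(FakePlan(Seq(1, 2, 3)), maxNumRows = Some(2))
      if (rows.length > 1) {
        sys.error("more than one row returned by a subquery used as an expression")
      }
    }
  }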