diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f6e8a5694dbd..2c639769e3b0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2765,6 +2765,11 @@ object SparkContext extends Logging {
   private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope"
   private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride"
 
+  /**
+   * Statement id, only used by the Thrift Server.
+   */
+  private[spark] val SPARK_STATEMENT_ID = "spark.statement.id"
+
   /**
    * Executor id for the driver. In earlier versions of Spark, this was `<driver>`, but this was
    * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 0c5fee20385e..afd4c93065b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -24,7 +24,7 @@ import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, SparkException}
+import org.apache.spark.{broadcast, SparkContext, SparkException}
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -74,7 +74,9 @@ case class BroadcastExchangeExec(
     child: SparkPlan) extends BroadcastExchangeLike {
   import BroadcastExchangeExec._
 
-  override val runId: UUID = UUID.randomUUID
+  // runId must be a UUID. Use the statement id (set by the Thrift Server) when it is defined.
+  override val runId: UUID = Option(sparkContext.getLocalProperty(SparkContext.SPARK_STATEMENT_ID))
+    .map(UUID.fromString).getOrElse(UUID.randomUUID)
 
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 8ca0ab91a73f..b68dd37be1f2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -31,6 +31,7 @@ import org.apache.hive.service.cli._
 import org.apache.hive.service.cli.operation.ExecuteStatementOperation
 import org.apache.hive.service.cli.session.HiveSession
 
+import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLContext}
 import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, toHiveString, TimeFormatters}
@@ -131,6 +132,7 @@ private[hive] class SparkExecuteStatementOperation(
       getNextRowSetInternal(order, maxRowsL)
     } finally {
       sqlContext.sparkContext.clearJobGroup()
+      clearStatementId()
     }
   }
 
@@ -285,7 +287,7 @@ private[hive] class SparkExecuteStatementOperation(
       if (!runInBackground) {
         parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader)
       }
-
+      setStatementId()
       sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement, forceCancel)
       result = sqlContext.sql(statement)
       logDebug(result.queryExecution.toString())
@@ -332,6 +334,7 @@ private[hive] class SparkExecuteStatementOperation(
         }
       }
       sqlContext.sparkContext.clearJobGroup()
+      clearStatementId()
     }
   }
 
@@ -369,6 +372,14 @@ private[hive] class SparkExecuteStatementOperation(
       sqlContext.sparkContext.cancelJobGroup(statementId)
     }
   }
+
+  private def setStatementId(): Unit = {
+    sqlContext.sparkContext.setLocalProperty(SparkContext.SPARK_STATEMENT_ID, statementId)
+  }
+
+  private def clearStatementId(): Unit = {
+    sqlContext.sparkContext.setLocalProperty(SparkContext.SPARK_STATEMENT_ID, null)
+  }
 }
 
 object SparkExecuteStatementOperation {
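For reference, here is a minimal, self-contained sketch (not part of the patch) of the local-property round trip this change relies on: the Thrift Server thread sets `spark.statement.id` before executing the statement, the broadcast exchange reads it back to derive its `runId`, and the property is cleared afterwards. The `SparkSession` setup and the object name `StatementIdPropagationSketch` are illustrative assumptions, not code from the patch.

```scala
import java.util.UUID

import org.apache.spark.sql.SparkSession

// Hypothetical standalone illustration of the set/read/clear cycle used in the patch.
object StatementIdPropagationSketch {
  // Same key as the new SparkContext.SPARK_STATEMENT_ID constant.
  val SPARK_STATEMENT_ID = "spark.statement.id"

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    val sc = spark.sparkContext

    // Thrift Server side: set the statement id on the executing thread (setStatementId()).
    val statementId = UUID.randomUUID().toString
    sc.setLocalProperty(SPARK_STATEMENT_ID, statementId)
    try {
      // Execution side: read it back, as BroadcastExchangeExec now does for runId.
      val runId: UUID = Option(sc.getLocalProperty(SPARK_STATEMENT_ID))
        .map(UUID.fromString).getOrElse(UUID.randomUUID)
      assert(runId.toString == statementId)
    } finally {
      // Clear it afterwards, mirroring clearStatementId().
      sc.setLocalProperty(SPARK_STATEMENT_ID, null)
      spark.stop()
    }
  }
}
```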