diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ef974dc176e51..75a510d389cb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -926,13 +926,23 @@ object SQLConf { .booleanConf .createWithDefault(false) + val THRIFTSERVER_FORCE_CANCEL = + buildConf("spark.sql.thriftServer.interruptOnCancel") + .doc("When true, all running tasks will be interrupted if one cancels a query. " + + "When false, all running tasks will remain until finished.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + val THRIFTSERVER_QUERY_TIMEOUT = buildConf("spark.sql.thriftServer.queryTimeout") .doc("Set a query duration timeout in seconds in Thrift Server. If the timeout is set to " + "a positive value, a running query will be cancelled automatically when the timeout is " + "exceeded, otherwise the query continues to run till completion. If timeout values are " + "set for each statement via `java.sql.Statement.setQueryTimeout` and they are smaller " + - "than this configuration value, they take precedence.") + "than this configuration value, they take precedence. 
If you set this timeout and prefer " + + "to cancel the queries right away without waiting task to finish, consider enabling " + + s"${THRIFTSERVER_FORCE_CANCEL.key} together.") .version("3.1.0") .timeConf(TimeUnit.SECONDS) .createWithDefault(0L) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index f7a4be9591818..0ced58b1b6d16 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -63,6 +63,8 @@ private[hive] class SparkExecuteStatementOperation( } } + private val forceCancel = sqlContext.conf.getConf(SQLConf.THRIFTSERVER_FORCE_CANCEL) + private val substitutorStatement = SQLConf.withExistingConf(sqlContext.conf) { new VariableSubstitution().substitute(statement) } @@ -131,7 +133,7 @@ private[hive] class SparkExecuteStatementOperation( def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = withLocalProperties { try { - sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement) + sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement, forceCancel) getNextRowSetInternal(order, maxRowsL) } finally { sqlContext.sparkContext.clearJobGroup() @@ -321,7 +323,7 @@ private[hive] class SparkExecuteStatementOperation( parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader) } - sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement) + sqlContext.sparkContext.setJobGroup(statementId, substitutorStatement, forceCancel) result = sqlContext.sql(statement) logDebug(result.queryExecution.toString()) HiveThriftServer2.eventManager.onStatementParsed(statementId, diff --git 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index fd3a638c4fa44..036eb5850695e 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -18,9 +18,14 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.SQLException +import java.util.concurrent.atomic.AtomicBoolean import org.apache.hive.service.cli.HiveSQLException +import org.apache.spark.TaskKilled +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.internal.SQLConf + trait ThriftServerWithSparkContextSuite extends SharedThriftServer { test("the scratch dir will be deleted during server start but recreated with new operation") { @@ -79,6 +84,38 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { "java.lang.NumberFormatException: invalid input syntax for type numeric: 1.2")) } } + + test("SPARK-33526: Add config to control if cancel invoke interrupt task on thriftserver") { + withJdbcStatement { statement => + val forceCancel = new AtomicBoolean(false) + val listener = new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + assert(taskEnd.reason.isInstanceOf[TaskKilled]) + if (forceCancel.get()) { + assert(System.currentTimeMillis() - taskEnd.taskInfo.launchTime < 1000) + } else { + // To tolerate timing inaccuracy, check for 2s rather than the full 3s sleep. 
+ assert(System.currentTimeMillis() - taskEnd.taskInfo.launchTime >= 2000) + } + } + } + + spark.sparkContext.addSparkListener(listener) + try { + statement.execute(s"SET ${SQLConf.THRIFTSERVER_QUERY_TIMEOUT.key}=1") + Seq(true, false).foreach { force => + statement.execute(s"SET ${SQLConf.THRIFTSERVER_FORCE_CANCEL.key}=$force") + forceCancel.set(force) + val e1 = intercept[SQLException] { + statement.execute("select java_method('java.lang.Thread', 'sleep', 3000L)") + }.getMessage + assert(e1.contains("Query timed out")) + } + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + } }