diff --git a/docs/deployment/settings.md b/docs/deployment/settings.md
index 7577aadd763..2f6032f8a75 100644
--- a/docs/deployment/settings.md
+++ b/docs/deployment/settings.md
@@ -180,6 +180,8 @@ kyuubi\.metrics\.reporters
kyuubi\.operation\.idle\.timeout|PT3H|Operation will be closed when it's not accessed for this duration of time|1.0.0
+kyuubi\.operation\.interrupt\.on\.cancel|true|When true, all running tasks will be interrupted if one cancels a query. When false, all running tasks will keep running until they finish.|1.2.0
+kyuubi\.operation\.query\.timeout|PT0S|Set a query duration timeout in seconds in Kyuubi. If the timeout is set to a positive value, a running query will be cancelled automatically when it exceeds this timeout. Otherwise, the query continues to run till completion. If timeout values are set for each statement via `java.sql.Statement.setQueryTimeout` and they are smaller than this configuration value, they take precedence. If you set this timeout and prefer to cancel the queries right away without waiting for the running tasks to finish, consider enabling kyuubi.operation.interrupt.on.cancel together.|1.2.0
kyuubi\.operation\.status\.polling\.timeout|PT5S|Timeout(ms) for long polling asynchronous running sql query's status|1.0.0
### Session
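
For orientation (not part of the patch), here is a minimal client-side sketch of how these options interact. The URL, user, and sleep-based query are placeholders; `java.sql.Statement.setQueryTimeout` is plain JDBC and, per the description above, takes precedence over `kyuubi.operation.query.timeout` when it is smaller.

```scala
import java.sql.{DriverManager, SQLTimeoutException}

object QueryTimeoutClientSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder URL; point it at a running Kyuubi server.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10009/", "user", "")
    val stmt = conn.createStatement()
    // Ask the driver to cancel the query if it runs longer than 5 seconds.
    stmt.setQueryTimeout(5)
    try {
      stmt.execute("select java_method('java.lang.Thread', 'sleep', 10000L)")
    } catch {
      case e: SQLTimeoutException =>
        // With kyuubi.operation.interrupt.on.cancel=true the running Spark tasks
        // are interrupted right away instead of draining to completion.
        println(s"query cancelled: ${e.getMessage}")
    } finally {
      stmt.close()
      conn.close()
    }
  }
}
```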
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
index f0ecb9ca1fa..1adc44c7f8c 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -17,24 +17,32 @@
package org.apache.kyuubi.engine.spark.operation
-import java.util.concurrent.RejectedExecutionException
+import java.util.concurrent.{RejectedExecutionException, TimeUnit}
+
+import scala.util.control.NonFatal
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._
import org.apache.kyuubi.{KyuubiSQLException, Logging}
+import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.spark.{ArrayFetchIterator, KyuubiSparkUtil}
import org.apache.kyuubi.operation.{OperationState, OperationType}
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
+import org.apache.kyuubi.util.ThreadUtils
class ExecuteStatement(
spark: SparkSession,
session: Session,
protected override val statement: String,
- override val shouldRunAsync: Boolean)
+ override val shouldRunAsync: Boolean,
+ queryTimeout: Long)
extends SparkOperation(spark, OperationType.EXECUTE_STATEMENT, session) with Logging {
+ private val forceCancel =
+ session.sessionManager.getConf.get(KyuubiConf.OPERATION_FORCE_CANCEL)
+
private val operationLog: OperationLog =
OperationLog.createOperationLog(session.handle, getHandle)
override def getOperationLog: Option[OperationLog] = Option(operationLog)
@@ -63,7 +71,7 @@ class ExecuteStatement(
setState(OperationState.RUNNING)
info(KyuubiSparkUtil.diagnostics(spark))
Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
- spark.sparkContext.setJobGroup(statementId, statement)
+ spark.sparkContext.setJobGroup(statementId, statement, forceCancel)
result = spark.sql(statement)
debug(result.queryExecution)
iter = new ArrayFetchIterator(result.collect())
@@ -76,6 +84,7 @@ class ExecuteStatement(
}
override protected def runInternal(): Unit = {
+ addTimeoutMonitor()
if (shouldRunAsync) {
val asyncOperation = new Runnable {
override def run(): Unit = {
@@ -100,4 +109,27 @@ class ExecuteStatement(
executeStatement()
}
}
+
+ private def addTimeoutMonitor(): Unit = {
+ if (queryTimeout > 0) {
+ val timeoutExecutor =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("query-timeout-thread")
+ timeoutExecutor.schedule(new Runnable {
+ override def run(): Unit = {
+ try {
+ if (getStatus.state != OperationState.TIMEOUT) {
+ info(s"Query with $statementId timed out after $queryTimeout seconds")
+ cleanup(OperationState.TIMEOUT)
+ }
+ } catch {
+ case NonFatal(e) =>
+ setOperationException(KyuubiSQLException(e))
+ error(s"Error cancelling the query after timeout: $queryTimeout seconds")
+ } finally {
+ timeoutExecutor.shutdown()
+ }
+ }
+ }, queryTimeout, TimeUnit.SECONDS)
+ }
+ }
}
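
Taken together, the changes above schedule a one-shot daemon task that moves the operation to TIMEOUT, while `setJobGroup(statementId, statement, forceCancel)` decides whether cancellation interrupts the running tasks. A standalone sketch of the same pattern using only Spark and JDK APIs (the local-mode session and the sleep job are illustrative assumptions, not code from the patch):

```scala
import java.util.concurrent.{Executors, TimeUnit}

import org.apache.spark.sql.SparkSession

object TimeoutMonitorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("timeout-sketch").getOrCreate()
    val sc = spark.sparkContext
    val groupId = "stmt-1234"
    // interruptOnCancel = true plays the role of kyuubi.operation.interrupt.on.cancel:
    // cancelJobGroup then calls Thread.interrupt() on the tasks' executor threads.
    sc.setJobGroup(groupId, "demo statement", interruptOnCancel = true)

    // Stand-in for the query-timeout-thread: cancel the job group after 3 seconds.
    val monitor = Executors.newSingleThreadScheduledExecutor()
    monitor.schedule(new Runnable {
      override def run(): Unit = sc.cancelJobGroup(groupId)
    }, 3, TimeUnit.SECONDS)

    try {
      // A job that would otherwise run for roughly ten seconds.
      sc.parallelize(1 to 10, 2).foreach(_ => Thread.sleep(2000))
    } catch {
      case e: Exception => println(s"job cancelled: ${e.getMessage}")
    } finally {
      monitor.shutdown()
      spark.stop()
    }
  }
}
```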
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
index 38fdc23fb7a..7dbfef4f783 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
@@ -58,7 +58,7 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
runAsync: Boolean,
queryTimeout: Long): Operation = {
val spark = getSparkSession(session.handle)
- val operation = new ExecuteStatement(spark, session, statement, runAsync)
+ val operation = new ExecuteStatement(spark, session, statement, runAsync, queryTimeout)
addOperation(operation)
}
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/IndividualSparkSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/IndividualSparkSuite.scala
new file mode 100644
index 00000000000..829eaf7e16a
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/IndividualSparkSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.spark
+
+import java.sql.{SQLTimeoutException, Statement}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
+
+import org.apache.spark.TaskKilled
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.SparkSession
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.kyuubi.KyuubiFunSuite
+import org.apache.kyuubi.config.KyuubiConf
+import org.apache.kyuubi.operation.JDBCTestUtils
+
+class SparkEngineSuites extends KyuubiFunSuite {
+
+ test("Add config to control whether cancel interrupts running tasks on the engine") {
+ Seq(true, false).foreach { force =>
+ withSparkJdbcStatement(Map(KyuubiConf.OPERATION_FORCE_CANCEL.key -> force.toString)) {
+ case (statement, spark) =>
+ val index = new AtomicInteger(0)
+ val forceCancel = new AtomicBoolean(false)
+ val listener = new SparkListener {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ assert(taskEnd.reason.isInstanceOf[TaskKilled])
+ if (forceCancel.get()) {
+ assert(System.currentTimeMillis() - taskEnd.taskInfo.launchTime < 3000)
+ index.incrementAndGet()
+ } else {
+ assert(System.currentTimeMillis() - taskEnd.taskInfo.launchTime >= 4000)
+ index.incrementAndGet()
+ }
+ }
+ }
+
+ spark.sparkContext.addSparkListener(listener)
+ try {
+ statement.setQueryTimeout(3)
+ forceCancel.set(force)
+ val e1 = intercept[SQLTimeoutException] {
+ statement.execute("select java_method('java.lang.Thread', 'sleep', 5000L)")
+ }.getMessage
+ assert(e1.contains("Query timed out"))
+ eventually(Timeout(30.seconds)) {
+ assert(index.get() == 1)
+ }
+ } finally {
+ spark.sparkContext.removeSparkListener(listener)
+ }
+ }
+ }
+ }
+
+ private def withSparkJdbcStatement(
+ conf: Map[String, String] = Map.empty)(
+ statement: (Statement, SparkSession) => Unit): Unit = {
+ val spark = new WithSparkSuite {
+ override def withKyuubiConf: Map[String, String] = conf
+ override protected def jdbcUrl: String = getJdbcUrl
+ }
+ spark.startSparkEngine()
+ val tmp: Statement => Unit = { tmpStatement =>
+ statement(tmpStatement, spark.getSpark)
+ }
+ try {
+ spark.withJdbcStatement()(tmp)
+ } finally {
+ spark.stopSparkEngine()
+ }
+ }
+}
+
+trait WithSparkSuite extends WithSparkSQLEngine with JDBCTestUtils
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala
index 87f759a44e3..605255b6364 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala
@@ -36,7 +36,7 @@ trait WithSparkSQLEngine extends KyuubiFunSuite {
super.beforeAll()
}
- protected def startSparkEngine(): Unit = {
+ def startSparkEngine(): Unit = {
val warehousePath = Utils.createTempDir()
val metastorePath = Utils.createTempDir()
warehousePath.toFile.delete()
@@ -63,7 +63,7 @@ trait WithSparkSQLEngine extends KyuubiFunSuite {
stopSparkEngine()
}
- protected def stopSparkEngine(): Unit = {
+ def stopSparkEngine(): Unit = {
// we need to clean up conf since it's the global config in same jvm.
withKyuubiConf.foreach { case (k, _) =>
System.clearProperty(k)
@@ -83,4 +83,5 @@ trait WithSparkSQLEngine extends KyuubiFunSuite {
}
protected def getJdbcUrl: String = s"jdbc:hive2://$connectionUrl/;"
+ def getSpark: SparkSession = spark
}
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
index 93bca62c096..dfd8a21de27 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
@@ -502,6 +502,27 @@ object KyuubiConf {
.timeConf
.createWithDefault(Duration.ofSeconds(5).toMillis)
+ val OPERATION_FORCE_CANCEL: ConfigEntry[Boolean] =
+ buildConf("operation.interrupt.on.cancel")
+ .doc("When true, all running tasks will be interrupted if one cancels a query. " +
+ "When false, all running tasks will keep running until they finish.")
+ .version("1.2.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val OPERATION_QUERY_TIMEOUT: ConfigEntry[Long] =
+ buildConf("operation.query.timeout")
+ .doc("Set a query duration timeout in seconds in Kyuubi. If the timeout is set to " +
+ "a positive value, a running query will be cancelled automatically when it exceeds " +
+ "this timeout. Otherwise, the query continues to run till completion. If timeout " +
+ "values are set for each statement via `java.sql.Statement.setQueryTimeout` and " +
+ "they are smaller than this configuration value, they take precedence. If you set " +
+ "this timeout and prefer to cancel the queries right away without waiting for the " +
+ s"running tasks to finish, consider enabling ${OPERATION_FORCE_CANCEL.key} together.")
+ .version("1.2.0")
+ .timeConf
+ .createWithDefault(Duration.ZERO.toMillis)
+
val ENGINE_SHARED_LEVEL: ConfigEntry[String] = buildConf("session.engine.share.level")
.doc("The SQL engine App will be shared in different levels, available configs are: " +
" - CONNECTION: the App will not be shared but only used by the current client" +
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTestUtils.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTestUtils.scala
index 9e75902a479..e5f5846535e 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTestUtils.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTestUtils.scala
@@ -34,7 +34,7 @@ trait JDBCTestUtils extends KyuubiFunSuite {
protected val patterns = Seq("", "*", "%", null, ".*", "_*", "_%", ".%")
protected def jdbcUrl: String
- protected def withMultipleConnectionJdbcStatement(
+ def withMultipleConnectionJdbcStatement(
tableNames: String*)(fs: (Statement => Unit)*): Unit = {
val connections = fs.map { _ => DriverManager.getConnection(jdbcUrl, user, "") }
val statements = connections.map(_.createStatement())
@@ -57,7 +57,7 @@ trait JDBCTestUtils extends KyuubiFunSuite {
}
}
- protected def withDatabases(dbNames: String*)(fs: (Statement => Unit)*): Unit = {
+ def withDatabases(dbNames: String*)(fs: (Statement => Unit)*): Unit = {
val connections = fs.map { _ => DriverManager.getConnection(jdbcUrl, user, "") }
val statements = connections.map(_.createStatement())
@@ -75,11 +75,11 @@ trait JDBCTestUtils extends KyuubiFunSuite {
}
}
- protected def withJdbcStatement(tableNames: String*)(f: Statement => Unit): Unit = {
+ def withJdbcStatement(tableNames: String*)(f: Statement => Unit): Unit = {
withMultipleConnectionJdbcStatement(tableNames: _*)(f)
}
- protected def withThriftClient(f: TCLIService.Iface => Unit): Unit = {
+ def withThriftClient(f: TCLIService.Iface => Unit): Unit = {
val hostAndPort = jdbcUrl.stripPrefix("jdbc:hive2://").split("/;").head.split(":")
val host = hostAndPort.head
val port = hostAndPort(1).toInt
@@ -96,7 +96,7 @@ trait JDBCTestUtils extends KyuubiFunSuite {
}
}
- protected def withSessionHandle(f: (TCLIService.Iface, TSessionHandle) => Unit): Unit = {
+ def withSessionHandle(f: (TCLIService.Iface, TSessionHandle) => Unit): Unit = {
withThriftClient { client =>
val req = new TOpenSessionReq()
req.setUsername(user)
@@ -117,7 +117,7 @@ trait JDBCTestUtils extends KyuubiFunSuite {
}
}
- protected def checkGetSchemas(
+ def checkGetSchemas(
rs: ResultSet, dbNames: Seq[String], catalogName: String = ""): Unit = {
var count = 0
while(rs.next()) {
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTests.scala
index e2e842fc763..10b1fa03c14 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTests.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/JDBCTests.scala
@@ -17,7 +17,7 @@
package org.apache.kyuubi.operation
-import java.sql.{Date, SQLException, Timestamp}
+import java.sql.{Date, SQLException, SQLTimeoutException, Timestamp}
import scala.collection.JavaConverters._
@@ -327,4 +327,26 @@ trait JDBCTests extends BasicJDBCTests {
assert(metaData.getScale(1) === 0)
}
}
+
+ test("Support automatic query timeout cancellation - setQueryTimeout") {
+ withJdbcStatement() { statement =>
+ statement.setQueryTimeout(1)
+ val e = intercept[SQLTimeoutException] {
+ statement.execute("select java_method('java.lang.Thread', 'sleep', 10000L)")
+ }.getMessage
+ assert(e.contains("Query timed out after"))
+
+ statement.setQueryTimeout(0)
+ val rs1 = statement.executeQuery(
+ "select 'test', java_method('java.lang.Thread', 'sleep', 3000L)")
+ rs1.next()
+ assert(rs1.getString(1) == "test")
+
+ statement.setQueryTimeout(-1)
+ val rs2 = statement.executeQuery(
+ "select 'test', java_method('java.lang.Thread', 'sleep', 3000L)")
+ rs2.next()
+ assert(rs2.getString(1) == "test")
+ }
+ }
}
diff --git a/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/ExecuteStatement.scala b/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/ExecuteStatement.scala
index c2f3d2ae471..0a4700cbbb3 100644
--- a/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/ExecuteStatement.scala
+++ b/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/ExecuteStatement.scala
@@ -33,7 +33,8 @@ class ExecuteStatement(
client: TCLIService.Iface,
remoteSessionHandle: TSessionHandle,
override val statement: String,
- override val shouldRunAsync: Boolean)
+ override val shouldRunAsync: Boolean,
+ queryTimeout: Long)
extends KyuubiOperation(
OperationType.EXECUTE_STATEMENT, session, client, remoteSessionHandle) {
@@ -71,6 +72,7 @@ class ExecuteStatement(
val req = new TExecuteStatementReq(remoteSessionHandle, statement)
req.setRunAsync(shouldRunAsync)
+ req.setQueryTimeout(queryTimeout)
val resp = client.ExecuteStatement(req)
verifyTStatus(resp.getStatus)
_remoteOpHandle = resp.getOperationHandle
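
In isolation, the change above just forwards the resolved timeout to the engine inside the Thrift request. A hedged sketch of that hand-off (the object and helper names and the hard-coded runAsync flag are illustrative; only the setter calls appear in the patch):

```scala
import org.apache.hive.service.rpc.thrift.{TExecuteStatementReq, TSessionHandle}

object ExecuteStatementReqSketch {
  // Hypothetical helper: remoteSessionHandle would come from an earlier OpenSession call.
  def buildExecuteReq(
      remoteSessionHandle: TSessionHandle,
      sql: String,
      timeoutSecs: Long): TExecuteStatementReq = {
    val req = new TExecuteStatementReq(remoteSessionHandle, sql)
    req.setRunAsync(true)
    // Forward the resolved timeout so the engine-side ExecuteStatement can start
    // its own timeout monitor for this operation.
    req.setQueryTimeout(timeoutSecs)
    req
  }
}
```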
diff --git a/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/KyuubiOperationManager.scala b/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/KyuubiOperationManager.scala
index 2c301d2ef72..52c3ad4e9ca 100644
--- a/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/KyuubiOperationManager.scala
+++ b/kyuubi-main/src/main/scala/org/apache/kyuubi/operation/KyuubiOperationManager.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap
import org.apache.hive.service.rpc.thrift.{TCLIService, TFetchResultsReq, TRowSet, TSessionHandle}
import org.apache.kyuubi.KyuubiSQLException
+import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.operation.FetchOrientation.FetchOrientation
import org.apache.kyuubi.session.{Session, SessionHandle}
import org.apache.kyuubi.util.ThriftUtils
@@ -49,6 +50,18 @@ class KyuubiOperationManager private (name: String) extends OperationManager(nam
tSessionHandle
}
+ private def getQueryTimeout(clientQueryTimeout: Long): Long = {
+ // Prefer the client-side timeout from setQueryTimeout when it is positive and either
+ // smaller than the server-side kyuubi.operation.query.timeout or the server-side
+ // value is unset (<= 0); otherwise fall back to the server-side value.
+ val systemQueryTimeout = getConf.get(KyuubiConf.OPERATION_QUERY_TIMEOUT)
+ if (clientQueryTimeout > 0 &&
+ (systemQueryTimeout <= 0 || clientQueryTimeout < systemQueryTimeout)) {
+ clientQueryTimeout
+ } else {
+ systemQueryTimeout
+ }
+ }
+
def setConnection(
sessionHandle: SessionHandle,
client: TCLIService.Iface,
@@ -69,9 +82,9 @@ class KyuubiOperationManager private (name: String) extends OperationManager(nam
queryTimeout: Long): Operation = {
val client = getThriftClient(session.handle)
val remoteSessionHandle = getRemoteTSessionHandle(session.handle)
- val operation = new ExecuteStatement(session, client, remoteSessionHandle, statement, runAsync)
+ val operation = new ExecuteStatement(session, client, remoteSessionHandle, statement, runAsync,
+ getQueryTimeout(queryTimeout))
addOperation(operation)
-
}
override def newGetTypeInfoOperation(session: Session): Operation = {
@@ -143,7 +156,6 @@ class KyuubiOperationManager private (name: String) extends OperationManager(nam
addOperation(operation)
}
-
override def getOperationLogRowSet(
opHandle: OperationHandle,
order: FetchOrientation, maxRows: Int): TRowSet = {
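
Finally, a small worked sketch of the precedence rule implemented by `getQueryTimeout` above; the object and the sample values are illustrative only:

```scala
object EffectiveTimeoutSketch {
  // Same rule as getQueryTimeout: a positive client value wins when the server
  // value is unset (<= 0) or larger; otherwise the server value applies.
  def effective(clientSecs: Long, serverSecs: Long): Long =
    if (clientSecs > 0 && (serverSecs <= 0 || clientSecs < serverSecs)) clientSecs
    else serverSecs

  def main(args: Array[String]): Unit = {
    println(effective(clientSecs = 5, serverSecs = 0))   // 5  - only the client set a timeout
    println(effective(clientSecs = 5, serverSecs = 30))  // 5  - the smaller client value wins
    println(effective(clientSecs = 60, serverSecs = 30)) // 30 - the server-side cap wins
    println(effective(clientSecs = 0, serverSecs = 0))   // 0  - no timeout at all
  }
}
```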