18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -1383,6 +1383,24 @@
],
"sqlState" : "22023"
},
"INVALID_HANDLE" : {
"message" : [
"The handle <handle> is invalid."
],
"subClass" : {
"ALREADY_EXISTS" : {
"message" : [
"Handle already exists."
]
},
"FORMAT" : {
"message" : [
"Handle has invalid format. Handle must an UUID string of the format '00112233-4455-6677-8899-aabbccddeeff'"
]
}
},
"sqlState" : "HY000"
},
"INVALID_HIVE_COLUMN_NAME" : {
"message" : [
"Cannot create the table <tableName> having the nested column <columnName> whose name contains invalid characters <invalidChars> in Hive metastore."
@@ -613,14 +613,40 @@ class SparkSession private[sql] (
/**
* Interrupt all operations of this session currently running on the connected server.
*
* TODO/WIP: Currently it will interrupt the Spark Jobs running on the server, triggered from
* ExecutePlan requests. If an operation is not running a Spark Job, it becomes a no-op and the
* operation will continue afterwards, possibly with more Spark Jobs.
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
-  def interruptAll(): Unit = {
-    client.interruptAll()
+  def interruptAll(): Seq[String] = {
+    client.interruptAll().getInterruptedIdsList.asScala.toSeq
}

/**
* Interrupt all operations of this session with the given operation tag.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptTag(tag: String): Seq[String] = {
client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
}

/**
* Interrupt an operation of this session with the given operationId.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptOperation(operationId: String): Seq[String] = {
client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
}
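Taken together, the three interrupt variants cover the whole session, a tag, or a single operation. A minimal usage sketch, assuming a Spark Connect server at a placeholder address and an illustrative long-running query:

```scala
import scala.concurrent.{ExecutionContext, Future}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

// Start a long-running query on a background thread.
implicit val ec: ExecutionContext = ExecutionContext.global
val pending = Future {
  spark.sql("SELECT count(*) FROM range(100000000000)").collect()
}

// Ask the server to interrupt everything this session is running. The returned ids
// name the operations that were interrupted; one may still finish concurrently.
val interruptedIds: Seq[String] = spark.interruptAll()
```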

/**
@@ -641,6 +667,50 @@
allocator.close()
SparkSession.onSessionClose(this)
}

/**
* Add a tag to be assigned to all the operations started by this thread in this session.
*
* @param tag
* The tag to be added. Cannot contain the ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def addTag(tag: String): Unit = {
client.addTag(tag)
}

/**
* Remove a tag previously added to be assigned to all the operations started by this thread in
* this session. Noop if such a tag was not added earlier.
*
* @param tag
* The tag to be removed. Cannot contain the ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def removeTag(tag: String): Unit = {
client.removeTag(tag)
}

/**
* Get the tags that are currently set to be assigned to all the operations started by this
* thread.
*
* @since 3.5.0
*/
def getTags(): Set[String] = {
client.getTags()
}

/**
* Clear the current thread's operation tags.
*
* @since 3.5.0
*/
def clearTags(): Unit = {
client.clearTags()
}
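A short sketch of how the tag API composes with interruptTag (continuing the `spark` session from the sketch above; the tag name is an arbitrary example):

```scala
// Tags are scoped to the calling thread and this session; every operation this thread
// starts while the tag is set carries it.
spark.addTag("nightly-etl")
assert(spark.getTags() == Set("nightly-etl"))

// ... submit queries here; all of them are tagged "nightly-etl" ...

// Interrupt only the tagged operations, then drop the thread-local tag state.
val stoppedIds: Seq[String] = spark.interruptTag("nightly-etl")
spark.removeTag("nightly-etl") // or spark.clearTags() to remove all tags at once
```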
}

// The minimal builder needed to create a spark session.
@@ -21,11 +21,15 @@ import java.net.URI
import java.util.UUID
import java.util.concurrent.Executor

import scala.collection.JavaConverters._
import scala.collection.mutable

import com.google.protobuf.ByteString
import io.grpc._

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.UserContext
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.common.config.ConnectCommon

/**
@@ -76,6 +80,7 @@ private[sql] class SparkConnectClient(
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
.build()
bstub.executePlan(request)
}
@@ -195,6 +200,59 @@
bstub.interrupt(request)
}

private[sql] def interruptTag(tag: String): proto.InterruptResponse = {
val builder = proto.InterruptRequest.newBuilder()
val request = builder
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
.setOperationTag(tag)
.build()
bstub.interrupt(request)
}

private[sql] def interruptOperation(id: String): proto.InterruptResponse = {
val builder = proto.InterruptRequest.newBuilder()
val request = builder
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
.setOperationId(id)
.build()
bstub.interrupt(request)
}

private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
override def childValue(parent: mutable.Set[String]): mutable.Set[String] = {
// Note: make a clone such that changes in the parent tags aren't reflected in
// those of the children threads.
parent.clone()
}
override protected def initialValue(): mutable.Set[String] = new mutable.HashSet[String]()
}
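The `childValue` override gives a newly spawned thread a snapshot of its parent's tags, after which the two sets evolve independently. A self-contained sketch of that InheritableThreadLocal behavior (standalone demo, not client code):

```scala
import scala.collection.mutable

object TagInheritanceDemo extends App {
  val tags = new InheritableThreadLocal[mutable.Set[String]] {
    override def childValue(parent: mutable.Set[String]): mutable.Set[String] = parent.clone()
    override protected def initialValue(): mutable.Set[String] = new mutable.HashSet[String]()
  }

  tags.get += "parent-tag"
  val child = new Thread(() => {
    tags.get += "child-tag"
    // The child mutates its own copy: Set(parent-tag, child-tag).
    println(s"child:  ${tags.get}")
  })
  child.start()
  child.join()
  // The parent's set is untouched by the child's mutation: Set(parent-tag).
  println(s"parent: ${tags.get}")
}
```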

private[sql] def addTag(tag: String): Unit = {
// Validation is also done server-side, but this gives the error earlier.
ProtoUtils.throwIfInvalidTag(tag)
tags.get += tag
}

private[sql] def removeTag(tag: String): Unit = {
// Validation is also done server-side, but this gives the error earlier.
ProtoUtils.throwIfInvalidTag(tag)
tags.get.remove(tag)
}

private[sql] def getTags(): Set[String] = {
tags.get.toSet
}

private[sql] def clearTags(): Unit = {
tags.get.clear()
}

def copy(): SparkConnectClient = configuration.toSparkConnectClient

/**
@@ -40,6 +40,7 @@ private[sql] class SparkResult[T](
extends AutoCloseable
with Cleanable { self =>

private[this] var opId: String = _
private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
private[this] var arrowSchema: pojo.Schema = _
@@ -72,13 +73,28 @@
}

private def processResponses(
stopOnOperationId: Boolean = false,
stopOnSchema: Boolean = false,
stopOnArrowSchema: Boolean = false,
stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
var nonEmpty = false
var stop = false
while (!stop && responses.hasNext) {
val response = responses.next()

// Save and validate operationId
if (opId == null) {
opId = response.getOperationId
}
if (opId != response.getOperationId) {
// Backwards compatibility: a response from an old server without the operationId
// field would have getOperationId == "".
throw new IllegalStateException(
"Received response with wrong operationId. " +
s"Expected '$opId' but received '${response.getOperationId}'.")
}
stop |= stopOnOperationId

if (response.hasSchema) {
// The original schema should arrive before ArrowBatches.
structType =
@@ -148,6 +164,17 @@
structType
}

/**
* @return
* the operationId of the result.
*/
def operationId: String = {
if (opId == null) {
processResponses(stopOnOperationId = true)
}
opId
}
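Exposing the operationId makes it possible to pair a specific result with `interruptOperation`. A sketch under the assumption that the caller sits inside the `org.apache.spark.sql` package (SparkResult is `private[sql]`); the object and method names are illustrative only:

```scala
package org.apache.spark.sql

import org.apache.spark.sql.connect.client.SparkResult

// Illustrative helper: cancel exactly one running query via its result handle.
object OperationCancellation {
  def cancel(spark: SparkSession, result: SparkResult[_]): Seq[String] = {
    // Reading operationId may consume responses from the stream until the id is known.
    spark.interruptOperation(result.operationId)
  }
}
```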

/**
* Create an Array with the contents of the result.
*/