@@ -195,11 +195,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
     val responseObserver = executeHolder.responseObserver

     val command = request.getPlan.getCommand
-    val planner = new SparkConnectPlanner(executeHolder.sessionHolder)
-    planner.process(
-      command = command,
-      responseObserver = responseObserver,
-      executeHolder = executeHolder)
+    val planner = new SparkConnectPlanner(executeHolder)
+    planner.process(command = command, responseObserver = responseObserver)
   }

   private def requestString(request: Message) = {
@@ -56,7 +56,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       throw new IllegalStateException(
         s"Illegal operation type ${request.getPlan.getOpTypeCase} to be handled here.")
     }
-    val planner = new SparkConnectPlanner(sessionHolder)
+    val planner = new SparkConnectPlanner(executeHolder)
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker
     val dataframe =
       Dataset.ofRows(
@@ -86,14 +86,29 @@ final case class InvalidCommandInput(
     private val cause: Throwable = null)
   extends Exception(message, cause)

-class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
+class SparkConnectPlanner(
+    val sessionHolder: SessionHolder,
+    val executeHolderOpt: Option[ExecuteHolder] = None)
+  extends Logging {
+
+  def this(executeHolder: ExecuteHolder) = {
+    this(executeHolder.sessionHolder, Some(executeHolder))
+  }
+
+  if (!executeHolderOpt.forall { e => e.sessionHolder == sessionHolder }) {
+    throw new IllegalArgumentException("executeHolder does not belong to sessionHolder")
+  }

   private[connect] def session: SparkSession = sessionHolder.session

   private[connect] def userId: String = sessionHolder.userId

   private[connect] def sessionId: String = sessionHolder.sessionId

+  lazy val executeHolder = executeHolderOpt.getOrElse {
+    throw new IllegalArgumentException("executeHolder is not set")
+  }

   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

Review comments on the new constructor:

Contributor: Why not just pass in the ExecuteHolder, and create a getter for SessionHolder? IMO it is a bit weird to create this structure for just a single unit test.

Contributor: It's used in AnalyzePlan, which doesn't have an ExecuteHolder.

Contributor: Ok, makes sense.
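To make the pattern under discussion concrete, here is a minimal, self-contained sketch. SessionHolder, ExecuteHolder, and Planner below are simplified stand-ins, not the real Spark Connect classes; the shape mirrors the diff: a defaulted optional holder, an auxiliary convenience constructor, a session-consistency check, and a lazy accessor that fails only when the holder is actually needed.

// Sketch only: stub types standing in for the real Spark Connect classes.
final case class SessionHolder(sessionId: String)
final case class ExecuteHolder(operationId: String, sessionHolder: SessionHolder)

class Planner(
    val sessionHolder: SessionHolder,
    val executeHolderOpt: Option[ExecuteHolder] = None) {

  // Convenience constructor for the execution path, which always has a holder.
  def this(executeHolder: ExecuteHolder) =
    this(executeHolder.sessionHolder, Some(executeHolder))

  // Reject a holder that was created for a different session.
  if (!executeHolderOpt.forall(_.sessionHolder == sessionHolder)) {
    throw new IllegalArgumentException("executeHolder does not belong to sessionHolder")
  }

  // Command handlers read this; analysis-only callers never touch it, so the
  // error surfaces on first access rather than at construction time.
  lazy val executeHolder: ExecuteHolder = executeHolderOpt.getOrElse {
    throw new IllegalArgumentException("executeHolder is not set")
  }
}

object PlannerDemo extends App {
  val session = SessionHolder("session-1")
  val forAnalysis = new Planner(session)                         // AnalyzePlan path
  val forExecution = new Planner(ExecuteHolder("op-1", session)) // ExecutePlan path
  println(forExecution.executeHolder.operationId)                // prints "op-1"
}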

@@ -2461,48 +2476,39 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {

   def process(
       command: proto.Command,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
Comment on lines 2477 to +2479:

Contributor: Conceptually, this shouldn't be an issue because it's not really a public API, but it's not backwards compatible. This is particularly interesting because you went through some length to fix this problem in the constructor, though. @hvanhovell Do we consider the SparkConnect planner to be public? Probably not, correct?

Contributor: FWIW, between Spark 3.4 and Spark 3.5 these interfaces also changed. For example, #41618 made it take a SessionHolder instead of a SparkSession, so that various parameters don't have to be passed to it loosely. This PR adding ExecuteHolder to it is in a very similar spirit.

Contributor: No, I don't think we should. In the end that will hamper our ability to evolve the interface.
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
-        handleRegisterUserDefinedFunction(command.getRegisterFunction, executeHolder)
+        handleRegisterUserDefinedFunction(command.getRegisterFunction)
       case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION =>
-        handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction, executeHolder)
+        handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction)
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
-        handleWriteOperation(command.getWriteOperation, executeHolder)
+        handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
-        handleCreateViewCommand(command.getCreateDataframeView, executeHolder)
+        handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
-        handleWriteOperationV2(command.getWriteOperationV2, executeHolder)
+        handleWriteOperationV2(command.getWriteOperationV2)
       case proto.Command.CommandTypeCase.EXTENSION =>
-        handleCommandPlugin(command.getExtension, executeHolder)
+        handleCommandPlugin(command.getExtension)
       case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver, executeHolder)
+        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
-        handleWriteStreamOperationStart(
-          command.getWriteStreamOperationStart,
-          responseObserver,
-          executeHolder)
+        handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
-        handleStreamingQueryCommand(
-          command.getStreamingQueryCommand,
-          responseObserver,
-          executeHolder)
+        handleStreamingQueryCommand(command.getStreamingQueryCommand, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_MANAGER_COMMAND =>
         handleStreamingQueryManagerCommand(
           command.getStreamingQueryManagerCommand,
-          responseObserver,
-          executeHolder)
+          responseObserver)
       case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
-        handleGetResourcesCommand(responseObserver, executeHolder)
+        handleGetResourcesCommand(responseObserver)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
   }

   def handleSqlCommand(
       getSqlCommand: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     // Eagerly execute commands of the provided SQL string.
     val args = getSqlCommand.getArgsMap
     val namedArguments = getSqlCommand.getNamedArgumentsMap
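One consequence worth making explicit: handlers now read the holder from the planner's lazy executeHolder member instead of taking it as a parameter, so a planner built for analysis only (no ExecuteHolder) fails at the first handler access rather than at construction. Note also the asymmetry the reviewers discuss above: the constructor change stays source-compatible, because the old one-argument new SparkConnectPlanner(sessionHolder) still resolves against the new signature with its defaulted executeHolderOpt, while the new process signature does break existing callers. A short sketch of the failure mode, reusing the stub Planner from the earlier example:

// Reuses the stub SessionHolder/ExecuteHolder/Planner from the sketch above.
object DispatchDemo extends App {
  val session = SessionHolder("session-1")

  // Execution path: the holder is present, so handlers can use it freely.
  val execPlanner = new Planner(ExecuteHolder("op-1", session))
  assert(execPlanner.executeHolder.operationId == "op-1")

  // Analysis path: construction succeeds, but any command handler touching
  // executeHolder fails lazily, at first access.
  val analyzePlanner = new Planner(session)
  try analyzePlanner.executeHolder
  catch { case e: IllegalArgumentException => println(e.getMessage) } // "executeHolder is not set"
}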
@@ -2600,8 +2606,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }

   private def handleRegisterUserDefinedFunction(
-      fun: proto.CommonInlineUserDefinedFunction,
-      executeHolder: ExecuteHolder): Unit = {
+      fun: proto.CommonInlineUserDefinedFunction): Unit = {
     fun.getFunctionCase match {
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         handleRegisterPythonUDF(fun)
@@ -2617,8 +2622,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }

   private def handleRegisterUserDefinedTableFunction(
-      fun: proto.CommonInlineUserDefinedTableFunction,
-      executeHolder: ExecuteHolder): Unit = {
+      fun: proto.CommonInlineUserDefinedTableFunction): Unit = {
     fun.getFunctionCase match {
       case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF =>
         val function = createPythonUserDefinedTableFunction(fun)
@@ -2685,7 +2689,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     session.udf.register(fun.getFunctionName, udf)
   }

-  private def handleCommandPlugin(extension: ProtoAny, executeHolder: ExecuteHolder): Unit = {
+  private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.
       .view
@@ -2698,9 +2702,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     executeHolder.eventsManager.postFinished()
   }

-  private def handleCreateViewCommand(
-      createView: proto.CreateDataFrameViewCommand,
-      executeHolder: ExecuteHolder): Unit = {
+  private def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = {
     val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView

     val tableIdentifier =
@@ -2736,9 +2738,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
    *
    * @param writeOperation
    */
-  private def handleWriteOperation(
-      writeOperation: proto.WriteOperation,
-      executeHolder: ExecuteHolder): Unit = {
+  private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
     // And create a Dataset from the plan.
@@ -2810,9 +2810,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
    *
    * @param writeOperation
    */
-  def handleWriteOperationV2(
-      writeOperation: proto.WriteOperationV2,
-      executeHolder: ExecuteHolder): Unit = {
+  def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
     // And create a Dataset from the plan.
@@ -2873,8 +2871,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {

   def handleWriteStreamOperationStart(
       writeOp: WriteStreamOperationStart,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val plan = transformRelation(writeOp.getInput)
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker
     val dataset = Dataset.ofRows(session, plan, tracker)
@@ -2999,8 +2996,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {

   def handleStreamingQueryCommand(
       command: StreamingQueryCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {

     val id = command.getQueryId.getId
     val runId = command.getQueryId.getRunId
@@ -3177,8 +3173,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {

   def handleStreamingQueryManagerCommand(
       command: StreamingQueryManagerCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val respBuilder = StreamingQueryManagerCommandResult.newBuilder()

     command.getCommandCase match {
@@ -3257,8 +3252,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }

   def handleGetResourcesCommand(
-      responseObserver: StreamObserver[proto.ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
     executeHolder.eventsManager.postFinished()
     responseObserver.onNext(
       proto.ExecutePlanResponse
@@ -58,8 +58,8 @@ trait SparkConnectPlanTest extends SharedSparkSession {

   def transform(cmd: proto.Command): Unit = {
     val executeHolder = buildExecutePlanHolder(cmd)
-    new SparkConnectPlanner(executeHolder.sessionHolder)
-      .process(cmd, new MockObserver(), executeHolder)
+    new SparkConnectPlanner(executeHolder)
+      .process(cmd, new MockObserver())
   }

   def readRel: proto.Relation =
@@ -148,7 +148,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

   test("Simple Limit") {
     assertThrows[IndexOutOfBoundsException] {
-      new SparkConnectPlanner(None.orNull)
+      new SparkConnectPlanner(SessionHolder.forTesting(None.orNull))
         .transformRelation(
           proto.Relation.newBuilder
             .setLimit(proto.Limit.newBuilder.setLimit(10))
@@ -159,10 +159,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   test("InvalidInputs") {
     // No Relation Set
     intercept[IndexOutOfBoundsException](
-      new SparkConnectPlanner(None.orNull).transformRelation(proto.Relation.newBuilder().build()))
+      new SparkConnectPlanner(SessionHolder.forTesting(None.orNull))
+        .transformRelation(proto.Relation.newBuilder().build()))

     intercept[InvalidPlanInput](
-      new SparkConnectPlanner(None.orNull)
+      new SparkConnectPlanner(SessionHolder.forTesting(None.orNull))
         .transformRelation(
           proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build()))
   }
@@ -196,8 +196,8 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
       .build()

     val executeHolder = buildExecutePlanHolder(plan)
-    new SparkConnectPlanner(executeHolder.sessionHolder)
-      .process(plan, new MockObserver(), executeHolder)
+    new SparkConnectPlanner(executeHolder)
+      .process(plan, new MockObserver())
     assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
   }
 }