diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 96ed593e72ff..ea2bbe0093fc 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -195,11 +195,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
     val responseObserver = executeHolder.responseObserver
 
     val command = request.getPlan.getCommand
-    val planner = new SparkConnectPlanner(executeHolder.sessionHolder)
-    planner.process(
-      command = command,
-      responseObserver = responseObserver,
-      executeHolder = executeHolder)
+    val planner = new SparkConnectPlanner(executeHolder)
+    planner.process(command = command, responseObserver = responseObserver)
   }
 
   private def requestString(request: Message) = {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 869608a9ab90..f8b67add879e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -56,7 +56,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
         throw new IllegalStateException(
           s"Illegal operation type ${request.getPlan.getOpTypeCase} to be handled here.")
     }
-    val planner = new SparkConnectPlanner(sessionHolder)
+    val planner = new SparkConnectPlanner(executeHolder)
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker
     val dataframe =
       Dataset.ofRows(
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index eead5cb38ad8..fa964c02a253 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -86,7 +86,18 @@ final case class InvalidCommandInput(
     private val cause: Throwable = null)
     extends Exception(message, cause)
 
-class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
+class SparkConnectPlanner(
+    val sessionHolder: SessionHolder,
+    val executeHolderOpt: Option[ExecuteHolder] = None)
+    extends Logging {
+
+  def this(executeHolder: ExecuteHolder) = {
+    this(executeHolder.sessionHolder, Some(executeHolder))
+  }
+
+  if (!executeHolderOpt.forall { e => e.sessionHolder == sessionHolder }) {
+    throw new IllegalArgumentException("executeHolder does not belong to sessionHolder")
+  }
 
   private[connect] def session: SparkSession = sessionHolder.session
 
@@ -94,6 +105,10 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
   private[connect] def sessionId: String = sessionHolder.sessionId
 
+  lazy val executeHolder = executeHolderOpt.getOrElse {
+    throw new IllegalArgumentException("executeHolder is not set")
+  }
+
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
 
@@ -2461,48 +2476,39 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
   def process(
       command: proto.Command,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
-        handleRegisterUserDefinedFunction(command.getRegisterFunction, executeHolder)
+        handleRegisterUserDefinedFunction(command.getRegisterFunction)
       case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION =>
-        handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction, executeHolder)
+        handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction)
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
-        handleWriteOperation(command.getWriteOperation, executeHolder)
+        handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
-        handleCreateViewCommand(command.getCreateDataframeView, executeHolder)
+        handleCreateViewCommand(command.getCreateDataframeView)
      case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
-        handleWriteOperationV2(command.getWriteOperationV2, executeHolder)
+        handleWriteOperationV2(command.getWriteOperationV2)
       case proto.Command.CommandTypeCase.EXTENSION =>
-        handleCommandPlugin(command.getExtension, executeHolder)
+        handleCommandPlugin(command.getExtension)
       case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver, executeHolder)
+        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
-        handleWriteStreamOperationStart(
-          command.getWriteStreamOperationStart,
-          responseObserver,
-          executeHolder)
+        handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
-        handleStreamingQueryCommand(
-          command.getStreamingQueryCommand,
-          responseObserver,
-          executeHolder)
+        handleStreamingQueryCommand(command.getStreamingQueryCommand, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_MANAGER_COMMAND =>
         handleStreamingQueryManagerCommand(
           command.getStreamingQueryManagerCommand,
-          responseObserver,
-          executeHolder)
+          responseObserver)
       case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
-        handleGetResourcesCommand(responseObserver, executeHolder)
+        handleGetResourcesCommand(responseObserver)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
   }
 
   def handleSqlCommand(
       getSqlCommand: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     // Eagerly execute commands of the provided SQL string.
     val args = getSqlCommand.getArgsMap
     val namedArguments = getSqlCommand.getNamedArgumentsMap
@@ -2600,8 +2606,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }
 
   private def handleRegisterUserDefinedFunction(
-      fun: proto.CommonInlineUserDefinedFunction,
-      executeHolder: ExecuteHolder): Unit = {
+      fun: proto.CommonInlineUserDefinedFunction): Unit = {
     fun.getFunctionCase match {
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         handleRegisterPythonUDF(fun)
@@ -2617,8 +2622,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }
 
   private def handleRegisterUserDefinedTableFunction(
-      fun: proto.CommonInlineUserDefinedTableFunction,
-      executeHolder: ExecuteHolder): Unit = {
+      fun: proto.CommonInlineUserDefinedTableFunction): Unit = {
     fun.getFunctionCase match {
       case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF =>
         val function = createPythonUserDefinedTableFunction(fun)
@@ -2685,7 +2689,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     session.udf.register(fun.getFunctionName, udf)
   }
 
-  private def handleCommandPlugin(extension: ProtoAny, executeHolder: ExecuteHolder): Unit = {
+  private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.
       .view
@@ -2698,9 +2702,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     executeHolder.eventsManager.postFinished()
   }
 
-  private def handleCreateViewCommand(
-      createView: proto.CreateDataFrameViewCommand,
-      executeHolder: ExecuteHolder): Unit = {
+  private def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = {
     val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView
 
     val tableIdentifier =
@@ -2736,9 +2738,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
    *
    * @param writeOperation
    */
-  private def handleWriteOperation(
-      writeOperation: proto.WriteOperation,
-      executeHolder: ExecuteHolder): Unit = {
+  private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
     // And create a Dataset from the plan.
@@ -2810,9 +2810,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
    *
    * @param writeOperation
    */
-  def handleWriteOperationV2(
-      writeOperation: proto.WriteOperationV2,
-      executeHolder: ExecuteHolder): Unit = {
+  def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
     // And create a Dataset from the plan.
@@ -2873,8 +2871,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
   def handleWriteStreamOperationStart(
       writeOp: WriteStreamOperationStart,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val plan = transformRelation(writeOp.getInput)
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker
     val dataset = Dataset.ofRows(session, plan, tracker)
@@ -2999,8 +2996,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
   def handleStreamingQueryCommand(
       command: StreamingQueryCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val id = command.getQueryId.getId
     val runId = command.getQueryId.getRunId
 
@@ -3177,8 +3173,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
   def handleStreamingQueryManagerCommand(
       command: StreamingQueryManagerCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val respBuilder = StreamingQueryManagerCommandResult.newBuilder()
 
     command.getCommandCase match {
@@ -3257,8 +3252,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }
 
   def handleGetResourcesCommand(
-      responseObserver: StreamObserver[proto.ExecutePlanResponse],
-      executeHolder: ExecuteHolder): Unit = {
+      responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
     executeHolder.eventsManager.postFinished()
     responseObserver.onNext(
       proto.ExecutePlanResponse
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index eb84dfc4e3df..dfada825df47 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -58,8 +58,8 @@ trait SparkConnectPlanTest extends SharedSparkSession {
 
   def transform(cmd: proto.Command): Unit = {
     val executeHolder = buildExecutePlanHolder(cmd)
-    new SparkConnectPlanner(executeHolder.sessionHolder)
-      .process(cmd, new MockObserver(), executeHolder)
+    new SparkConnectPlanner(executeHolder)
+      .process(cmd, new MockObserver())
   }
 
   def readRel: proto.Relation =
@@ -148,7 +148,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
   test("Simple Limit") {
     assertThrows[IndexOutOfBoundsException] {
-      new SparkConnectPlanner(None.orNull)
+      new SparkConnectPlanner(SessionHolder.forTesting(None.orNull))
         .transformRelation(
           proto.Relation.newBuilder
             .setLimit(proto.Limit.newBuilder.setLimit(10))
@@ -159,10 +159,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   test("InvalidInputs") {
     // No Relation Set
     intercept[IndexOutOfBoundsException](
-      new SparkConnectPlanner(None.orNull).transformRelation(proto.Relation.newBuilder().build()))
+      new SparkConnectPlanner(SessionHolder.forTesting(None.orNull))
+        .transformRelation(proto.Relation.newBuilder().build()))
 
     intercept[InvalidPlanInput](
-      new SparkConnectPlanner(None.orNull)
+      new SparkConnectPlanner(SessionHolder.forTesting(None.orNull))
         .transformRelation(
           proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build()))
   }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index fdb903237941..ea9ae3ed9d9c 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -196,8 +196,8 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
       .build()
 
     val executeHolder = buildExecutePlanHolder(plan)
-    new SparkConnectPlanner(executeHolder.sessionHolder)
-      .process(plan, new MockObserver(), executeHolder)
+    new SparkConnectPlanner(executeHolder)
+      .process(plan, new MockObserver())
     assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
   }
 }
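
For illustration, a minimal sketch of the two construction paths this patch creates. It is not part of the change itself, and it assumes that SessionHolder and ExecuteHolder come from org.apache.spark.sql.connect.service (as in the server module) and that instances are obtained elsewhere:

import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder}

// Plan-only path: a session suffices; command handlers must not be reached.
def transformOnly(sessionHolder: SessionHolder): Unit = {
  val planner = new SparkConnectPlanner(sessionHolder)
  // transformRelation(...) and other pure planning calls work as before;
  // touching planner.executeHolder here would throw
  // IllegalArgumentException("executeHolder is not set").
}

// Command path: the ExecuteHolder carries its SessionHolder, so the auxiliary
// constructor forwards both, and process() no longer takes the extra argument.
def runCommand(executeHolder: ExecuteHolder): Unit = {
  val planner = new SparkConnectPlanner(executeHolder)
  // Equivalent to:
  //   new SparkConnectPlanner(executeHolder.sessionHolder, Some(executeHolder))
  // A pair from different sessions would fail the constructor's check.
}

Because executeHolder is lazy, the failure is deferred to the first handle* method that actually needs it, so session-only callers (such as the tests above) can still construct a planner for transformRelation alone.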