diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4d55639f876b..a887a9d8199d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -79,7 +79,10 @@ final case class InvalidCommandInput(
     private val cause: Throwable = null)
     extends Exception(message, cause)
 
-class SparkConnectPlanner(val session: SparkSession) extends Logging {
+class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
+
+  def session: SparkSession = sessionHolder.session
+
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index ca7fa0d42c54..cc2327abb5cd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -61,3 +61,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     }
   }
 }
+
+object SessionHolder {
+
+  /** Creates a dummy session holder for use in tests. */
+  def forTesting(session: SparkSession): SessionHolder = {
+    SessionHolder(userId = "testUser", sessionId = UUID.randomUUID().toString, session = session)
+  }
+}
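Note: the planner now takes a SessionHolder rather than a bare SparkSession, and the added "def session" simply delegates to the holder, so existing planner.session call sites keep compiling. A minimal sketch of the resulting API (not part of the patch; assumes an already-started SparkSession named spark):

    // Sketch only: build a planner from a holder. SessionHolder.forTesting is
    // the test factory added above (userId = "testUser", random session id).
    val holder = SessionHolder.forTesting(spark)
    val planner = new SparkConnectPlanner(holder)
    // The planner's session accessor is pure delegation to the holder.
    assert(planner.session eq holder.session)
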
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 034644f5765a..947f6ebbebeb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -23,7 +23,7 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
@@ -35,12 +35,10 @@ private[connect] class SparkConnectAnalyzeHandler(
   def handle(request: proto.AnalyzePlanRequest): Unit =
     SparkConnectArtifactManager.withArtifactClassLoader {
-      val session =
-        SparkConnectService
-          .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
-          .session
-      session.withActive {
-        val response = process(request, session)
+      val sessionHolder = SparkConnectService
+        .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
+      sessionHolder.session.withActive {
+        val response = process(request, sessionHolder)
         responseObserver.onNext(response)
         responseObserver.onCompleted()
       }
     }
@@ -48,8 +46,9 @@ private[connect] class SparkConnectAnalyzeHandler(
 
   def process(
       request: proto.AnalyzePlanRequest,
-      session: SparkSession): proto.AnalyzePlanResponse = {
-    lazy val planner = new SparkConnectPlanner(session)
+      sessionHolder: SessionHolder): proto.AnalyzePlanResponse = {
+    lazy val planner = new SparkConnectPlanner(sessionHolder)
+    val session = sessionHolder.session
 
     val builder = proto.AnalyzePlanResponse.newBuilder()
     request.getAnalyzeCase match {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 4958fd69b9de..d11f4dcc6002 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -28,7 +28,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoUtils}
@@ -83,8 +83,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
 
     try {
       v.getPlan.getOpTypeCase match {
-        case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
-        case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
+        case proto.Plan.OpTypeCase.COMMAND => handleCommand(sessionHolder, v)
+        case proto.Plan.OpTypeCase.ROOT => handlePlan(sessionHolder, v)
         case _ =>
           throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
       }
@@ -94,10 +94,11 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     }
   }
 
-  private def handlePlan(session: SparkSession, request: ExecutePlanRequest): Unit = {
+  private def handlePlan(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = {
     // Extract the plan from the request and convert it to a logical plan
-    val planner = new SparkConnectPlanner(session)
-    val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
+    val planner = new SparkConnectPlanner(sessionHolder)
+    val dataframe =
+      Dataset.ofRows(sessionHolder.session, planner.transformRelation(request.getPlan.getRoot))
     responseObserver.onNext(
       SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(request.getSessionId, dataframe, responseObserver)
@@ -110,9 +111,9 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     responseObserver.onCompleted()
   }
 
-  private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
+  private def handleCommand(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = {
     val command = request.getPlan.getCommand
-    val planner = new SparkConnectPlanner(session)
+    val planner = new SparkConnectPlanner(sessionHolder)
     planner.process(
       command = command,
       userId = request.getUserContext.getUserId,
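Note: both handlers now follow the same shape: resolve the per-user, per-session SessionHolder once at the gRPC boundary, execute under its session, and hand the holder (not the session) down to the planner. A condensed sketch of that shared pattern (simplified from the hunks above; userId and sessionId stand in for the request fields):

    // Resolve the isolated session holder once per request.
    val sessionHolder = SparkConnectService
      .getOrCreateIsolatedSession(userId, sessionId)
    // Execute with the underlying session active, but pass the holder down.
    sessionHolder.session.withActive {
      val planner = new SparkConnectPlanner(sessionHolder)
      // ... planner.transformRelation(...) or planner.process(...) as above
    }
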
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index 1d8e8446c459..e5c14192a318 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.test.SharedSparkSession
@@ -162,7 +163,7 @@ class ProtoToParsedPlanTestSuite
     val name = fileName.stripSuffix(".proto.bin")
     test(name) {
       val relation = readRelation(file)
-      val planner = new SparkConnectPlanner(spark)
+      val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
       val catalystPlan =
         analyzer.executeAndCheck(planner.transformRelation(relation), new QueryPlanningTracker)
       val actual =
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 37d4bec9c874..ab01f2a6c147 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProj
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
+import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -51,11 +52,12 @@ trait SparkConnectPlanTest extends SharedSparkSession {
   }
 
   def transform(rel: proto.Relation): logical.LogicalPlan = {
-    new SparkConnectPlanner(spark).transformRelation(rel)
+    new SparkConnectPlanner(SessionHolder.forTesting(spark)).transformRelation(rel)
   }
 
   def transform(cmd: proto.Command): Unit = {
-    new SparkConnectPlanner(spark).process(cmd, "clientId", "sessionId", new MockObserver())
+    new SparkConnectPlanner(SessionHolder.forTesting(spark))
+      .process(cmd, "clientId", "sessionId", new MockObserver())
   }
 
   def readRel: proto.Relation =
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index c36ba76f9845..ed47e8a647c1 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService}
+import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.test.SharedSparkSession
 
 /**
@@ -37,6 +38,8 @@ import org.apache.spark.sql.test.SharedSparkSession
  */
 class SparkConnectServiceSuite extends SharedSparkSession {
 
+  private def sparkSessionHolder = SessionHolder.forTesting(spark)
+
   test("Test schema in analyze response") {
     withTable("test") {
       spark.sql("""
@@ -64,7 +67,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
         .newBuilder()
         .setSchema(proto.AnalyzePlanRequest.Schema.newBuilder().setPlan(plan).build())
         .build()
-      val response1 = handler.process(request1, spark)
+      val response1 = handler.process(request1, sparkSessionHolder)
       assert(response1.hasSchema)
       assert(response1.getSchema.getSchema.hasStruct)
       val schema = response1.getSchema.getSchema.getStruct
@@ -85,7 +88,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
           .setExplainMode(proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE)
           .build())
         .build()
-      val response2 = handler.process(request2, spark)
+      val response2 = handler.process(request2, sparkSessionHolder)
       assert(response2.hasExplain)
       assert(response2.getExplain.getExplainString.size > 0)
 
@@ -93,7 +96,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
         .newBuilder()
         .setIsLocal(proto.AnalyzePlanRequest.IsLocal.newBuilder().setPlan(plan).build())
         .build()
-      val response3 = handler.process(request3, spark)
+      val response3 = handler.process(request3, sparkSessionHolder)
       assert(response3.hasIsLocal)
       assert(!response3.getIsLocal.getIsLocal)
 
@@ -101,7 +104,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
         .newBuilder()
         .setIsStreaming(proto.AnalyzePlanRequest.IsStreaming.newBuilder().setPlan(plan).build())
         .build()
-      val response4 = handler.process(request4, spark)
+      val response4 = handler.process(request4, sparkSessionHolder)
       assert(response4.hasIsStreaming)
       assert(!response4.getIsStreaming.getIsStreaming)
 
@@ -109,7 +112,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
         .newBuilder()
         .setTreeString(proto.AnalyzePlanRequest.TreeString.newBuilder().setPlan(plan).build())
         .build()
-      val response5 = handler.process(request5, spark)
+      val response5 = handler.process(request5, sparkSessionHolder)
       assert(response5.hasTreeString)
       val treeString = response5.getTreeString.getTreeString
       assert(treeString.contains("root"))
@@ -120,7 +123,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
         .newBuilder()
         .setInputFiles(proto.AnalyzePlanRequest.InputFiles.newBuilder().setPlan(plan).build())
         .build()
-      val response6 = handler.process(request6, spark)
+      val response6 = handler.process(request6, sparkSessionHolder)
       assert(response6.hasInputFiles)
       assert(response6.getInputFiles.getFilesCount === 0)
     }
@@ -291,7 +294,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
           .build())
         .build()
 
-      val response = handler.process(request, spark)
+      val response = handler.process(request, sparkSessionHolder)
 
       assert(response.getExplain.getExplainString.contains("Parsed Logical Plan"))
       assert(response.getExplain.getExplainString.contains("Analyzed Logical Plan"))
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index d61b54c67c25..2bdabc7ccc21 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest}
+import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.test.SharedSparkSession
 
 class DummyPlugin extends RelationPlugin {
@@ -195,7 +196,8 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
         .build()))
       .build()
 
-    new SparkConnectPlanner(spark).process(plan, "clientId", "sessionId", new MockObserver())
+    new SparkConnectPlanner(SessionHolder.forTesting(spark))
+      .process(plan, "clientId", "sessionId", new MockObserver())
     assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
   }
 }
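
Note: SessionHolder.forTesting is what keeps these suites independent of the service's session registry: each call produces a standalone holder with userId "testUser" and a freshly generated UUID session id. A minimal usage sketch (assumes a SharedSparkSession-style fixture named spark and a proto.Relation value named relationProto, both placeholders here):

    // Build a planner the way the updated suites do, then translate a relation.
    val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
    val logicalPlan = planner.transformRelation(relationProto)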