@@ -79,7 +79,10 @@ final case class InvalidCommandInput(
private val cause: Throwable = null)
extends Exception(message, cause)

-class SparkConnectPlanner(val session: SparkSession) extends Logging {
+class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
[Author] This is the main change; the rest of the changes are a consequence of the API change.
@grundprinzip, @HyukjinKwon PTAL.

[Member] 👍

[Member @HyukjinKwon, Jun 16, 2023] I believe this needs a review from @vicennial and/or @hvanhovell.
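
A minimal sketch of the caller-side migration described above (the `buildPlanner` helper is hypothetical; the resolution calls are the ones used in the handler diffs below):

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.SparkConnectService

// Hypothetical helper condensing the new wiring: resolve the
// SessionHolder once, then hand it to the planner.
def buildPlanner(request: proto.AnalyzePlanRequest): SparkConnectPlanner = {
  // Before this change, .getOrCreateIsolatedSession(...).session fed a
  // SparkSession straight into new SparkConnectPlanner(session).
  val sessionHolder = SparkConnectService
    .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
  // After: the holder travels with the planner; the session stays
  // reachable via planner.session, which delegates to sessionHolder.session.
  new SparkConnectPlanner(sessionHolder)
}
```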


+def session: SparkSession = sessionHolder.session

private lazy val pythonExec =
sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

@@ -61,3 +61,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
}
}
}

+object SessionHolder {
+
+/** Creates a dummy session holder for use in tests. */
+def forTesting(session: SparkSession): SessionHolder = {
+SessionHolder(userId = "testUser", sessionId = UUID.randomUUID().toString, session = session)
+}
+}
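
As the test-suite diffs below show, this keeps planner construction in tests to a one-liner; a minimal sketch, assuming `spark` is provided by `SharedSparkSession`:

```scala
// Sketch: forTesting fills in userId = "testUser" and a random sessionId.
val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
```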
@@ -23,7 +23,7 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.Dataset
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
@@ -35,21 +35,20 @@ private[connect] class SparkConnectAnalyzeHandler(

def handle(request: proto.AnalyzePlanRequest): Unit =
SparkConnectArtifactManager.withArtifactClassLoader {
-val session =
-SparkConnectService
-.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
-.session
-session.withActive {
-val response = process(request, session)
+val sessionHolder = SparkConnectService
+.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
+sessionHolder.session.withActive {
+val response = process(request, sessionHolder)
responseObserver.onNext(response)
responseObserver.onCompleted()
}
}

def process(
request: proto.AnalyzePlanRequest,
-session: SparkSession): proto.AnalyzePlanResponse = {
-lazy val planner = new SparkConnectPlanner(session)
+sessionHolder: SessionHolder): proto.AnalyzePlanResponse = {
+lazy val planner = new SparkConnectPlanner(sessionHolder)
+val session = sessionHolder.session
val builder = proto.AnalyzePlanResponse.newBuilder()

request.getAnalyzeCase match {
@@ -28,7 +28,7 @@ import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoUtils}
@@ -83,8 +83,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp

try {
v.getPlan.getOpTypeCase match {
-case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
-case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
+case proto.Plan.OpTypeCase.COMMAND => handleCommand(sessionHolder, v)
+case proto.Plan.OpTypeCase.ROOT => handlePlan(sessionHolder, v)
case _ =>
throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
}
@@ -94,10 +94,11 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
}
}

-private def handlePlan(session: SparkSession, request: ExecutePlanRequest): Unit = {
+private def handlePlan(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = {
// Extract the plan from the request and convert it to a logical plan
-val planner = new SparkConnectPlanner(session)
-val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
+val planner = new SparkConnectPlanner(sessionHolder)
+val dataframe =
+Dataset.ofRows(sessionHolder.session, planner.transformRelation(request.getPlan.getRoot))
responseObserver.onNext(
SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(request.getSessionId, dataframe, responseObserver)
Expand All @@ -110,9 +111,9 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
responseObserver.onCompleted()
}

-private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
+private def handleCommand(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = {
val command = request.getPlan.getCommand
-val planner = new SparkConnectPlanner(session)
+val planner = new SparkConnectPlanner(sessionHolder)
planner.process(
command = command,
userId = request.getUserContext.getUserId,
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.test.SharedSparkSession
@@ -162,7 +163,7 @@ class ProtoToParsedPlanTestSuite
val name = fileName.stripSuffix(".proto.bin")
test(name) {
val relation = readRelation(file)
-val planner = new SparkConnectPlanner(spark)
+val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
val catalystPlan =
analyzer.executeAndCheck(planner.transformRelation(relation), new QueryPlanningTracker)
val actual =
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProj
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
+import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -51,11 +52,12 @@ trait SparkConnectPlanTest extends SharedSparkSession {
}

def transform(rel: proto.Relation): logical.LogicalPlan = {
-new SparkConnectPlanner(spark).transformRelation(rel)
+new SparkConnectPlanner(SessionHolder.forTesting(spark)).transformRelation(rel)
[Contributor] Love it!

}

def transform(cmd: proto.Command): Unit = {
-new SparkConnectPlanner(spark).process(cmd, "clientId", "sessionId", new MockObserver())
+new SparkConnectPlanner(SessionHolder.forTesting(spark))
+.process(cmd, "clientId", "sessionId", new MockObserver())
}

def readRel: proto.Relation =
@@ -30,13 +30,16 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService}
+import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.test.SharedSparkSession

/**
* Testing Connect Service implementation.
*/
class SparkConnectServiceSuite extends SharedSparkSession {

+private def sparkSessionHolder = SessionHolder.forTesting(spark)

test("Test schema in analyze response") {
withTable("test") {
spark.sql("""
@@ -64,7 +67,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.newBuilder()
.setSchema(proto.AnalyzePlanRequest.Schema.newBuilder().setPlan(plan).build())
.build()
-val response1 = handler.process(request1, spark)
+val response1 = handler.process(request1, sparkSessionHolder)
assert(response1.hasSchema)
assert(response1.getSchema.getSchema.hasStruct)
val schema = response1.getSchema.getSchema.getStruct
@@ -85,31 +88,31 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.setExplainMode(proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE)
.build())
.build()
-val response2 = handler.process(request2, spark)
+val response2 = handler.process(request2, sparkSessionHolder)
assert(response2.hasExplain)
assert(response2.getExplain.getExplainString.size > 0)

val request3 = proto.AnalyzePlanRequest
.newBuilder()
.setIsLocal(proto.AnalyzePlanRequest.IsLocal.newBuilder().setPlan(plan).build())
.build()
-val response3 = handler.process(request3, spark)
+val response3 = handler.process(request3, sparkSessionHolder)
assert(response3.hasIsLocal)
assert(!response3.getIsLocal.getIsLocal)

val request4 = proto.AnalyzePlanRequest
.newBuilder()
.setIsStreaming(proto.AnalyzePlanRequest.IsStreaming.newBuilder().setPlan(plan).build())
.build()
-val response4 = handler.process(request4, spark)
+val response4 = handler.process(request4, sparkSessionHolder)
assert(response4.hasIsStreaming)
assert(!response4.getIsStreaming.getIsStreaming)

val request5 = proto.AnalyzePlanRequest
.newBuilder()
.setTreeString(proto.AnalyzePlanRequest.TreeString.newBuilder().setPlan(plan).build())
.build()
-val response5 = handler.process(request5, spark)
+val response5 = handler.process(request5, sparkSessionHolder)
assert(response5.hasTreeString)
val treeString = response5.getTreeString.getTreeString
assert(treeString.contains("root"))
@@ -120,7 +123,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.newBuilder()
.setInputFiles(proto.AnalyzePlanRequest.InputFiles.newBuilder().setPlan(plan).build())
.build()
-val response6 = handler.process(request6, spark)
+val response6 = handler.process(request6, sparkSessionHolder)
assert(response6.hasInputFiles)
assert(response6.getInputFiles.getFilesCount === 0)
}
@@ -291,7 +294,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.build())
.build()

-val response = handler.process(request, spark)
+val response = handler.process(request, sparkSessionHolder)

assert(response.getExplain.getExplainString.contains("Parsed Logical Plan"))
assert(response.getExplain.getExplainString.contains("Analyzed Logical Plan"))
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest}
+import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.test.SharedSparkSession

class DummyPlugin extends RelationPlugin {
@@ -195,7 +196,8 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
.build()))
.build()

-new SparkConnectPlanner(spark).process(plan, "clientId", "sessionId", new MockObserver())
+new SparkConnectPlanner(SessionHolder.forTesting(spark))
+.process(plan, "clientId", "sessionId", new MockObserver())
assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
}
}