115 changes: 86 additions & 29 deletions connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -38,38 +38,105 @@ message Plan {
}
}

// A request to be executed by the service.
message Request {
// Explains the input plan based on a configurable mode.
message Explain {
// Plan explanation mode.
enum ExplainMode {
MODE_UNSPECIFIED = 0;

// Generates only the physical plan.
SIMPLE = 1;

// Generates parsed logical plan, analyzed logical plan, optimized logical plan and physical plan.
// The parsed logical plan is an unresolved plan extracted from the query. The analyzed logical
// plan transforms unresolvedAttribute and unresolvedRelation into fully typed objects. The
// optimized logical plan is obtained by applying a set of optimization rules and is used to
// generate the physical plan.
EXTENDED = 2;

// Generates code for the statement, if any, and a physical plan.
CODEGEN = 3;

// If plan node statistics are available, generates a logical plan and also the statistics.
COST = 4;

// Generates a physical plan outline and also node details.
FORMATTED = 5;
}

// (Required) For analyzePlan rpc calls, configures the mode used to explain the plan as a string.
ExplainMode explain_mode = 1;
}

// User Context is used to refer to one particular user session that is executing
// queries in the backend.
message UserContext {
string user_id = 1;
string user_name = 2;

// To extend the existing user context message that is used to identify incoming requests,
// Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other
// messages into this message. Extensions are stored as a `repeated` type to be able to
// handle multiple active extensions.
repeated google.protobuf.Any extensions = 999;
}

// Request to analyze the plan and, optionally, to explain it.
message AnalyzePlanRequest {
// (Required)
//
// The client_id is set by the client to be able to collate streaming responses from
// different queries.
string client_id = 1;
// User context

// (Required) User context
UserContext user_context = 2;
// The logical plan to be executed / analyzed.

// (Required) The logical plan to be analyzed.
Plan plan = 3;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 4;

// User Context is used to refer to one particular user session that is executing
// queries in the backend.
message UserContext {
string user_id = 1;
string user_name = 2;

// To extend the existing user context message that is used to identify incoming requests,
// Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other
// messages into this message. Extensions are stored as a `repeated` type to be able to
// handle multiple active extensions.
repeated google.protobuf.Any extensions = 999;
}
// (Optional) Get the explain string of the plan.
Explain explain = 5;
}

// Response to performing analysis of the query. Contains relevant metadata to be able to
// reason about the performance.
message AnalyzePlanResponse {
string client_id = 1;
DataType schema = 2;

// The extended explain string as produced by Spark.
string explain_string = 3;
}

// A request to be executed by the service.
message ExecutePlanRequest {
// (Required)
//
// The client_id is set by the client to be able to collate streaming responses from
// different queries.
string client_id = 1;

// (Required) User context
UserContext user_context = 2;

// (Required) The logical plan to be executed / analyzed.
Plan plan = 3;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 4;
}

// The response of a query; there can be one or more responses for each request. Responses
// belonging to the same input query carry the same `client_id`.
message Response {
message ExecutePlanResponse {
string client_id = 1;

// Result type
@@ -115,23 +182,13 @@ message Response {
}
}

// Response to performing analysis of the query. Contains relevant metadata to be able to
// reason about the performance.
message AnalyzeResponse {
string client_id = 1;
DataType schema = 2;

// The extended explain string as produced by Spark.
string explain_string = 3;
}

// Main interface for the SparkConnect service.
service SparkConnectService {

// Executes a request that contains the query and returns a stream of [[Response]].
rpc ExecutePlan(Request) returns (stream Response) {}
rpc ExecutePlan(ExecutePlanRequest) returns (stream ExecutePlanResponse) {}

// Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query.
rpc AnalyzePlan(Request) returns (AnalyzeResponse) {}
rpc AnalyzePlan(AnalyzePlanRequest) returns (AnalyzePlanResponse) {}
}
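
To make the new message layout concrete, here is a minimal sketch of how a client could assemble an AnalyzePlanRequest using the top-level UserContext and Explain messages defined above. The client id, user fields, and empty Plan are illustrative placeholders; a real request would set a relation tree as the plan root.

```scala
import org.apache.spark.connect.proto

object AnalyzePlanRequestSketch {
  def main(args: Array[String]): Unit = {
    // EXTENDED asks the server for parsed, analyzed, optimized and physical plans.
    val explain = proto.Explain
      .newBuilder()
      .setExplainMode(proto.Explain.ExplainMode.EXTENDED)
      .build()

    val request = proto.AnalyzePlanRequest
      .newBuilder()
      .setClientId("example-client") // illustrative client id
      .setUserContext(
        proto.UserContext.newBuilder().setUserId("1").setUserName("Martin"))
      .setPlan(proto.Plan.newBuilder()) // a real client sets the root relation here
      .setExplain(explain)
      .build()

    println(request)
  }
}
```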

@@ -28,12 +28,12 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, SparkConnectServiceGrpc}
import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
import org.apache.spark.sql.execution.ExtendedMode
import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, ExtendedMode, FormattedMode, SimpleMode}

/**
* The SparkConnectService implementation.
@@ -57,7 +57,9 @@ class SparkConnectService(debug: Boolean)
* @param request
* @param responseObserver
*/
override def executePlan(request: Request, responseObserver: StreamObserver[Response]): Unit = {
override def executePlan(
request: ExecutePlanRequest,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
try {
new SparkConnectStreamHandler(responseObserver).handle(request)
} catch {
Expand All @@ -81,8 +83,8 @@ class SparkConnectService(debug: Boolean)
* @param responseObserver
*/
override def analyzePlan(
request: Request,
responseObserver: StreamObserver[AnalyzeResponse]): Unit = {
request: AnalyzePlanRequest,
responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = {
try {
if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) {
responseObserver.onError(
@@ -91,7 +93,20 @@ }
}
val session =
SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session
val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session)

val explainMode = request.getExplain.getExplainMode match {
case proto.Explain.ExplainMode.SIMPLE => SimpleMode
case proto.Explain.ExplainMode.EXTENDED => ExtendedMode
case proto.Explain.ExplainMode.CODEGEN => CodegenMode
case proto.Explain.ExplainMode.COST => CostMode
case proto.Explain.ExplainMode.FORMATTED => FormattedMode
case _ =>
throw new IllegalArgumentException(
s"Explain mode unspecified. Accepted " +
"explain modes are 'simple', 'extended', 'codegen', 'cost', 'formatted'.")
}

val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session, explainMode)
response.setClientId(request.getClientId)
responseObserver.onNext(response.build())
responseObserver.onCompleted()
@@ -105,13 +120,14 @@

def handleAnalyzePlanRequest(
relation: proto.Relation,
session: SparkSession): proto.AnalyzeResponse.Builder = {
session: SparkSession,
explainMode: ExplainMode): proto.AnalyzePlanResponse.Builder = {
val logicalPlan = new SparkConnectPlanner(session).transformRelation(relation)

val ds = Dataset.ofRows(session, logicalPlan)
val explainString = ds.queryExecution.explainString(ExtendedMode)
val explainString = ds.queryExecution.explainString(explainMode)

val response = proto.AnalyzeResponse
val response = proto.AnalyzePlanResponse
.newBuilder()
.setExplainString(explainString)
response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema))
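
The rpc exercised by the handler above is AnalyzePlan. The following hypothetical client-side sketch shows that call through a gRPC blocking stub; the host, port, and client id are placeholders, and the plan is assumed to already contain a root relation (otherwise the service responds with an error, as the check above shows).

```scala
import io.grpc.ManagedChannelBuilder

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.SparkConnectServiceGrpc

object AnalyzePlanClientSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder address; a real client would use the configured gRPC binding port.
    val channel = ManagedChannelBuilder.forAddress("localhost", 15002).usePlaintext().build()
    val stub = SparkConnectServiceGrpc.newBlockingStub(channel)

    val request = proto.AnalyzePlanRequest
      .newBuilder()
      .setClientId("example-client")
      .setUserContext(proto.UserContext.newBuilder().setUserId("1"))
      .setPlan(proto.Plan.newBuilder()) // assumed to carry a root relation in practice
      .setExplain(
        proto.Explain.newBuilder().setExplainMode(proto.Explain.ExplainMode.FORMATTED))
      .build()

    // Unary rpc: returns the schema plus the explain string produced with the requested mode.
    val response: proto.AnalyzePlanResponse = stub.analyzePlan(request)
    println(response.getExplainString)

    channel.shutdown()
  }
}
```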
@@ -25,7 +25,7 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{Request, Response}
import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
@@ -36,11 +36,13 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils

class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
extends Logging {

// The maximum batch size in bytes for a single batch of data to be returned via proto.
private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024

def handle(v: Request): Unit = {
def handle(v: ExecutePlanRequest): Unit = {
val session =
SparkConnectService.getOrCreateIsolatedSession(v.getUserContext.getUserId).session
v.getPlan.getOpTypeCase match {
@@ -51,7 +53,7 @@ }
}
}

def handlePlan(session: SparkSession, request: Request): Unit = {
def handlePlan(session: SparkSession, request: ExecutePlanRequest): Unit = {
// Extract the plan from the request and convert it to a logical plan
val planner = new SparkConnectPlanner(session)
val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
@@ -88,8 +90,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte

// Case 1 - FLush and send.
if (sb.size + row.size > MAX_BATCH_SIZE) {
val response = proto.Response.newBuilder().setClientId(clientId)
val batch = proto.Response.JSONBatch
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.JSONBatch
.newBuilder()
.setData(ByteString.copyFromUtf8(sb.toString()))
.setRowCount(rowCount)
@@ -112,8 +114,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte

// If the last batch is not empty, send out the data to the client.
if (sb.size > 0) {
val response = proto.Response.newBuilder().setClientId(clientId)
val batch = proto.Response.JSONBatch
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.JSONBatch
.newBuilder()
.setData(ByteString.copyFromUtf8(sb.toString()))
.setRowCount(rowCount)
@@ -205,8 +207,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
}

partition.foreach { case (bytes, count) =>
val response = proto.Response.newBuilder().setClientId(clientId)
val batch = proto.Response.ArrowBatch
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(count)
.setData(ByteString.copyFrom(bytes))
@@ -223,8 +225,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
// Make sure at least 1 batch will be sent.
if (numSent == 0) {
val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
val response = proto.Response.newBuilder().setClientId(clientId)
val batch = proto.Response.ArrowBatch
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(0L)
.setData(ByteString.copyFrom(bytes))
@@ -238,16 +240,16 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
}
}

def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = {
def sendMetricsToResponse(clientId: String, rows: DataFrame): ExecutePlanResponse = {
// Send a last batch with the metrics
Response
ExecutePlanResponse
.newBuilder()
.setClientId(clientId)
.setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan))
.build()
}

def handleCommand(session: SparkSession, request: Request): Unit = {
def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
val command = request.getPlan.getCommand
val planner = new SparkConnectPlanner(session)
planner.process(command)
@@ -274,13 +276,13 @@ object SparkConnectStreamHandler {
}

object MetricGenerator extends AdaptiveSparkPlanHelper {
def buildMetrics(p: SparkPlan): Response.Metrics = {
val b = Response.Metrics.newBuilder
def buildMetrics(p: SparkPlan): ExecutePlanResponse.Metrics = {
val b = ExecutePlanResponse.Metrics.newBuilder
b.addAllMetrics(transformPlan(p, p.id).asJava)
b.build()
}

def transformChildren(p: SparkPlan): Seq[Response.Metrics.MetricObject] = {
def transformChildren(p: SparkPlan): Seq[ExecutePlanResponse.Metrics.MetricObject] = {
allChildren(p).flatMap(c => transformPlan(c, p.id))
}

@@ -290,14 +292,16 @@ object MetricGenerator extends AdaptiveSparkPlanHelper {
case _ => p.children
}

def transformPlan(p: SparkPlan, parentId: Int): Seq[Response.Metrics.MetricObject] = {
def transformPlan(
p: SparkPlan,
parentId: Int): Seq[ExecutePlanResponse.Metrics.MetricObject] = {
val mv = p.metrics.map(m =>
m._1 -> Response.Metrics.MetricValue.newBuilder
m._1 -> ExecutePlanResponse.Metrics.MetricValue.newBuilder
.setName(m._2.name.getOrElse(""))
.setValue(m._2.value)
.setMetricType(m._2.metricType)
.build())
val mo = Response.Metrics.MetricObject
val mo = ExecutePlanResponse.Metrics.MetricObject
.newBuilder()
.setName(p.nodeName)
.setPlanId(p.id)
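
For the streaming path handled above, a hypothetical client consumes the server stream of ExecutePlanResponse messages until it is exhausted. Only fields visible in this diff (client_id and metrics) are read; the channel setup mirrors the earlier sketch and uses placeholder values.

```scala
import io.grpc.ManagedChannelBuilder

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.SparkConnectServiceGrpc

object ExecutePlanClientSketch {
  def main(args: Array[String]): Unit = {
    val channel = ManagedChannelBuilder.forAddress("localhost", 15002).usePlaintext().build()
    val stub = SparkConnectServiceGrpc.newBlockingStub(channel)

    val request = proto.ExecutePlanRequest
      .newBuilder()
      .setClientId("example-client")
      .setUserContext(proto.UserContext.newBuilder().setUserId("1"))
      .setPlan(proto.Plan.newBuilder()) // assumed to carry a root relation or command
      .build()

    // Server-streaming rpc: the blocking stub returns a Java iterator over the responses.
    val responses = stub.executePlan(request)
    while (responses.hasNext) {
      val response = responses.next()
      // Every response belonging to the same query carries the same client_id.
      if (response.hasMetrics) {
        // The final response carries the collected plan metrics.
        println(s"${response.getClientId}: ${response.getMetrics.getMetricsCount} metric objects")
      } else {
        println(s"${response.getClientId}: data batch")
      }
    }

    channel.shutdown()
  }
}
```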
@@ -22,7 +22,7 @@ import org.apache.spark.connect.proto
class ConnectProtoMessagesSuite extends SparkFunSuite {
test("UserContext can deal with extensions") {
// Create the builder.
val builder = proto.Request.UserContext.newBuilder().setUserId("1").setUserName("Martin")
val builder = proto.UserContext.newBuilder().setUserId("1").setUserName("Martin")

// Create the extension value.
val lit = proto.Expression
@@ -36,7 +36,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
val serialized = builder.build().toByteArray

// Now, read the serialized value.
val result = proto.Request.UserContext.parseFrom(serialized)
val result = proto.UserContext.parseFrom(serialized)
assert(result.getUserId.equals("1"))
assert(result.getUserName.equals("Martin"))
assert(result.getExtensionsCount == 1)
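
As a complement to the test above, this small hypothetical sketch packs an extension into UserContext and reads it back out after deserialization, using only the standard com.google.protobuf.Any API and the proto.Expression message already used in the test.

```scala
import com.google.protobuf.{Any => ProtoAny}

import org.apache.spark.connect.proto

object UserContextExtensionSketch {
  def main(args: Array[String]): Unit = {
    // Pack an arbitrary message into the repeated extensions field, as the test does.
    val context = proto.UserContext
      .newBuilder()
      .setUserId("1")
      .setUserName("Martin")
      .addExtensions(ProtoAny.pack(proto.Expression.newBuilder().build()))
      .build()

    // Round-trip through serialization, then inspect and unpack the first extension.
    val roundTripped = proto.UserContext.parseFrom(context.toByteArray)
    val ext = roundTripped.getExtensions(0)
    if (ext.is(classOf[proto.Expression])) {
      val expr = ext.unpack(classOf[proto.Expression])
      println(expr)
    }
  }
}
```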