From 574fe4c14a3f7f9f2997567a0f76e6ce6df05d96 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Fri, 11 Nov 2022 14:17:41 -0800 Subject: [PATCH 1/2] [SPARK-41122][CONNECT] Explain API can support different modes. --- .../main/protobuf/spark/connect/base.proto | 115 ++++-- .../connect/service/SparkConnectService.scala | 34 +- .../service/SparkConnectStreamHandler.scala | 46 +-- .../messages/ConnectProtoMessagesSuite.scala | 4 +- .../planner/SparkConnectServiceSuite.scala | 49 ++- python/pyspark/sql/connect/client.py | 55 ++- python/pyspark/sql/connect/dataframe.py | 63 +++- python/pyspark/sql/connect/proto/base_pb2.py | 56 +-- python/pyspark/sql/connect/proto/base_pb2.pyi | 329 +++++++++++++----- .../sql/connect/proto/base_pb2_grpc.py | 24 +- .../sql/tests/connect/test_connect_basic.py | 16 +- 11 files changed, 586 insertions(+), 205 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index a521eab20d842..66e27187153b9 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -38,14 +38,61 @@ message Plan { } } -// A request to be executed by the service. -message Request { +// Explains the input plan based on a configurable mode. +message Explain { + // Plan explanation mode. + enum ExplainMode { + MODE_UNSPECIFIED = 0; + + // Generates only the physical plan. + SIMPLE = 1; + + // Generates the parsed logical plan, analyzed logical plan, optimized logical plan and physical plan. + // The parsed logical plan is the unresolved plan extracted from the query. The analyzed logical plan + // translates unresolvedAttribute and unresolvedRelation into fully typed objects. The optimized + // logical plan is the result of applying a set of optimization rules, from which the physical plan + // is generated. + EXTENDED = 2; + + // Generates code for the statement, if any, and a physical plan. + CODEGEN = 3; + + // If plan node statistics are available, generates a logical plan and also the statistics. + COST = 4; + + // Generates a physical plan outline and also node details. + FORMATTED = 5; + } + + // (Required) For analyzePlan rpc calls, configures the mode used to explain the plan as a string. + ExplainMode explain_mode = 1; +} + +// User Context is used to refer to one particular user session that is executing +// queries in the backend. +message UserContext { + string user_id = 1; + string user_name = 2; + + // To extend the existing user context message that is used to identify incoming requests, + // Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other + // messages into this message. Extensions are stored as a `repeated` type to be able to + // handle multiple active extensions. + repeated google.protobuf.Any extensions = 999; +} + +// Request to perform plan analysis, optionally explaining the plan. +message AnalyzePlanRequest { + // (Required) + // + // The client_id is set by the client to be able to collate streaming responses from // different queries. string client_id = 1; - // User context + + // (Required) User context UserContext user_context = 2; - // The logical plan to be executed / analyzed. + + // (Required) The logical plan to be analyzed. Plan plan = 3; // Provides optional information about the client sending the request. This field // can be used for language or version specific information and is only intended for // logging purposes and will not be interpreted by the server. 
optional string client_type = 4; - // User Context is used to refer to one particular user session that is executing - // queries in the backend. - message UserContext { - string user_id = 1; - string user_name = 2; - - // To extend the existing user context message that is used to identify incoming requests, - // Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other - // messages into this message. Extensions are stored as a `repeated` type to be able to - // handle multiple active extensions. - repeated google.protobuf.Any extensions = 999; - } + // (Optional) Get the explain string of the plan. + Explain explain = 5; +} + +// Response to performing analysis of the query. Contains relevant metadata to be able to +// reason about the performance. +message AnalyzePlanResponse { + string client_id = 1; + DataType schema = 2; + + // The extended explain string as produced by Spark. + string explain_string = 3; +} + +// A request to be executed by the service. +message ExecutePlanRequest { + // (Required) + // + // The client_id is set by the client to be able to collate streaming responses from + // different queries. + string client_id = 1; + + // (Required) User context + UserContext user_context = 2; + + // (Required) The logical plan to be executed / analyzed. + Plan plan = 3; + + // Provides optional information about the client sending the request. This field + // can be used for language or version specific information and is only intended for + // logging purposes and will not be interpreted by the server. + optional string client_type = 4; } // The response of a query, can be one or more for each request. Responses belonging to the // same input query, carry the same `client_id`. -message Response { +message ExecutePlanResponse { string client_id = 1; // Result type @@ -115,23 +182,13 @@ message Response { } } -// Response to performing analysis of the query. Contains relevant metadata to be able to -// reason about the performance. -message AnalyzeResponse { - string client_id = 1; - DataType schema = 2; - - // The extended explain string as produced by Spark. - string explain_string = 3; -} - // Main interface for the SparkConnect service. service SparkConnectService { // Executes a request that contains the query and returns a stream of [[Response]]. - rpc ExecutePlan(Request) returns (stream Response) {} + rpc ExecutePlan(ExecutePlanRequest) returns (stream ExecutePlanResponse) {} // Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query. 
- rpc AnalyzePlan(Request) returns (AnalyzeResponse) {} + rpc AnalyzePlan(AnalyzePlanRequest) returns (AnalyzePlanResponse) {} } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index abbad51c601aa..0c7a2ad2690c3 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -28,12 +28,12 @@ import io.grpc.stub.StreamObserver import org.apache.spark.SparkEnv import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, SparkConnectServiceGrpc} +import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner} -import org.apache.spark.sql.execution.ExtendedMode +import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, ExtendedMode, FormattedMode, SimpleMode} /** * The SparkConnectService implementation. @@ -57,7 +57,9 @@ class SparkConnectService(debug: Boolean) * @param request * @param responseObserver */ - override def executePlan(request: Request, responseObserver: StreamObserver[Response]): Unit = { + override def executePlan( + request: ExecutePlanRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { try { new SparkConnectStreamHandler(responseObserver).handle(request) } catch { @@ -81,8 +83,8 @@ class SparkConnectService(debug: Boolean) * @param responseObserver */ override def analyzePlan( - request: Request, - responseObserver: StreamObserver[AnalyzeResponse]): Unit = { + request: AnalyzePlanRequest, + responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { try { if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) { responseObserver.onError( @@ -91,7 +93,20 @@ class SparkConnectService(debug: Boolean) } val session = SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session - val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session) + + val explainMode = request.getExplain.getExplainMode match { + case proto.Explain.ExplainMode.SIMPLE => SimpleMode + case proto.Explain.ExplainMode.EXTENDED => ExtendedMode + case proto.Explain.ExplainMode.CODEGEN => CodegenMode + case proto.Explain.ExplainMode.COST => CostMode + case proto.Explain.ExplainMode.FORMATTED => FormattedMode + case _ => + throw new IllegalArgumentException( + s"Explain mode unspecified. 
Accepted " + + "explain modes are 'simple', 'extended', 'codegen', 'cost', 'formatted'.") + } + + val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session, explainMode) response.setClientId(request.getClientId) responseObserver.onNext(response.build()) responseObserver.onCompleted() @@ -105,13 +120,14 @@ class SparkConnectService(debug: Boolean) def handleAnalyzePlanRequest( relation: proto.Relation, - session: SparkSession): proto.AnalyzeResponse.Builder = { + session: SparkSession, + explainMode: ExplainMode): proto.AnalyzePlanResponse.Builder = { val logicalPlan = new SparkConnectPlanner(session).transformRelation(relation) val ds = Dataset.ofRows(session, logicalPlan) - val explainString = ds.queryExecution.explainString(ExtendedMode) + val explainString = ds.queryExecution.explainString(explainMode) - val response = proto.AnalyzeResponse + val response = proto.AnalyzePlanResponse .newBuilder() .setExplainString(explainString) response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema)) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index a780858d55caa..50ff08f997cb6 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -25,7 +25,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.SparkException import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{Request, Response} +import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -36,11 +36,13 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils -class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { +class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse]) + extends Logging { + // The maximum batch size in bytes for a single batch of data to be returned via proto. private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024 - def handle(v: Request): Unit = { + def handle(v: ExecutePlanRequest): Unit = { val session = SparkConnectService.getOrCreateIsolatedSession(v.getUserContext.getUserId).session v.getPlan.getOpTypeCase match { @@ -51,7 +53,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte } } - def handlePlan(session: SparkSession, request: Request): Unit = { + def handlePlan(session: SparkSession, request: ExecutePlanRequest): Unit = { // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(session) val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot)) @@ -88,8 +90,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte // Case 1 - FLush and send. 
if (sb.size + row.size > MAX_BATCH_SIZE) { - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.JSONBatch + val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) + val batch = proto.ExecutePlanResponse.JSONBatch .newBuilder() .setData(ByteString.copyFromUtf8(sb.toString())) .setRowCount(rowCount) @@ -112,8 +114,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte // If the last batch is not empty, send out the data to the client. if (sb.size > 0) { - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.JSONBatch + val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) + val batch = proto.ExecutePlanResponse.JSONBatch .newBuilder() .setData(ByteString.copyFromUtf8(sb.toString())) .setRowCount(rowCount) @@ -205,8 +207,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte } partition.foreach { case (bytes, count) => - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.ArrowBatch + val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) + val batch = proto.ExecutePlanResponse.ArrowBatch .newBuilder() .setRowCount(count) .setData(ByteString.copyFrom(bytes)) @@ -223,8 +225,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte // Make sure at least 1 batch will be sent. if (numSent == 0) { val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.ArrowBatch + val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) + val batch = proto.ExecutePlanResponse.ArrowBatch .newBuilder() .setRowCount(0L) .setData(ByteString.copyFrom(bytes)) @@ -238,16 +240,16 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte } } - def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = { + def sendMetricsToResponse(clientId: String, rows: DataFrame): ExecutePlanResponse = { // Send a last batch with the metrics - Response + ExecutePlanResponse .newBuilder() .setClientId(clientId) .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan)) .build() } - def handleCommand(session: SparkSession, request: Request): Unit = { + def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = { val command = request.getPlan.getCommand val planner = new SparkConnectPlanner(session) planner.process(command) @@ -274,13 +276,13 @@ object SparkConnectStreamHandler { } object MetricGenerator extends AdaptiveSparkPlanHelper { - def buildMetrics(p: SparkPlan): Response.Metrics = { - val b = Response.Metrics.newBuilder + def buildMetrics(p: SparkPlan): ExecutePlanResponse.Metrics = { + val b = ExecutePlanResponse.Metrics.newBuilder b.addAllMetrics(transformPlan(p, p.id).asJava) b.build() } - def transformChildren(p: SparkPlan): Seq[Response.Metrics.MetricObject] = { + def transformChildren(p: SparkPlan): Seq[ExecutePlanResponse.Metrics.MetricObject] = { allChildren(p).flatMap(c => transformPlan(c, p.id)) } @@ -290,14 +292,16 @@ object MetricGenerator extends AdaptiveSparkPlanHelper { case _ => p.children } - def transformPlan(p: SparkPlan, parentId: Int): Seq[Response.Metrics.MetricObject] = { + def transformPlan( + p: SparkPlan, + parentId: Int): Seq[ExecutePlanResponse.Metrics.MetricObject] = { val mv = p.metrics.map(m => - m._1 -> 
Response.Metrics.MetricValue.newBuilder + m._1 -> ExecutePlanResponse.Metrics.MetricValue.newBuilder .setName(m._2.name.getOrElse("")) .setValue(m._2.value) .setMetricType(m._2.metricType) .build()) - val mo = Response.Metrics.MetricObject + val mo = ExecutePlanResponse.Metrics.MetricObject .newBuilder() .setName(p.nodeName) .setPlanId(p.id) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala index 4132cca91086c..31b572afa21fd 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.connect.proto class ConnectProtoMessagesSuite extends SparkFunSuite { test("UserContext can deal with extensions") { // Create the builder. - val builder = proto.Request.UserContext.newBuilder().setUserId("1").setUserName("Martin") + val builder = proto.UserContext.newBuilder().setUserId("1").setUserName("Martin") // Create the extension value. val lit = proto.Expression @@ -36,7 +36,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite { val serialized = builder.build().toByteArray // Now, read the serialized value. - val result = proto.Request.UserContext.parseFrom(serialized) + val result = proto.UserContext.parseFrom(serialized) assert(result.getUserId.equals("1")) assert(result.getUserName.equals("Martin")) assert(result.getExtensionsCount == 1) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 7ff3a823fa1c2..1ada249131028 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.connect.proto import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.plans._ import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.execution.ExplainMode import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.ThreadUtils @@ -51,7 +52,8 @@ class SparkConnectServiceSuite extends SharedSparkSession { .build()) .build() - val response = instance.handleAnalyzePlanRequest(relation, spark) + val response = + instance.handleAnalyzePlanRequest(relation, spark, ExplainMode.fromString("simple")) assert(response.getSchema.hasStruct) val schema = response.getSchema.getStruct @@ -95,12 +97,57 @@ class SparkConnectServiceSuite extends SharedSparkSession { request, new StreamObserver[proto.Response] { private val responses = Seq.newBuilder[proto.Response] + override def onNext(v: proto.Response): Unit = responses += v + override def onError(throwable: Throwable): Unit = promise.failure(throwable) + override def onCompleted(): Unit = promise.success(responses.result()) }) intercept[SparkException] { ThreadUtils.awaitResult(promise.future, 2.seconds) } } + + test("Test explain mode in analyze response") { + withTable("test") { + spark.sql(""" + | CREATE TABLE test (col1 INT, col2 STRING) + | USING parquet + |""".stripMargin) + val instance = new SparkConnectService(false) + val relation = 
proto.Relation + .newBuilder() + .setProject( + proto.Project + .newBuilder() + .addExpressions( + proto.Expression + .newBuilder() + .setUnresolvedFunction( + proto.Expression.UnresolvedFunction + .newBuilder() + .addParts("abs") + .addArguments(proto.Expression + .newBuilder() + .setLiteral(proto.Expression.Literal.newBuilder().setI32(-1))))) + .setInput( + proto.Relation + .newBuilder() + .setRead(proto.Read + .newBuilder() + .setNamedTable( + proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build())))) + .build() + + val response = + instance + .handleAnalyzePlanRequest(relation, spark, ExplainMode.fromString("extended")) + .build() + assert(response.getExplainString.contains("Parsed Logical Plan")) + assert(response.getExplainString.contains("Analyzed Logical Plan")) + assert(response.getExplainString.contains("Optimized Logical Plan")) + assert(response.getExplainString.contains("Physical Plan")) + } + } } diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 3c3203a8f514d..5bdf01afc99c4 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -301,13 +301,15 @@ def register_udf( fun.parts.append(name) fun.serialized_function = cloudpickle.dumps((function, return_type)) - req = self._request_with_metadata() + req = self._execute_plan_request_with_metadata() req.plan.command.create_function.CopyFrom(fun) self._execute_and_fetch(req) return name - def _build_metrics(self, metrics: "pb2.Response.Metrics") -> typing.List[PlanMetrics]: + def _build_metrics( + self, metrics: "pb2.ExecutePlanResponse.Metrics" + ) -> typing.List[PlanMetrics]: return [ PlanMetrics( x.name, @@ -355,7 +357,7 @@ def range( ) def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]: - req = self._request_with_metadata() + req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) return self._execute_and_fetch(req) @@ -393,30 +395,57 @@ def schema(self, plan: pb2.Plan) -> StructType: ) return StructType(structFields) - def explain_string(self, plan: pb2.Plan) -> str: - return self._analyze(plan).explain_string + def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str: + result = self._analyze(plan, explain_mode) + return result.explain_string def execute_command(self, command: pb2.Command) -> None: - req = pb2.Request() + req = self._execute_plan_request_with_metadata() if self._user_id: req.user_context.user_id = self._user_id req.plan.command.CopyFrom(command) self._execute_and_fetch(req) + return - def _request_with_metadata(self) -> pb2.Request: - req = pb2.Request() + def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest: + req = pb2.ExecutePlanRequest() req.client_type = "_SPARK_CONNECT_PYTHON" if self._user_id: req.user_context.user_id = self._user_id return req - def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: - req = self._request_with_metadata() + def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: + req = pb2.AnalyzePlanRequest() + req.client_type = "_SPARK_CONNECT_PYTHON" + if self._user_id: + req.user_context.user_id = self._user_id + return req + + def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeResult: + req = self._analyze_plan_request_with_metadata() req.plan.CopyFrom(plan) + if explain_mode not in ["simple", "extended", "codegen", "cost", "formatted"]: + raise ValueError( + f""" + Unknown explain mode: {explain_mode}. 
Accepted " + "explain modes are 'simple', 'extended', 'codegen', 'cost', 'formatted'." + """ + ) + if explain_mode == "simple": + req.explain.explain_mode = pb2.Explain.ExplainMode.SIMPLE + elif explain_mode == "extended": + req.explain.explain_mode = pb2.Explain.ExplainMode.EXTENDED + elif explain_mode == "cost": + req.explain.explain_mode = pb2.Explain.ExplainMode.COST + elif explain_mode == "codegen": + req.explain.explain_mode = pb2.Explain.ExplainMode.CODEGEN + else: # formatted + req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED + resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) return AnalyzeResult.fromProto(resp) - def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]: + def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]: import pandas as pd if b.arrow_batch is not None and len(b.arrow_batch.data) > 0: @@ -426,10 +455,10 @@ def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]: return pd.read_json(io.BytesIO(b.json_batch.data), lines=True) return None - def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]: + def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]: import pandas as pd - m: Optional[pb2.Response.Metrics] = None + m: Optional[pb2.ExecutePlanResponse.Metrics] = None result_dfs = [] for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 5dd28c0e6a94a..cfcbd6a8394aa 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -25,6 +25,7 @@ Union, TYPE_CHECKING, overload, + cast, ) import pandas @@ -755,12 +756,70 @@ def schema(self) -> StructType: else: return self._schema - def explain(self) -> str: + def explain( + self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None + ) -> str: + """Retruns plans in string for debugging purpose. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + extended : bool, optional + default ``False``. If ``False``, returns only the physical plan. + When this is a string without specifying the ``mode``, it works as the mode is + specified. + mode : str, optional + specifies the expected output format of plans. + + * ``simple``: Print only a physical plan. + * ``extended``: Print both logical and physical plans. + * ``codegen``: Print a physical plan and generated codes if they are available. + * ``cost``: Print a logical plan and statistics if they are available. + * ``formatted``: Split explain output into two sections: a physical plan outline \ + and node details. 
+ """ + if extended is not None and mode is not None: + raise ValueError("extended and mode should not be set together.") + + # For the no argument case: df.explain() + is_no_argument = extended is None and mode is None + + # For the cases below: + # explain(True) + # explain(extended=False) + is_extended_case = isinstance(extended, bool) and mode is None + + # For the case when extended is mode: + # df.explain("formatted") + is_extended_as_mode = isinstance(extended, str) and mode is None + + # For the mode specified: + # df.explain(mode="formatted") + is_mode_case = extended is None and isinstance(mode, str) + + if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case): + argtypes = [str(type(arg)) for arg in [extended, mode] if arg is not None] + raise TypeError( + "extended (optional) and mode (optional) should be a string " + "and bool; however, got [%s]." % ", ".join(argtypes) + ) + + # Sets an explain mode depending on a given argument + if is_no_argument: + explain_mode = "simple" + elif is_extended_case: + explain_mode = "extended" if extended else "simple" + elif is_mode_case: + explain_mode = cast(str, mode) + elif is_extended_as_mode: + explain_mode = cast(str, extended) + if self._plan is not None: query = self._plan.to_proto(self._session) if self._session is None: raise Exception("Cannot analyze without RemoteSparkSession.") - return self._session.explain_string(query) + return self._session.explain_string(query, explain_mode) else: return "" diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 0527e9b49aa86..8f61cde151e45 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xc8\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensionsB\x0e\n\x0c_client_type"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 
\x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x01\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xad\x07\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12M\n\njson_batch\x18\x03 \x01(\x0b\x32,.spark.connect.ExecutePlanResponse.JSONBatchH\x00R\tjsonBatch\x12\x44\n\x07metrics\x18\x04 
\x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -44,30 +44,36 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" _PLAN._serialized_start = 158 _PLAN._serialized_end = 274 - _REQUEST._serialized_start = 277 - _REQUEST._serialized_end = 605 - _REQUEST_USERCONTEXT._serialized_start = 467 - _REQUEST_USERCONTEXT._serialized_end = 589 - _RESPONSE._serialized_start = 608 - _RESPONSE._serialized_end = 1472 - _RESPONSE_ARROWBATCH._serialized_start = 847 - _RESPONSE_ARROWBATCH._serialized_end = 908 - _RESPONSE_JSONBATCH._serialized_start = 910 - _RESPONSE_JSONBATCH._serialized_end = 970 - _RESPONSE_METRICS._serialized_start = 973 - _RESPONSE_METRICS._serialized_end = 1457 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1057 - _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1367 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1255 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1367 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1369 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1457 - _ANALYZERESPONSE._serialized_start = 1475 - _ANALYZERESPONSE._serialized_end = 1609 - _SPARKCONNECTSERVICE._serialized_start = 1612 - _SPARKCONNECTSERVICE._serialized_end = 1774 + _EXPLAIN._serialized_start = 277 + _EXPLAIN._serialized_end = 458 + _EXPLAIN_EXPLAINMODE._serialized_start = 359 + _EXPLAIN_EXPLAINMODE._serialized_end = 458 + _USERCONTEXT._serialized_start = 460 + _USERCONTEXT._serialized_end = 582 + _ANALYZEPLANREQUEST._serialized_start = 585 + _ANALYZEPLANREQUEST._serialized_end = 842 + _ANALYZEPLANRESPONSE._serialized_start = 845 + _ANALYZEPLANRESPONSE._serialized_end = 983 + 
_EXECUTEPLANREQUEST._serialized_start = 986 + _EXECUTEPLANREQUEST._serialized_end = 1193 + _EXECUTEPLANRESPONSE._serialized_start = 1196 + _EXECUTEPLANRESPONSE._serialized_end = 2137 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1479 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1540 + _EXECUTEPLANRESPONSE_JSONBATCH._serialized_start = 1542 + _EXECUTEPLANRESPONSE_JSONBATCH._serialized_end = 1602 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1605 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2122 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1700 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2032 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1909 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2032 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2034 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2122 + _SPARKCONNECTSERVICE._serialized_start = 2140 + _SPARKCONNECTSERVICE._serialized_end = 2339 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index e70f9db14a368..18b70de57a3cd 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -38,13 +38,15 @@ import collections.abc import google.protobuf.any_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper import google.protobuf.message import pyspark.sql.connect.proto.commands_pb2 import pyspark.sql.connect.proto.relations_pb2 import pyspark.sql.connect.proto.types_pb2 import sys +import typing -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 10): import typing as typing_extensions else: import typing_extensions @@ -90,74 +92,148 @@ class Plan(google.protobuf.message.Message): global___Plan = Plan -class Request(google.protobuf.message.Message): - """A request to be executed by the service.""" +class Explain(google.protobuf.message.Message): + """Explains the input plan based on a configurable mode.""" DESCRIPTOR: google.protobuf.descriptor.Descriptor - class UserContext(google.protobuf.message.Message): - """User Context is used to refer to one particular user session that is executing - queries in the backend. + class _ExplainMode: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _ExplainModeEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Explain._ExplainMode.ValueType], + builtins.type, + ): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + MODE_UNSPECIFIED: Explain._ExplainMode.ValueType # 0 + SIMPLE: Explain._ExplainMode.ValueType # 1 + """Generates only physical plan.""" + EXTENDED: Explain._ExplainMode.ValueType # 2 + """Generates parsed logical plan, analyzed logical plan, optimized logical plan and physical plan. + Parsed Logical plan is a unresolved plan that extracted from the query. Analyzed logical plans + transforms which translates unresolvedAttribute and unresolvedRelation into fully typed objects. + The optimized logical plan transforms through a set of optimization rules, resulting in the + physical plan. 
""" + CODEGEN: Explain._ExplainMode.ValueType # 3 + """Generates code for the statement, if any and a physical plan.""" + COST: Explain._ExplainMode.ValueType # 4 + """If plan node statistics are available, generates a logical plan and also the statistics.""" + FORMATTED: Explain._ExplainMode.ValueType # 5 + """Generates a physical plan outline and also node details.""" + + class ExplainMode(_ExplainMode, metaclass=_ExplainModeEnumTypeWrapper): + """Plan explanation mode.""" + + MODE_UNSPECIFIED: Explain.ExplainMode.ValueType # 0 + SIMPLE: Explain.ExplainMode.ValueType # 1 + """Generates only physical plan.""" + EXTENDED: Explain.ExplainMode.ValueType # 2 + """Generates parsed logical plan, analyzed logical plan, optimized logical plan and physical plan. + Parsed Logical plan is a unresolved plan that extracted from the query. Analyzed logical plans + transforms which translates unresolvedAttribute and unresolvedRelation into fully typed objects. + The optimized logical plan transforms through a set of optimization rules, resulting in the + physical plan. + """ + CODEGEN: Explain.ExplainMode.ValueType # 3 + """Generates code for the statement, if any and a physical plan.""" + COST: Explain.ExplainMode.ValueType # 4 + """If plan node statistics are available, generates a logical plan and also the statistics.""" + FORMATTED: Explain.ExplainMode.ValueType # 5 + """Generates a physical plan outline and also node details.""" + + EXPLAIN_MODE_FIELD_NUMBER: builtins.int + explain_mode: global___Explain.ExplainMode.ValueType + """(Required) For analyzePlan rpc calls, configure the mode to explain plan in strings.""" + def __init__( + self, + *, + explain_mode: global___Explain.ExplainMode.ValueType = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["explain_mode", b"explain_mode"] + ) -> None: ... - DESCRIPTOR: google.protobuf.descriptor.Descriptor +global___Explain = Explain - USER_ID_FIELD_NUMBER: builtins.int - USER_NAME_FIELD_NUMBER: builtins.int - EXTENSIONS_FIELD_NUMBER: builtins.int - user_id: builtins.str - user_name: builtins.str - @property - def extensions( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - google.protobuf.any_pb2.Any - ]: - """To extend the existing user context message that is used to identify incoming requests, - Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other - messages into this message. Extensions are stored as a `repeated` type to be able to - handle multiple active extensions. - """ - def __init__( - self, - *, - user_id: builtins.str = ..., - user_name: builtins.str = ..., - extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ..., - ) -> None: ... - def ClearField( - self, - field_name: typing_extensions.Literal[ - "extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name" - ], - ) -> None: ... +class UserContext(google.protobuf.message.Message): + """User Context is used to refer to one particular user session that is executing + queries in the backend. 
+ """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + USER_ID_FIELD_NUMBER: builtins.int + USER_NAME_FIELD_NUMBER: builtins.int + EXTENSIONS_FIELD_NUMBER: builtins.int + user_id: builtins.str + user_name: builtins.str + @property + def extensions( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + google.protobuf.any_pb2.Any + ]: + """To extend the existing user context message that is used to identify incoming requests, + Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other + messages into this message. Extensions are stored as a `repeated` type to be able to + handle multiple active extensions. + """ + def __init__( + self, + *, + user_id: builtins.str = ..., + user_name: builtins.str = ..., + extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name" + ], + ) -> None: ... + +global___UserContext = UserContext + +class AnalyzePlanRequest(google.protobuf.message.Message): + """Request to perform plan analyze, optionally to explain the plan.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor CLIENT_ID_FIELD_NUMBER: builtins.int USER_CONTEXT_FIELD_NUMBER: builtins.int PLAN_FIELD_NUMBER: builtins.int CLIENT_TYPE_FIELD_NUMBER: builtins.int + EXPLAIN_FIELD_NUMBER: builtins.int client_id: builtins.str - """The client_id is set by the client to be able to collate streaming responses from + """(Required) + + The client_id is set by the client to be able to collate streaming responses from different queries. """ @property - def user_context(self) -> global___Request.UserContext: - """User context""" + def user_context(self) -> global___UserContext: + """(Required) User context""" @property def plan(self) -> global___Plan: - """The logical plan to be executed / analyzed.""" + """(Required) The logical plan to be analyzed.""" client_type: builtins.str """Provides optional information about the client sending the request. This field can be used for language or version specific information and is only intended for logging purposes and will not be interpreted by the server. """ + @property + def explain(self) -> global___Explain: + """(Optional) Get the explain string of the plan.""" def __init__( self, *, client_id: builtins.str = ..., - user_context: global___Request.UserContext | None = ..., + user_context: global___UserContext | None = ..., plan: global___Plan | None = ..., client_type: builtins.str | None = ..., + explain: global___Explain | None = ..., ) -> None: ... def HasField( self, @@ -166,6 +242,8 @@ class Request(google.protobuf.message.Message): b"_client_type", "client_type", b"client_type", + "explain", + b"explain", "plan", b"plan", "user_context", @@ -181,6 +259,8 @@ class Request(google.protobuf.message.Message): b"client_id", "client_type", b"client_type", + "explain", + b"explain", "plan", b"plan", "user_context", @@ -191,9 +271,111 @@ class Request(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] ) -> typing_extensions.Literal["client_type"] | None: ... -global___Request = Request +global___AnalyzePlanRequest = AnalyzePlanRequest + +class AnalyzePlanResponse(google.protobuf.message.Message): + """Response to performing analysis of the query. Contains relevant metadata to be able to + reason about the performance. 
+ """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLIENT_ID_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_NUMBER: builtins.int + EXPLAIN_STRING_FIELD_NUMBER: builtins.int + client_id: builtins.str + @property + def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... + explain_string: builtins.str + """The extended explain string as produced by Spark.""" + def __init__( + self, + *, + client_id: builtins.str = ..., + schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + explain_string: builtins.str = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["schema", b"schema"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "client_id", b"client_id", "explain_string", b"explain_string", "schema", b"schema" + ], + ) -> None: ... + +global___AnalyzePlanResponse = AnalyzePlanResponse + +class ExecutePlanRequest(google.protobuf.message.Message): + """A request to be executed by the service.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLIENT_ID_FIELD_NUMBER: builtins.int + USER_CONTEXT_FIELD_NUMBER: builtins.int + PLAN_FIELD_NUMBER: builtins.int + CLIENT_TYPE_FIELD_NUMBER: builtins.int + client_id: builtins.str + """(Required) -class Response(google.protobuf.message.Message): + The client_id is set by the client to be able to collate streaming responses from + different queries. + """ + @property + def user_context(self) -> global___UserContext: + """(Required) User context""" + @property + def plan(self) -> global___Plan: + """(Required) The logical plan to be executed / analyzed.""" + client_type: builtins.str + """Provides optional information about the client sending the request. This field + can be used for language or version specific information and is only intended for + logging purposes and will not be interpreted by the server. + """ + def __init__( + self, + *, + client_id: builtins.str = ..., + user_context: global___UserContext | None = ..., + plan: global___Plan | None = ..., + client_type: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "plan", + b"plan", + "user_context", + b"user_context", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_id", + b"client_id", + "client_type", + b"client_type", + "plan", + b"plan", + "user_context", + b"user_context", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] + ) -> typing_extensions.Literal["client_type"] | None: ... + +global___ExecutePlanRequest = ExecutePlanRequest + +class ExecutePlanResponse(google.protobuf.message.Message): """The response of a query, can be one or more for each request. Responses belonging to the same input query, carry the same `client_id`. """ @@ -254,12 +436,12 @@ class Response(google.protobuf.message.Message): VALUE_FIELD_NUMBER: builtins.int key: builtins.str @property - def value(self) -> global___Response.Metrics.MetricValue: ... + def value(self) -> global___ExecutePlanResponse.Metrics.MetricValue: ... def __init__( self, *, key: builtins.str = ..., - value: global___Response.Metrics.MetricValue | None = ..., + value: global___ExecutePlanResponse.Metrics.MetricValue | None = ..., ) -> None: ... 
def HasField( self, field_name: typing_extensions.Literal["value", b"value"] @@ -279,7 +461,7 @@ class Response(google.protobuf.message.Message): def execution_metrics( self, ) -> google.protobuf.internal.containers.MessageMap[ - builtins.str, global___Response.Metrics.MetricValue + builtins.str, global___ExecutePlanResponse.Metrics.MetricValue ]: ... def __init__( self, @@ -288,7 +470,7 @@ class Response(google.protobuf.message.Message): plan_id: builtins.int = ..., parent: builtins.int = ..., execution_metrics: collections.abc.Mapping[ - builtins.str, global___Response.Metrics.MetricValue + builtins.str, global___ExecutePlanResponse.Metrics.MetricValue ] | None = ..., ) -> None: ... @@ -334,12 +516,13 @@ class Response(google.protobuf.message.Message): def metrics( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - global___Response.Metrics.MetricObject + global___ExecutePlanResponse.Metrics.MetricObject ]: ... def __init__( self, *, - metrics: collections.abc.Iterable[global___Response.Metrics.MetricObject] | None = ..., + metrics: collections.abc.Iterable[global___ExecutePlanResponse.Metrics.MetricObject] + | None = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal["metrics", b"metrics"] @@ -351,11 +534,11 @@ class Response(google.protobuf.message.Message): METRICS_FIELD_NUMBER: builtins.int client_id: builtins.str @property - def arrow_batch(self) -> global___Response.ArrowBatch: ... + def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ... @property - def json_batch(self) -> global___Response.JSONBatch: ... + def json_batch(self) -> global___ExecutePlanResponse.JSONBatch: ... @property - def metrics(self) -> global___Response.Metrics: + def metrics(self) -> global___ExecutePlanResponse.Metrics: """Metrics for the query execution. Typically, this field is only present in the last batch of results and then represent the overall state of the query execution. """ @@ -363,9 +546,9 @@ class Response(google.protobuf.message.Message): self, *, client_id: builtins.str = ..., - arrow_batch: global___Response.ArrowBatch | None = ..., - json_batch: global___Response.JSONBatch | None = ..., - metrics: global___Response.Metrics | None = ..., + arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ..., + json_batch: global___ExecutePlanResponse.JSONBatch | None = ..., + metrics: global___ExecutePlanResponse.Metrics | None = ..., ) -> None: ... def HasField( self, @@ -399,38 +582,4 @@ class Response(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["result_type", b"result_type"] ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ... -global___Response = Response - -class AnalyzeResponse(google.protobuf.message.Message): - """Response to performing analysis of the query. Contains relevant metadata to be able to - reason about the performance. - """ - - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - CLIENT_ID_FIELD_NUMBER: builtins.int - SCHEMA_FIELD_NUMBER: builtins.int - EXPLAIN_STRING_FIELD_NUMBER: builtins.int - client_id: builtins.str - @property - def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... - explain_string: builtins.str - """The extended explain string as produced by Spark.""" - def __init__( - self, - *, - client_id: builtins.str = ..., - schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., - explain_string: builtins.str = ..., - ) -> None: ... 
- def HasField( - self, field_name: typing_extensions.Literal["schema", b"schema"] - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing_extensions.Literal[ - "client_id", b"client_id", "explain_string", b"explain_string", "schema", b"schema" - ], - ) -> None: ... - -global___AnalyzeResponse = AnalyzeResponse +global___ExecutePlanResponse = ExecutePlanResponse diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py index 77307603f6ebc..139727e283007 100644 --- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -32,13 +32,13 @@ def __init__(self, channel): """ self.ExecutePlan = channel.unary_stream( "/spark.connect.SparkConnectService/ExecutePlan", - request_serializer=spark_dot_connect_dot_base__pb2.Request.SerializeToString, - response_deserializer=spark_dot_connect_dot_base__pb2.Response.FromString, + request_serializer=spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString, ) self.AnalyzePlan = channel.unary_unary( "/spark.connect.SparkConnectService/AnalyzePlan", - request_serializer=spark_dot_connect_dot_base__pb2.Request.SerializeToString, - response_deserializer=spark_dot_connect_dot_base__pb2.AnalyzeResponse.FromString, + request_serializer=spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.FromString, ) @@ -62,13 +62,13 @@ def add_SparkConnectServiceServicer_to_server(servicer, server): rpc_method_handlers = { "ExecutePlan": grpc.unary_stream_rpc_method_handler( servicer.ExecutePlan, - request_deserializer=spark_dot_connect_dot_base__pb2.Request.FromString, - response_serializer=spark_dot_connect_dot_base__pb2.Response.SerializeToString, + request_deserializer=spark_dot_connect_dot_base__pb2.ExecutePlanRequest.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.ExecutePlanResponse.SerializeToString, ), "AnalyzePlan": grpc.unary_unary_rpc_method_handler( servicer.AnalyzePlan, - request_deserializer=spark_dot_connect_dot_base__pb2.Request.FromString, - response_serializer=spark_dot_connect_dot_base__pb2.AnalyzeResponse.SerializeToString, + request_deserializer=spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -98,8 +98,8 @@ def ExecutePlan( request, target, "/spark.connect.SparkConnectService/ExecutePlan", - spark_dot_connect_dot_base__pb2.Request.SerializeToString, - spark_dot_connect_dot_base__pb2.Response.FromString, + spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString, + spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString, options, channel_credentials, insecure, @@ -127,8 +127,8 @@ def AnalyzePlan( request, target, "/spark.connect.SparkConnectService/AnalyzePlan", - spark_dot_connect_dot_base__pb2.Request.SerializeToString, - spark_dot_connect_dot_base__pb2.AnalyzeResponse.FromString, + spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.SerializeToString, + spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.FromString, options, channel_credentials, insecure, diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 01000e95f568b..fb455b7f3f430 100644 --- 
a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -189,7 +189,9 @@ def test_take(self) -> None: def test_subquery_alias(self) -> None: # SPARK-40938: test subquery alias. - plan_text = self.connect.read.table(self.tbl_name).alias("special_alias").explain() + plan_text = ( + self.connect.read.table(self.tbl_name).alias("special_alias").explain(extended=True) + ) self.assertTrue("special_alias" in plan_text) def test_range(self): @@ -277,6 +279,18 @@ def test_show(self): expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n" self.assertEqual(show_str, expected) + def test_explain_string(self): + # SPARK-41122: test explain API. + plan_str = self.connect.sql("SELECT 1").explain(extended=True) + self.assertTrue("Parsed Logical Plan" in plan_str) + self.assertTrue("Analyzed Logical Plan" in plan_str) + self.assertTrue("Optimized Logical Plan" in plan_str) + self.assertTrue("Physical Plan" in plan_str) + + with self.assertRaises(ValueError) as context: + self.connect.sql("SELECT 1").explain(mode="unknown") + self.assertTrue("unknown" in str(context.exception)) + def test_simple_datasource_read(self) -> None: writeDf = self.df_text tmpPath = tempfile.mkdtemp() From ab96e1ad73360b75d6d8a5bfcfda877fa9e401da Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 17 Nov 2022 16:53:26 -0800 Subject: [PATCH 2/2] update --- .../connect/planner/SparkConnectServiceSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 1ada249131028..133ce980ecddb 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -78,7 +78,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { session.udf.register("insta_kill", instaKill) val connect = new MockRemoteSession() - val context = proto.Request.UserContext + val context = proto.UserContext .newBuilder() .setUserId("c1") .build() @@ -86,19 +86,19 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setRoot(connect.sql("select insta_kill(id) from range(10)")) .build() - val request = proto.Request + val request = proto.ExecutePlanRequest .newBuilder() .setPlan(plan) .setUserContext(context) .build() - val promise = Promise[Seq[proto.Response]] + val promise = Promise[Seq[proto.ExecutePlanResponse]] instance.executePlan( request, - new StreamObserver[proto.Response] { - private val responses = Seq.newBuilder[proto.Response] + new StreamObserver[proto.ExecutePlanResponse] { + private val responses = Seq.newBuilder[proto.ExecutePlanResponse] - override def onNext(v: proto.Response): Unit = responses += v + override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v override def onError(throwable: Throwable): Unit = promise.failure(throwable)
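
Taken together, the patch wires an explain mode end to end: the Connect DataFrame.explain() resolves its classic extended/mode arguments to one of the five mode strings, client.py validates that string and maps it onto the new Explain.ExplainMode enum inside AnalyzePlanRequest, and the server translates the enum into Spark's ExplainMode (rejecting MODE_UNSPECIFIED) before calling queryExecution.explainString. Below is a minimal standalone sketch, in plain Python with no Spark dependency, of the argument-resolution step added to dataframe.py; the names resolve_explain_mode and ACCEPTED_MODES are illustrative only and not part of the patch.

    from typing import Optional, Union, cast

    # The five modes accepted by client.py before it sets
    # AnalyzePlanRequest.explain.explain_mode.
    ACCEPTED_MODES = ["simple", "extended", "codegen", "cost", "formatted"]


    def resolve_explain_mode(
        extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
    ) -> str:
        # Mirrors the resolution logic added to the Connect DataFrame.explain().
        if extended is not None and mode is not None:
            raise ValueError("extended and mode should not be set together.")
        if extended is None and mode is None:
            return "simple"                              # df.explain()
        if isinstance(extended, bool) and mode is None:
            return "extended" if extended else "simple"  # df.explain(True)
        if isinstance(extended, str) and mode is None:
            return cast(str, extended)                   # df.explain("formatted")
        if extended is None and isinstance(mode, str):
            return cast(str, mode)                       # df.explain(mode="cost")
        raise TypeError("extended should be bool or str, and mode should be str.")


    assert resolve_explain_mode() == "simple"
    assert resolve_explain_mode(True) == "extended"
    assert resolve_explain_mode("codegen") == "codegen"
    assert resolve_explain_mode(mode="formatted") == "formatted"
    assert resolve_explain_mode(mode="formatted") in ACCEPTED_MODES

The resolved string travels over the AnalyzePlan RPC rather than ExecutePlan, which is why the response carries explain_string alongside the schema and why the new test asserts that the "extended" mode surfaces all four plan sections.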