From 8de008ffedc836df5cf9a0e3209c7195e65619c9 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 18 Nov 2022 12:43:03 +0800 Subject: [PATCH 1/3] init --- .../service/SparkConnectStreamHandler.scala | 93 +++---------------- python/pyspark/sql/connect/client.py | 5 - .../sql/tests/connect/test_connect_basic.py | 47 +++++++++- 3 files changed, 57 insertions(+), 88 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 50ff08f997cb..676862f5f25f 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.service import scala.collection.JavaConverters._ -import scala.util.control.NonFatal import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver @@ -34,7 +33,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ThreadUtils class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse]) extends Logging { @@ -57,13 +55,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(session) val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot)) - try { - processAsArrowBatches(request.getClientId, dataframe) - } catch { - case e: Exception => - logWarning(e.getMessage) - processAsJsonBatches(request.getClientId, dataframe) - } + processAsArrowBatches(request.getClientId, dataframe) } def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { @@ -142,83 +134,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp var numSent = 0 if (numPartitions > 0) { - type Batch = (Array[Byte], Long) - val batches = rows.mapPartitionsInternal( SparkConnectStreamHandler .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)) - val signal = new Object - val partitions = collection.mutable.Map.empty[Int, Array[Batch]] - var error: Throwable = null - - val processPartition = (iter: Iterator[Batch]) => iter.toArray - - // This callback is executed by the DAGScheduler thread. - // After fetching a partition, it inserts the partition into the Map, and then - // wakes up the main thread. - val resultHandler = (partitionId: Int, partition: Array[Batch]) => { - signal.synchronized { - partitions(partitionId) = partition - signal.notify() - } - () - } - - val future = spark.sparkContext.submitJob( - rdd = batches, - processPartition = processPartition, - partitions = Seq.range(0, numPartitions), - resultHandler = resultHandler, - resultFunc = () => ()) - - // Collect errors and propagate them to the main thread. 
- future.onComplete { result => - result.failed.foreach { throwable => - signal.synchronized { - error = throwable - signal.notify() - } - } - }(ThreadUtils.sameThread) - - // The main thread will wait until 0-th partition is available, - // then send it to client and wait for the next partition. - // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends - // the arrow batches in main thread to avoid DAGScheduler thread been blocked for - // tasks not related to scheduling. This is particularly important if there are - // multiple users or clients running code at the same time. - var currentPartitionId = 0 - while (currentPartitionId < numPartitions) { - val partition = signal.synchronized { - var result = partitions.remove(currentPartitionId) - while (result.isEmpty && error == null) { - signal.wait() - result = partitions.remove(currentPartitionId) - } - error match { - case NonFatal(e) => - responseObserver.onError(error) - logError("Error while processing query.", e) - return - case fatal: Throwable => throw fatal - case null => result.get - } - } - - partition.foreach { case (bytes, count) => - val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) - val batch = proto.ExecutePlanResponse.ArrowBatch - .newBuilder() - .setRowCount(count) - .setData(ByteString.copyFrom(bytes)) - .build() - response.setArrowBatch(batch) - responseObserver.onNext(response.build()) - numSent += 1 - } - - currentPartitionId += 1 + batches.collect().foreach { case (bytes, count) => + val response = proto.Response.newBuilder().setClientId(clientId) + val batch = proto.Response.ArrowBatch + .newBuilder() + .setRowCount(count) + .setData(ByteString.copyFrom(bytes)) + .build() + response.setArrowBatch(batch) + responseObserver.onNext(response.build()) + numSent += 1 } } diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 5bdf01afc99c..fdcf34b7a47e 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -16,7 +16,6 @@ # -import io import logging import os import typing @@ -446,13 +445,9 @@ def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeRes return AnalyzeResult.fromProto(resp) def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]: - import pandas as pd - if b.arrow_batch is not None and len(b.arrow_batch.data) > 0: with pa.ipc.open_stream(b.arrow_batch.data) as rd: return rd.read_pandas() - elif b.json_batch is not None and len(b.json_batch.data) > 0: - return pd.read_json(io.BytesIO(b.json_batch.data), lines=True) return None def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]: diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index d3de94a379f8..a1b7c04a50fe 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -221,7 +221,52 @@ def test_create_global_temp_view(self): with self.assertRaises(_MultiThreadedRendezvous): self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") - @unittest.skip("test_fill_na is flaky") + def test_to_pandas(self): + # SPARK-XXXX: Test to pandas + query = """ + SELECT * FROM VALUES + (false, 1, float(NULL)), (false, NULL, float(2.0)), (NULL, 3, float(3.0)) + AS tab(a, b, c) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + + query = """ 
+ SELECT * FROM VALUES + (1, 1, float(NULL)), (2, NULL, float(2.0)), (3, 3, float(3.0)) + AS tab(a, b, c) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + + query = """ + SELECT * FROM VALUES + (1.0, 1, "1"), (NULL, NULL, NULL), (2.0, 3, "3") + AS tab(a, b, c) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + + query = """ + SELECT * FROM VALUES + (float(1.0), 1.0, 1, "1"), (float(2.0), 2.0, 2, "2"), (float(3.0), 2.0, 3, "3") + AS tab(a, b, c, d) + """ + + self.assert_eq( + self.connect.sql(query).toPandas(), + self.spark.sql(query).toPandas(), + ) + def test_fill_na(self): # SPARK-41128: Test fill na query = """ From 4965ac7daf3b5fd642de1effb4101a7c4746cc1e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 18 Nov 2022 13:22:55 +0800 Subject: [PATCH 2/3] nit --- .../spark/sql/connect/service/SparkConnectStreamHandler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 676862f5f25f..37f0d159092e 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -139,8 +139,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)) batches.collect().foreach { case (bytes, count) => - val response = proto.Response.newBuilder().setClientId(clientId) - val batch = proto.Response.ArrowBatch + val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) + val batch = proto.ExecutePlanResponse.ArrowBatch .newBuilder() .setRowCount(count) .setData(ByteString.copyFrom(bytes)) From ed59e6cf6caae9c4fbb19a7958a509e21a96c130 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 21 Nov 2022 10:32:40 +0800 Subject: [PATCH 3/3] remove json path --- .../main/protobuf/spark/connect/base.proto | 14 +---- .../service/SparkConnectStreamHandler.scala | 63 ------------------- python/pyspark/sql/connect/proto/base_pb2.py | 41 +++++------- python/pyspark/sql/connect/proto/base_pb2.pyi | 51 +-------------- .../sql/tests/connect/test_connect_basic.py | 18 ++++-- 5 files changed, 30 insertions(+), 157 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index 66e27187153b..277da6b2431d 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -139,11 +139,7 @@ message ExecutePlanRequest { message ExecutePlanResponse { string client_id = 1; - // Result type - oneof result_type { - ArrowBatch arrow_batch = 2; - JSONBatch json_batch = 3; - } + ArrowBatch arrow_batch = 2; // Metrics for the query execution. Typically, this field is only present in the last // batch of results and then represent the overall state of the query execution. @@ -155,14 +151,6 @@ message ExecutePlanResponse { bytes data = 2; } - // Message type when the result is returned as JSON. This is essentially a bulk wrapper - // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format - // of `{col -> row}`. 
- message JSONBatch { - int64 row_count = 1; - bytes data = 2; - } - message Metrics { repeated MetricObject metrics = 1; diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 37f0d159092e..092bdd00dc1c 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver -import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} import org.apache.spark.internal.Logging @@ -58,68 +57,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp processAsArrowBatches(request.getClientId, dataframe) } - def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = { - // Only process up to 10MB of data. - val sb = new StringBuilder - var rowCount = 0 - dataframe.toJSON - .collect() - .foreach(row => { - - // There are a few cases to cover here. - // 1. The aggregated buffer size is larger than the MAX_BATCH_SIZE - // -> send the current batch and reset. - // 2. The aggregated buffer size is smaller than the MAX_BATCH_SIZE - // -> append the row to the buffer. - // 3. The row in question is larger than the MAX_BATCH_SIZE - // -> fail the query. - - // Case 3. - Fail - if (row.size > MAX_BATCH_SIZE) { - throw SparkException.internalError( - s"Serialized row is larger than MAX_BATCH_SIZE: ${row.size} > ${MAX_BATCH_SIZE}") - } - - // Case 1 - FLush and send. - if (sb.size + row.size > MAX_BATCH_SIZE) { - val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) - val batch = proto.ExecutePlanResponse.JSONBatch - .newBuilder() - .setData(ByteString.copyFromUtf8(sb.toString())) - .setRowCount(rowCount) - .build() - response.setJsonBatch(batch) - responseObserver.onNext(response.build()) - sb.clear() - sb.append(row) - rowCount = 1 - } else { - // Case 2 - Append. - // Make sure to put the newline delimiters only between items and not at the end. - if (rowCount > 0) { - sb.append("\n") - } - sb.append(row) - rowCount += 1 - } - }) - - // If the last batch is not empty, send out the data to the client. 
- if (sb.size > 0) { - val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) - val batch = proto.ExecutePlanResponse.JSONBatch - .newBuilder() - .setData(ByteString.copyFromUtf8(sb.toString())) - .setRowCount(rowCount) - .build() - response.setJsonBatch(batch) - responseObserver.onNext(response.build()) - } - - responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) - responseObserver.onCompleted() - } - def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { val spark = dataframe.sparkSession val schema = dataframe.schema diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 7d9f98b243e1..daa1c25cc8fe 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x01\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xad\x07\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12M\n\njson_batch\x18\x03 \x01(\x0b\x32,.spark.connect.ExecutePlanResponse.JSONBatchH\x00R\tjsonBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 
\x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x01\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\x8f\x06\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12N\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchR\narrowBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 
\x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -48,7 +48,6 @@ _EXECUTEPLANREQUEST = DESCRIPTOR.message_types_by_name["ExecutePlanRequest"] _EXECUTEPLANRESPONSE = DESCRIPTOR.message_types_by_name["ExecutePlanResponse"] _EXECUTEPLANRESPONSE_ARROWBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["ArrowBatch"] -_EXECUTEPLANRESPONSE_JSONBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["JSONBatch"] _EXECUTEPLANRESPONSE_METRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["Metrics"] _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[ "MetricObject" @@ -139,15 +138,6 @@ # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.ArrowBatch) }, ), - "JSONBatch": _reflection.GeneratedProtocolMessageType( - "JSONBatch", - (_message.Message,), - { - "DESCRIPTOR": _EXECUTEPLANRESPONSE_JSONBATCH, - "__module__": "spark.connect.base_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.JSONBatch) - }, - ), "Metrics": _reflection.GeneratedProtocolMessageType( "Metrics", (_message.Message,), @@ -191,7 +181,6 @@ ) _sym_db.RegisterMessage(ExecutePlanResponse) _sym_db.RegisterMessage(ExecutePlanResponse.ArrowBatch) -_sym_db.RegisterMessage(ExecutePlanResponse.JSONBatch) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry) @@ -219,19 +208,17 @@ _EXECUTEPLANREQUEST._serialized_start = 986 _EXECUTEPLANREQUEST._serialized_end = 1193 _EXECUTEPLANRESPONSE._serialized_start = 1196 - _EXECUTEPLANRESPONSE._serialized_end = 2137 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1479 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1540 - _EXECUTEPLANRESPONSE_JSONBATCH._serialized_start = 1542 - _EXECUTEPLANRESPONSE_JSONBATCH._serialized_end = 1602 - _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1605 - _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2122 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1700 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2032 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1909 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2032 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2034 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2122 - _SPARKCONNECTSERVICE._serialized_start = 2140 - 
_SPARKCONNECTSERVICE._serialized_end = 2339 + _EXECUTEPLANRESPONSE._serialized_end = 1979 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1398 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1459 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1462 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 1979 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1557 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 1889 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1766 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1889 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 1891 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 1979 + _SPARKCONNECTSERVICE._serialized_start = 1982 + _SPARKCONNECTSERVICE._serialized_end = 2181 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 18b70de57a3c..64bb51d4c0b9 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -401,28 +401,6 @@ class ExecutePlanResponse(google.protobuf.message.Message): self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"] ) -> None: ... - class JSONBatch(google.protobuf.message.Message): - """Message type when the result is returned as JSON. This is essentially a bulk wrapper - for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format - of `{col -> row}`. - """ - - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ROW_COUNT_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - row_count: builtins.int - data: builtins.bytes - def __init__( - self, - *, - row_count: builtins.int = ..., - data: builtins.bytes = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"] - ) -> None: ... - class Metrics(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -530,14 +508,11 @@ class ExecutePlanResponse(google.protobuf.message.Message): CLIENT_ID_FIELD_NUMBER: builtins.int ARROW_BATCH_FIELD_NUMBER: builtins.int - JSON_BATCH_FIELD_NUMBER: builtins.int METRICS_FIELD_NUMBER: builtins.int client_id: builtins.str @property def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ... @property - def json_batch(self) -> global___ExecutePlanResponse.JSONBatch: ... - @property def metrics(self) -> global___ExecutePlanResponse.Metrics: """Metrics for the query execution. Typically, this field is only present in the last batch of results and then represent the overall state of the query execution. @@ -547,39 +522,17 @@ class ExecutePlanResponse(google.protobuf.message.Message): *, client_id: builtins.str = ..., arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ..., - json_batch: global___ExecutePlanResponse.JSONBatch | None = ..., metrics: global___ExecutePlanResponse.Metrics | None = ..., ) -> None: ... def HasField( self, - field_name: typing_extensions.Literal[ - "arrow_batch", - b"arrow_batch", - "json_batch", - b"json_batch", - "metrics", - b"metrics", - "result_type", - b"result_type", - ], + field_name: typing_extensions.Literal["arrow_batch", b"arrow_batch", "metrics", b"metrics"], ) -> builtins.bool: ... 
def ClearField( self, field_name: typing_extensions.Literal[ - "arrow_batch", - b"arrow_batch", - "client_id", - b"client_id", - "json_batch", - b"json_batch", - "metrics", - b"metrics", - "result_type", - b"result_type", + "arrow_batch", b"arrow_batch", "client_id", b"client_id", "metrics", b"metrics" ], ) -> None: ... - def WhichOneof( - self, oneof_group: typing_extensions.Literal["result_type", b"result_type"] - ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ... global___ExecutePlanResponse = ExecutePlanResponse diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index a1b7c04a50fe..9e7a5f2f4a54 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -222,10 +222,12 @@ def test_create_global_temp_view(self): self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") def test_to_pandas(self): - # SPARK-XXXX: Test to pandas + # SPARK-41005: Test to pandas query = """ SELECT * FROM VALUES - (false, 1, float(NULL)), (false, NULL, float(2.0)), (NULL, 3, float(3.0)) + (false, 1, NULL), + (false, NULL, float(2.0)), + (NULL, 3, float(3.0)) AS tab(a, b, c) """ @@ -236,7 +238,9 @@ def test_to_pandas(self): query = """ SELECT * FROM VALUES - (1, 1, float(NULL)), (2, NULL, float(2.0)), (3, 3, float(3.0)) + (1, 1, NULL), + (2, NULL, float(2.0)), + (3, 3, float(3.0)) AS tab(a, b, c) """ @@ -247,7 +251,9 @@ def test_to_pandas(self): query = """ SELECT * FROM VALUES - (1.0, 1, "1"), (NULL, NULL, NULL), (2.0, 3, "3") + (double(1.0), 1, "1"), + (NULL, NULL, NULL), + (double(2.0), 3, "3") AS tab(a, b, c) """ @@ -258,7 +264,9 @@ def test_to_pandas(self): query = """ SELECT * FROM VALUES - (float(1.0), 1.0, 1, "1"), (float(2.0), 2.0, 2, "2"), (float(3.0), 2.0, 3, "3") + (float(1.0), double(1.0), 1, "1"), + (float(2.0), double(2.0), 2, "2"), + (float(3.0), double(3.0), 3, "3") AS tab(a, b, c, d) """
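
Notes (context for the series above, not part of the applied diffs): after these three
patches, ExecutePlanResponse carries results only as ArrowBatch messages; the JSONBatch
variant and the server-side JSON fallback are removed, and processAsArrowBatches simply
collects the converted partitions and emits one ArrowBatch per (bytes, rowCount) pair.
The sketch below shows how a client can fold such a response stream back into a pandas
DataFrame. The per-response decoding mirrors _process_batch in
python/pyspark/sql/connect/client.py as changed in PATCH 1/3; the surrounding loop, the
concat step, the helper names, and the exact import path are illustrative assumptions
rather than code from this series.

    from typing import Iterable, Optional

    import pandas as pd
    import pyarrow as pa

    # Assumed import path; the generated module edited in PATCH 3/3 lives at
    # python/pyspark/sql/connect/proto/base_pb2.py.
    import pyspark.sql.connect.proto.base_pb2 as pb2


    def _decode_batch(b: pb2.ExecutePlanResponse) -> Optional[pd.DataFrame]:
        # Same logic as _process_batch after the json_batch branch was removed:
        # each response may carry one Arrow IPC stream in arrow_batch.data.
        if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
            with pa.ipc.open_stream(b.arrow_batch.data) as rd:
                return rd.read_pandas()
        return None


    def collect_to_pandas(
        responses: Iterable[pb2.ExecutePlanResponse],
    ) -> Optional[pd.DataFrame]:
        # 'responses' stands in for the gRPC stream returned by ExecutePlan; how it
        # is obtained (stub, request construction) is outside the scope of this patch.
        frames = [df for df in map(_decode_batch, responses) if df is not None]
        if not frames:
            return None
        return pd.concat(frames, ignore_index=True)

Compared with the removed JSON path, this keeps a single wire format end to end: the
server reuses the ArrowConverters output directly, and the client needs only the
pyarrow IPC reader shown above.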