7 changes: 2 additions & 5 deletions connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -69,7 +69,7 @@ message Response {

// Result type
oneof result_type {
ArrowBatch batch = 2;
ArrowBatch arrow_batch = 2;
JSONBatch json_batch = 3;
}

@@ -80,10 +80,7 @@ message Response {
// Batch results of metrics.
message ArrowBatch {
int64 row_count = 1;
int64 uncompressed_bytes = 2;
int64 compressed_bytes = 3;
bytes data = 4;
bytes schema = 5;
bytes data = 2;
}
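
A hedged aside on the simplified ArrowBatch: with the separate schema, uncompressed_bytes, and compressed_bytes fields dropped, the data field is expected to carry a self-contained Arrow IPC stream (the Python client below opens it directly with pa.ipc.open_stream). A minimal JVM-side sketch of consuming such a payload, assuming arrow-java is on the classpath; readArrowBatchData is an illustrative helper, not part of this PR:

import java.io.ByteArrayInputStream
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader

def readArrowBatchData(data: Array[Byte]): Long = {
  val allocator = new RootAllocator()
  val reader = new ArrowStreamReader(new ByteArrayInputStream(data), allocator)
  try {
    val root = reader.getVectorSchemaRoot  // the schema is read from the stream itself
    var rows = 0L
    while (reader.loadNextBatch()) {       // iterate the record batches in the stream
      rows += root.getRowCount
    }
    rows
  } finally {
    reader.close()
    allocator.close()
  }
}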

// Message type when the result is returned as JSON. This is essentially a bulk wrapper
@@ -29,8 +29,9 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
import org.apache.spark.sql.execution.arrow.ArrowConverters

class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {

@@ -48,19 +49,24 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
}
}

def handlePlan(session: SparkSession, request: proto.Request): Unit = {
def handlePlan(session: SparkSession, request: Request): Unit = {
// Extract the plan from the request and convert it to a logical plan
val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
val rows =
Dataset.ofRows(session, planner.transform())
processRows(request.getClientId, rows)
val dataframe = Dataset.ofRows(session, planner.transform())
try {
processAsArrowBatches(request.getClientId, dataframe)
} catch {
case e: Exception =>
logWarning(e.getMessage)
processAsJsonBatches(request.getClientId, dataframe)
}
}

def processRows(clientId: String, rows: DataFrame): Unit = {
def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = {
// Only process up to 10MB of data.
val sb = new StringBuilder
var rowCount = 0
rows.toJSON
dataframe.toJSON
.collect()
.foreach(row => {

@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
responseObserver.onNext(response.build())
}

responseObserver.onNext(sendMetricsToResponse(clientId, rows))
responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
responseObserver.onCompleted()
}

def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
val spark = dataframe.sparkSession
val schema = dataframe.schema
// TODO: control the batch size instead of max records
Member: I think you can remove this, since we're already handling the max records?

Contributor (author): #38468 (comment) suggested controlling the batch size to < 4 MB.

Contributor: Interesting. I only knew that gRPC has a hard limit of 2 GB/s transfer rate; I never knew it might not favor large messages.

Contributor: I don't think there is a throughput limit in gRPC itself. The reason for the batching is that protobuf is not suited for this: embedding large binary objects might require the reader to materialize them in memory. Fixing this is an optimization for later.

Contributor: Another downside of large allocations is that the GC does not really like them. All large allocations (> 1 MB) are generally placed in the old generation immediately, which requires a full GC to clean up.
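
An illustrative sketch of the batch-size idea from the thread above (not what this PR implements): greedily group already-serialized batches so that a single gRPC response stays under a byte budget. packBySize is a hypothetical helper; each element is (serializedBytes, rowCount), matching the Batch type below.

def packBySize(
    batches: Iterator[(Array[Byte], Long)],
    maxBytes: Long = 4L * 1024 * 1024): Iterator[Seq[(Array[Byte], Long)]] = {
  val buffered = batches.buffered
  new Iterator[Seq[(Array[Byte], Long)]] {
    override def hasNext: Boolean = buffered.hasNext
    override def next(): Seq[(Array[Byte], Long)] = {
      val group = Seq.newBuilder[(Array[Byte], Long)]
      // Always take at least one batch, so an oversized batch is still sent on its own.
      var size = buffered.head._1.length.toLong
      group += buffered.next()
      while (buffered.hasNext && size + buffered.head._1.length <= maxBytes) {
        size += buffered.head._1.length
        group += buffered.next()
      }
      group.result()
    }
  }
}

Each emitted group could then be serialized into a single Response.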

val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone

SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
val rows = dataframe.queryExecution.executedPlan.execute()
val numPartitions = rows.getNumPartitions
var numSent = 0

if (numPartitions > 0) {
type Batch = (Array[Byte], Long)

val batches = rows.mapPartitionsInternal { iter =>
ArrowConverters
.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
}

val signal = new Object
val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
Contributor: Do we really need a map? We know the number of partitions, so we could just create an array.

Contributor: Then we wouldn't need a lock.

Contributor: Never mind, this is a kind of async collect.

Contributor (author): Just changed from array to map ... see #38468 (comment).

Contributor: We should probably add some comments at the beginning to explain the overall workflow.

Contributor: Can we apply the same idea to JSON batches?

Contributor (author): I think so, let's optimize it later.


val processPartition = (iter: Iterator[Batch]) => iter.toArray

// This callback is executed by the DAGScheduler thread.
// After fetching a partition, it inserts the partition into the Map, and then
// wakes up the main thread.
val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
signal.synchronized {
partitions(partitionId) = partition
signal.notify()
}
()
}

spark.sparkContext.runJob(batches, processPartition, resultHandler)
Member: Seems like it had to be the (async) submitJob instead of the (sync) runJob (#38468 (comment)). In fact, I figured out a simpler way to avoid the synchronization. PTAL #38613
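
A minimal sketch of the asynchronous submitJob variant mentioned above, assuming the caller only needs a per-partition callback as results arrive; collectAsync and sendBatch are illustrative names, and this is not the code of #38613:

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

def collectAsync[T: ClassTag](rdd: RDD[T])(sendBatch: (Int, Array[T]) => Unit): Unit = {
  val future = rdd.sparkContext.submitJob(
    rdd,
    (iter: Iterator[T]) => iter.toArray,  // runs on the executors
    0 until rdd.getNumPartitions,         // all partitions
    sendBatch,                            // driver-side callback, called per finished partition
    ()                                    // nothing extra to return
  )
  // Unlike runJob, submitJob returns immediately; block here until every partition is handled.
  Await.result(future, Duration.Inf)
}

Note that the callback still fires in completion order, not partition order, so the ordering hand-off below (or client-side reordering) is needed either way.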


// The main thread will wait until the 0-th partition is available,
// then send it to the client and wait for the next partition.
var currentPartitionId = 0
while (currentPartitionId < numPartitions) {
val partition = signal.synchronized {
while (!partitions.contains(currentPartitionId)) {
signal.wait()
Member: Is this to wait for the partitions to be fetched in order? I think we can just fetch them all and send the first one when it arrives. To optimize this, I think we should eventually do the reordering in a way that matches PySpark's implementation. Ideally we should even deduplicate the code.

Contributor (author): No, partitions can be fetched in random order. Here we wait for the currentPartitionId-th partition (starting from 0).

Contributor: Sorry, I am still a bit confused about how this maintains the ordering of partitions. I assume other people reading this code might be confused as well. Is it possible to add a developer comment here to explain the algorithm?

Member: If the first partition arrives last, the whole dataset stays in the driver's memory, right?

Contributor: Yes. We can look into spilling to deal with these situations.

Member (pan3793, Nov 10, 2022): In the reduce phase, a task fetches the map data of each partition in random order; without a local sort, the user still sees indeterminate data even if the driver returns data by partition id.

Contributor (author): Regarding "If the first partition arrives last, the whole dataset stays in the driver's memory, right?": yes, but at least it's not worse than the existing collect, which always keeps the whole dataset in memory. Receiving the partitions in order may make them easier to consume in the client, if ordering matters. I think we will optimize it further; this is just an initial implementation.

Contributor: The current approach still has the risk of holding all the results in driver memory (if the first partition comes last), which violates the design goal of Spark Connect. I think the Spark driver should send whichever partition arrives to the client, and the client should allocate an array to hold the Arrow batches of all partitions. The client needs to keep all the results in memory anyway, so it's better to ask the client to buffer the results and reorder them by partition id.

Contributor: That makes the client and the API more complicated. I don't want the implementors of the clients to deal with this. We can add an optimization for unordered dataframes in a follow-up, but for now let's just merge what works.

Contributor: Agreed, it's also important to keep client implementations simple. This "async collect" should be OK in most cases.
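
An illustrative sketch of the alternative debated above (deferred, not what this PR does): send partitions in whatever order they arrive and let the client reorder them. A partition id would have to be added to the response message; Chunk is a hypothetical stand-in for that extended message.

final case class Chunk(partitionId: Int, data: Array[Byte])

def reorderOnClient(received: Seq[Chunk], numPartitions: Int): Seq[Array[Byte]] = {
  // The client keeps all results in memory anyway, so it can group the chunks by
  // partition id and stitch them back together in partition order at the end.
  val byPartition = received.groupBy(_.partitionId)
  (0 until numPartitions).flatMap(id => byPartition.getOrElse(id, Seq.empty).map(_.data))
}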

}
partitions.remove(currentPartitionId).get
}

partition.foreach { case (bytes, count) =>
val response = proto.Response.newBuilder().setClientId(clientId)
val batch = proto.Response.ArrowBatch
.newBuilder()
.setRowCount(count)
.setData(ByteString.copyFrom(bytes))
.build()
response.setArrowBatch(batch)
responseObserver.onNext(response.build())
numSent += 1
Member: Can we use a boolean?

Contributor (author): Sure.

Contributor: Is there a way we could track this as a Spark metric for the query? Fine to do in a follow-up if we create a JIRA.

}

currentPartitionId += 1
}
}

// Make sure at least 1 batch will be sent.
if (numSent == 0) {
val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
val response = proto.Response.newBuilder().setClientId(clientId)
val batch = proto.Response.ArrowBatch
.newBuilder()
.setRowCount(0L)
.setData(ByteString.copyFrom(bytes))
.build()
response.setArrowBatch(batch)
responseObserver.onNext(response.build())
}

responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
responseObserver.onCompleted()
}
}

def sendMetricsToResponse(clientId: String, rows: DataFrame): Response = {
// Send a last batch with the metrics
Response
11 changes: 9 additions & 2 deletions python/pyspark/sql/connect/client.py
@@ -377,8 +377,8 @@ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]:
import pandas as pd

if b.batch is not None and len(b.batch.data) > 0:
with pa.ipc.open_stream(b.batch.data) as rd:
if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
with pa.ipc.open_stream(b.arrow_batch.data) as rd:
return rd.read_pandas()
elif b.json_batch is not None and len(b.json_batch.data) > 0:
return pd.read_json(io.BytesIO(b.json_batch.data), lines=True)
@@ -400,6 +400,13 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra

if len(result_dfs) > 0:
df = pd.concat(result_dfs)

# pd.concat generates a non-consecutive index like:
# Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
# Set it to a RangeIndex to be consistent with PySpark.
n = len(df)
df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)

# Attach the metrics to the DataFrame attributes.
if m is not None:
df.attrs["metrics"] = self._build_metrics(m)
36 changes: 18 additions & 18 deletions python/pyspark/sql/connect/proto/base_pb2.py
@@ -35,7 +35,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
)

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -53,21 +53,21 @@
_REQUEST_USERCONTEXT._serialized_start = 429
_REQUEST_USERCONTEXT._serialized_end = 551
_RESPONSE._serialized_start = 554
_RESPONSE._serialized_end = 1522
_RESPONSE_ARROWBATCH._serialized_start = 783
_RESPONSE_ARROWBATCH._serialized_end = 958
_RESPONSE_JSONBATCH._serialized_start = 960
_RESPONSE_JSONBATCH._serialized_end = 1020
_RESPONSE_METRICS._serialized_start = 1023
_RESPONSE_METRICS._serialized_end = 1507
_RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107
_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417
_RESPONSE_METRICS_METRICVALUE._serialized_start = 1419
_RESPONSE_METRICS_METRICVALUE._serialized_end = 1507
_ANALYZERESPONSE._serialized_start = 1525
_ANALYZERESPONSE._serialized_end = 1659
_SPARKCONNECTSERVICE._serialized_start = 1662
_SPARKCONNECTSERVICE._serialized_end = 1824
_RESPONSE._serialized_end = 1418
_RESPONSE_ARROWBATCH._serialized_start = 793
_RESPONSE_ARROWBATCH._serialized_end = 854
_RESPONSE_JSONBATCH._serialized_start = 856
_RESPONSE_JSONBATCH._serialized_end = 916
_RESPONSE_METRICS._serialized_start = 919
_RESPONSE_METRICS._serialized_end = 1403
_RESPONSE_METRICS_METRICOBJECT._serialized_start = 1003
_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1313
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1201
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1313
_RESPONSE_METRICS_METRICVALUE._serialized_start = 1315
_RESPONSE_METRICS_METRICVALUE._serialized_end = 1403
_ANALYZERESPONSE._serialized_start = 1421
_ANALYZERESPONSE._serialized_end = 1555
_SPARKCONNECTSERVICE._serialized_start = 1558
_SPARKCONNECTSERVICE._serialized_end = 1720
# @@protoc_insertion_point(module_scope)
39 changes: 9 additions & 30 deletions python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -178,38 +178,17 @@ class Response(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

ROW_COUNT_FIELD_NUMBER: builtins.int
UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int
COMPRESSED_BYTES_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
row_count: builtins.int
uncompressed_bytes: builtins.int
compressed_bytes: builtins.int
data: builtins.bytes
schema: builtins.bytes
def __init__(
self,
*,
row_count: builtins.int = ...,
uncompressed_bytes: builtins.int = ...,
compressed_bytes: builtins.int = ...,
data: builtins.bytes = ...,
schema: builtins.bytes = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"compressed_bytes",
b"compressed_bytes",
"data",
b"data",
"row_count",
b"row_count",
"schema",
b"schema",
"uncompressed_bytes",
b"uncompressed_bytes",
],
self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
) -> None: ...

class JSONBatch(google.protobuf.message.Message):
@@ -339,12 +318,12 @@ class Response(google.protobuf.message.Message):
) -> None: ...

CLIENT_ID_FIELD_NUMBER: builtins.int
BATCH_FIELD_NUMBER: builtins.int
ARROW_BATCH_FIELD_NUMBER: builtins.int
JSON_BATCH_FIELD_NUMBER: builtins.int
METRICS_FIELD_NUMBER: builtins.int
client_id: builtins.str
@property
def batch(self) -> global___Response.ArrowBatch: ...
def arrow_batch(self) -> global___Response.ArrowBatch: ...
@property
def json_batch(self) -> global___Response.JSONBatch: ...
@property
@@ -356,15 +335,15 @@
self,
*,
client_id: builtins.str = ...,
batch: global___Response.ArrowBatch | None = ...,
arrow_batch: global___Response.ArrowBatch | None = ...,
json_batch: global___Response.JSONBatch | None = ...,
metrics: global___Response.Metrics | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"batch",
b"batch",
"arrow_batch",
b"arrow_batch",
"json_batch",
b"json_batch",
"metrics",
@@ -376,8 +355,8 @@
def ClearField(
self,
field_name: typing_extensions.Literal[
"batch",
b"batch",
"arrow_batch",
b"arrow_batch",
"client_id",
b"client_id",
"json_batch",
Expand All @@ -390,7 +369,7 @@ class Response(google.protobuf.message.Message):
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["result_type", b"result_type"]
) -> typing_extensions.Literal["batch", "json_batch"] | None: ...
) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ...

global___Response = Response

12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -197,6 +197,18 @@ def test_range(self):
.equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas())
)

def test_empty_dataset(self):
# SPARK-41005: Test arrow based collection with empty dataset.
self.assertTrue(
Member: Let's add a comment with a JIRA, like # SPARK-XXXX: ... (https://spark.apache.org/contributing.html)

Contributor: +1. Let's maintain the contribution style in this module.

self.connect.sql("SELECT 1 AS X LIMIT 0")
.toPandas()
.equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas())
)
pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas()
self.assertEqual(0, len(pdf)) # empty dataset
self.assertEqual(1, len(pdf.columns)) # one column
self.assertEqual("X", pdf.columns[0])

def test_simple_datasource_read(self) -> None:
writeDf = self.df_text
tmpPath = tempfile.mkdtemp()