100 changes: 81 additions & 19 deletions python/pyspark/sql/connect/client/core.py
@@ -607,6 +607,8 @@ def __init__(
retry_policy: Optional[Dict[str, Any]] = None,
use_reattachable_execute: bool = True,
session_hooks: Optional[list["SparkSession.Hook"]] = None,
allow_arrow_batch_chunking: bool = True,
preferred_arrow_chunk_size: Optional[int] = None,
):
"""
Creates a new SparkSession for the Spark Connect interface.
@@ -639,6 +641,21 @@ def __init__(
Enable reattachable execution.
session_hooks: list[SparkSession.Hook], optional
List of session hooks to call.
allow_arrow_batch_chunking: bool
    Whether to allow the server to split large Arrow batches into smaller chunks.
    Arrow results are split into batches based on an estimated size limit, but the
    actual size of a batch is not guaranteed to stay under that limit; in particular,
    a single row larger than the limit cannot be split further, so the client may
    encounter a gRPC error stating "Received message larger than max" when a batch
    is too large.
    If true, the server will split large Arrow batches into smaller chunks, and the
    client is expected to handle the chunked Arrow batches.
    If false, the server will not chunk large Arrow batches.
preferred_arrow_chunk_size: Optional[int]
    Optional preferred Arrow chunk size in bytes for the server to use when sending
    Arrow results.
    The server will attempt to use this size if it is set and within the valid range
    ([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used.
"""
self.thread_local = threading.local()

@@ -678,6 +695,8 @@ def __init__(
self._user_id, self._session_id, self._channel, self._builder.metadata()
)
self._use_reattachable_execute = use_reattachable_execute
self._allow_arrow_batch_chunking = allow_arrow_batch_chunking
self._preferred_arrow_chunk_size = preferred_arrow_chunk_size
self._session_hooks = session_hooks or []
# Configure logging for the SparkConnect client.

@@ -1235,6 +1254,15 @@ def _execute_plan_request_with_metadata(
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id
# Add request option to allow result chunking.
req.request_options.append(
pb2.ExecutePlanRequest.RequestOption(
result_chunking_options=pb2.ResultChunkingOptions(
allow_arrow_batch_chunking=self._allow_arrow_batch_chunking,
preferred_arrow_chunk_size=self._preferred_arrow_chunk_size,
)
)
)
if operation_id is not None:
try:
uuid.UUID(operation_id, version=4)
@@ -1408,6 +1436,7 @@ def _execute_and_fetch_as_iterator(
req = hook.on_execute_plan(req)

num_records = 0
arrow_batch_chunks_to_assemble: List[bytes] = []

def handle_response(
b: pb2.ExecutePlanResponse,
Expand Down Expand Up @@ -1495,32 +1524,65 @@ def handle_response(
if b.HasField("arrow_batch"):
logger.debug(
f"Received arrow batch rows={b.arrow_batch.row_count} "
f"Number of chunks in batch={b.arrow_batch.num_chunks_in_batch} "
f"Chunk index={b.arrow_batch.chunk_index} "
f"size={len(b.arrow_batch.data)}"
)

if arrow_batch_chunks_to_assemble:
    # Expect next chunk of the same batch
    if b.arrow_batch.chunk_index != len(arrow_batch_chunks_to_assemble):
        raise SparkConnectException(
            f"Expected chunk index {len(arrow_batch_chunks_to_assemble)} of the "
            f"arrow batch but got {b.arrow_batch.chunk_index}."
        )
else:
    # Expect next batch
    if (
        b.arrow_batch.HasField("start_offset")
        and num_records != b.arrow_batch.start_offset
    ):
        raise SparkConnectException(
            f"Expected arrow batch to start at row offset {num_records} in "
            + "results, but received arrow batch starting at offset "
            + f"{b.arrow_batch.start_offset}."
        )
    if b.arrow_batch.chunk_index != 0:
        raise SparkConnectException(
            f"Expected chunk index 0 of the next arrow batch "
            f"but got {b.arrow_batch.chunk_index}."
        )

arrow_batch_chunks_to_assemble.append(b.arrow_batch.data)
# Assemble the chunks into an arrow batch to process if
# (a) chunking is not enabled (num_chunks_in_batch is not set or is 0;
# in this case, the data is the single chunk of the batch),
# (b) or the client has received all chunks of the batch.
if (
    not b.arrow_batch.HasField("num_chunks_in_batch")
    or b.arrow_batch.num_chunks_in_batch == 0
    or len(arrow_batch_chunks_to_assemble) == b.arrow_batch.num_chunks_in_batch
):
    arrow_batch_data = b"".join(arrow_batch_chunks_to_assemble)
    arrow_batch_chunks_to_assemble.clear()
    logger.debug(
        f"Assembling arrow batch of size {len(arrow_batch_data)} from "
        f"{b.arrow_batch.num_chunks_in_batch} chunks."
    )

num_records_in_batch = 0
with pa.ipc.open_stream(b.arrow_batch.data) as reader:
for batch in reader:
assert isinstance(batch, pa.RecordBatch)
num_records_in_batch += batch.num_rows
yield batch

if num_records_in_batch != b.arrow_batch.row_count:
raise SparkConnectException(
f"Expected {b.arrow_batch.row_count} rows in arrow batch but got "
+ f"{num_records_in_batch}."
)
num_records += num_records_in_batch
num_records_in_batch = 0
with pa.ipc.open_stream(arrow_batch_data) as reader:
for batch in reader:
assert isinstance(batch, pa.RecordBatch)
num_records_in_batch += batch.num_rows
if num_records_in_batch != b.arrow_batch.row_count:
raise SparkConnectException(
f"Expected {b.arrow_batch.row_count} rows in arrow batch but "
+ f"got {num_records_in_batch}."
)
num_records += num_records_in_batch
yield batch
if b.HasField("create_resource_profile_command_result"):
profile_id = b.create_resource_profile_command_result.profile_id
yield {"create_resource_profile_command_result": profile_id}
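Below, a minimal usage sketch of the new client options. The constructor and parameter names come from the diff above; the connection string and chunk size value are illustrative assumptions, and real applications normally reach this constructor through SparkSession.builder rather than directly.

from pyspark.sql.connect.client.core import SparkConnectClient

# Hypothetical client construction; connection string and sizes are examples only.
client = SparkConnectClient(
    connection="sc://localhost:15002",
    # Let the server split oversized Arrow batches into chunks.
    allow_arrow_batch_chunking=True,
    # Prefer ~4 MiB chunks; values outside [1KB, server max] make the server
    # fall back to its maximum batch size.
    preferred_arrow_chunk_size=4 * 1024 * 1024,
)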
224 changes: 113 additions & 111 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

111 changes: 109 additions & 2 deletions python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1093,16 +1093,20 @@ class ExecutePlanRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

REATTACH_OPTIONS_FIELD_NUMBER: builtins.int
RESULT_CHUNKING_OPTIONS_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def reattach_options(self) -> global___ReattachOptions: ...
@property
def result_chunking_options(self) -> global___ResultChunkingOptions: ...
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""Extension type for request options"""
def __init__(
self,
*,
reattach_options: global___ReattachOptions | None = ...,
result_chunking_options: global___ResultChunkingOptions | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1114,6 +1118,8 @@ class ExecutePlanRequest(google.protobuf.message.Message):
b"reattach_options",
"request_option",
b"request_option",
"result_chunking_options",
b"result_chunking_options",
],
) -> builtins.bool: ...
def ClearField(
@@ -1125,11 +1131,16 @@
b"reattach_options",
"request_option",
b"request_option",
"result_chunking_options",
b"result_chunking_options",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["request_option", b"request_option"]
) -> typing_extensions.Literal["reattach_options", "extension"] | None: ...
) -> (
typing_extensions.Literal["reattach_options", "result_chunking_options", "extension"]
| None
): ...

SESSION_ID_FIELD_NUMBER: builtins.int
CLIENT_OBSERVED_SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
@@ -1308,38 +1319,78 @@ class ExecutePlanResponse(google.protobuf.message.Message):
ROW_COUNT_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
START_OFFSET_FIELD_NUMBER: builtins.int
CHUNK_INDEX_FIELD_NUMBER: builtins.int
NUM_CHUNKS_IN_BATCH_FIELD_NUMBER: builtins.int
row_count: builtins.int
"""Count rows in `data`. Must match the number of rows inside `data`."""
data: builtins.bytes
"""Serialized Arrow data."""
start_offset: builtins.int
"""If set, row offset of the start of this ArrowBatch in execution results."""
chunk_index: builtins.int
"""Index of this chunk in the batch if chunking is enabled. The index starts from 0."""
num_chunks_in_batch: builtins.int
"""Total number of chunks in this batch if chunking is enabled.
It is missing when chunking is disabled - the batch is returned whole
and client will treat this response as the batch.
"""
def __init__(
self,
*,
row_count: builtins.int = ...,
data: builtins.bytes = ...,
start_offset: builtins.int | None = ...,
chunk_index: builtins.int | None = ...,
num_chunks_in_batch: builtins.int | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_start_offset", b"_start_offset", "start_offset", b"start_offset"
"_chunk_index",
b"_chunk_index",
"_num_chunks_in_batch",
b"_num_chunks_in_batch",
"_start_offset",
b"_start_offset",
"chunk_index",
b"chunk_index",
"num_chunks_in_batch",
b"num_chunks_in_batch",
"start_offset",
b"start_offset",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_chunk_index",
b"_chunk_index",
"_num_chunks_in_batch",
b"_num_chunks_in_batch",
"_start_offset",
b"_start_offset",
"chunk_index",
b"chunk_index",
"data",
b"data",
"num_chunks_in_batch",
b"num_chunks_in_batch",
"row_count",
b"row_count",
"start_offset",
b"start_offset",
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_chunk_index", b"_chunk_index"]
) -> typing_extensions.Literal["chunk_index"] | None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal["_num_chunks_in_batch", b"_num_chunks_in_batch"],
) -> typing_extensions.Literal["num_chunks_in_batch"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_start_offset", b"_start_offset"]
) -> typing_extensions.Literal["start_offset"] | None: ...
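Read together, chunk_index and num_chunks_in_batch let a client reassemble a batch from consecutive responses. A minimal reassembly sketch follows, assuming chunks of a batch arrive in order and contiguously (the same assumption the core.py change above enforces); responses is a hypothetical iterable of ExecutePlanResponse messages.

import pyarrow as pa

def iter_record_batches(responses):
    # Accumulate chunk payloads until a whole batch is available, then parse
    # the concatenated bytes as one Arrow IPC stream.
    chunks = []
    for resp in responses:
        if not resp.HasField("arrow_batch"):
            continue
        ab = resp.arrow_batch
        chunks.append(ab.data)
        # num_chunks_in_batch unset or 0 means chunking is disabled and the
        # payload already holds the whole batch.
        total = ab.num_chunks_in_batch if ab.HasField("num_chunks_in_batch") else 0
        if total == 0 or len(chunks) == total:
            data = b"".join(chunks)
            chunks = []
            with pa.ipc.open_stream(data) as reader:
                yield from reader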
@@ -2942,6 +2993,62 @@ class ReattachOptions(google.protobuf.message.Message):

global___ReattachOptions = ReattachOptions

class ResultChunkingOptions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

ALLOW_ARROW_BATCH_CHUNKING_FIELD_NUMBER: builtins.int
PREFERRED_ARROW_CHUNK_SIZE_FIELD_NUMBER: builtins.int
allow_arrow_batch_chunking: builtins.bool
"""Although Arrow results are split into batches with a size limit according to estimation, the
size of the batches is not guaranteed to be less than the limit, especially when a single row
is larger than the limit, in which case the server will fail to split it further into smaller
batches. As a result, the client may encounter a gRPC error stating “Received message larger
than max” when a batch is too large.
If allow_arrow_batch_chunking=true, the server will split large Arrow batches into smaller chunks,
and the client is expected to handle the chunked Arrow batches.

If false, the server will not chunk large Arrow batches.
"""
preferred_arrow_chunk_size: builtins.int
"""Optional preferred Arrow batch size in bytes for the server to use when sending Arrow results.
The server will attempt to use this size if it is set and within the valid range
([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used.
"""
def __init__(
self,
*,
allow_arrow_batch_chunking: builtins.bool = ...,
preferred_arrow_chunk_size: builtins.int | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_preferred_arrow_chunk_size",
b"_preferred_arrow_chunk_size",
"preferred_arrow_chunk_size",
b"preferred_arrow_chunk_size",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_preferred_arrow_chunk_size",
b"_preferred_arrow_chunk_size",
"allow_arrow_batch_chunking",
b"allow_arrow_batch_chunking",
"preferred_arrow_chunk_size",
b"preferred_arrow_chunk_size",
],
) -> None: ...
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_preferred_arrow_chunk_size", b"_preferred_arrow_chunk_size"
],
) -> typing_extensions.Literal["preferred_arrow_chunk_size"] | None: ...

global___ResultChunkingOptions = ResultChunkingOptions
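For completeness, a sketch of how a client attaches these options to an execution request, mirroring the request_options change in core.py above; the import alias follows this PR's file layout and the chunk size value is an illustrative assumption.

import pyspark.sql.connect.proto.base_pb2 as pb2

req = pb2.ExecutePlanRequest()
req.request_options.append(
    pb2.ExecutePlanRequest.RequestOption(
        result_chunking_options=pb2.ResultChunkingOptions(
            allow_arrow_batch_chunking=True,
            # Honored only when within [1KB, max batch size on server].
            preferred_arrow_chunk_size=2 * 1024 * 1024,
        )
    )
)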

class ReattachExecuteRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
