55 changes: 55 additions & 0 deletions python/pyspark/sql/connect/client/artifact.py
@@ -427,6 +427,30 @@ def is_cached_artifact(self, hash: str) -> bool:
status = resp.statuses.get(artifactName)
return status.exists if status is not None else False

def get_cached_artifacts(self, hashes: list[str]) -> set[str]:
"""
Batch check which artifacts are already cached on the server.
Returns a set of hashes that are already cached.
"""
if not hashes:
return set()

artifact_names = [f"{CACHE_PREFIX}/{hash}" for hash in hashes]
request = proto.ArtifactStatusesRequest(
user_context=self._user_context, session_id=self._session_id, names=artifact_names
)
resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(
request, metadata=self._metadata
)

cached = set()
for hash in hashes:
artifact_name = f"{CACHE_PREFIX}/{hash}"
status = resp.statuses.get(artifact_name)
if status is not None and status.exists:
cached.add(hash)
return cached

def cache_artifact(self, blob: bytes) -> str:
"""
Cache the given blob at the session.
@@ -442,3 +466,34 @@ def cache_artifact(self, blob: bytes) -> str:
# TODO(SPARK-42658): Handle responses containing CRC failures.

return hash

def cache_artifacts(self, blobs: list[bytes]) -> list[str]:
"""
Cache the given blobs at the session.

This method batches artifact status checks and uploads to minimize RPC overhead.
"""
# Compute hashes for all blobs upfront
hashes = [hashlib.sha256(blob).hexdigest() for blob in blobs]
unique_hashes = list(set(hashes))

# Batch check which artifacts are already cached
cached_hashes = self.get_cached_artifacts(unique_hashes)

# Collect unique artifacts that need to be uploaded
seen_hashes = set()
artifacts_to_add = []
for blob, hash in zip(blobs, hashes):
if hash not in cached_hashes and hash not in seen_hashes:
artifacts_to_add.append(new_cache_artifact(hash, InMemory(blob)))
seen_hashes.add(hash)

# Batch upload all missing artifacts in a single RPC call
if artifacts_to_add:
requests = self._add_artifacts(artifacts_to_add)
response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
for summary in response.artifacts:
summaries.append(summary)
# TODO(SPARK-42658): Handle responses containing CRC failures.
return hashes
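
To illustrate the intended call pattern of the two batch helpers above, here is a hedged usage sketch. It assumes an already constructed `ArtifactManager` instance named `artifact_manager` and uses made-up payloads; it is not part of the diff.

import hashlib

# Hypothetical blobs, e.g. serialized Arrow IPC chunks of a local relation.
blobs = [b"chunk-0", b"chunk-1", b"chunk-0"]

# cache_artifacts() hashes every blob, asks the server once which hashes are
# already cached, uploads only the missing unique blobs in a single
# AddArtifacts call, and returns one sha-256 hex digest per input blob
# (duplicate blobs share a digest).
hashes = artifact_manager.cache_artifacts(blobs)
assert hashes[0] == hashes[2] == hashlib.sha256(blobs[0]).hexdigest()

# get_cached_artifacts() can also be used on its own for a batched status check.
assert set(hashes) <= artifact_manager.get_cached_artifacts(hashes)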
6 changes: 6 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -2003,6 +2003,12 @@ def cache_artifact(self, blob: bytes) -> str:
return self._artifact_manager.cache_artifact(blob)
raise SparkConnectException("Invalid state during retry exception handling.")

def cache_artifacts(self, blobs: list[bytes]) -> list[str]:
for attempt in self._retrying():
with attempt:
return self._artifact_manager.cache_artifacts(blobs)
raise SparkConnectException("Invalid state during retry exception handling.")

def _verify_response_integrity(
self,
response: Union[
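
For context, a minimal sketch of the client-level entry point added above, assuming an active Spark Connect session that exposes its `SparkConnectClient` as `spark.client`; the names and payloads are illustrative and not part of the diff.

# Hypothetical call site: cache several serialized chunks through the client.
# The wrapper retries transient RPC failures via the same _retrying() loop
# that cache_artifact() already uses before delegating to the artifact manager.
chunks = [b"chunk-0", b"chunk-1"]
hashes = spark.client.cache_artifacts(chunks)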
91 changes: 79 additions & 12 deletions python/pyspark/sql/connect/plan.py
@@ -429,16 +429,78 @@ def __init__(
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
if self._table is not None:
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
for b in self._table.to_batches():
writer.write_batch(b)
plan.local_relation.data = sink.getvalue().to_pybytes()
plan.local_relation.data = self._serialize_table()

if self._schema is not None:
plan.local_relation.schema = self._schema
return plan

def _serialize_table(self) -> bytes:
assert self._table is not None
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
batches = self._table.to_batches()
for b in batches:
writer.write_batch(b)
return sink.getvalue().to_pybytes()

def _serialize_table_chunks(
self,
max_chunk_size_rows: int,
max_chunk_size_bytes: int,
) -> list[bytes]:
"""
Serialize the table into multiple chunks, each up to max_chunk_size_bytes bytes
and max_chunk_size_rows rows.
Each chunk is a valid Arrow IPC stream.

This method processes the table in fixed-size batches (1024 rows) for
efficiency, matching the Scala implementation's batchSizeCheckInterval.
"""
assert self._table is not None
chunks = []
schema = self._table.schema

# Calculate schema serialization size once
schema_buffer = pa.BufferOutputStream()
with pa.ipc.new_stream(schema_buffer, schema):
pass # Just write schema
schema_size = len(schema_buffer.getvalue())

current_batches: list[pa.RecordBatch] = []
current_size = schema_size

for batch in self._table.to_batches(max_chunksize=min(1024, max_chunk_size_rows)):
batch_size = sum(arr.nbytes for arr in batch.columns)

# If this batch would exceed limit and we have data, flush current chunk
if current_size > schema_size and current_size + batch_size > max_chunk_size_bytes:
combined = pa.Table.from_batches(current_batches, schema=schema)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, schema) as writer:
writer.write_table(combined)
chunks.append(sink.getvalue().to_pybytes())
current_batches = []
current_size = schema_size

current_batches.append(batch)
current_size += batch_size

# Flush remaining batches
if current_batches:
combined = pa.Table.from_batches(current_batches, schema=schema)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, schema) as writer:
writer.write_table(combined)
chunks.append(sink.getvalue().to_pybytes())

return chunks
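
The chunking idea above can be exercised on its own. The following self-contained sketch (independent of this class, with an assumed 64 KiB byte budget and a hypothetical helper name) splits a pyarrow Table into standalone IPC streams and checks that they reassemble losslessly.

import pyarrow as pa

def serialize_in_chunks(table: pa.Table, max_bytes: int, max_rows: int = 1024) -> list:
    # Each element of the result is a complete Arrow IPC stream (schema plus
    # batches), so every chunk can be deserialized independently of the others.
    chunks, current, current_size = [], [], 0
    for batch in table.to_batches(max_chunksize=max_rows):
        batch_size = sum(arr.nbytes for arr in batch.columns)
        # Flush the buffered batches once adding this batch would exceed the budget.
        if current and current_size + batch_size > max_bytes:
            sink = pa.BufferOutputStream()
            with pa.ipc.new_stream(sink, table.schema) as writer:
                for b in current:
                    writer.write_batch(b)
            chunks.append(sink.getvalue().to_pybytes())
            current, current_size = [], 0
        current.append(batch)
        current_size += batch_size
    if current:
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, table.schema) as writer:
            for b in current:
                writer.write_batch(b)
        chunks.append(sink.getvalue().to_pybytes())
    return chunks

table = pa.table({"id": list(range(100_000))})
parts = serialize_in_chunks(table, max_bytes=64 * 1024)
assert len(parts) > 1
assert pa.concat_tables([pa.ipc.open_stream(p).read_all() for p in parts]) == table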

def _serialize_schema(self) -> bytes:
# the server uses UTF-8 for decoding the schema
assert self._schema is not None
return self._schema.encode("utf-8")

def serialize(self, session: "SparkConnectClient") -> bytes:
p = self.plan(session)
return bytes(p.local_relation.SerializeToString())
@@ -454,29 +516,34 @@ def _repr_html_(self) -> str:
"""


class CachedLocalRelation(LogicalPlan):
class ChunkedCachedLocalRelation(LogicalPlan):
"""Creates a CachedLocalRelation plan object based on a hash of a LocalRelation."""

def __init__(self, hash: str) -> None:
def __init__(self, data_hashes: list[str], schema_hash: Optional[str]) -> None:
super().__init__(None)

self._hash = hash
self._data_hashes = data_hashes
self._schema_hash = schema_hash

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
clr = plan.cached_local_relation
clr = plan.chunked_cached_local_relation

clr.hash = self._hash
# Add hex string hashes directly to protobuf
for data_hash in self._data_hashes:
clr.dataHashes.append(data_hash)
if self._schema_hash is not None:
clr.schemaHash = self._schema_hash

return plan

def print(self, indent: int = 0) -> str:
return f"{' ' * indent}<CachedLocalRelation>\n"
return f"{' ' * indent}<ChunkedCachedLocalRelation>\n"

def _repr_html_(self) -> str:
return """
<ul>
<li><b>CachedLocalRelation</b></li>
<li><b>ChunkedCachedLocalRelation</b></li>
</ul>
"""

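
A hedged sketch of how the renamed plan node might be constructed once the chunk and schema artifacts are cached; the hash values, the `session` object, and the wiring are illustrative assumptions, not code from this PR.

# Hypothetical wiring: data_hashes come from cache_artifacts() over the
# serialized Arrow chunks; schema_hash comes from caching the UTF-8 schema
# bytes, or is None since Python clients may omit the schema.
data_hashes = ["<sha-256 hex of chunk 0>", "<sha-256 hex of chunk 1>"]
plan = ChunkedCachedLocalRelation(data_hashes=data_hashes, schema_hash=None)
relation = plan.plan(session)  # populates relation.chunked_cached_local_relation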
338 changes: 170 additions & 168 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

58 changes: 57 additions & 1 deletion python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -108,6 +108,7 @@ class Relation(google.protobuf.message.Message):
TRANSPOSE_FIELD_NUMBER: builtins.int
UNRESOLVED_TABLE_VALUED_FUNCTION_FIELD_NUMBER: builtins.int
LATERAL_JOIN_FIELD_NUMBER: builtins.int
CHUNKED_CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -216,6 +217,8 @@ class Relation(google.protobuf.message.Message):
@property
def lateral_join(self) -> global___LateralJoin: ...
@property
def chunked_cached_local_relation(self) -> global___ChunkedCachedLocalRelation: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -301,6 +304,7 @@ class Relation(google.protobuf.message.Message):
transpose: global___Transpose | None = ...,
unresolved_table_valued_function: global___UnresolvedTableValuedFunction | None = ...,
lateral_join: global___LateralJoin | None = ...,
chunked_cached_local_relation: global___ChunkedCachedLocalRelation | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -334,6 +338,8 @@ class Relation(google.protobuf.message.Message):
b"cached_remote_relation",
"catalog",
b"catalog",
"chunked_cached_local_relation",
b"chunked_cached_local_relation",
"co_group_map",
b"co_group_map",
"collect_metrics",
@@ -459,6 +465,8 @@ class Relation(google.protobuf.message.Message):
b"cached_remote_relation",
"catalog",
b"catalog",
"chunked_cached_local_relation",
b"chunked_cached_local_relation",
"co_group_map",
b"co_group_map",
"collect_metrics",
@@ -614,6 +622,7 @@ class Relation(google.protobuf.message.Message):
"transpose",
"unresolved_table_valued_function",
"lateral_join",
"chunked_cached_local_relation",
"fill_na",
"drop_na",
"replace",
@@ -2084,7 +2093,9 @@ class LocalRelation(google.protobuf.message.Message):
global___LocalRelation = LocalRelation

class CachedLocalRelation(google.protobuf.message.Message):
"""A local relation that has been cached already."""
"""A local relation that has been cached already.
CachedLocalRelation doesn't support LocalRelations of size over 2GB.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -2100,6 +2111,51 @@ class CachedLocalRelation(google.protobuf.message.Message):

global___CachedLocalRelation = CachedLocalRelation

class ChunkedCachedLocalRelation(google.protobuf.message.Message):
"""A local relation that has been cached already."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

DATAHASHES_FIELD_NUMBER: builtins.int
SCHEMAHASH_FIELD_NUMBER: builtins.int
@property
def dataHashes(
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""(Required) A list of sha-256 hashes for representing LocalRelation.data.
Data is serialized in Arrow IPC streaming format, each batch is cached on the server as
a separate artifact. Each hash represents one batch stored on the server.
Hashes are hex-encoded strings (e.g., "a3b2c1d4...").
"""
schemaHash: builtins.str
"""(Optional) A sha-256 hash of the serialized LocalRelation.schema.
Scala clients always provide the schema, Python clients can omit it.
Hash is a hex-encoded string (e.g., "a3b2c1d4...").
"""
def __init__(
self,
*,
dataHashes: collections.abc.Iterable[builtins.str] | None = ...,
schemaHash: builtins.str | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_schemaHash", b"_schemaHash", "schemaHash", b"schemaHash"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_schemaHash", b"_schemaHash", "dataHashes", b"dataHashes", "schemaHash", b"schemaHash"
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_schemaHash", b"_schemaHash"]
) -> typing_extensions.Literal["schemaHash"] | None: ...

global___ChunkedCachedLocalRelation = ChunkedCachedLocalRelation

class CachedRemoteRelation(google.protobuf.message.Message):
"""Represents a remote relation that has been cached on server."""

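
Finally, a short sketch of populating the new message through the regenerated protobuf bindings shown above; it assumes the rebuilt pyspark.sql.connect.proto package from this PR, and the hash strings are placeholders.

import pyspark.sql.connect.proto as proto

rel = proto.Relation()
clr = rel.chunked_cached_local_relation
clr.dataHashes.extend(["<sha-256 hex of chunk 0>", "<sha-256 hex of chunk 1>"])
clr.schemaHash = "<sha-256 hex of utf-8 schema>"

# Setting the sub-message selects it in the Relation.rel_type oneof, and the
# optional schemaHash field reports presence via HasField().
assert rel.WhichOneof("rel_type") == "chunked_cached_local_relation"
assert clr.HasField("schemaHash")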