52 changes: 28 additions & 24 deletions python/pyspark/sql/connect/plan.py
@@ -24,6 +24,7 @@

from typing import (
Any,
Iterator,
List,
Optional,
Type,
@@ -437,7 +438,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan

def _serialize_table(self) -> bytes:
assert self._table is not None
assert self._table is not None, "table cannot be None"
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
batches = self._table.to_batches()
@@ -449,57 +450,60 @@ def _serialize_table_chunks(
self,
max_chunk_size_rows: int,
max_chunk_size_bytes: int,
) -> list[bytes]:
) -> Iterator[bytes]:
"""
Serialize the table into multiple chunks, each up to max_chunk_size_bytes bytes
and max_chunk_size_rows rows.
Each chunk is a valid Arrow IPC stream.

This method processes the table in fixed-size batches (1024 rows) for
efficiency, matching the Scala implementation's batchSizeCheckInterval.

Yields chunks one at a time to avoid materializing all chunks in memory.
"""
assert self._table is not None
chunks = []
assert self._table is not None, "table cannot be None"
assert self._table.num_rows > 0, "table must have at least one row"
schema = self._table.schema

# Calculate schema serialization size once
schema_buffer = pa.BufferOutputStream()
with pa.ipc.new_stream(schema_buffer, schema):
pass # Just write schema
schema_size = len(schema_buffer.getvalue())
# Calculate schema serialization size once (empty table = just schema)
schema_size = len(self._serialize_batches_to_ipc([], schema))

current_batches: list[pa.RecordBatch] = []
current_size = schema_size

for batch in self._table.to_batches(max_chunksize=min(1024, max_chunk_size_rows)):
# Approximate batch size using raw column data (fast, ignores IPC overhead).
# Calculating the real batch size of the IPC stream would require serializing each
# batch separately, which adds overhead.
batch_size = sum(arr.nbytes for arr in batch.columns)

# If this batch would exceed limit and we have data, flush current chunk
if current_size > schema_size and current_size + batch_size > max_chunk_size_bytes:
combined = pa.Table.from_batches(current_batches, schema=schema)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, schema) as writer:
writer.write_table(combined)
chunks.append(sink.getvalue().to_pybytes())
if len(current_batches) > 0 and current_size + batch_size > max_chunk_size_bytes:
yield self._serialize_batches_to_ipc(current_batches, schema)
current_batches = []
current_size = schema_size

current_batches.append(batch)
current_size += batch_size

# Flush remaining batches
if current_batches:
combined = pa.Table.from_batches(current_batches, schema=schema)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, schema) as writer:
writer.write_table(combined)
chunks.append(sink.getvalue().to_pybytes())
# Flush remaining batches (guaranteed to have at least one due to assertion)
yield self._serialize_batches_to_ipc(current_batches, schema)

return chunks
def _serialize_batches_to_ipc(
self,
batches: list[pa.RecordBatch],
schema: pa.Schema,
) -> bytes:
"""Helper method to serialize Arrow batches to IPC stream format."""
combined = pa.Table.from_batches(batches, schema=schema)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, schema) as writer:
writer.write_table(combined)
return sink.getvalue().to_pybytes()

def _serialize_schema(self) -> bytes:
# the server uses UTF-8 for decoding the schema
assert self._schema is not None
assert self._schema is not None, "schema cannot be None"
return self._schema.encode("utf-8")

def serialize(self, session: "SparkConnectClient") -> bytes:
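For reference, the chunking strategy implemented above can be reproduced standalone with pyarrow. This is a minimal sketch, not the PR's code: the table, the limits, and the helper names are made up, but the flow (size-bounded groups of record batches combined into self-contained IPC streams, yielded lazily) mirrors `_serialize_table_chunks` and `_serialize_batches_to_ipc`:

```python
import pyarrow as pa


def serialize_batches_to_ipc(batches, schema):
    # Combine the batches into one table and write it as a single Arrow IPC stream.
    combined = pa.Table.from_batches(batches, schema=schema)
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, schema) as writer:
        writer.write_table(combined)
    return sink.getvalue().to_pybytes()


def serialize_table_chunks(table, max_chunk_size_rows, max_chunk_size_bytes):
    schema = table.schema
    # An empty IPC stream measures the schema-only overhead carried by every chunk.
    schema_size = len(serialize_batches_to_ipc([], schema))
    current_batches, current_size = [], schema_size
    for batch in table.to_batches(max_chunksize=min(1024, max_chunk_size_rows)):
        # Approximate the batch size from raw column buffers (ignores IPC framing).
        batch_size = sum(arr.nbytes for arr in batch.columns)
        if current_batches and current_size + batch_size > max_chunk_size_bytes:
            yield serialize_batches_to_ipc(current_batches, schema)
            current_batches, current_size = [], schema_size
        current_batches.append(batch)
        current_size += batch_size
    if current_batches:
        yield serialize_batches_to_ipc(current_batches, schema)


table = pa.table({"id": list(range(100_000))})
chunks = list(serialize_table_chunks(table, max_chunk_size_rows=4096, max_chunk_size_bytes=64 * 1024))
# Each chunk is a self-contained IPC stream and can be deserialized independently.
assert sum(pa.ipc.open_stream(c).read_all().num_rows for c in chunks) == table.num_rows
```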
66 changes: 53 additions & 13 deletions python/pyspark/sql/connect/session.py
@@ -29,6 +29,7 @@
from typing import (
Optional,
Any,
Iterator,
Union,
Dict,
List,
@@ -537,6 +538,7 @@ def createDataFrame(
"spark.sql.session.localRelationCacheThreshold",
"spark.sql.session.localRelationChunkSizeRows",
"spark.sql.session.localRelationChunkSizeBytes",
"spark.sql.session.localRelationBatchOfChunksSizeBytes",
"spark.sql.execution.pandas.convertToArrowArraySafely",
"spark.sql.execution.pandas.inferPandasDictAsMap",
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
@@ -772,10 +774,16 @@ def createDataFrame(
max_chunk_size_bytes = int(
configs["spark.sql.session.localRelationChunkSizeBytes"] # type: ignore[arg-type]
)
max_batch_of_chunks_size_bytes = int(
configs["spark.sql.session.localRelationBatchOfChunksSizeBytes"] # type: ignore[arg-type] # noqa: E501
)
plan: LogicalPlan = local_relation
if cache_threshold <= _table.nbytes:
plan = self._cache_local_relation(
local_relation, max_chunk_size_rows, max_chunk_size_bytes
local_relation,
max_chunk_size_rows,
max_chunk_size_bytes,
max_batch_of_chunks_size_bytes,
)

df = DataFrame(plan, self)
@@ -1054,30 +1062,62 @@ def _cache_local_relation(
local_relation: LocalRelation,
max_chunk_size_rows: int,
max_chunk_size_bytes: int,
max_batch_of_chunks_size_bytes: int,
) -> ChunkedCachedLocalRelation:
"""
Cache the local relation at the server side if it has not been cached yet.

Should only be called on LocalRelations with _table set.
This method serializes the input local relation into multiple data chunks and
a schema chunk (if the schema is available) and uploads these chunks as artifacts
to the server.

The method collects a batch of chunks of size up to max_batch_of_chunks_size_bytes and
uploads them together to the server.
Uploading each chunk separately would require an additional RPC call for each chunk.
Uploading all chunks together would require materializing all chunks in memory which
may cause high memory usage on the client.
Uploading batches of chunks is the middle-ground solution.

Should only be called on a LocalRelation with a non-empty _table.
"""
assert local_relation._table is not None
assert local_relation._table is not None, "table cannot be None"
has_schema = local_relation._schema is not None

# Serialize table into chunks
data_chunks = local_relation._serialize_table_chunks(
max_chunk_size_rows, max_chunk_size_bytes
hashes = []
current_batch = []
current_batch_size = 0
if has_schema:
schema_chunk = local_relation._serialize_schema()
current_batch.append(schema_chunk)
current_batch_size += len(schema_chunk)

data_chunks: Iterator[bytes] = local_relation._serialize_table_chunks(
max_chunk_size_rows, min(max_chunk_size_bytes, max_batch_of_chunks_size_bytes)
)
blobs = data_chunks.copy() # Start with data chunks

if has_schema:
blobs.append(local_relation._serialize_schema())
for chunk in data_chunks:
chunk_size = len(chunk)

hashes = self._client.cache_artifacts(blobs)
# Check if adding this chunk would exceed batch size
if (
len(current_batch) > 0
and current_batch_size + chunk_size > max_batch_of_chunks_size_bytes
):
hashes += self._client.cache_artifacts(current_batch)
# start a new batch
current_batch = []
current_batch_size = 0

# Extract data hashes and schema hash
data_hashes = hashes[: len(data_chunks)]
schema_hash = hashes[len(data_chunks)] if has_schema else None
current_batch.append(chunk)
current_batch_size += chunk_size
hashes += self._client.cache_artifacts(current_batch)

if has_schema:
schema_hash = hashes[0]
data_hashes = hashes[1:]
else:
schema_hash = None
data_hashes = hashes
return ChunkedCachedLocalRelation(data_hashes, schema_hash)

def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
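A self-contained sketch of the batched-upload loop implemented above. `cache_artifacts` here is a stand-in for the client RPC (the real call goes through `SparkConnectClient`), and the chunk contents and batch budget are invented; the point is how the schema chunk rides in the first batch and how the returned hashes are split afterwards:

```python
import hashlib
from typing import Iterator, List, Optional, Tuple


def cache_artifacts(blobs: List[bytes]) -> List[str]:
    # Stand-in for the client RPC: pretend the server caches each blob and returns its hash.
    return [hashlib.sha256(blob).hexdigest() for blob in blobs]


def upload_in_batches(
    schema_chunk: Optional[bytes],
    data_chunks: Iterator[bytes],
    max_batch_of_chunks_size_bytes: int,
) -> Tuple[List[str], Optional[str]]:
    hashes: List[str] = []
    current_batch: List[bytes] = []
    current_batch_size = 0
    if schema_chunk is not None:
        # The schema chunk goes into the first batch, so its hash ends up at hashes[0].
        current_batch.append(schema_chunk)
        current_batch_size += len(schema_chunk)
    for chunk in data_chunks:
        # Flush the current batch before it would exceed the memory budget.
        if current_batch and current_batch_size + len(chunk) > max_batch_of_chunks_size_bytes:
            hashes += cache_artifacts(current_batch)
            current_batch, current_batch_size = [], 0
        current_batch.append(chunk)
        current_batch_size += len(chunk)
    # Upload whatever is left (at least the schema chunk or one data chunk).
    hashes += cache_artifacts(current_batch)
    if schema_chunk is not None:
        return hashes[1:], hashes[0]  # data hashes, schema hash
    return hashes, None


data_hashes, schema_hash = upload_in_batches(
    schema_chunk=b'{"type": "struct", "fields": []}',
    data_chunks=iter([b"chunk-1", b"chunk-2", b"chunk-3"]),
    max_batch_of_chunks_size_bytes=16,
)
assert len(data_hashes) == 3 and schema_hash is not None
```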
8 changes: 5 additions & 3 deletions python/pyspark/sql/tests/arrow/test_arrow.py
@@ -438,10 +438,12 @@ def check_cached_local_relation_changing_values(self):
assert not df.filter(df["col2"].endswith(suffix)).isEmpty()

def check_large_cached_local_relation_same_values(self):
data = [("C000000032", "R20", 0.2555)] * 500_000
row_count = 500_000
data = [("C000000032", "R20", 0.2555)] * row_count
pdf = pd.DataFrame(data=data, columns=["Contrat", "Recommandation", "Distance"])
df = self.spark.createDataFrame(pdf)
df.collect()
for _ in range(2):
df = self.spark.createDataFrame(pdf)
assert df.count() == row_count

def test_toArrow_keep_utc_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -67,6 +67,8 @@ private[sql] object SqlApiConf {
SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY
val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String =
SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY
val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY: String =
SqlApiConfHelper.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY
val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
SqlApiConfHelper.PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY
val PARSER_DFA_CACHE_FLUSH_RATIO_KEY: String =
@@ -35,6 +35,8 @@ private[sql] object SqlApiConfHelper {
val LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY: String = "spark.sql.session.localRelationChunkSizeRows"
val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String =
"spark.sql.session.localRelationChunkSizeBytes"
val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY: String =
"spark.sql.session.localRelationBatchOfChunksSizeBytes"
val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = "spark.sql.execution.arrow.useLargeVarTypes"
val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
"spark.sql.parser.parserDfaCacheFlushThreshold"
@@ -6117,7 +6117,9 @@ object SQLConf {
.doc("The chunk size in bytes when splitting ChunkedCachedLocalRelation.data " +
"into batches. A new chunk is created when either " +
"spark.sql.session.localRelationChunkSizeBytes " +
"or spark.sql.session.localRelationChunkSizeRows is reached.")
"or spark.sql.session.localRelationChunkSizeRows is reached. " +
"Limited by the spark.sql.session.localRelationBatchOfChunksSizeBytes, " +
"a minimum of the two confs is used to determine the chunk size.")
.version("4.1.0")
.longConf
.checkValue(_ > 0, "The chunk size in bytes must be positive")
@@ -6141,6 +6143,21 @@
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("3GB")

val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES =
buildConf(SqlApiConfHelper.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY)
Contributor: We could update the name here to be a bit more explicit in the sense that this pertains to the maximum number of bytes that we will materialise in memory for the specific local relation

Contributor: (specific because multi-threading can result in multiple artifacts being materialised at once)

.internal()
.doc("Limit on how much memory the client can use when uploading a local relation to the " +
"server. The client collects multiple local relation chunks into a single batch in " +
"memory until the limit is reached, then uploads the batch to the server. " +
"This helps reduce memory pressure on the client when dealing with very large local " +
"relations because the client does not have to materialize all chunks in memory. " +
"Limits the spark.sql.session.localRelationChunkSizeBytes, " +
"a minimum of the two confs is used to determine the chunk size.")
.version("4.1.0")
.longConf
.checkValue(_ > 0, "The batch size in bytes must be positive")
Contributor: We should check if this value is greater than the chunk size value as an initial step

Contributor: In the case of conflicts, this conf should be respected and the operation should error out as we wouldn't want to bypass an explicitly set max materialisation size (to avoid system failures)

.createWithDefault(1 * 1024 * 1024 * 1024L)
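As the doc strings above describe, the effective per-chunk byte limit ends up being the minimum of spark.sql.session.localRelationChunkSizeBytes and spark.sql.session.localRelationBatchOfChunksSizeBytes. A hedged Python sketch of that resolution, plus the stricter guard the review comments suggest (hypothetical, not part of this change); the default values are taken from the diff above:

```python
# Conf keys as defined in SqlApiConfHelper above.
CHUNK_SIZE_BYTES_KEY = "spark.sql.session.localRelationChunkSizeBytes"
BATCH_OF_CHUNKS_SIZE_BYTES_KEY = "spark.sql.session.localRelationBatchOfChunksSizeBytes"


def effective_chunk_size_bytes(configs: dict, strict: bool = False) -> int:
    chunk_size = int(configs[CHUNK_SIZE_BYTES_KEY])
    batch_size = int(configs[BATCH_OF_CHUNKS_SIZE_BYTES_KEY])
    if strict and batch_size < chunk_size:
        # Hypothetical guard along the lines of the review comments: refuse to silently
        # shrink chunks below an explicitly configured size instead of taking the minimum.
        raise ValueError(
            f"{BATCH_OF_CHUNKS_SIZE_BYTES_KEY} ({batch_size}) is smaller than "
            f"{CHUNK_SIZE_BYTES_KEY} ({chunk_size})"
        )
    # Behaviour in this change: a chunk can never be larger than one upload batch.
    return min(chunk_size, batch_size)


configs = {
    CHUNK_SIZE_BYTES_KEY: 3 * 1024 * 1024 * 1024,            # 3GB default chunk size
    BATCH_OF_CHUNKS_SIZE_BYTES_KEY: 1 * 1024 * 1024 * 1024,  # 1GB default batch budget
}
assert effective_chunk_size_bytes(configs) == 1 * 1024 * 1024 * 1024
```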

val DECORRELATE_JOIN_PREDICATE_ENABLED =
buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled")
.internal()
@@ -118,39 +118,73 @@ class SparkSession private[sql] (
newDataset(encoder) { builder =>
if (data.nonEmpty) {
val threshold = conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt
val maxRecordsPerBatch = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY).toInt
val maxBatchSize = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY).toInt
val maxChunkSizeRows = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY).toInt
val maxChunkSizeBytes = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY).toInt
val maxBatchOfChunksSize =
conf.get(SqlApiConf.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY).toLong

// Serialize with chunking support
val it = ArrowSerializer.serialize(
data,
encoder,
allocator,
maxRecordsPerBatch = maxRecordsPerBatch,
maxBatchSize = maxBatchSize,
maxRecordsPerBatch = maxChunkSizeRows,
maxBatchSize = math.min(maxChunkSizeBytes, maxBatchOfChunksSize),
timeZoneId = timeZoneId,
largeVarTypes = largeVarTypes,
batchSizeCheckInterval = math.min(1024, maxRecordsPerBatch))
batchSizeCheckInterval = math.min(1024, maxChunkSizeRows))

try {
val schemaBytes = encoder.schema.json.getBytes
// Schema is the first chunk, data chunks follow from the iterator
val currentBatch = scala.collection.mutable.ArrayBuffer[Array[Byte]](schemaBytes)
var totalChunks = 1
var currentBatchSize = schemaBytes.length.toLong
var totalSize = currentBatchSize

// store all hashes of uploaded chunks. The first hash is schema, rest are data hashes
val allHashes = scala.collection.mutable.ArrayBuffer[String]()
while (it.hasNext) {
val chunk = it.next()
val chunkSize = chunk.length
totalChunks += 1
totalSize += chunkSize

// Check if adding this chunk would exceed batch size
if (currentBatchSize + chunkSize > maxBatchOfChunksSize) {
// Upload current batch
allHashes ++= client.artifactManager.cacheArtifacts(currentBatch.toArray)
// Start new batch
currentBatch.clear()
currentBatchSize = 0
}

val chunks =
try {
it.toArray
} finally {
it.close()
currentBatch += chunk
currentBatchSize += chunkSize
}

// If we got multiple chunks or a single large chunk, use ChunkedCachedLocalRelation
val totalSize = chunks.map(_.length).sum
if (chunks.length > 1 || totalSize > threshold) {
val (dataHashes, schemaHash) = client.cacheLocalRelation(chunks, encoder.schema.json)
builder.getChunkedCachedLocalRelationBuilder
.setSchemaHash(schemaHash)
.addAllDataHashes(dataHashes.asJava)
} else {
// Small data, use LocalRelation directly
val arrowData = ByteString.copyFrom(chunks(0))
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
.setData(arrowData)
// Decide whether to use LocalRelation or ChunkedCachedLocalRelation
if (totalChunks == 2 && totalSize <= threshold) {
// Schema + single small data chunk: use LocalRelation with inline data
val arrowData = ByteString.copyFrom(currentBatch.last)
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
.setData(arrowData)
} else {
// Multiple data chunks or large data: use ChunkedCachedLocalRelation
// Upload remaining batch
allHashes ++= client.artifactManager.cacheArtifacts(currentBatch.toArray)

// First hash is schema, rest are data
val schemaHash = allHashes.head
val dataHashes = allHashes.tail

builder.getChunkedCachedLocalRelationBuilder
.setSchemaHash(schemaHash)
.addAllDataHashes(dataHashes.asJava)
}
} finally {
it.close()
}
} else {
builder.getLocalRelationBuilder
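The Scala client above counts the schema as the first chunk, so totalChunks == 2 means exactly one data chunk was produced. A small sketch of just that decision rule (names and values are illustrative, not the client API):

```python
def choose_relation_type(total_chunks: int, total_size: int, threshold: int) -> str:
    # total_chunks includes the schema chunk, so 2 means exactly one data chunk.
    if total_chunks == 2 and total_size <= threshold:
        return "LocalRelation"  # inline the single small Arrow batch in the plan
    return "ChunkedCachedLocalRelation"  # reference the uploaded artifact hashes instead


assert choose_relation_type(total_chunks=2, total_size=10_000, threshold=64 * 1024) == "LocalRelation"
assert choose_relation_type(total_chunks=5, total_size=10_000, threshold=64 * 1024) == "ChunkedCachedLocalRelation"
```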