diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 02fe7176b6fee..6630d96f21ded 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -24,6 +24,7 @@ from typing import ( Any, + Iterator, List, Optional, Type, @@ -437,7 +438,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: return plan def _serialize_table(self) -> bytes: - assert self._table is not None + assert self._table is not None, "table cannot be None" sink = pa.BufferOutputStream() with pa.ipc.new_stream(sink, self._table.schema) as writer: batches = self._table.to_batches() @@ -449,7 +450,7 @@ def _serialize_table_chunks( self, max_chunk_size_rows: int, max_chunk_size_bytes: int, - ) -> list[bytes]: + ) -> Iterator[bytes]: """ Serialize the table into multiple chunks, each up to max_chunk_size_bytes bytes and max_chunk_size_rows rows. @@ -457,49 +458,52 @@ def _serialize_table_chunks( This method processes the table in fixed-size batches (1024 rows) for efficiency, matching the Scala implementation's batchSizeCheckInterval. + + Yields chunks one at a time to avoid materializing all chunks in memory. """ - assert self._table is not None - chunks = [] + assert self._table is not None, "table cannot be None" + assert self._table.num_rows > 0, "table must have at least one row" schema = self._table.schema - # Calculate schema serialization size once - schema_buffer = pa.BufferOutputStream() - with pa.ipc.new_stream(schema_buffer, schema): - pass # Just write schema - schema_size = len(schema_buffer.getvalue()) + # Calculate schema serialization size once (empty table = just schema) + schema_size = len(self._serialize_batches_to_ipc([], schema)) current_batches: list[pa.RecordBatch] = [] current_size = schema_size for batch in self._table.to_batches(max_chunksize=min(1024, max_chunk_size_rows)): + # Approximate batch size using raw column data (fast, ignores IPC overhead). + # Calculating the real batch size of the IPC stream would require serializing each + # batch separately, which adds overhead. 
batch_size = sum(arr.nbytes for arr in batch.columns) # If this batch would exceed limit and we have data, flush current chunk - if current_size > schema_size and current_size + batch_size > max_chunk_size_bytes: - combined = pa.Table.from_batches(current_batches, schema=schema) - sink = pa.BufferOutputStream() - with pa.ipc.new_stream(sink, schema) as writer: - writer.write_table(combined) - chunks.append(sink.getvalue().to_pybytes()) + if len(current_batches) > 0 and current_size + batch_size > max_chunk_size_bytes: + yield self._serialize_batches_to_ipc(current_batches, schema) current_batches = [] current_size = schema_size current_batches.append(batch) current_size += batch_size - # Flush remaining batches - if current_batches: - combined = pa.Table.from_batches(current_batches, schema=schema) - sink = pa.BufferOutputStream() - with pa.ipc.new_stream(sink, schema) as writer: - writer.write_table(combined) - chunks.append(sink.getvalue().to_pybytes()) + # Flush remaining batches (guaranteed to have at least one due to assertion) + yield self._serialize_batches_to_ipc(current_batches, schema) - return chunks + def _serialize_batches_to_ipc( + self, + batches: list[pa.RecordBatch], + schema: pa.Schema, + ) -> bytes: + """Helper method to serialize Arrow batches to IPC stream format.""" + combined = pa.Table.from_batches(batches, schema=schema) + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, schema) as writer: + writer.write_table(combined) + return sink.getvalue().to_pybytes() def _serialize_schema(self) -> bytes: # the server uses UTF-8 for decoding the schema - assert self._schema is not None + assert self._schema is not None, "schema cannot be None" return self._schema.encode("utf-8") def serialize(self, session: "SparkConnectClient") -> bytes: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 21a7c8329a354..ac1d1f5681e3e 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -29,6 +29,7 @@ from typing import ( Optional, Any, + Iterator, Union, Dict, List, @@ -537,6 +538,7 @@ def createDataFrame( "spark.sql.session.localRelationCacheThreshold", "spark.sql.session.localRelationChunkSizeRows", "spark.sql.session.localRelationChunkSizeBytes", + "spark.sql.session.localRelationBatchOfChunksSizeBytes", "spark.sql.execution.pandas.convertToArrowArraySafely", "spark.sql.execution.pandas.inferPandasDictAsMap", "spark.sql.pyspark.inferNestedDictAsStruct.enabled", @@ -772,10 +774,16 @@ def createDataFrame( max_chunk_size_bytes = int( configs["spark.sql.session.localRelationChunkSizeBytes"] # type: ignore[arg-type] ) + max_batch_of_chunks_size_bytes = int( + configs["spark.sql.session.localRelationBatchOfChunksSizeBytes"] # type: ignore[arg-type] # noqa: E501 + ) plan: LogicalPlan = local_relation if cache_threshold <= _table.nbytes: plan = self._cache_local_relation( - local_relation, max_chunk_size_rows, max_chunk_size_bytes + local_relation, + max_chunk_size_rows, + max_chunk_size_bytes, + max_batch_of_chunks_size_bytes, ) df = DataFrame(plan, self) @@ -1054,30 +1062,62 @@ def _cache_local_relation( local_relation: LocalRelation, max_chunk_size_rows: int, max_chunk_size_bytes: int, + max_batch_of_chunks_size_bytes: int, ) -> ChunkedCachedLocalRelation: """ Cache the local relation at the server side if it has not been cached yet. - Should only be called on LocalRelations with _table set. 
+ This method serializes the input local relation into multiple data chunks and + a schema chunk (if the schema is available) and uploads these chunks as artifacts + to the server. + + The method collects a batch of chunks of size up to max_batch_of_chunks_size_bytes and + uploads them together to the server. + Uploading each chunk separately would require an additional RPC call for each chunk. + Uploading all chunks together would require materializing all chunks in memory which + may cause high memory usage on the client. + Uploading batches of chunks is the middle-ground solution. + + Should only be called on a LocalRelation with a non-empty _table. """ - assert local_relation._table is not None + assert local_relation._table is not None, "table cannot be None" has_schema = local_relation._schema is not None - # Serialize table into chunks - data_chunks = local_relation._serialize_table_chunks( - max_chunk_size_rows, max_chunk_size_bytes + hashes = [] + current_batch = [] + current_batch_size = 0 + if has_schema: + schema_chunk = local_relation._serialize_schema() + current_batch.append(schema_chunk) + current_batch_size += len(schema_chunk) + + data_chunks: Iterator[bytes] = local_relation._serialize_table_chunks( + max_chunk_size_rows, min(max_chunk_size_bytes, max_batch_of_chunks_size_bytes) ) - blobs = data_chunks.copy() # Start with data chunks - if has_schema: - blobs.append(local_relation._serialize_schema()) + for chunk in data_chunks: + chunk_size = len(chunk) - hashes = self._client.cache_artifacts(blobs) + # Check if adding this chunk would exceed batch size + if ( + len(current_batch) > 0 + and current_batch_size + chunk_size > max_batch_of_chunks_size_bytes + ): + hashes += self._client.cache_artifacts(current_batch) + # start a new batch + current_batch = [] + current_batch_size = 0 - # Extract data hashes and schema hash - data_hashes = hashes[: len(data_chunks)] - schema_hash = hashes[len(data_chunks)] if has_schema else None + current_batch.append(chunk) + current_batch_size += chunk_size + hashes += self._client.cache_artifacts(current_batch) + if has_schema: + schema_hash = hashes[0] + data_hashes = hashes[1:] + else: + schema_hash = None + data_hashes = hashes return ChunkedCachedLocalRelation(data_hashes, schema_hash) def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None: diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py index 14f8fbe33c8e1..79d8bf77d9d5b 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow.py +++ b/python/pyspark/sql/tests/arrow/test_arrow.py @@ -438,10 +438,12 @@ def check_cached_local_relation_changing_values(self): assert not df.filter(df["col2"].endswith(suffix)).isEmpty() def check_large_cached_local_relation_same_values(self): - data = [("C000000032", "R20", 0.2555)] * 500_000 + row_count = 500_000 + data = [("C000000032", "R20", 0.2555)] * row_count pdf = pd.DataFrame(data=data, columns=["Contrat", "Recommandation", "Distance"]) - df = self.spark.createDataFrame(pdf) - df.collect() + for _ in range(2): + df = self.spark.createDataFrame(pdf) + assert df.count() == row_count def test_toArrow_keep_utc_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala index f715f8f9ed8cd..0973750c65ce4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala @@ -67,6 +67,8 @@ private[sql] object SqlApiConf { SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String = SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY + val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY: String = + SqlApiConfHelper.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String = SqlApiConfHelper.PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY val PARSER_DFA_CACHE_FLUSH_RATIO_KEY: String = diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala index b839caba3f547..4fcc2f4e150d1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala @@ -35,6 +35,8 @@ private[sql] object SqlApiConfHelper { val LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY: String = "spark.sql.session.localRelationChunkSizeRows" val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String = "spark.sql.session.localRelationChunkSizeBytes" + val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY: String = + "spark.sql.session.localRelationBatchOfChunksSizeBytes" val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = "spark.sql.execution.arrow.useLargeVarTypes" val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String = "spark.sql.parser.parserDfaCacheFlushThreshold" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 303e2f0234f62..d92a8acf1af3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6117,7 +6117,9 @@ object SQLConf { .doc("The chunk size in bytes when splitting ChunkedCachedLocalRelation.data " + "into batches. A new chunk is created when either " + "spark.sql.session.localRelationChunkSizeBytes " + - "or spark.sql.session.localRelationChunkSizeRows is reached.") + "or spark.sql.session.localRelationChunkSizeRows is reached. " + + "It is also capped by spark.sql.session.localRelationBatchOfChunksSizeBytes; " + + "the minimum of the two values determines the chunk size.") .version("4.1.0") .longConf .checkValue(_ > 0, "The chunk size in bytes must be positive") @@ -6141,6 +6143,21 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("3GB") + val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES = + buildConf(SqlApiConfHelper.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY) + .internal() + .doc("Limit on how much memory the client can use when uploading a local relation to the " + + "server. The client collects multiple local relation chunks into a single batch in " + + "memory until the limit is reached, then uploads the batch to the server. " + + "This helps reduce memory pressure on the client when dealing with very large local " + + "relations because the client does not have to materialize all chunks in memory. 
" + "It also caps spark.sql.session.localRelationChunkSizeBytes; " + "the minimum of the two values determines the chunk size.") + .version("4.1.0") + .longConf + .checkValue(_ > 0, "The batch size in bytes must be positive") + .createWithDefault(1 * 1024 * 1024 * 1024L) + val DECORRELATE_JOIN_PREDICATE_ENABLED = buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled") .internal() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala index 0d9d4e5d60f0a..daa2cc2001e42 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala @@ -118,39 +118,73 @@ class SparkSession private[sql] ( newDataset(encoder) { builder => if (data.nonEmpty) { val threshold = conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt - val maxRecordsPerBatch = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY).toInt - val maxBatchSize = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY).toInt + val maxChunkSizeRows = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY).toInt + val maxChunkSizeBytes = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY).toInt + val maxBatchOfChunksSize = + conf.get(SqlApiConf.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY).toLong + + // Serialize with chunking support val it = ArrowSerializer.serialize( data, encoder, allocator, - maxRecordsPerBatch = maxRecordsPerBatch, - maxBatchSize = maxBatchSize, + maxRecordsPerBatch = maxChunkSizeRows, + maxBatchSize = math.min(maxChunkSizeBytes, maxBatchOfChunksSize), timeZoneId = timeZoneId, largeVarTypes = largeVarTypes, - batchSizeCheckInterval = math.min(1024, maxRecordsPerBatch)) + batchSizeCheckInterval = math.min(1024, maxChunkSizeRows)) + + try { + val schemaBytes = encoder.schema.json.getBytes + // Schema is the first chunk, data chunks follow from the iterator + val currentBatch = scala.collection.mutable.ArrayBuffer[Array[Byte]](schemaBytes) + var totalChunks = 1 + var currentBatchSize = schemaBytes.length.toLong + var totalSize = currentBatchSize + + // store all hashes of uploaded chunks. 
The first hash is schema, rest are data hashes + val allHashes = scala.collection.mutable.ArrayBuffer[String]() + while (it.hasNext) { + val chunk = it.next() + val chunkSize = chunk.length + totalChunks += 1 + totalSize += chunkSize + + // Check if adding this chunk would exceed batch size + if (currentBatchSize + chunkSize > maxBatchOfChunksSize) { + // Upload current batch + allHashes ++= client.artifactManager.cacheArtifacts(currentBatch.toArray) + // Start new batch + currentBatch.clear() + currentBatchSize = 0 + } - val chunks = - try { - it.toArray - } finally { - it.close() + currentBatch += chunk + currentBatchSize += chunkSize } - // If we got multiple chunks or a single large chunk, use ChunkedCachedLocalRelation - val totalSize = chunks.map(_.length).sum - if (chunks.length > 1 || totalSize > threshold) { - val (dataHashes, schemaHash) = client.cacheLocalRelation(chunks, encoder.schema.json) - builder.getChunkedCachedLocalRelationBuilder - .setSchemaHash(schemaHash) - .addAllDataHashes(dataHashes.asJava) - } else { - // Small data, use LocalRelation directly - val arrowData = ByteString.copyFrom(chunks(0)) - builder.getLocalRelationBuilder - .setSchema(encoder.schema.json) - .setData(arrowData) + // Decide whether to use LocalRelation or ChunkedCachedLocalRelation + if (totalChunks == 2 && totalSize <= threshold) { + // Schema + single small data chunk: use LocalRelation with inline data + val arrowData = ByteString.copyFrom(currentBatch.last) + builder.getLocalRelationBuilder + .setSchema(encoder.schema.json) + .setData(arrowData) + } else { + // Multiple data chunks or large data: use ChunkedCachedLocalRelation + // Upload remaining batch + allHashes ++= client.artifactManager.cacheArtifacts(currentBatch.toArray) + + // First hash is schema, rest are data + val schemaHash = allHashes.head + val dataHashes = allHashes.tail + + builder.getChunkedCachedLocalRelationBuilder + .setSchemaHash(schemaHash) + .addAllDataHashes(dataHashes.asJava) + } + } finally { + it.close() } } else { builder.getLocalRelationBuilder diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index e5fd16a7c2612..ee42a873787f9 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -428,26 +428,6 @@ private[sql] class SparkConnectClient( channel.shutdownNow() } - /** - * Cache the given local relation Arrow stream from a local file and return its hashes. The file - * is streamed in chunks and does not need to fit in memory. - * - * This method batches artifact status checks and uploads to minimize RPC overhead. - */ - private[sql] def cacheLocalRelation( - data: Array[Array[Byte]], - schema: String): (Seq[String], String) = { - val schemaBytes = schema.getBytes - val allBlobs = data :+ schemaBytes - val allHashes = artifactManager.cacheArtifacts(allBlobs) - - // Last hash is the schema hash, rest are data hashes - val dataHashes = allHashes.dropRight(1) - val schemaHash = allHashes.last - - (dataHashes, schemaHash) - } - /** * Clone this client session, creating a new session with the same configuration and shared * state as the current session but with independent runtime state. 
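Both the Python client (`_cache_local_relation`) and the Scala client above follow the same batching pattern: accumulate serialized chunks in memory until the next chunk would push the batch past `localRelationBatchOfChunksSizeBytes`, flush the batch with a single artifact-upload RPC, then continue, and always flush the final non-empty batch. The sketch below restates that loop in standalone Python as a reading aid; it omits the schema-chunk handling, and `upload_batch` is a hypothetical stand-in for the real artifact-upload call, not an API introduced by this patch.

```python
from typing import Iterable, Iterator, List


def upload_batch(batch: List[bytes]) -> List[str]:
    # Hypothetical stand-in for the artifact-upload RPC: returns one fake
    # "hash" per uploaded blob so the batching behaviour is visible.
    return [f"blob-{i}-{len(blob)}B" for i, blob in enumerate(batch)]


def cache_in_batches(chunks: Iterable[bytes], max_batch_bytes: int) -> List[str]:
    """Accumulate chunks into in-memory batches of at most max_batch_bytes
    and flush each batch with a single upload call."""
    hashes: List[str] = []
    batch: List[bytes] = []
    batch_size = 0
    for chunk in chunks:
        # Flush before the batch would exceed the limit; never flush an empty
        # batch, so a single oversized chunk is still uploaded on its own.
        if batch and batch_size + len(chunk) > max_batch_bytes:
            hashes += upload_batch(batch)
            batch, batch_size = [], 0
        batch.append(chunk)
        batch_size += len(chunk)
    # Flush the final batch once the chunk iterator is exhausted.
    hashes += upload_batch(batch)
    return hashes


if __name__ == "__main__":
    # Three 20-byte chunks with a 48-byte budget -> two uploads: [2 chunks], then [1 chunk].
    demo_chunks: Iterator[bytes] = iter([b"a" * 20, b"b" * 20, b"c" * 20])
    print(cache_in_batches(demo_chunks, max_batch_bytes=48))
```

This is also why both clients compute the chunk size as `min(localRelationChunkSizeBytes, localRelationBatchOfChunksSizeBytes)`: a single chunk must always fit into one upload batch.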
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 8f8e6261066f4..8bc33c41b3a30 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.connect.planner import java.util.{HashMap, Properties, UUID} -import scala.collection.immutable.ArraySeq import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.Try @@ -1608,7 +1607,7 @@ class SparkConnectPlanner( schemaOpt match { case None => - logical.LocalRelation(attributes, ArraySeq.unsafeWrapArray(data.map(_.copy()).toArray)) + logical.LocalRelation(attributes, data.map(_.copy()).toArray.toImmutableArraySeq) case Some(schema) => def normalize(dt: DataType): DataType = dt match { case udt: UserDefinedType[_] => normalize(udt.sqlType)
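For reference, a hedged usage sketch of how the session confs interact on the Python client: chunks are cut at roughly `min(localRelationChunkSizeBytes, localRelationBatchOfChunksSizeBytes)` bytes and at most `localRelationChunkSizeRows` rows, and the resulting chunks are uploaded in batches of up to `localRelationBatchOfChunksSizeBytes` bytes per RPC once the relation exceeds `localRelationCacheThreshold`. The conf keys come from this patch; the `sc://localhost` endpoint, the assumption that these session confs can be set at runtime via `spark.conf.set`, and the sizes shown are illustrative only.

```python
import pandas as pd
from pyspark.sql import SparkSession

# Assumed: a Spark Connect server reachable at sc://localhost.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Illustrative values: cap each chunk at 128 MiB / 1M rows and cap the
# client-side upload batches at 256 MiB.
spark.conf.set("spark.sql.session.localRelationChunkSizeBytes", str(128 * 1024 * 1024))
spark.conf.set("spark.sql.session.localRelationChunkSizeRows", str(1_000_000))
spark.conf.set("spark.sql.session.localRelationBatchOfChunksSizeBytes", str(256 * 1024 * 1024))

# A local pandas DataFrame large enough that it should exceed the default
# localRelationCacheThreshold and go through the chunked cached-relation path.
pdf = pd.DataFrame({"id": range(2_000_000), "payload": ["x" * 64] * 2_000_000})
df = spark.createDataFrame(pdf)
print(df.count())
```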