diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 59b5fa7f3a43..11d75f4f99a7 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -679,10 +679,10 @@ def _serialize_to_jvm(
         data: Iterable[T],
         serializer: Serializer,
         reader_func: Callable,
-        createRDDServer: Callable,
+        server_func: Callable,
     ) -> JavaObject:
         """
-        Using py4j to send a large dataset to the jvm is really slow, so we use either a file
+        Using Py4J to send a large dataset to the jvm is slow, so we use either a file
         or a socket if we have encryption enabled.

         Examples
@@ -693,13 +693,13 @@ def _serialize_to_jvm(
         reader_func : function
             A function which takes a filename and reads in the data in the jvm and
             returns a JavaRDD. Only used when encryption is disabled.
-        createRDDServer : function
-            A function which creates a PythonRDDServer in the jvm to
+        server_func : function
+            A function which creates a SocketAuthServer in the JVM to
             accept the serialized data, for use when encryption is enabled.
         """
         if self._encryption_enabled:
             # with encryption, we open a server in java and send the data directly
-            server = createRDDServer()
+            server = server_func()
             (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
             chunked_out = ChunkedStream(sock_file, 8192)
             serializer.dump_stream(data, chunked_out)
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index fff0bac5480a..119a9bf315ca 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -596,7 +596,7 @@ def _create_from_pandas_with_arrow(
         ]

         # Slice the DataFrame to be batched
-        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
+        step = self._jconf.arrowMaxRecordsPerBatch()
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))

         # Create list of Arrow (columns, type) for serializer dump_stream
@@ -613,16 +613,16 @@ def _create_from_pandas_with_arrow(

         @no_type_check
         def reader_func(temp_filename):
-            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsparkSession, temp_filename)
+            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)

         @no_type_check
-        def create_RDD_server():
-            return self._jvm.ArrowRDDServer(jsparkSession)
+        def create_iter_server():
+            return self._jvm.ArrowIteratorServer()

         # Create Spark DataFrame from Arrow stream file, using one batch per partition
-        jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
+        jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server)
         assert self._jvm is not None
-        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsparkSession)
+        jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
         df = DataFrame(jdf, self)
         df._schema = schema
         return df
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index b737848b11ac..9b1b204542b0 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -803,6 +803,18 @@ def conf(cls):
         return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")


+class RDDBasedArrowTests(ArrowTests):
+    @classmethod
+    def conf(cls):
+        return (
+            super(RDDBasedArrowTests, cls)
+            .conf()
+            .set("spark.sql.execution.arrow.localRelationThreshold", "0")
+            # to test multiple partitions
+            .set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")
+        )
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_arrow import *  # noqa: F401

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8a752e90eca..4b64d91e56a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2576,6 +2576,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)

+  val ARROW_LOCAL_RELATION_THRESHOLD =
+    buildConf("spark.sql.execution.arrow.localRelationThreshold")
+      .doc(
+        "When converting Arrow batches to Spark DataFrame, local collections are used in the " +
+          "driver side if the byte size of Arrow batches is smaller than this threshold. " +
+          "Otherwise, the Arrow batches are sent and deserialized to Spark internal rows " +
+          "in the executors.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .checkValue(_ >= 0, "This value must be equal to or greater than 0.")
+      .createWithDefaultString("48MB")
+
   val PYSPARK_JVM_STACKTRACE_ENABLED =
     buildConf("spark.sql.pyspark.jvmStacktrace.enabled")
       .doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " +
@@ -4418,6 +4430,8 @@ class SQLConf extends Serializable with Logging {

   def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)

+  def arrowLocalRelationThreshold: Long = getConf(ARROW_LOCAL_RELATION_THRESHOLD)
+
   def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)

   def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index fd689bf502a3..a3ba86362339 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -18,15 +18,15 @@
 package org.apache.spark.sql.api.python

 import java.io.InputStream
+import java.net.Socket
 import java.nio.channels.Channels
 import java.util.Locale

 import net.razorvine.pickle.Pickler

-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.PythonRDDServer
+import org.apache.spark.api.python.DechunkedInputStream
 import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthServer
 import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
@@ -70,22 +70,22 @@ private[sql] object PythonSQLUtils extends Logging {
     SQLConf.get.timestampType == org.apache.spark.sql.types.TimestampNTZType

   /**
-   * Python callable function to read a file in Arrow stream format and create a [[RDD]]
-   * using each serialized ArrowRecordBatch as a partition.
+   * Python callable function to read a file in Arrow stream format and create an iterator
+   * of serialized ArrowRecordBatches.
    */
-  def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = {
-    ArrowConverters.readArrowStreamFromFile(session, filename)
+  def readArrowStreamFromFile(filename: String): Iterator[Array[Byte]] = {
+    ArrowConverters.readArrowStreamFromFile(filename).iterator
   }

   /**
    * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
-   * from an RDD.
+   * from the Arrow batch iterator.
    */
   def toDataFrame(
-      arrowBatchRDD: JavaRDD[Array[Byte]],
+      arrowBatches: Iterator[Array[Byte]],
       schemaString: String,
       session: SparkSession): DataFrame = {
-    ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session)
+    ArrowConverters.toDataFrame(arrowBatches, schemaString, session)
   }

   def explainString(queryExecution: QueryExecution, mode: String): String = {
@@ -137,16 +137,16 @@ private[sql] object PythonSQLUtils extends Logging {
 }

 /**
- * Helper for making a dataframe from arrow data from data sent from python over a socket. This is
+ * Helper for making a DataFrame from Arrow data sent from Python over a socket. This is
  * used when encryption is enabled, and we don't want to write data to a file.
  */
-private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer {
-
-  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
-    // Create array to consume iterator so that we can safely close the inputStream
-    val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
-    // Parallelize the record batches to create an RDD
-    JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
+private[spark] class ArrowIteratorServer
+  extends SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {
+
+  def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
+    val in = sock.getInputStream()
+    val dechunkedInput: InputStream = new DechunkedInputStream(in)
+    // Create array to consume iterator so that we can safely close the file
+    ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
   }
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 7831ddee4f9b..f58afcfa05d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,6 +23,7 @@ import java.util.{Locale, Map => JMap}
 import scala.collection.JavaConverters._
 import scala.util.matching.Regex

+import org.apache.spark.TaskContext
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.r.SerDe
 import org.apache.spark.broadcast.Broadcast
@@ -230,7 +231,9 @@ private[sql] object SQLUtils extends Logging {
   def readArrowStreamFromFile(
       sparkSession: SparkSession,
       filename: String): JavaRDD[Array[Byte]] = {
-    ArrowConverters.readArrowStreamFromFile(sparkSession, filename)
+    // Parallelize the record batches to create an RDD
+    val batches = ArrowConverters.readArrowStreamFromFile(filename)
+    JavaRDD.fromRDD(sparkSession.sparkContext.parallelize(batches, batches.length))
   }

   /**
@@ -241,6 +244,11 @@ private[sql] object SQLUtils extends Logging {
       arrowBatchRDD: JavaRDD[Array[Byte]],
       schema: StructType,
       sparkSession: SparkSession): DataFrame = {
-    ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession)
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
+      val context = TaskContext.get()
+      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
+    }
+    sparkSession.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 93ff276529da..bded158645cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -29,10 +29,12 @@ import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
 import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}

 import org.apache.spark.TaskContext
-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -68,7 +70,7 @@
   }
 }

-private[sql] object ArrowConverters {
+private[sql] object ArrowConverters extends Logging {

   /**
    * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
@@ -143,7 +145,7 @@
     new Iterator[InternalRow] {
       private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty

-      context.addTaskCompletionListener[Unit] { _ =>
+      if (context != null) context.addTaskCompletionListener[Unit] { _ =>
         root.close()
         allocator.close()
       }
@@ -190,32 +192,54 @@
   }

   /**
-   * Create a DataFrame from an RDD of serialized ArrowRecordBatches.
+   * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
    */
-  private[sql] def toDataFrame(
-      arrowBatchRDD: JavaRDD[Array[Byte]],
+  def toDataFrame(
+      arrowBatches: Iterator[Array[Byte]],
       schemaString: String,
       session: SparkSession): DataFrame = {
     val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
-    val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
-    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
-      val context = TaskContext.get()
-      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
+    val attrs = schema.toAttributes
+    val batchesInDriver = arrowBatches.toArray
+    val shouldUseRDD = session.sessionState.conf
+      .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum
+
+    if (shouldUseRDD) {
+      logDebug("Using RDD-based createDataFrame with Arrow optimization.")
+      val timezone = session.sessionState.conf.sessionLocalTimeZone
+      val rdd = session.sparkContext.parallelize(batchesInDriver, batchesInDriver.length)
+        .mapPartitions { batchesInExecutors =>
+          ArrowConverters.fromBatchIterator(
+            batchesInExecutors,
+            schema,
+            timezone,
+            TaskContext.get())
+        }
+      session.internalCreateDataFrame(rdd.setName("arrow"), schema)
+    } else {
+      logDebug("Using LocalRelation in createDataFrame with Arrow optimization.")
+      val data = ArrowConverters.fromBatchIterator(
+        batchesInDriver.toIterator,
+        schema,
+        session.sessionState.conf.sessionLocalTimeZone,
+        TaskContext.get())
+
+      // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
+      val proj = UnsafeProjection.create(attrs, attrs)
+      Dataset.ofRows(session, LocalRelation(attrs, data.map(r => proj(r).copy()).toArray))
     }
-    session.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }

   /**
-   * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
+   * Read a file as an Arrow stream and return an array of serialized ArrowRecordBatches.
    */
-  private[sql] def readArrowStreamFromFile(
-      session: SparkSession,
-      filename: String): JavaRDD[Array[Byte]] = {
+  private[sql] def readArrowStreamFromFile(filename: String): Array[Array[Byte]] = {
     Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
       // Create array to consume iterator so that we can safely close the file
-      val batches = getBatchesFromStream(fileStream.getChannel).toArray
-      // Parallelize the record batches to create an RDD
-      JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
+      getBatchesFromStream(fileStream.getChannel).toArray
     }
   }
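
Usage note (not part of the patch): a minimal PySpark sketch of how the new spark.sql.execution.arrow.localRelationThreshold interacts with Arrow-backed createDataFrame. The config names come from this diff; the session setup and data are illustrative only. Setting the threshold to 0 forces the RDD-based path (as RDDBasedArrowTests does above), while the default of 48MB keeps small inputs as a driver-side LocalRelation.

# Illustrative sketch only; assumes a Spark build that includes this patch.
import pandas as pd
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    # 0 forces the RDD-based branch of ArrowConverters.toDataFrame;
    # the default "48MB" keeps small inputs as a driver-side LocalRelation.
    .config("spark.sql.execution.arrow.localRelationThreshold", "0")
    # small batches so the input spans multiple Arrow record batches
    .config("spark.sql.execution.arrow.maxRecordsPerBatch", "2")
    .getOrCreate()
)

pdf = pd.DataFrame({"id": range(10), "value": [float(i) for i in range(10)]})
df = spark.createDataFrame(pdf)  # goes through _create_from_pandas_with_arrow
print(df.rdd.getNumPartitions())  # expected to be > 1 when the RDD path is taken
df.show()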