10 changes: 5 additions & 5 deletions python/pyspark/context.py
@@ -679,10 +679,10 @@ def _serialize_to_jvm(
data: Iterable[T],
serializer: Serializer,
reader_func: Callable,
createRDDServer: Callable,
server_func: Callable,
) -> JavaObject:
"""
Using py4j to send a large dataset to the jvm is really slow, so we use either a file
Using Py4J to send a large dataset to the jvm is slow, so we use either a file
or a socket if we have encryption enabled.

Examples
@@ -693,13 +693,13 @@ def _serialize_to_jvm(
reader_func : function
A function which takes a filename and reads in the data in the jvm and
returns a JavaRDD. Only used when encryption is disabled.
createRDDServer : function
A function which creates a PythonRDDServer in the jvm to
server_func : function
A function which creates a SocketAuthServer in the JVM to
accept the serialized data, for use when encryption is enabled.
"""
if self._encryption_enabled:
# with encryption, we open a server in java and send the data directly
server = createRDDServer()
server = server_func()
(sock_file, _) = local_connect_and_auth(server.port(), server.secret())
chunked_out = ChunkedStream(sock_file, 8192)
serializer.dump_stream(data, chunked_out)
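To make the renamed parameter concrete, here is a minimal caller-side sketch that mirrors the wiring used in python/pyspark/sql/pandas/conversion.py later in this diff. It assumes an already-running SparkContext sc, its Py4J view jvm = sc._jvm, pre-sliced Arrow data arrow_data, and the matching Arrow serializer ser; it illustrates the calling pattern, not the exact library code.

```python
# Sketch only: how reader_func / server_func are wired after this rename,
# mirroring python/pyspark/sql/pandas/conversion.py below. The names sc, jvm,
# arrow_data and ser are assumed to exist already (an active SparkContext,
# sc._jvm, the sliced Arrow data, and the Arrow serializer).

def reader_func(temp_filename):
    # Unencrypted path: the JVM reads the temp file written by _serialize_to_jvm.
    return jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)

def server_func():
    # Encrypted path: the JVM opens a SocketAuthServer; Python then streams the
    # serialized data to it over an authenticated local socket.
    return jvm.ArrowIteratorServer()

jiter = sc._serialize_to_jvm(arrow_data, ser, reader_func, server_func)
```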
12 changes: 6 additions & 6 deletions python/pyspark/sql/pandas/conversion.py
@@ -596,7 +596,7 @@ def _create_from_pandas_with_arrow(
]

# Slice the DataFrame to be batched
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
step = self._jconf.arrowMaxRecordsPerBatch()
Member Author:
The reason for doing this is to avoid having to reconfigure spark.rpc.message.maxSize. When a batch is too large, an exception is thrown complaining that spark.rpc.message.maxSize is too small.

Member:
I thought this was to control how many partitions were in the RDD? Each partition could have multiple batches, and probably should be capped at arrowMaxRecordsPerBatch, but since it was coming from a local pandas DataFrame already in memory, that didn't seem to be a big deal.

Member Author:
Yeah, that's true... the perf difference seems trivial in any event, and this works around the spark.rpc.message.maxSize issue.

Member:
I believe it was like this to create the same number of partitions as when Arrow is disabled, although that might have changed since. If the DataFrame is split with arrowMaxRecordsPerBatch and a user wanted to create a certain number of partitions, would they have to look at the size of the input and then adjust arrowMaxRecordsPerBatch accordingly?

Member Author:
Yeah, that's true... but I wonder whether the default number of partitions is something we should worry about, given that it wasn't configurable before and SparkSession.createDataFrame does not expose the number of partitions either.

If they really need it, users might want to create an RDD with explicit parallelism... we don't support that now, though (see also #29719).

Member Author (@HyukjinKwon), Jun 10, 2022:
BTW, just to clarify further: when the pandas DataFrame is small (below the threshold), the number of partitions remains the same (configured by spark.sql.leafNodeDefaultParallelism, which falls back to sparkContext.defaultParallelism if not set).

The number of partitions only differs when the input DataFrame is large, which I think makes more sense in general.

pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))

# Create list of Arrow (columns, type) for serializer dump_stream
@@ -613,16 +613,16 @@ def _create_from_pandas_with_arrow(

@no_type_check
def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsparkSession, temp_filename)
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)

@no_type_check
def create_RDD_server():
return self._jvm.ArrowRDDServer(jsparkSession)
def create_iter_server():
return self._jvm.ArrowIteratorServer()

# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server)
assert self._jvm is not None
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsparkSession)
jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
df = DataFrame(jdf, self)
df._schema = schema
return df
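As a quick illustration of the new slicing rule discussed above, here is a self-contained sketch using plain pandas: the local DataFrame is now cut into chunks of spark.sql.execution.arrow.maxRecordsPerBatch rows (rather than len(pdf) divided by the default parallelism), and each chunk becomes one Arrow batch.

```python
# Minimal sketch of the new slicing behaviour; 4 stands in for the value
# returned by self._jconf.arrowMaxRecordsPerBatch().
import pandas as pd

pdf = pd.DataFrame({"id": range(10)})
step = 4

pdf_slices = [pdf.iloc[start:start + step] for start in range(0, len(pdf), step)]
print([len(s) for s in pdf_slices])  # [4, 4, 2] -> three Arrow batches
```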
12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -803,6 +803,18 @@ def conf(cls):
return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")


class RDDBasedArrowTests(ArrowTests):
@classmethod
def conf(cls):
return (
super(RDDBasedArrowTests, cls)
.conf()
.set("spark.sql.execution.arrow.localRelationThreshold", "0")
# to test multiple partitions
.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")
)


if __name__ == "__main__":
from pyspark.sql.tests.test_arrow import * # noqa: F401

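A hedged sketch of what this suite exercises from a user's point of view (requires a build that includes this patch): a threshold of 0 forces the RDD path for any input, and maxRecordsPerBatch=2 splits a small pandas DataFrame into several Arrow batches, each of which becomes its own partition. The partition count in the comment is an expectation derived from the slicing logic above, not verified output.

```python
# Sketch: the two configs used by RDDBasedArrowTests, applied in a local session.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.localRelationThreshold", "0")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")

df = spark.createDataFrame(pd.DataFrame({"id": range(7)}))
print(df.rdd.getNumPartitions())  # expected: ceil(7 / 2) = 4, one per Arrow batch
```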
SQLConf.scala
@@ -2576,6 +2576,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val ARROW_LOCAL_RELATION_THRESHOLD =
buildConf("spark.sql.execution.arrow.localRelationThreshold")
.doc(
"When converting Arrow batches to Spark DataFrame, local collections are used in the " +
"driver side if the byte size of Arrow batches is smaller than this threshold. " +
"Otherwise, the Arrow batches are sent and deserialized to Spark internal rows " +
"in the executors.")
.version("3.4.0")
.bytesConf(ByteUnit.BYTE)
.checkValue(_ >= 0, "This value must be equal to or greater than 0.")
.createWithDefaultString("48MB")
Member Author:
Maybe 32 MB? Don't have a strong preference.

Member:
Is this the max size of each batch or all batches together?

Member Author:
It's all of them together... so pretty small.
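To pin down the semantics discussed here, a small sketch of the decision as implemented in ArrowConverters.toDataFrame further down in this diff: the threshold is compared against the combined byte size of all serialized batches, not the size of any single batch. The numbers below are made up for illustration.

```python
# Sketch: the RDD-vs-LocalRelation decision uses the sum of all batch sizes.
batch_sizes = [16 * 1024 * 1024, 40 * 1024 * 1024]  # bytes per serialized batch
threshold = 48 * 1024 * 1024                        # the "48MB" default

should_use_rdd = threshold < sum(batch_sizes)
print(should_use_rdd)  # True: 56MB of batches together exceeds the 48MB default
```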


val PYSPARK_JVM_STACKTRACE_ENABLED =
buildConf("spark.sql.pyspark.jvmStacktrace.enabled")
.doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " +
@@ -4418,6 +4430,8 @@ class SQLConf extends Serializable with Logging {

def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)

def arrowLocalRelationThreshold: Long = getConf(ARROW_LOCAL_RELATION_THRESHOLD)

def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)

def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED)
PythonSQLUtils.scala
@@ -18,15 +18,15 @@
package org.apache.spark.sql.api.python

import java.io.InputStream
import java.net.Socket
import java.nio.channels.Channels
import java.util.Locale

import net.razorvine.pickle.Pickler

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.PythonRDDServer
import org.apache.spark.api.python.DechunkedInputStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthServer
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
@@ -70,22 +70,22 @@ private[sql] object PythonSQLUtils extends Logging {
SQLConf.get.timestampType == org.apache.spark.sql.types.TimestampNTZType

/**
* Python callable function to read a file in Arrow stream format and create a [[RDD]]
* using each serialized ArrowRecordBatch as a partition.
* Python callable function to read a file in Arrow stream format and create an iterator
* of serialized ArrowRecordBatches.
*/
def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = {
ArrowConverters.readArrowStreamFromFile(session, filename)
def readArrowStreamFromFile(filename: String): Iterator[Array[Byte]] = {
Member Author:
I intentionally used an Iterator here so that Py4J does not copy the Array into the Python driver side.

ArrowConverters.readArrowStreamFromFile(filename).iterator
}

/**
* Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
* from an RDD.
* from the Arrow batch iterator.
*/
def toDataFrame(
arrowBatchRDD: JavaRDD[Array[Byte]],
arrowBatches: Iterator[Array[Byte]],
schemaString: String,
session: SparkSession): DataFrame = {
ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session)
ArrowConverters.toDataFrame(arrowBatches, schemaString, session)
}

def explainString(queryExecution: QueryExecution, mode: String): String = {
@@ -137,16 +137,16 @@ private[sql] object PythonSQLUtils extends Logging {
}

/**
* Helper for making a dataframe from arrow data from data sent from python over a socket. This is
* Helper for making a dataframe from Arrow data from data sent from python over a socket. This is
* used when encryption is enabled, and we don't want to write data to a file.
*/
private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer {

override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
// Create array to consume iterator so that we can safely close the inputStream
val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
// Parallelize the record batches to create an RDD
JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
private[spark] class ArrowIteratorServer
extends SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {

def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
val in = sock.getInputStream()
val dechunkedInput: InputStream = new DechunkedInputStream(in)
// Create array to consume iterator so that we can safely close the file
ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
}

}
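To illustrate the comment above about avoiding copies, here is a hedged driver-side sketch: Py4J hands Python a proxy object for the Scala Iterator rather than converting the byte arrays, and that handle is passed straight back into toDataFrame, as conversion.py does earlier in this diff. The names jvm, temp_filename, schema and jsparkSession are assumed from that context.

```python
# Sketch only: the Arrow batch bytes stay in the JVM; Python only holds a handle.
jiter = jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
print(type(jiter))  # a Py4J proxy (py4j.java_gateway.JavaObject), not the bytes

# The handle goes straight back into the JVM to build the DataFrame.
jdf = jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
```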
12 changes: 10 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,6 +23,7 @@ import java.util.{Locale, Map => JMap}
import scala.collection.JavaConverters._
import scala.util.matching.Regex

import org.apache.spark.TaskContext
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.broadcast.Broadcast
@@ -230,7 +231,9 @@ private[sql] object SQLUtils extends Logging {
def readArrowStreamFromFile(
sparkSession: SparkSession,
filename: String): JavaRDD[Array[Byte]] = {
ArrowConverters.readArrowStreamFromFile(sparkSession, filename)
// Parallelize the record batches to create an RDD
val batches = ArrowConverters.readArrowStreamFromFile(filename)
JavaRDD.fromRDD(sparkSession.sparkContext.parallelize(batches, batches.length))
}

/**
@@ -241,6 +244,11 @@
arrowBatchRDD: JavaRDD[Array[Byte]],
schema: StructType,
sparkSession: SparkSession): DataFrame = {
ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession)
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
}
sparkSession.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
}
ArrowConverters.scala
@@ -29,10 +29,12 @@ import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel
import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}

import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -68,7 +70,7 @@ private[sql] class ArrowBatchStreamWriter(
}
}

private[sql] object ArrowConverters {
private[sql] object ArrowConverters extends Logging {

/**
* Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
@@ -143,7 +145,7 @@ private[sql] object ArrowConverters {
new Iterator[InternalRow] {
private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty

context.addTaskCompletionListener[Unit] { _ =>
if (context != null) context.addTaskCompletionListener[Unit] { _ =>
root.close()
allocator.close()
}
@@ -190,32 +192,54 @@
}

/**
* Create a DataFrame from an RDD of serialized ArrowRecordBatches.
* Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
private[sql] def toDataFrame(
arrowBatchRDD: JavaRDD[Array[Byte]],
/**
* Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
def toDataFrame(
arrowBatches: Iterator[Array[Byte]],
schemaString: String,
session: SparkSession): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
val attrs = schema.toAttributes
val batchesInDriver = arrowBatches.toArray
val shouldUseRDD = session.sessionState.conf
.arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum

if (shouldUseRDD) {
logDebug("Using RDD-based createDataFrame with Arrow optimization.")
val timezone = session.sessionState.conf.sessionLocalTimeZone
val rdd = session.sparkContext.parallelize(batchesInDriver, batchesInDriver.length)
.mapPartitions { batchesInExecutors =>
ArrowConverters.fromBatchIterator(
batchesInExecutors,
schema,
timezone,
TaskContext.get())
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
} else {
logDebug("Using LocalRelation in createDataFrame with Arrow optimization.")
val data = ArrowConverters.fromBatchIterator(
batchesInDriver.toIterator,
schema,
session.sessionState.conf.sessionLocalTimeZone,
TaskContext.get())

// Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
val proj = UnsafeProjection.create(attrs, attrs)
Dataset.ofRows(session, LocalRelation(attrs, data.map(r => proj(r).copy()).toArray))
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
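A hedged, user-facing view of the two branches above (requires a build with this patch): below the threshold the plan should come from a LocalRelation, while forcing the threshold to 0 should route the same input through the parallelized "arrow" RDD. The expected plan fragments in the comments are assumptions based on the code above, not captured output.

```python
# Sketch: observing the LocalRelation vs. RDD branch from PySpark.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
pdf = pd.DataFrame({"id": range(1000)})

spark.createDataFrame(pdf).explain()  # expect a LocalTableScan (driver-side path)

spark.conf.set("spark.sql.execution.arrow.localRelationThreshold", "0")
spark.createDataFrame(pdf).explain()  # expect a scan over the "arrow" RDD (executor path)
```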

/**
* Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
* Read a file as an Arrow stream and return an array of serialized ArrowRecordBatches.
*/
private[sql] def readArrowStreamFromFile(
session: SparkSession,
filename: String): JavaRDD[Array[Byte]] = {
private[sql] def readArrowStreamFromFile(filename: String): Array[Array[Byte]] = {
Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
// Create array to consume iterator so that we can safely close the file
val batches = getBatchesFromStream(fileStream.getChannel).toArray
// Parallelize the record batches to create an RDD
JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
getBatchesFromStream(fileStream.getChannel).toArray
}
}
