[SPARK-39301][SQL][PYTHON] Leverage LocalRelation and respect Arrow batch size in createDataFrame with Arrow optimization #36683
Changes to `SQLConf`:

```diff
@@ -2576,6 +2576,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val ARROW_LOCAL_RELATION_THRESHOLD =
+    buildConf("spark.sql.execution.arrow.localRelationThreshold")
+      .doc(
+        "When converting Arrow batches to Spark DataFrame, local collections are used in the " +
+          "driver side if the byte size of Arrow batches is smaller than this threshold. " +
+          "Otherwise, the Arrow batches are sent and deserialized to Spark internal rows " +
+          "in the executors.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .checkValue(_ >= 0, "This value must be equal to or greater than 0.")
+      .createWithDefaultString("48MB")
+
   val PYSPARK_JVM_STACKTRACE_ENABLED =
     buildConf("spark.sql.pyspark.jvmStacktrace.enabled")
       .doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " +
@@ -4418,6 +4430,8 @@ class SQLConf extends Serializable with Logging {
 
   def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)
 
+  def arrowLocalRelationThreshold: Long = getConf(ARROW_LOCAL_RELATION_THRESHOLD)
+
   def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)
 
   def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED)
```
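For context, a minimal sketch of how the new threshold is exercised from PySpark. It assumes a Spark build that includes this patch (the config is versioned 3.4.0) plus pandas and pyarrow installed locally; the session settings and data are illustrative, not part of the PR:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .master("local[*]")
    # Enable Arrow optimization for createDataFrame/toPandas.
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    # New config from this PR: inputs whose Arrow byte size is below this
    # threshold stay on the driver as a LocalRelation; larger inputs are
    # shipped to executors and deserialized into internal rows there.
    .config("spark.sql.execution.arrow.localRelationThreshold", "48MB")
    .getOrCreate()
)

pdf = pd.DataFrame({"id": range(10), "v": [float(i) for i in range(10)]})
df = spark.createDataFrame(pdf)  # small input, so it should become a LocalRelation
df.show()
```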
Changes to `PythonSQLUtils`:

```diff
@@ -18,15 +18,15 @@
 package org.apache.spark.sql.api.python
 
 import java.io.InputStream
+import java.net.Socket
 import java.nio.channels.Channels
 import java.util.Locale
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.PythonRDDServer
+import org.apache.spark.api.python.DechunkedInputStream
 import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthServer
 import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
@@ -70,22 +70,22 @@ private[sql] object PythonSQLUtils extends Logging {
     SQLConf.get.timestampType == org.apache.spark.sql.types.TimestampNTZType
 
   /**
-   * Python callable function to read a file in Arrow stream format and create a [[RDD]]
-   * using each serialized ArrowRecordBatch as a partition.
+   * Python callable function to read a file in Arrow stream format and create an iterator
+   * of serialized ArrowRecordBatches.
    */
-  def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = {
-    ArrowConverters.readArrowStreamFromFile(session, filename)
+  def readArrowStreamFromFile(filename: String): Iterator[Array[Byte]] = {
+    ArrowConverters.readArrowStreamFromFile(filename).iterator
   }
 
   /**
    * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
-   * from an RDD.
+   * from the Arrow batch iterator.
    */
   def toDataFrame(
-      arrowBatchRDD: JavaRDD[Array[Byte]],
+      arrowBatches: Iterator[Array[Byte]],
       schemaString: String,
       session: SparkSession): DataFrame = {
-    ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session)
+    ArrowConverters.toDataFrame(arrowBatches, schemaString, session)
   }
 
   def explainString(queryExecution: QueryExecution, mode: String): String = {
@@ -137,16 +137,16 @@
 }
 
 /**
- * Helper for making a dataframe from arrow data from data sent from python over a socket. This is
+ * Helper for making a dataframe from Arrow data from data sent from python over a socket. This is
  * used when encryption is enabled, and we don't want to write data to a file.
  */
-private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer {
-
-  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
-    // Create array to consume iterator so that we can safely close the inputStream
-    val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
-    // Parallelize the record batches to create an RDD
-    JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
+private[spark] class ArrowIteratorServer
+  extends SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {
+
+  def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
+    val in = sock.getInputStream()
+    val dechunkedInput: InputStream = new DechunkedInputStream(in)
+    // Create array to consume iterator so that we can safely close the file
+    ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
   }
-
 }
```
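As a rough illustration of the "respect Arrow batch size" part of the title, the hypothetical helper below chunks a pandas DataFrame into serialized Arrow record batches with a bounded row count using pyarrow directly. It mirrors the intent of `spark.sql.execution.arrow.maxRecordsPerBatch`; it is not the code path Spark itself uses:

```python
import pandas as pd
import pyarrow as pa

def pandas_to_arrow_batches(pdf: pd.DataFrame, max_records_per_batch: int):
    """Yield serialized Arrow record batches with at most max_records_per_batch rows each."""
    table = pa.Table.from_pandas(pdf, preserve_index=False)
    for batch in table.to_batches(max_chunksize=max_records_per_batch):
        sink = pa.BufferOutputStream()
        # Each yielded element is a complete Arrow IPC stream holding one record batch.
        with pa.ipc.new_stream(sink, batch.schema) as writer:
            writer.write_batch(batch)
        yield sink.getvalue().to_pybytes()

pdf = pd.DataFrame({"id": range(10_000)})
batches = list(pandas_to_arrow_batches(pdf, max_records_per_batch=1_000))
print(len(batches), "batches,", sum(len(b) for b in batches), "bytes in total")
```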
The reason for doing this is to avoid having to reconfigure `spark.rpc.message.maxSize`. When the batch is too large, it throws an exception complaining that `spark.rpc.message.maxSize` is too small.
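For reference, the workaround being avoided here is raising the RPC message size cap by hand, along these lines (illustrative value; `spark.rpc.message.maxSize` is specified in MiB and defaults to 128):

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Raise the cap so oversized serialized batches no longer trip the RPC limit.
    .config("spark.rpc.message.maxSize", "256")
    .getOrCreate()
)
```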
I thought this was to control how many partitions were in the RDD? Each partition could have multiple batches, and probably should be capped at `arrowMaxRecordsPerBatch`, but since it was coming from a local pandas DataFrame already in memory, that didn't seem to be a big deal.
Yeah, that's true... the perf difference seems trivial in any event, and this works around the `spark.rpc.message.maxSize` issue.
I believe it was like this to create the same number of partitions as when Arrow is disabled, although that might have changed since. If the DataFrame is split with `arrowMaxRecordsPerBatch` and a user wanted to create a certain number of partitions, would they then have to look at the size of the input and adjust `arrowMaxRecordsPerBatch` accordingly?
Yeah, that's true .. but I wonder whether the default number of partitions is something we should consider, given that it wasn't configurable before, and `SparkSession.createDataFrame` does not expose the number of partitions either. If they really need it, users might want to create an RDD with an explicit parallelism .. we don't support that now, though (see also #29719).
BTW, just to clarify further: when the pandas DataFrame is small (below the threshold), the number of partitions remains the same (configured by `spark.sql.leafNodeDefaultParallelism`, which falls back to `sparkContext.defaultParallelism` if not set). The number of partitions only differs when the input DataFrame is large, which I think makes more sense in general.
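A small sketch of how one might observe this (hypothetical session and sizes; the exact partitioning of a LocalRelation remains an implementation detail):

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.leafNodeDefaultParallelism", "4")
    .getOrCreate()
)

small_df = spark.createDataFrame(pd.DataFrame({"id": range(100)}))
# Below the threshold the plan is a LocalRelation, so the parallelism should follow
# spark.sql.leafNodeDefaultParallelism (falling back to sparkContext.defaultParallelism).
print(small_df.rdd.getNumPartitions())
```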