[SPARK-13534][PySpark] Using Apache Arrow to increase performance of DataFrame.toPandas #15821
pom.xml

@@ -184,6 +184,7 @@
     <paranamer.version>2.6</paranamer.version>
     <maven-antrun.version>1.8</maven-antrun.version>
     <commons-crypto.version>1.0.0</commons-crypto.version>
+    <arrow.version>0.2.0</arrow.version>
     <test.java.home>${java.home}</test.java.home>
     <test.exclude.tags></test.exclude.tags>

@@ -1871,6 +1872,29 @@
       <artifactId>paranamer</artifactId>
       <version>${paranamer.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-vector</artifactId>
+      <version>${arrow.version}</version>
+      <exclusions>
Contributor: Should we consider excluding netty here, since we exclude it in most of the other related projects (like parquet)? It seems to have added some unnecessary jars to the deps list.

Member (Author): I added netty to the exclusions and it does not seem to cause any issues.
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-annotations</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>jackson-databind</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.slf4j</groupId>
+          <artifactId>log4j-over-slf4j</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>io.netty</groupId>
+          <artifactId>netty-handler</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
   </dependencies>
 </dependencyManagement>
python/pyspark/serializers.py

@@ -182,6 +182,23 @@ def loads(self, obj):
         raise NotImplementedError


+class ArrowSerializer(FramedSerializer):
+    """
+    Serializes an Arrow stream.
+    """
+
+    def dumps(self, obj):
+        raise NotImplementedError
+
+    def loads(self, obj):
+        from pyarrow import FileReader, BufferReader
+
+        reader = FileReader(BufferReader(obj))
+        return reader.read_all()
Contributor: Since we are sending multiple batches from the JVM, does …

Member (Author): This will read all batches in a framed byte array from a stream and return. The stream can have multiple framed byte arrays, so it repeats until the end of the stream. How many batches this reads depends on how it was serialized. When calling …
+
+    def __repr__(self):
+        return "ArrowSerializer"
+
+
 class BatchedSerializer(Serializer):

     """
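For reference, a minimal standalone sketch of what `ArrowSerializer.loads` plus `concat_tables` does on the driver side, assuming the pyarrow 0.2.x API used in this patch (`FileReader`, `BufferReader`, `pyarrow.table.concat_tables`); the helper name and the `payloads` variable are hypothetical, standing in for the list returned by `_load_from_socket`:

```python
# Hypothetical illustration of how framed Arrow payloads from the socket are
# turned into tables and finally a pandas DataFrame (pyarrow 0.2.x names).
from pyarrow import FileReader, BufferReader
from pyarrow.table import concat_tables

def arrow_payloads_to_pandas(payloads):
    # Each payload is one Arrow file that may contain several record batches;
    # read_all() concatenates those batches into a single Table.
    tables = [FileReader(BufferReader(buf)).read_all() for buf in payloads]
    # Concatenate the per-payload tables, then hand the result to pandas.
    return concat_tables(tables).to_pandas()
```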
python/pyspark/sql/dataframe.py

@@ -27,7 +27,8 @@
 from pyspark import copy_func, since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
+from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
+    UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.types import _parse_datatype_json_string

@@ -1597,21 +1598,46 @@ def toDF(self, *cols):
         return DataFrame(jdf, self.sql_ctx)

     @since(1.3)
-    def toPandas(self):
-        """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
+    def toPandas(self, useArrow=False):
+        """
+        Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

         This is only available if Pandas is installed and available.

+        :param useArrow: Make use of Apache Arrow for conversion, pyarrow must be installed
+            and available on the calling Python process (Experimental).
+
         .. note:: This method should only be used if the resulting Pandas's DataFrame is expected
             to be small, as all the data is loaded into the driver's memory.

+        .. note:: Using pyarrow is experimental and currently supports the following data types:
+            StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType,
+            LongType, ShortType
+
         >>> df.toPandas()  # doctest: +SKIP
            age   name
         0    2  Alice
         1    5    Bob
         """
-        import pandas as pd
-        return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        if useArrow:
+            from pyarrow.table import concat_tables
+
+            tables = self._collectAsArrow()
+            table = concat_tables(tables)
+            return table.to_pandas()
+        else:
+            import pandas as pd
+            return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+    def _collectAsArrow(self):
Contributor: So right now it seems we check the types in …

Member (Author): This is done in the Scala tests; the error that is thrown is: …
+        """
+        Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
+        and available.
+
+        .. note:: Experimental.
+        """
+        with SCCallSiteSync(self._sc) as css:
+            port = self._jdf.collectAsArrowToPython()
+        return list(_load_from_socket(port, ArrowSerializer()))
+
     ##########################################################################################
     # Pandas compatibility
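A short usage sketch of the new flag; `spark` and the example rows are placeholders, not part of the patch. With `useArrow=True` the driver needs pyarrow installed; otherwise the existing pickle-based path is used:

```python
# Hypothetical driver-side usage; `spark` is an existing SparkSession.
df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ["name", "age"])

pdf_default = df.toPandas()              # existing collect()-and-pickle path
pdf_arrow = df.toPandas(useArrow=True)   # Arrow path; requires pyarrow on the driver

# For the supported types the two DataFrames should hold the same values
# (this mirrors test_toPandas_arrow_toggle in the tests below).
print(pdf_default.equals(pdf_arrow))
```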
python/pyspark/sql/tests.py

@@ -56,6 +56,15 @@
 from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


+_have_arrow = False
+try:
+    import pyarrow
+    _have_arrow = True
Member: We should do a similar thing above when using an Arrow-required feature, e.g., ArrowSerializer.

Member (Author): Do you mean to automatically enable the Arrow functionality if pyarrow is installed? Right now it is enabled manually with a flag.

Member: I mean we should throw an exception when …

Member: Maybe give the param docstring as the exception message? I.e., …
+except:
+    # No Arrow, but that's okay, we'll skip those tests
+    pass


 class UTCOffsetTimezone(datetime.tzinfo):
     """
     Specifies timezone in UTC offset

@@ -2338,6 +2347,55 @@ def range_frame_match():
         importlib.reload(window)


+@unittest.skipIf(not _have_arrow, "Arrow not installed")
+class ArrowTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+        cls.schema = StructType([
+            StructField("str_t", StringType(), True),
+            StructField("int_t", IntegerType(), True),
+            StructField("long_t", LongType(), True),
+            StructField("float_t", FloatType(), True),
+            StructField("double_t", DoubleType(), True)])
+        cls.data = [("a", 1, 10, 0.2, 2.0),
+                    ("b", 2, 20, 0.4, 4.0),
+                    ("c", 3, 30, 0.8, 6.0)]
+
+    def assertFramesEqual(self, df_with_arrow, df_without):
+        msg = ("DataFrame from Arrow is not equal" +
+               ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
+               ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
+        self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
+
+    def test_null_conversion(self):
+        df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
+                                             self.data)
+        pdf = df_null.toPandas(useArrow=True)
+        null_counts = pdf.isnull().sum().tolist()
+        self.assertTrue(all([c == 1 for c in null_counts]))
+
+    def test_toPandas_arrow_toggle(self):
+        df = self.spark.createDataFrame(self.data, schema=self.schema)
+        # NOTE - toPandas(useArrow=False) will infer standard data types
+        df_sel = df.select("str_t", "long_t", "double_t")
+        pdf = df_sel.toPandas(useArrow=False)
+        pdf_arrow = df_sel.toPandas(useArrow=True)
+        self.assertFramesEqual(pdf_arrow, pdf)
+
+    def test_pandas_round_trip(self):
+        import pandas as pd
+        data_dict = {}
+        for j, name in enumerate(self.schema.names):
+            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+        pdf = pd.DataFrame(data=data_dict)
+        pdf_arrow = self.spark.createDataFrame(pdf).toPandas(useArrow=True)
+        self.assertFramesEqual(pdf_arrow, pdf)
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
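Following the review discussion above about failing clearly when pyarrow is missing, here is one possible guard. This is a sketch only, not part of the patch as submitted; the helper name is hypothetical and the error text simply reuses the `useArrow` parameter description:

```python
# Hypothetical guard that toPandas(useArrow=True) could call before importing
# pyarrow, surfacing a clearer error than a bare ImportError, as suggested above.
def _require_pyarrow():
    try:
        import pyarrow  # noqa: F401
    except ImportError:
        raise ImportError(
            "toPandas(useArrow=True) requires pyarrow to be installed and "
            "available on the calling Python process (Experimental).")
```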
Why do we add the arrow dependency at the root instead of only in spark-sql?

I think this is just standard pom usage.

This is just the dependencyManagement section in the main pom; the only actual dependency is in spark-sql.