Commit d6632d1

[SPARK-23380][PYTHON] Adds a conf for Arrow fallback in toPandas/createDataFrame with Pandas DataFrame
## What changes were proposed in this pull request?

This PR adds a configuration to control the fallback of Arrow optimization for `toPandas` and `createDataFrame` with Pandas DataFrame.

## How was this patch tested?

Manually tested and unit tests added. You can test this by:

**`createDataFrame`**

```python
spark.conf.set("spark.sql.execution.arrow.enabled", False)
pdf = spark.createDataFrame([[{'a': 1}]]).toPandas()
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True)
spark.createDataFrame(pdf, "a: map<string, int>")
```

```python
spark.conf.set("spark.sql.execution.arrow.enabled", False)
pdf = spark.createDataFrame([[{'a': 1}]]).toPandas()
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False)
spark.createDataFrame(pdf, "a: map<string, int>")
```

**`toPandas`**

```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True)
spark.createDataFrame([[{'a': 1}]]).toPandas()
```

```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False)
spark.createDataFrame([[{'a': 1}]]).toPandas()
```

Author: hyukjinkwon <[email protected]>

Closes #20678 from HyukjinKwon/SPARK-23380-conf.
1 parent 9bb239c commit d6632d1
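The snippets above only show the calls; a quick way to see both behaviors side by side (not part of this commit, and assuming an existing `spark` session) is to capture the fallback `UserWarning` in one case and the `RuntimeError` in the other:

```python
import warnings

# A map column is used because Arrow cannot convert it, which is what triggers
# the fallback path added in this commit (assumes an existing SparkSession `spark`).
spark.conf.set("spark.sql.execution.arrow.enabled", "false")
pdf = spark.createDataFrame([[{'a': 1}]]).toPandas()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

# With fallback enabled (the default), Arrow fails, a UserWarning is emitted,
# and the non-Arrow path is used.
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    df = spark.createDataFrame(pdf, "a: map<string, int>")
print([str(w.message) for w in caught if issubclass(w.category, UserWarning)])

# With fallback disabled, the same call raises RuntimeError instead.
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
try:
    spark.createDataFrame(pdf, "a: map<string, int>")
except RuntimeError as e:
    print("Arrow failed and fallback is disabled:", e)
```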

File tree: 5 files changed (+186, -58)
docs/sql-programming-guide.md (5 additions, 0 deletions)

```diff
@@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da
 `createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set
 the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default.
 
+In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fall back automatically
+to a non-Arrow implementation if an error occurs before the actual computation within Spark.
+This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'.
+
 <div class="codetabs">
 <div data-lang="python" markdown="1">
 {% include_example dataframe_with_arrow python/sql/arrow.py %}
@@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
 ## Upgrading From Spark SQL 2.3 to 2.4
 
 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively.
+- In PySpark, when Arrow optimization is enabled, `toPandas` previously just failed when Arrow optimization was unable to be used, whereas `createDataFrame` from a Pandas DataFrame allowed falling back to the non-optimized path. Now, both `toPandas` and `createDataFrame` from a Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`.
 
 ## Upgrading From Spark SQL 2.2 to 2.3
 
```
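As a usage sketch (not part of the patch), the two properties mentioned in the guide can also be set once when the session is built; the application name below is just a placeholder:

```python
from pyspark.sql import SparkSession

# Enable Arrow transfers and keep the automatic fallback (which defaults to true
# after this change); both settings can also be changed later via spark.conf.set.
spark = (SparkSession.builder
         .appName("arrow-fallback-example")  # placeholder name
         .config("spark.sql.execution.arrow.enabled", "true")
         .config("spark.sql.execution.arrow.fallback.enabled", "true")
         .getOrCreate())

pdf = spark.range(10).toPandas()  # uses Arrow when possible, falls back otherwise
```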
python/pyspark/sql/dataframe.py (78 additions, 42 deletions)

```diff
@@ -1992,55 +1992,91 @@ def toPandas(self):
         timezone = None
 
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
+            use_arrow = True
             try:
-                from pyspark.sql.types import _check_dataframe_convert_date, \
-                    _check_dataframe_localize_timestamps, to_arrow_schema
+                from pyspark.sql.types import to_arrow_schema
                 from pyspark.sql.utils import require_minimum_pyarrow_version
+
                 require_minimum_pyarrow_version()
-                import pyarrow
                 to_arrow_schema(self.schema)
-                tables = self._collectAsArrow()
-                if tables:
-                    table = pyarrow.concat_tables(tables)
-                    pdf = table.to_pandas()
-                    pdf = _check_dataframe_convert_date(pdf, self.schema)
-                    return _check_dataframe_localize_timestamps(pdf, timezone)
-                else:
-                    return pd.DataFrame.from_records([], columns=self.columns)
             except Exception as e:
-                msg = (
-                    "Note: toPandas attempted Arrow optimization because "
-                    "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false "
-                    "to disable this.")
-                raise RuntimeError("%s\n%s" % (_exception_message(e), msg))
-        else:
-            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
 
-            dtype = {}
+                if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \
+                        .lower() == "true":
+                    msg = (
+                        "toPandas attempted Arrow optimization because "
+                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                        "failed by the reason below:\n  %s\n"
+                        "Attempts non-optimization as "
+                        "'spark.sql.execution.arrow.fallback.enabled' is set to "
+                        "true." % _exception_message(e))
+                    warnings.warn(msg)
+                    use_arrow = False
+                else:
+                    msg = (
+                        "toPandas attempted Arrow optimization because "
+                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                        "failed by the reason below:\n  %s\n"
+                        "For fallback to non-optimization automatically, please set true to "
+                        "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
+                    raise RuntimeError(msg)
+
+            # Try to use Arrow optimization when the schema is supported and the required version
+            # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
+            if use_arrow:
+                try:
+                    from pyspark.sql.types import _check_dataframe_convert_date, \
+                        _check_dataframe_localize_timestamps
+                    import pyarrow
+
+                    tables = self._collectAsArrow()
+                    if tables:
+                        table = pyarrow.concat_tables(tables)
+                        pdf = table.to_pandas()
+                        pdf = _check_dataframe_convert_date(pdf, self.schema)
+                        return _check_dataframe_localize_timestamps(pdf, timezone)
+                    else:
+                        return pd.DataFrame.from_records([], columns=self.columns)
+                except Exception as e:
+                    # We might have to allow fallback here as well but multiple Spark jobs can
+                    # be executed. So, simply fail in this case for now.
+                    msg = (
+                        "toPandas attempted Arrow optimization because "
+                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                        "failed unexpectedly:\n  %s\n"
+                        "Note that 'spark.sql.execution.arrow.fallback.enabled' does "
+                        "not have an effect in such failure in the middle of "
+                        "computation." % _exception_message(e))
+                    raise RuntimeError(msg)
+
+        # Below is toPandas without Arrow optimization.
+        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+        dtype = {}
+        for field in self.schema:
+            pandas_type = _to_corrected_pandas_type(field.dataType)
+            # SPARK-21766: if an integer field is nullable and has null values, it can be
+            # inferred by pandas as float column. Once we convert the column with NaN back
+            # to integer type e.g., np.int16, we will hit exception. So we use the inferred
+            # float type, not the corrected type from the schema in this case.
+            if pandas_type is not None and \
+                not(isinstance(field.dataType, IntegralType) and field.nullable and
+                    pdf[field.name].isnull().any()):
+                dtype[field.name] = pandas_type
+
+        for f, t in dtype.items():
+            pdf[f] = pdf[f].astype(t, copy=False)
+
+        if timezone is None:
+            return pdf
+        else:
+            from pyspark.sql.types import _check_series_convert_timestamps_local_tz
             for field in self.schema:
-                pandas_type = _to_corrected_pandas_type(field.dataType)
-                # SPARK-21766: if an integer field is nullable and has null values, it can be
-                # inferred by pandas as float column. Once we convert the column with NaN back
-                # to integer type e.g., np.int16, we will hit exception. So we use the inferred
-                # float type, not the corrected type from the schema in this case.
-                if pandas_type is not None and \
-                    not(isinstance(field.dataType, IntegralType) and field.nullable and
-                        pdf[field.name].isnull().any()):
-                    dtype[field.name] = pandas_type
-
-            for f, t in dtype.items():
-                pdf[f] = pdf[f].astype(t, copy=False)
-
-            if timezone is None:
-                return pdf
-            else:
-                from pyspark.sql.types import _check_series_convert_timestamps_local_tz
-                for field in self.schema:
-                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-                    if isinstance(field.dataType, TimestampType):
-                        pdf[field.name] = \
-                            _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
-                return pdf
+                # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+                if isinstance(field.dataType, TimestampType):
+                    pdf[field.name] = \
+                        _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
+            return pdf
 
     def _collectAsArrow(self):
         """
```
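Read through the diff markup, the new control flow can be hard to follow; the sketch below restates the same decision logic in isolation. It is illustrative only: `check_schema`, `collect_with_arrow`, and `collect_without_arrow` are placeholder callables, not PySpark APIs.

```python
import warnings


def to_pandas_with_fallback(arrow_enabled, fallback_enabled,
                            check_schema, collect_with_arrow, collect_without_arrow):
    """Simplified sketch of the toPandas() flow introduced by this patch."""
    use_arrow = False
    if arrow_enabled:
        try:
            check_schema()  # stands in for require_minimum_pyarrow_version() + to_arrow_schema()
            use_arrow = True
        except Exception as e:
            if fallback_enabled:
                # Pre-computation failure: warn and fall back to the non-Arrow path.
                warnings.warn("Arrow optimization failed (%s); attempting non-optimization." % e)
            else:
                raise RuntimeError("Arrow optimization failed and fallback is disabled: %s" % e)

    if use_arrow:
        # Failures from here on are not retried without Arrow, because Spark jobs
        # may already have been launched; the real code raises RuntimeError instead.
        return collect_with_arrow()

    return collect_without_arrow()
```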
python/pyspark/sql/session.py (20 additions, 2 deletions)

```diff
@@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
                 try:
                     return self._create_from_pandas_with_arrow(data, schema, timezone)
                 except Exception as e:
-                    warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
-                    # Fallback to create DataFrame without arrow if raise some exception
+                    from pyspark.util import _exception_message
+
+                    if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \
+                            .lower() == "true":
+                        msg = (
+                            "createDataFrame attempted Arrow optimization because "
+                            "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                            "failed by the reason below:\n  %s\n"
+                            "Attempts non-optimization as "
+                            "'spark.sql.execution.arrow.fallback.enabled' is set to "
+                            "true." % _exception_message(e))
+                        warnings.warn(msg)
+                    else:
+                        msg = (
+                            "createDataFrame attempted Arrow optimization because "
+                            "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                            "failed by the reason below:\n  %s\n"
+                            "For fallback to non-optimization automatically, please set true to "
+                            "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
+                        raise RuntimeError(msg)
             data = self._convert_from_pandas(data, schema, timezone)
 
         if isinstance(schema, StructType):
```
python/pyspark/sql/tests.py (72 additions, 12 deletions)

```diff
@@ -32,7 +32,9 @@
 import datetime
 import array
 import ctypes
+import warnings
 import py4j
+from contextlib import contextmanager
 
 try:
     import xmlrunner
@@ -48,12 +50,13 @@
 else:
     import unittest
 
+from pyspark.util import _exception_message
+
 _pandas_requirement_message = None
 try:
     from pyspark.sql.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 except ImportError as e:
-    from pyspark.util import _exception_message
     # If Pandas version requirement is not satisfied, skip related tests.
     _pandas_requirement_message = _exception_message(e)
 
@@ -62,7 +65,6 @@
     from pyspark.sql.utils import require_minimum_pyarrow_version
     require_minimum_pyarrow_version()
 except ImportError as e:
-    from pyspark.util import _exception_message
     # If Arrow version requirement is not satisfied, skip related tests.
     _pyarrow_requirement_message = _exception_message(e)
 
@@ -195,6 +197,28 @@ def tearDownClass(cls):
         ReusedPySparkTestCase.tearDownClass()
         cls.spark.stop()
 
+    @contextmanager
+    def sql_conf(self, pairs):
+        """
+        A convenient context manager to test some configuration specific logic. This sets
+        `value` to the configuration `key` and then restores it back when it exits.
+        """
+        assert isinstance(pairs, dict), "pairs should be a dictionary."
+
+        keys = pairs.keys()
+        new_values = pairs.values()
+        old_values = [self.spark.conf.get(key, None) for key in keys]
+        for key, new_value in zip(keys, new_values):
+            self.spark.conf.set(key, new_value)
+        try:
+            yield
+        finally:
+            for key, old_value in zip(keys, old_values):
+                if old_value is None:
+                    self.spark.conf.unset(key)
+                else:
+                    self.spark.conf.set(key, old_value)
+
     def assertPandasEqual(self, expected, result):
         msg = ("DataFrames are not equal: " +
                "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
@@ -3458,6 +3482,8 @@ def setUpClass(cls):
 
         cls.spark.conf.set("spark.sql.session.timeZone", tz)
         cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        # Disable fallback by default to easily detect the failures.
+        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
         cls.schema = StructType([
             StructField("1_str_t", StringType(), True),
             StructField("2_int_t", IntegerType(), True),
@@ -3493,20 +3519,30 @@ def create_pandas_data_frame(self):
         data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
         return pd.DataFrame(data=data_dict)
 
-    def test_unsupported_datatype(self):
+    def test_toPandas_fallback_enabled(self):
+        import pandas as pd
+
+        with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+            df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
+            with QuietTest(self.sc):
+                with warnings.catch_warnings(record=True) as warns:
+                    pdf = df.toPandas()
+                    # Catch and check the last UserWarning.
+                    user_warns = [
+                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                    self.assertTrue(len(user_warns) > 0)
+                    self.assertTrue(
+                        "Attempts non-optimization" in _exception_message(user_warns[-1]))
+                    self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+
+    def test_toPandas_fallback_disabled(self):
         schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                 df.toPandas()
 
-        df = self.spark.createDataFrame([(None,)], schema="a binary")
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'):
-                df.toPandas()
-
     def test_null_conversion(self):
         df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
                                              self.data)
@@ -3625,7 +3661,7 @@ def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
         wrong_schema = StructType(list(reversed(self.schema)))
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+            with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"):
                 self.spark.createDataFrame(pdf, schema=wrong_schema)
 
     def test_createDataFrame_with_names(self):
@@ -3650,7 +3686,7 @@ def test_createDataFrame_column_name_encoding(self):
     def test_createDataFrame_with_single_data_type(self):
         import pandas as pd
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+            with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"):
                 self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
 
     def test_createDataFrame_does_not_modify_input(self):
@@ -3705,6 +3741,30 @@ def test_createDataFrame_with_int_col_names(self):
         self.assertEqual(pdf_col_names, df.columns)
         self.assertEqual(pdf_col_names, df_arrow.columns)
 
+    def test_createDataFrame_fallback_enabled(self):
+        import pandas as pd
+
+        with QuietTest(self.sc):
+            with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+                with warnings.catch_warnings(record=True) as warns:
+                    df = self.spark.createDataFrame(
+                        pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+                    # Catch and check the last UserWarning.
+                    user_warns = [
+                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                    self.assertTrue(len(user_warns) > 0)
+                    self.assertTrue(
+                        "Attempts non-optimization" in _exception_message(user_warns[-1]))
+                    self.assertEqual(df.collect(), [Row(a={u'a': 1})])
+
+    def test_createDataFrame_fallback_disabled(self):
+        import pandas as pd
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+                self.spark.createDataFrame(
+                    pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
         import pandas as pd
```
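The `sql_conf` helper added to the test base class is a handy pattern outside tests as well; a standalone variant might look like the sketch below (the name `temporary_sql_confs` is illustrative, not a PySpark API):

```python
from contextlib import contextmanager


@contextmanager
def temporary_sql_confs(spark, pairs):
    """Set the given SQL confs, then restore (or unset) them on exit."""
    old_values = {key: spark.conf.get(key, None) for key in pairs}
    for key, value in pairs.items():
        spark.conf.set(key, value)
    try:
        yield
    finally:
        for key, old_value in old_values.items():
            if old_value is None:
                spark.conf.unset(key)
            else:
                spark.conf.set(key, old_value)


# Example: disable the Arrow fallback only inside this block.
# with temporary_sql_confs(spark, {"spark.sql.execution.arrow.fallback.enabled": "false"}):
#     spark.createDataFrame(pandas_df)
```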
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala (11 additions, 2 deletions)

```diff
@@ -1058,7 +1058,7 @@ object SQLConf {
     .intConf
     .createWithDefault(100)
 
-  val ARROW_EXECUTION_ENABLE =
+  val ARROW_EXECUTION_ENABLED =
     buildConf("spark.sql.execution.arrow.enabled")
       .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
         "for use with pyspark.sql.DataFrame.toPandas, and " +
@@ -1068,6 +1068,13 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val ARROW_FALLBACK_ENABLED =
+    buildConf("spark.sql.execution.arrow.fallback.enabled")
+      .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " +
+        "fallback automatically to non-optimized implementations if an error occurs.")
+      .booleanConf
+      .createWithDefault(true)
+
   val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
     buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
       .doc("When using Apache Arrow, limit the maximum number of records that can be written " +
@@ -1518,7 +1525,9 @@ class SQLConf extends Serializable with Logging {
 
   def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
 
-  def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
+  def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED)
+
+  def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED)
 
   def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
 
```
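On the PySpark side the new entry behaves like any other runtime conf; a small sketch (assuming an existing `spark` session) to inspect the defaults registered above:

```python
# 'spark.sql.execution.arrow.enabled' still defaults to "false";
# the new 'spark.sql.execution.arrow.fallback.enabled' defaults to "true".
print(spark.conf.get("spark.sql.execution.arrow.enabled"))
print(spark.conf.get("spark.sql.execution.arrow.fallback.enabled"))
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")  # opt out of the fallback
```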