Commit 71cfba0
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
## What changes were proposed in this pull request?

This PR proposes to explicitly specify the Pandas and PyArrow versions in PySpark tests, so that each test is run or skipped based on the installed version. We declared the extra dependencies here: https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204

In the case of PyArrow: currently we only check whether pyarrow is installed, without checking its version, and the tests already fail to run in that case. For example, if PyArrow 0.7.0 is installed:

```
======================================================================
ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type
    f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
  File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf
    return _create_udf(f=f, returnType=return_type, evalType=eval_type)
  File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf
    require_minimum_pyarrow_version()
  File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version
    "however, your version was %s." % pyarrow.__version__)
ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0.

----------------------------------------------------------------------
Ran 33 tests in 8.098s

FAILED (errors=33)
```

In the case of Pandas: a few tests for old Pandas used to run only when the installed Pandas version was below the minimum; they were rewritten so that they are exercised both when Pandas is too old and when it is missing.

## How was this patch tested?

Manually tested by modifying the skip condition:

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.'
```

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.'
```

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.'
```

```
test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.'
```

Author: hyukjinkwon <[email protected]>

Closes #20487 from HyukjinKwon/pyarrow-pandas-skip.
1 parent 9775df6 commit 71cfba0

File tree

6 files changed: +89 −48 lines

pom.xml

Lines changed: 4 additions & 0 deletions

```diff
@@ -185,6 +185,10 @@
     <paranamer.version>2.8</paranamer.version>
     <maven-antrun.version>1.8</maven-antrun.version>
     <commons-crypto.version>1.0.0</commons-crypto.version>
+    <!--
+      If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py,
+      ./python/run-tests.py and ./python/setup.py too.
+    -->
     <arrow.version>0.8.0</arrow.version>

     <test.java.home>${java.home}</test.java.home>
```

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -1913,6 +1913,9 @@ def toPandas(self):
         0    2  Alice
         1    5    Bob
         """
+        from pyspark.sql.utils import require_minimum_pandas_version
+        require_minimum_pandas_version()
+
         import pandas as pd

         if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
```

python/pyspark/sql/session.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -646,6 +646,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         except Exception:
             has_pandas = False
         if has_pandas and isinstance(data, pandas.DataFrame):
+            from pyspark.sql.utils import require_minimum_pandas_version
+            require_minimum_pandas_version()
+
             if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
                     == "true":
                 timezone = self.conf.get("spark.sql.session.timeZone")
```

python/pyspark/sql/tests.py

Lines changed: 48 additions & 39 deletions

```diff
@@ -48,19 +48,26 @@
 else:
     import unittest

-_have_pandas = False
-_have_old_pandas = False
+_pandas_requirement_message = None
 try:
-    import pandas
-    try:
-        from pyspark.sql.utils import require_minimum_pandas_version
-        require_minimum_pandas_version()
-        _have_pandas = True
-    except:
-        _have_old_pandas = True
-except:
-    # No Pandas, but that's okay, we'll skip those tests
-    pass
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+except ImportError as e:
+    from pyspark.util import _exception_message
+    # If Pandas version requirement is not satisfied, skip related tests.
+    _pandas_requirement_message = _exception_message(e)
+
+_pyarrow_requirement_message = None
+try:
+    from pyspark.sql.utils import require_minimum_pyarrow_version
+    require_minimum_pyarrow_version()
+except ImportError as e:
+    from pyspark.util import _exception_message
+    # If Arrow version requirement is not satisfied, skip related tests.
+    _pyarrow_requirement_message = _exception_message(e)
+
+_have_pandas = _pandas_requirement_message is None
+_have_pyarrow = _pyarrow_requirement_message is None

 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
@@ -75,15 +82,6 @@
 from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


-_have_arrow = False
-try:
-    import pyarrow
-    _have_arrow = True
-except:
-    # No Arrow, but that's okay, we'll skip those tests
-    pass
-
-
 class UTCOffsetTimezone(datetime.tzinfo):
     """
     Specifies timezone in UTC offset
@@ -2794,7 +2792,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"):

     def _to_pandas(self):
         from datetime import datetime, date
-        import numpy as np
         schema = StructType().add("a", IntegerType()).add("b", StringType())\
             .add("c", BooleanType()).add("d", FloatType())\
             .add("dt", DateType()).add("ts", TimestampType())
@@ -2807,7 +2804,7 @@ def _to_pandas(self):
         df = self.spark.createDataFrame(data, schema)
         return df.toPandas()

-    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
     def test_to_pandas(self):
         import numpy as np
         pdf = self._to_pandas()
@@ -2819,13 +2816,13 @@ def test_to_pandas(self):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')

-    @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
-    def test_to_pandas_old(self):
+    @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+    def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                 self._to_pandas()

-    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
     def test_to_pandas_avoid_astype(self):
         import numpy as np
         schema = StructType().add("a", IntegerType()).add("b", StringType())\
@@ -2843,7 +2840,7 @@ def test_create_dataframe_from_array_of_long(self):
         df = self.spark.createDataFrame(data)
         self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))

-    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
     def test_create_dataframe_from_pandas_with_timestamp(self):
         import pandas as pd
         from datetime import datetime
@@ -2858,14 +2855,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
         self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
         self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

-    @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
-    def test_create_dataframe_from_old_pandas(self):
-        import pandas as pd
-        from datetime import datetime
-        pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
-                            "d": [pd.Timestamp.now().date()]})
+    @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+    def test_create_dataframe_required_pandas_not_found(self):
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
+            with self.assertRaisesRegexp(
+                    ImportError,
+                    '(Pandas >= .* must be installed|No module named pandas)'):
+                import pandas as pd
+                from datetime import datetime
+                pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
+                                    "d": [pd.Timestamp.now().date()]})
                 self.spark.createDataFrame(pdf)


@@ -3383,7 +3382,9 @@ def __init__(self, **kwargs):
         _make_type_verifier(data_type, nullable=False)(obj)


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class ArrowTests(ReusedSQLTestCase):

     @classmethod
@@ -3641,7 +3642,9 @@ def test_createDataFrame_with_int_col_names(self):
         self.assertEqual(pdf_col_names, df_arrow.columns)


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class PandasUDFTests(ReusedSQLTestCase):
     def test_pandas_udf_basic(self):
         from pyspark.rdd import PythonEvalType
@@ -3765,7 +3768,9 @@ def foo(k, v):
         return k


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class ScalarPandasUDFTests(ReusedSQLTestCase):

     @classmethod
@@ -4278,7 +4283,9 @@ def test_register_vectorized_udf_basic(self):
         self.assertEquals(expected.collect(), res2.collect())


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class GroupedMapPandasUDFTests(ReusedSQLTestCase):

     @property
@@ -4447,7 +4454,9 @@ def test_unsupported_types(self):
         df.groupby('id').apply(f).collect()


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class GroupedAggPandasUDFTests(ReusedSQLTestCase):

     @property
```

python/pyspark/sql/utils.py

Lines changed: 22 additions & 8 deletions

```diff
@@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr):
 def require_minimum_pandas_version():
     """ Raise ImportError if minimum version of Pandas is not installed
     """
+    # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+    minimum_pandas_version = "0.19.2"
+
     from distutils.version import LooseVersion
-    import pandas
-    if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
-        raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
-                          "however, your version was %s." % pandas.__version__)
+    try:
+        import pandas
+    except ImportError:
+        raise ImportError("Pandas >= %s must be installed; however, "
+                          "it was not found." % minimum_pandas_version)
+    if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
+        raise ImportError("Pandas >= %s must be installed; however, "
+                          "your version was %s." % (minimum_pandas_version, pandas.__version__))


 def require_minimum_pyarrow_version():
     """ Raise ImportError if minimum version of pyarrow is not installed
     """
+    # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+    minimum_pyarrow_version = "0.8.0"
+
     from distutils.version import LooseVersion
-    import pyarrow
-    if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
-        raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
-                          "however, your version was %s." % pyarrow.__version__)
+    try:
+        import pyarrow
+    except ImportError:
+        raise ImportError("PyArrow >= %s must be installed; however, "
+                          "it was not found." % minimum_pyarrow_version)
+    if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
+        raise ImportError("PyArrow >= %s must be installed; however, "
+                          "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
```
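The checks above compare versions with `LooseVersion` rather than plain strings, which matters because string comparison gets multi-digit components wrong (as strings, `"0.10.0" < "0.8.0"`). A rough equivalent of the numeric comparison, using a hypothetical tuple-based parser so the sketch does not depend on `distutils` (which newer Python versions have removed):

```python
def parse_version(version):
    # Split a dotted release string into a tuple of ints so each
    # component is compared numerically, not lexicographically.
    return tuple(int(part) for part in version.split("."))

# Lexicographic comparison would rank "0.10.0" below "0.8.0"; numeric does not.
assert parse_version("0.10.0") > parse_version("0.8.0")
assert parse_version("0.7.0") < parse_version("0.8.0")
```

This is why a `LooseVersion`-style parse is used instead of `pyarrow.__version__ < "0.8.0"`.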

python/setup.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -100,6 +100,11 @@ def _supports_symlinks():
           file=sys.stderr)
     exit(-1)

+# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and
+# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml.
+_minimum_pandas_version = "0.19.2"
+_minimum_pyarrow_version = "0.8.0"
+
 try:
     # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts
     # find it where expected. The rest of the files aren't copied because they are accessed
@@ -201,7 +206,10 @@ def _supports_symlinks():
     extras_require={
         'ml': ['numpy>=1.7'],
         'mllib': ['numpy>=1.7'],
-        'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0']
+        'sql': [
+            'pandas>=%s' % _minimum_pandas_version,
+            'pyarrow>=%s' % _minimum_pyarrow_version,
+        ]
     },
     classifiers=[
         'Development Status :: 5 - Production/Stable',
```
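The setup.py change builds the `sql` extra's requirement specifiers from the shared minimum-version variables, so the pins cannot drift apart within the file. A small sketch of the string construction (values copied from the diff above):

```python
# Shared minimum versions, as in the patched setup.py.
_minimum_pandas_version = "0.19.2"
_minimum_pyarrow_version = "0.8.0"

# The extras_require entries are just formatted requirement specifiers.
sql_requires = [
    'pandas>=%s' % _minimum_pandas_version,
    'pyarrow>=%s' % _minimum_pyarrow_version,
]
assert sql_requires == ['pandas>=0.19.2', 'pyarrow>=0.8.0']
```

With this in place, `pip install pyspark[sql]` pulls Pandas and PyArrow at or above the same minimums the runtime checks enforce.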
