Skip to content

Commit ff9ba5e

Browse files
committed
[SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test)
This PR proposes to explicitly specify Pandas and PyArrow versions in PySpark tests to skip or test. We declared the extra dependencies: https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204 In case of PyArrow: Currently we only check if pyarrow is installed or not without checking the version. It already fails to run tests. For example, if PyArrow 0.7.0 is installed: ``` ====================================================================== ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF) ---------------------------------------------------------------------- Traceback (most recent call last): File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf return _create_udf(f=f, returnType=return_type, evalType=eval_type) File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf require_minimum_pyarrow_version() File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version "however, your version was %s." % pyarrow.__version__) ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0. ---------------------------------------------------------------------- Ran 33 tests in 8.098s FAILED (errors=33) ``` In case of Pandas: There are a few tests for old Pandas which were tested only when Pandas version was lower, and I rewrote them to be tested when both Pandas version is lower and missing. Manually tested by modifying the condition: ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... 
skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ``` Author: hyukjinkwon <[email protected]> Closes apache#20487 from HyukjinKwon/pyarrow-pandas-skip. (cherry picked from commit 71cfba0) Signed-off-by: hyukjinkwon <[email protected]>
1 parent cb22e83 commit ff9ba5e

File tree

6 files changed

+86
-47
lines changed

6 files changed

+86
-47
lines changed

pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@
185185
<paranamer.version>2.8</paranamer.version>
186186
<maven-antrun.version>1.8</maven-antrun.version>
187187
<commons-crypto.version>1.0.0</commons-crypto.version>
188+
<!--
189+
If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py,
190+
./python/run-tests.py and ./python/setup.py too.
191+
-->
188192
<arrow.version>0.8.0</arrow.version>
189193

190194
<test.java.home>${java.home}</test.java.home>

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,9 @@ def toPandas(self):
19131913
0 2 Alice
19141914
1 5 Bob
19151915
"""
1916+
from pyspark.sql.utils import require_minimum_pandas_version
1917+
require_minimum_pandas_version()
1918+
19161919
import pandas as pd
19171920

19181921
if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \

python/pyspark/sql/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
640640
except Exception:
641641
has_pandas = False
642642
if has_pandas and isinstance(data, pandas.DataFrame):
643+
from pyspark.sql.utils import require_minimum_pandas_version
644+
require_minimum_pandas_version()
645+
643646
if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
644647
== "true":
645648
timezone = self.conf.get("spark.sql.session.timeZone")

python/pyspark/sql/tests.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,26 @@
4848
else:
4949
import unittest
5050

51-
_have_pandas = False
52-
_have_old_pandas = False
51+
_pandas_requirement_message = None
5352
try:
54-
import pandas
55-
try:
56-
from pyspark.sql.utils import require_minimum_pandas_version
57-
require_minimum_pandas_version()
58-
_have_pandas = True
59-
except:
60-
_have_old_pandas = True
61-
except:
62-
# No Pandas, but that's okay, we'll skip those tests
63-
pass
53+
from pyspark.sql.utils import require_minimum_pandas_version
54+
require_minimum_pandas_version()
55+
except ImportError as e:
56+
from pyspark.util import _exception_message
57+
# If Pandas version requirement is not satisfied, skip related tests.
58+
_pandas_requirement_message = _exception_message(e)
59+
60+
_pyarrow_requirement_message = None
61+
try:
62+
from pyspark.sql.utils import require_minimum_pyarrow_version
63+
require_minimum_pyarrow_version()
64+
except ImportError as e:
65+
from pyspark.util import _exception_message
66+
# If Arrow version requirement is not satisfied, skip related tests.
67+
_pyarrow_requirement_message = _exception_message(e)
68+
69+
_have_pandas = _pandas_requirement_message is None
70+
_have_pyarrow = _pyarrow_requirement_message is None
6471

6572
from pyspark import SparkContext
6673
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
@@ -75,15 +82,6 @@
7582
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
7683

7784

78-
_have_arrow = False
79-
try:
80-
import pyarrow
81-
_have_arrow = True
82-
except:
83-
# No Arrow, but that's okay, we'll skip those tests
84-
pass
85-
86-
8785
class UTCOffsetTimezone(datetime.tzinfo):
8886
"""
8987
Specifies timezone in UTC offset
@@ -2788,7 +2786,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
27882786

27892787
def _to_pandas(self):
27902788
from datetime import datetime, date
2791-
import numpy as np
27922789
schema = StructType().add("a", IntegerType()).add("b", StringType())\
27932790
.add("c", BooleanType()).add("d", FloatType())\
27942791
.add("dt", DateType()).add("ts", TimestampType())
@@ -2801,7 +2798,7 @@ def _to_pandas(self):
28012798
df = self.spark.createDataFrame(data, schema)
28022799
return df.toPandas()
28032800

2804-
@unittest.skipIf(not _have_pandas, "Pandas not installed")
2801+
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
28052802
def test_to_pandas(self):
28062803
import numpy as np
28072804
pdf = self._to_pandas()
@@ -2813,13 +2810,13 @@ def test_to_pandas(self):
28132810
self.assertEquals(types[4], np.object) # datetime.date
28142811
self.assertEquals(types[5], 'datetime64[ns]')
28152812

2816-
@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
2817-
def test_to_pandas_old(self):
2813+
@unittest.skipIf(_have_pandas, "Required Pandas was found.")
2814+
def test_to_pandas_required_pandas_not_found(self):
28182815
with QuietTest(self.sc):
28192816
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
28202817
self._to_pandas()
28212818

2822-
@unittest.skipIf(not _have_pandas, "Pandas not installed")
2819+
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
28232820
def test_to_pandas_avoid_astype(self):
28242821
import numpy as np
28252822
schema = StructType().add("a", IntegerType()).add("b", StringType())\
@@ -2837,7 +2834,7 @@ def test_create_dataframe_from_array_of_long(self):
28372834
df = self.spark.createDataFrame(data)
28382835
self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
28392836

2840-
@unittest.skipIf(not _have_pandas, "Pandas not installed")
2837+
@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
28412838
def test_create_dataframe_from_pandas_with_timestamp(self):
28422839
import pandas as pd
28432840
from datetime import datetime
@@ -2852,14 +2849,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
28522849
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
28532850
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
28542851

2855-
@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
2856-
def test_create_dataframe_from_old_pandas(self):
2857-
import pandas as pd
2858-
from datetime import datetime
2859-
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
2860-
"d": [pd.Timestamp.now().date()]})
2852+
@unittest.skipIf(_have_pandas, "Required Pandas was found.")
2853+
def test_create_dataframe_required_pandas_not_found(self):
28612854
with QuietTest(self.sc):
2862-
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
2855+
with self.assertRaisesRegexp(
2856+
ImportError,
2857+
'(Pandas >= .* must be installed|No module named pandas)'):
2858+
import pandas as pd
2859+
from datetime import datetime
2860+
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
2861+
"d": [pd.Timestamp.now().date()]})
28632862
self.spark.createDataFrame(pdf)
28642863

28652864

@@ -3351,7 +3350,9 @@ def __init__(self, **kwargs):
33513350
_make_type_verifier(data_type, nullable=False)(obj)
33523351

33533352

3354-
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
3353+
@unittest.skipIf(
3354+
not _have_pandas or not _have_pyarrow,
3355+
_pandas_requirement_message or _pyarrow_requirement_message)
33553356
class ArrowTests(ReusedSQLTestCase):
33563357

33573358
@classmethod
@@ -3615,7 +3616,9 @@ def test_createDataFrame_with_int_col_names(self):
36153616
self.assertEqual(pdf_col_names, df_arrow.columns)
36163617

36173618

3618-
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
3619+
@unittest.skipIf(
3620+
not _have_pandas or not _have_pyarrow,
3621+
_pandas_requirement_message or _pyarrow_requirement_message)
36193622
class PandasUDFTests(ReusedSQLTestCase):
36203623
def test_pandas_udf_basic(self):
36213624
from pyspark.rdd import PythonEvalType
@@ -3739,7 +3742,9 @@ def foo(k, v):
37393742
return k
37403743

37413744

3742-
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
3745+
@unittest.skipIf(
3746+
not _have_pandas or not _have_pyarrow,
3747+
_pandas_requirement_message or _pyarrow_requirement_message)
37433748
class ScalarPandasUDFTests(ReusedSQLTestCase):
37443749

37453750
@classmethod
@@ -4252,7 +4257,9 @@ def test_register_vectorized_udf_basic(self):
42524257
self.assertEquals(expected.collect(), res2.collect())
42534258

42544259

4255-
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
4260+
@unittest.skipIf(
4261+
not _have_pandas or not _have_pyarrow,
4262+
_pandas_requirement_message or _pyarrow_requirement_message)
42564263
class GroupedMapPandasUDFTests(ReusedSQLTestCase):
42574264

42584265
def assertFramesEqual(self, expected, result):

python/pyspark/sql/utils.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr):
115115
def require_minimum_pandas_version():
116116
""" Raise ImportError if minimum version of Pandas is not installed
117117
"""
118+
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
119+
minimum_pandas_version = "0.19.2"
120+
118121
from distutils.version import LooseVersion
119-
import pandas
120-
if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
121-
raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
122-
"however, your version was %s." % pandas.__version__)
122+
try:
123+
import pandas
124+
except ImportError:
125+
raise ImportError("Pandas >= %s must be installed; however, "
126+
"it was not found." % minimum_pandas_version)
127+
if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
128+
raise ImportError("Pandas >= %s must be installed; however, "
129+
"your version was %s." % (minimum_pandas_version, pandas.__version__))
123130

124131

125132
def require_minimum_pyarrow_version():
126133
""" Raise ImportError if minimum version of pyarrow is not installed
127134
"""
135+
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
136+
minimum_pyarrow_version = "0.8.0"
137+
128138
from distutils.version import LooseVersion
129-
import pyarrow
130-
if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
131-
raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
132-
"however, your version was %s." % pyarrow.__version__)
139+
try:
140+
import pyarrow
141+
except ImportError:
142+
raise ImportError("PyArrow >= %s must be installed; however, "
143+
"it was not found." % minimum_pyarrow_version)
144+
if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
145+
raise ImportError("PyArrow >= %s must be installed; however, "
146+
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))

python/setup.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ def _supports_symlinks():
100100
file=sys.stderr)
101101
exit(-1)
102102

103+
# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and
104+
# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml.
105+
_minimum_pandas_version = "0.19.2"
106+
_minimum_pyarrow_version = "0.8.0"
107+
103108
try:
104109
# We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts
105110
# find it where expected. The rest of the files aren't copied because they are accessed
@@ -201,7 +206,10 @@ def _supports_symlinks():
201206
extras_require={
202207
'ml': ['numpy>=1.7'],
203208
'mllib': ['numpy>=1.7'],
204-
'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0']
209+
'sql': [
210+
'pandas>=%s' % _minimum_pandas_version,
211+
'pyarrow>=%s' % _minimum_pyarrow_version,
212+
]
205213
},
206214
classifiers=[
207215
'Development Status :: 5 - Production/Stable',

0 commit comments

Comments
 (0)