Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,16 +493,15 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, \
_old_pandas_exception_message, TimestampType
from pyspark.sql.utils import _require_minimum_pyarrow_version
try:
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
except ImportError as e:
raise ImportError(_old_pandas_exception_message(e))
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from pyspark.sql.utils import _require_minimum_pandas_version, \
_require_minimum_pyarrow_version

_require_minimum_pandas_version()
_require_minimum_pyarrow_version()

from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
Expand Down
7 changes: 4 additions & 3 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
try:
import pandas
try:
import pandas.api
from pyspark.sql.utils import _require_minimum_pandas_version
_require_minimum_pandas_version()
_have_pandas = True
except:
_have_old_pandas = True
Expand Down Expand Up @@ -2600,7 +2601,7 @@ def test_to_pandas(self):
@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
def test_to_pandas_old(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
self._to_pandas()

@unittest.skipIf(not _have_pandas, "Pandas not installed")
Expand Down Expand Up @@ -2643,7 +2644,7 @@ def test_create_dataframe_from_old_pandas(self):
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
"d": [pd.Timestamp.now().date()]})
with QuietTest(self.sc):
with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
self.spark.createDataFrame(pdf)


Expand Down
33 changes: 13 additions & 20 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,13 +1678,6 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])


def _old_pandas_exception_message(e):
    """ Build the error message used when importing from an old Pandas fails.

    The result of ``_exception_message(e)`` comes first, followed on a new line by a
    note stating the minimum Pandas version that must be installed.
    """
    original_text = _exception_message(e)
    version_note = "note: Pandas (>=0.19.2) must be installed and available on calling Python process"
    return "%s\n%s" % (original_text, version_note)


def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
Expand All @@ -1693,10 +1686,10 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
try:
from pandas.api.types import is_datetime64tz_dtype
except ImportError as e:
raise ImportError(_old_pandas_exception_message(e))
from pyspark.sql.utils import _require_minimum_pandas_version
_require_minimum_pandas_version()

from pandas.api.types import is_datetime64tz_dtype
tz = timezone or 'tzlocal()'
for column, series in pdf.iteritems():
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
Expand All @@ -1714,10 +1707,10 @@ def _check_series_convert_timestamps_internal(s, timezone):
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
try:
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
except ImportError as e:
raise ImportError(_old_pandas_exception_message(e))
from pyspark.sql.utils import _require_minimum_pandas_version
_require_minimum_pandas_version()

from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
tz = timezone or 'tzlocal()'
Expand All @@ -1737,11 +1730,11 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
:param to_timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
try:
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
except ImportError as e:
raise ImportError(_old_pandas_exception_message(e))
from pyspark.sql.utils import _require_minimum_pandas_version
_require_minimum_pandas_version()

import pandas as pd
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
from_tz = from_timezone or 'tzlocal()'
to_tz = to_timezone or 'tzlocal()'
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
Expand Down
9 changes: 9 additions & 0 deletions python/pyspark/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def toJArray(gateway, jtype, arr):
return jarr


def _require_minimum_pandas_version():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, by the way, I think we could simply rename _require_minimum_pandas_version -> require_minimum_pandas_version and _require_minimum_pyarrow_version -> require_minimum_pyarrow_version. I'm also fine with leaving them as they are.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I'll update them. Thanks!

""" Raise ImportError if minimum version of Pandas is not installed
"""
from distutils.version import LooseVersion
import pandas
if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process")


def _require_minimum_pyarrow_version():
""" Raise ImportError if minimum version of pyarrow is not installed
"""
Expand Down