[SPARK-21375][PYSPARK][SQL] Add Date and Timestamp support to ArrowConverters for toPandas() Conversion #18664
@@ -214,6 +214,7 @@ def __repr__(self):

 def _create_batch(series):
+    from pyspark.sql.types import _series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
     if not isinstance(series, (list, tuple)) or \
@@ -224,7 +225,13 @@ def _create_batch(series):
     # If a nullable integer series has been promoted to floating point with NaNs, need to cast
     # NOTE: this is not necessary with Arrow >= 0.7
     def cast_series(s, t):
-        if t is None or s.dtype == t.to_pandas_dtype():
+        if type(t) == pa.TimestampType:
+            # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
+            return _series_convert_timestamps_internal(s).values.astype('datetime64[us]')
+        elif t == pa.date32():
+            # TODO: ValueError: Cannot cast DatetimeIndex to dtype datetime64[D]
+            return s.dt.values.astype('datetime64[D]')
+        elif t is None or s.dtype == t.to_pandas_dtype():
             return s
         else:
             return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
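For context, here is a minimal standalone sketch, not part of the patch, of the normalization the two new branches perform; it assumes pandas, numpy, and a reasonably recent pyarrow rather than the versions targeted by the PR. Timestamps are shifted from local time to UTC and truncated to microseconds (Spark's internal unit), and dates are truncated to day precision for Arrow's date32 type.

```python
import pandas as pd
import pyarrow as pa

s = pd.Series(pd.to_datetime(["1969-01-01 01:01:01", "2012-02-02 02:02:02"]))

# Timestamp branch: localize to the session's local zone, convert to UTC, then cast
# the underlying numpy values to microsecond precision.
utc_us = s.dt.tz_localize("tzlocal()").dt.tz_convert("UTC").values.astype("datetime64[us]")
print(pa.array(utc_us).type)                  # timestamp[us]

# Date branch: truncate to day precision so Arrow can store the values as date32.
days = s.values.astype("datetime64[D]")
print(pa.array(days, type=pa.date32()).type)  # date32[day]
```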
@@ -260,11 +267,13 @@ def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
+        from pyspark.sql.types import _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
         for batch in reader:
-            table = pa.Table.from_batches([batch])
-            yield [c.to_pandas() for c in table.itercolumns()]
+            # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
+            pdf = _check_dataframe_localize_timestamps(batch.to_pandas())
+            yield [c for _, c in pdf.iteritems()]
Member (Author):
After running some tests, this change does not significantly degrade performance, but there seems to be a small difference. cc @ueshin. I ran various columns of random data through it (before change: 2.595558). Do you think the difference here is acceptable for now, until Arrow is upgraded and we can look into it again?

Member:
I ran your script in my local, too. I think it's okay to use this workaround.
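The script referenced above is not reproduced in the thread; the following is a rough, hypothetical benchmark sketch (assuming a recent pandas/pyarrow) of how one could compare the old per-column conversion with the new whole-batch conversion:

```python
import time
import numpy as np
import pandas as pd
import pyarrow as pa

# A few columns of random data, similar in spirit to the test described above.
pdf = pd.DataFrame({
    "doubles": np.random.rand(1_000_000),
    "longs": np.random.randint(0, 100, size=1_000_000),
    "timestamps": pd.date_range("2012-01-01", periods=1_000_000, freq="s"),
})
batch = pa.RecordBatch.from_pandas(pdf, preserve_index=False)
table = pa.Table.from_batches([batch])

start = time.time()
old_path = [col.to_pandas() for col in table.itercolumns()]  # old: one Series per column
t_old = time.time() - start

start = time.time()
new_path = [s for _, s in batch.to_pandas().items()]         # new: convert the whole batch
t_new = time.time() - start

print("per-column: %.6f s  whole-batch: %.6f s" % (t_old, t_new))
```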

     def __repr__(self):
         return "ArrowStreamPandasSerializer"

@@ -1880,11 +1880,13 @@ def toPandas(self):
         import pandas as pd
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
+                from pyspark.sql.types import _check_dataframe_localize_timestamps
                 import pyarrow
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
-                    return table.to_pandas()
+                    df = table.to_pandas()
+                    return _check_dataframe_localize_timestamps(df)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
             except ImportError as e:
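As an illustration of the read path touched here (a sketch only, assuming a recent pandas/pyarrow rather than the versions targeted by the PR): the concatenated Arrow tables carry the timestamp column as timezone-aware UTC, and `_check_dataframe_localize_timestamps` is what strips that back to timezone-naive local values after `to_pandas()`.

```python
import pandas as pd
import pyarrow as pa

ts = pd.Series(pd.to_datetime(["2012-02-02 02:02:02", "2100-03-03 03:03:03"]).tz_localize("UTC"))
table = pa.concat_tables([pa.Table.from_pandas(pd.DataFrame({"ts": ts}), preserve_index=False)])

pdf = table.to_pandas()
print(pdf["ts"].dtype)   # datetime64[ns, UTC] -- timezone aware after Arrow conversion

# Equivalent of the localization step: convert to local time, then drop the tz info.
pdf["ts"] = pdf["ts"].dt.tz_convert("tzlocal()").dt.tz_localize(None)
print(pdf["ts"].dtype)   # datetime64[ns] -- naive, expressed in local time
```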
@@ -1952,6 +1954,7 @@ def _to_corrected_pandas_type(dt):
     """
     When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
     This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
+    NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
     """
     import numpy as np
     if type(dt) == ByteType:

@@ -1962,6 +1965,8 @@ def _to_corrected_pandas_type(dt):
         return np.int32
     elif type(dt) == FloatType:
         return np.float32
+    elif type(dt) == DateType:
+        return 'datetime64[ns]'
     else:
         return None
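A short sketch of the behavior this mapping corrects (illustrative only; it assumes a pandas version where object-to-datetime `astype` is still permitted): `datetime.date` values collected from Spark land in an 'object' column, and the corrected dtype is what lets toPandas() cast them.

```python
import pandas as pd
from datetime import date

pdf = pd.DataFrame.from_records([(date(1969, 1, 1),), (date(2012, 2, 2),)], columns=["6_date_t"])
print(pdf["6_date_t"].dtype)    # object -- DateType is inferred incorrectly

# Apply the corrected pandas type returned by _to_corrected_pandas_type for DateType.
pdf["6_date_t"] = pdf["6_date_t"].astype("datetime64[ns]")
print(pdf["6_date_t"].dtype)    # datetime64[ns]
```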

@@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase):

     @classmethod
     def setUpClass(cls):
+        from datetime import datetime
         ReusedPySparkTestCase.setUpClass()

+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
Member:
Maybe I am too much worried, but shouldn't we keep the original …

Member (Author):
Hmm, I was assuming that …
+        time.tzset()

         cls.spark = SparkSession(cls.sc)
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
         cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         cls.schema = StructType([
             StructField("1_str_t", StringType(), True),
             StructField("2_int_t", IntegerType(), True),
             StructField("3_long_t", LongType(), True),
             StructField("4_float_t", FloatType(), True),
-            StructField("5_double_t", DoubleType(), True)])
-        cls.data = [("a", 1, 10, 0.2, 2.0),
-                    ("b", 2, 20, 0.4, 4.0),
-                    ("c", 3, 30, 0.8, 6.0)]
+            StructField("5_double_t", DoubleType(), True),
+            StructField("6_date_t", DateType(), True),
+            StructField("7_timestamp_t", TimestampType(), True)])
+        cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]

     @classmethod
     def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
         ReusedPySparkTestCase.tearDownClass()
         cls.spark.stop()

     def assertFramesEqual(self, df_with_arrow, df_without):
         msg = ("DataFrame from Arrow is not equal" +
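For readers unfamiliar with the TZ trick used in setUpClass/tearDownClass above, here is a small standalone sketch (Unix-only, since `time.tzset()` is not available on Windows) of how a process can pin and later restore its local timezone:

```python
import os
import time
from datetime import datetime

tz_prev = os.environ.get("TZ", None)    # save current tz if set
os.environ["TZ"] = "America/Los_Angeles"
time.tzset()
print(time.tzname, datetime.now())      # local time now reported as America/Los_Angeles

# Restore the previous setting, mirroring tearDownClass.
del os.environ["TZ"]
if tz_prev is not None:
    os.environ["TZ"] = tz_prev
time.tzset()
```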
@@ -3106,8 +3126,8 @@ def assertFramesEqual(self, df_with_arrow, df_without):
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

     def test_unsupported_datatype(self):
-        schema = StructType([StructField("dt", DateType(), True)])
-        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
+        schema = StructType([StructField("decimal", DecimalType(), True)])
+        df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             self.assertRaises(Exception, lambda: df.toPandas())
@@ -3385,13 +3405,59 @@ def test_vectorized_udf_varargs(self):

     def test_vectorized_udf_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
-        schema = StructType([StructField("dt", DateType(), True)])
-        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
-        f = pandas_udf(lambda x: x, DateType())
+        schema = StructType([StructField("dt", DecimalType(), True)])
+        df = self.spark.createDataFrame([(None,)], schema=schema)
+        f = pandas_udf(lambda x: x, DecimalType())
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
                 df.select(f(col('dt'))).collect()
+    def test_vectorized_udf_null_date(self):
+        from pyspark.sql.functions import pandas_udf, col
+        from datetime import date
+        schema = StructType().add("date", DateType())
+        data = [(date(1969, 1, 1),),
+                (date(2012, 2, 2),),
+                (None,),
+                (date(2100, 4, 4),)]
+        df = self.spark.createDataFrame(data, schema=schema)
+        date_f = pandas_udf(lambda t: t, returnType=DateType())
+        res = df.select(date_f(col("date")))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_timestamps(self):
+        from pyspark.sql.functions import pandas_udf, col
+        from datetime import datetime
+        schema = StructType([
+            StructField("idx", LongType(), True),
+            StructField("timestamp", TimestampType(), True)])
+        data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
+                (1, datetime(2012, 2, 2, 2, 2, 2)),
+                (2, None),
+                (3, datetime(2100, 4, 4, 4, 4, 4))]
+        df = self.spark.createDataFrame(data, schema=schema)
+
+        # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
+        f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
+        df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))
+
+        @pandas_udf(returnType=BooleanType())
+        def check_data(idx, timestamp, timestamp_copy):
+            is_equal = timestamp.isnull()  # use this array to check values are equal
+            for i in range(len(idx)):
+                # Check that timestamps are as expected in the UDF
+                is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
+                    timestamp[i].to_pydatetime() == data[idx[i]][1]
+            return is_equal
+
+        result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
+                                                      col("timestamp_copy"))).collect()
+        # Check that collection values are correct
+        self.assertEquals(len(data), len(result))
+        for i in range(len(result)):
+            self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
+            self.assertTrue(result[i][3])  # "is_equal" data in udf was as expected


 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class GroupbyApplyTests(ReusedPySparkTestCase):
@@ -3550,8 +3616,8 @@ def test_wrong_args(self):
     def test_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
         schema = StructType(
-            [StructField("id", LongType(), True), StructField("dt", DateType(), True)])
-        df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema)
+            [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
+        df = self.spark.createDataFrame([(1, None,)], schema=schema)
         f = pandas_udf(lambda x: x, df.schema)
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):

@@ -1619,11 +1619,33 @@ def to_arrow_type(dt):
         arrow_type = pa.decimal(dt.precision, dt.scale)
     elif type(dt) == StringType:
         arrow_type = pa.string()
+    elif type(dt) == DateType:
+        arrow_type = pa.date32()
+    elif type(dt) == TimestampType:
+        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
+        arrow_type = pa.timestamp('us', tz='UTC')
Member (Author):
I think it is ok to specify 'UTC' here. In Spark the …

Member:
I think it'd be nicer if we have some comments for this explanation.
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type


+def _check_dataframe_localize_timestamps(df):
+    """ Convert timezone aware timestamps to timezone-naive in local time
+    """
+    from pandas.api.types import is_datetime64tz_dtype
+    for column, series in df.iteritems():
+        # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+        if is_datetime64tz_dtype(series.dtype):
+            df[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
+    return df
+
+
+def _series_convert_timestamps_internal(s):
+    """ Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
+    """
+    return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext
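To make the new type mapping concrete, here is a hedged sketch (the helper name `spark_sql_type_to_arrow` is hypothetical and only mirrors the branches added to `to_arrow_type` above) showing the Arrow types chosen for DateType and TimestampType, and how pandas values are encoded with them:

```python
import pandas as pd
import pyarrow as pa
from datetime import date

def spark_sql_type_to_arrow(type_name):
    # DateType -> 32-bit days since epoch; TimestampType -> microseconds, read as UTC on the JVM side
    if type_name == "date":
        return pa.date32()
    elif type_name == "timestamp":
        return pa.timestamp("us", tz="UTC")
    raise TypeError("Unsupported type in conversion to Arrow: " + type_name)

ts = pd.Series(pd.to_datetime(["1969-01-01 01:01:01", "2012-02-02 02:02:02"]).tz_localize("UTC"))
arr = pa.Array.from_pandas(ts, type=spark_sql_type_to_arrow("timestamp"))
print(arr.type)   # timestamp[us, tz=UTC]

d = pd.Series([date(1969, 1, 1), date(2012, 2, 2)])
print(pa.Array.from_pandas(d, type=spark_sql_type_to_arrow("date")).type)   # date32[day]
```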
Review thread on the null-handling in this conversion:

It seems we need s.fillna(0) for null values.

Why is that? We did that for integers that were promoted to floats to get rid of NaN, but here we are converting datetime64[ns] to datetime64[us], and both support missing values.

I'm not exactly sure of the reason, but it seems s.dt.tz_localize('tzlocal()') in _series_convert_timestamps_internal doesn't work properly when the series includes null values.

Hmm, that's strange. s.dt.tz_localize('tzlocal()') gets an "OverflowError: Python int too large to convert to C long" when printing, while s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC') works but comes up with a bogus time where the NaT was. I agree that fillna(0) is safer to avoid the overflow.

I fixed the date/time-related casting bugs in pyarrow and added new cast implementations -- conversions from one timestamp unit to another in Arrow-land fail silently right now. This will all be in the 0.8.0 release, landing hopefully the week of 11/6 or thereabouts: apache/arrow#1245

Great, thanks @wesm!

@BryanCutler It seems s.apply(lambda ts: ts.tz_localize('tzlocal()')) works without s.fillna(0). Do you know the difference between this and s.dt.tz_localize('tzlocal()')?

apply() will invoke the given function on each individual value of the series. I think this iterates over the series, whereas s.dt.tz_localize() would do a vectorized operation and should be faster.
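To make the distinction in this thread easier to follow, a small sketch of the two localization paths (assuming a recent pandas; the NaT overflow described above was specific to the pandas version in use at the time):

```python
import pandas as pd

s = pd.Series(pd.to_datetime(["2012-02-02 02:02:02", None]))   # second value is NaT

# Vectorized path used by _series_convert_timestamps_internal: operates on the whole block at once.
vectorized = s.dt.tz_localize("tzlocal()").dt.tz_convert("UTC")

# Element-wise path suggested above: tz_localize is applied to each Timestamp individually,
# and pd.NaT.tz_localize(...) simply returns NaT.
elementwise = s.apply(lambda ts: ts.tz_localize("tzlocal()"))

print(vectorized)
print(elementwise)
```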