Commits (39)
5aa8b9e  added date type and started test, still some issue with time difference (BryanCutler, Jul 13, 2017)
20313f9  DateTimeUtils forces defaultTimeZone (BryanCutler, Jul 18, 2017)
69e1e21  fix style checks (BryanCutler, Jul 18, 2017)
dbfbef3  date type java tests passing (BryanCutler, Jul 18, 2017)
436afff  timestamp type java tests passing (BryanCutler, Jul 18, 2017)
78119ca  adding date and timestamp data to python tests, not passing (BryanCutler, Jul 19, 2017)
b709d78  TimestampType is correctly inferred as datetime64[ns] (BryanCutler, Jul 19, 2017)
399e527  Merge remote-tracking branch 'upstream/master' into arrow-date-timest… (BryanCutler, Jul 24, 2017)
e6d8590  Adding DateType and TimestampType to ArrowUtils conversions (BryanCutler, Jul 24, 2017)
719e77c  using default timezone, fixed tests (BryanCutler, Jul 24, 2017)
3585520  fixed scala tests for timestamp (BryanCutler, Jul 25, 2017)
f977d0b  Adding sync between Python and Java default timezones (BryanCutler, Jul 26, 2017)
b826445  Merge remote-tracking branch 'upstream/master' into arrow-date-timest… (BryanCutler, Jul 27, 2017)
3b83d7a  added date timestamp writers, fixed tests (BryanCutler, Jul 27, 2017)
a6009a5  Modify ArrowUtils to have timeZoneId when convert schema to Arrow sch… (ueshin, Jul 28, 2017)
2ec98cc  fixed python test tearDownClass (BryanCutler, Aug 1, 2017)
c29018c  using Date.valueOf for tests instead (BryanCutler, Aug 2, 2017)
7dbdb1f  Made timezone id required for TimestampType (BryanCutler, Aug 14, 2017)
c3f4e4d  added test for TimestampType without specifying timezone id (BryanCutler, Aug 14, 2017)
ddbea24  added date and timestamp to ArrowWriter and tests (BryanCutler, Aug 15, 2017)
c6b597d  removed unused import (BryanCutler, Aug 16, 2017)
874f104  Merge remote-tracking branch 'upstream/master' into arrow-date-timest… (BryanCutler, Oct 10, 2017)
d8bae0b  added Python timezone converions for working with Pandas (BryanCutler, Oct 10, 2017)
36f58b1  Merge remote-tracking branch 'upstream/master' into arrow-date-timest… (BryanCutler, Oct 11, 2017)
c4fd5ae  fix compilation (BryanCutler, Oct 11, 2017)
d1617fd  fixed test comp (BryanCutler, Oct 11, 2017)
d7d9b47  add conversion to Python system local timezone before localize (BryanCutler, Oct 11, 2017)
efe3e27  timestamps with Arrow almost working for pandas_udfs (BryanCutler, Oct 11, 2017)
9894519  added workaround for Series to_pandas with timestamps, store os.envir… (BryanCutler, Oct 17, 2017)
a3ba4ac  change use of xrange for py3 (BryanCutler, Oct 17, 2017)
7266304  remove check for valid timezone in vector for ArrowWriter (BryanCutler, Oct 17, 2017)
e428cbe  added note for 'us' conversion (BryanCutler, Oct 17, 2017)
cade921  changed python api for is_datetime64 (BryanCutler, Oct 19, 2017)
f512deb  remove Option for timezoneId (BryanCutler, Oct 19, 2017)
171d9e1  Merge remote-tracking branch 'upstream/master' into arrow-date-timest… (BryanCutler, Oct 20, 2017)
79bb93f  added pandas_udf test for date (BryanCutler, Oct 23, 2017)
c555207  added workaround for date casting, put back check for timestamp conve… (BryanCutler, Oct 24, 2017)
4d40893  added fillna for null timestamp values (BryanCutler, Oct 25, 2017)
addd35f  added check for pandas_udf return is a timestamp with tz, added comme… (BryanCutler, Oct 26, 2017)
15 changes: 12 additions & 3 deletions python/pyspark/serializers.py
@@ -214,6 +214,7 @@ def __repr__(self):


def _create_batch(series):
from pyspark.sql.types import _series_convert_timestamps_internal
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
@@ -224,7 +225,13 @@ def _create_batch(series):
# If a nullable integer series has been promoted to floating point with NaNs, need to cast
# NOTE: this is not necessary with Arrow >= 0.7
def cast_series(s, t):
if t is None or s.dtype == t.to_pandas_dtype():
if type(t) == pa.TimestampType:
# NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
return _series_convert_timestamps_internal(s).values.astype('datetime64[us]')
Member:
It seems we need s.fillna(0) for null values.

Member Author (BryanCutler):
Why is that? We did that for integers that were promoted to floats to get rid of NaN, but here we are converting datetime64[ns] to datetime64[us] and both support missing values

In [28]: s = pd.Series([pd.datetime.now(), None])

In [29]: s
Out[29]: 
0   2017-10-24 10:44:51.483694
1                          NaT
dtype: datetime64[ns]

In [33]: s.values.astype('datetime64[us]')
Out[33]: array(['2017-10-24T10:44:51.483694', 'NaT'], dtype='datetime64[us]')

Member:
I'm not exactly sure of the reason, but it seems s.dt.tz_localize('tzlocal()') in _series_convert_timestamps_internal doesn't work properly when the series includes null values.

Member Author (BryanCutler):
Hmmm, that's strange. s.dt.tz_localize('tzlocal()') gets an OverflowError: Python int too large to convert to C long when printing, but s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC') works and comes up with a bogus time where the NaT was. I agree that fillna(0) is safer to avoid overflow.

In [44]: s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
Out[44]: 
0      2017-10-24 17:44:51.483694+00:00
1   1677-09-21 08:12:43.145224192+00:00
dtype: datetime64[ns, UTC]

Member:
I fixed the date/time-related casting bugs in pyarrow and added new cast implementations; conversions from one timestamp unit to another in Arrow-land fail silently right now. This will all be in the 0.8.0 release, landing hopefully the week of 11/6 or thereabouts: apache/arrow#1245

Member Author (BryanCutler):
great, thanks @wesm!

Member:
@BryanCutler It seems s.apply(lambda ts: ts.tz_localize('tzlocal()')) works without s.fillna(0). Do you know the difference between this and s.dt.tz_localize('tzlocal()')?

Member Author (BryanCutler):
apply() will invoke the given function on each individual value of the series. I think this iterates over the series, whereas s.dt.tz_localize() does a vectorized operation and should be faster.
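As a side note, a minimal sketch (not from this PR) contrasting the two forms, assuming only pandas and dateutil are installed:

import pandas as pd

s = pd.Series([pd.Timestamp('2017-10-24 10:44:51'), None])

# element-wise: the lambda is called once per value; NaT is passed through unchanged
per_element = s.apply(lambda ts: ts.tz_localize('tzlocal()'))

# vectorized: one call on the .dt accessor over the whole block of values
vectorized = s.dt.tz_localize('tzlocal()')

The per-element form pays Python call overhead for every value, which is the difference the comment above refers to.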

elif t == pa.date32():
# TODO: ValueError: Cannot cast DatetimeIndex to dtype datetime64[D]
Member Author (BryanCutler):
I came across an issue when using a pandas_udf that returns a DateType. Dates are input to the udf as a Series with dtype datetime64[ns], and trying to use this for pa.Array.from_pandas with type=pa.date32() fails with an error. I am also unable to call series.dt.values.astype('datetime64[D]'), which results in an error. Without specifying the type, pyarrow will read the values as a timestamp. I filed https://issues.apache.org/jira/browse/ARROW-1718 to look into this; aside from this being fixed in a new version, any ideas for a workaround @ueshin and @HyukjinKwon ?
cc @wesm

Member:
How about:

@@ -229,14 +229,20 @@ def _create_batch(series):
             # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
             return _series_convert_timestamps_internal(s).values.astype('datetime64[us]')
         elif t == pa.date32():
-            # TODO: ValueError: Cannot cast DatetimeIndex to dtype datetime64[D]
-            return s.dt.values.astype('datetime64[D]')
+            return s.dt.date
         elif t is None or s.dtype == t.to_pandas_dtype():
             return s
         else:
             return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)

-    arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+    def create_array(s, t):
+        casted = cast_series(s, t)
+        if casted.dtype == 'object':
+            return pa.Array.from_pandas(casted, type=t)
+        else:
+            return pa.Array.from_pandas(casted, mask=s.isnull(), type=t)
+
+    arrs = [create_array(s, t) for s, t in series]
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])

Member (ueshin, Oct 24, 2017):
It seems we need to remove mask=s.isnull() when dtype is object because Arrow raises an error saying ArrowNotImplementedError: NotImplemented: mask not supported in object conversions yet.
Will this be fixed in a future (> 0.4) version?

Member:
There's a small number of places where this is not yet supported, but we should be able to fix them in 0.8.0 (~next 2 weeks): https://issues.apache.org/jira/browse/ARROW-1721

Member Author (BryanCutler):
Thanks @ueshin and @wesm, this seems to work. I will add a note with the related JIRA to look into it later
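A small standalone sketch of the workaround settled on above (an illustration, not the PR code; exact behavior depends on the pyarrow version, and pre-0.8 releases raise ArrowNotImplementedError if a mask is passed for object-dtype input):

import pandas as pd
import pyarrow as pa

s = pd.Series(pd.to_datetime(['1969-01-01', None, '2100-04-04']))  # datetime64[ns], as the udf sees it

dates = s.dt.date                                    # object dtype holding datetime.date values (and NaT)
arr = pa.Array.from_pandas(dates, type=pa.date32())  # no mask: nulls are inferred from the values
print(arr)                                           # the NaT slot becomes a null date32 entry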

return s.dt.values.astype('datetime64[D]')
elif t is None or s.dtype == t.to_pandas_dtype():
return s
else:
return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
@@ -260,11 +267,13 @@ def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow as pa
reader = pa.open_stream(stream)
for batch in reader:
table = pa.Table.from_batches([batch])
yield [c.to_pandas() for c in table.itercolumns()]
# NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
pdf = _check_dataframe_localize_timestamps(batch.to_pandas())
yield [c for _, c in pdf.iteritems()]
Member Author (BryanCutler, Oct 17, 2017):
After running some tests, this change does not significantly degrade performance, but there seems to be a small difference. cc @ueshin

I ran various columns of random data through a pandas_udf repeatedly with and without this change. Test was in local mode with default Spark conf, looking at min wall clock time of 10 loops

before change: 2.595558
after change: 2.681813

Do you think the difference here is acceptable for now, until arrow is upgraded and we can look into it again?
pandas_udf_perf.py.txt

Member:
I ran your script locally, too.

  • before change:
    • mean: 2.605722
    • min: 2.502404
    • max: 3.045294
  • after change:
    • mean: 2.626306
    • min: 2.341781
    • max: 2.742432

I think it's okay to use this workaround.
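For context, a rough sketch of the kind of timing loop described above; this is a hypothetical reconstruction, not the attached pandas_udf_perf.py.txt, and the column size, names, and udf are made up:

import time
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.master("local[*]").appName("pandas_udf_perf_sketch").getOrCreate()

# one double column, large enough to amortize per-batch overhead
df = spark.createDataFrame(pd.DataFrame({"v": [float(i % 1000) / 7.0 for i in range(1 << 18)]}))
df.cache().count()  # materialize the input so only the udf path is timed

plus_one = pandas_udf(lambda v: v + 1, returnType=DoubleType())

times = []
for _ in range(10):
    start = time.time()
    df.select(plus_one(col("v"))).collect()  # forces the udf to run over every batch
    times.append(time.time() - start)

print("min: %f  mean: %f  max: %f" % (min(times), sum(times) / len(times), max(times)))
spark.stop()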


def __repr__(self):
return "ArrowStreamPandasSerializer"
7 changes: 6 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -1880,11 +1880,13 @@ def toPandas(self):
import pandas as pd
if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
try:
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
return table.to_pandas()
df = table.to_pandas()
return _check_dataframe_localize_timestamps(df)
Member:
ditto (df -> pdf).

else:
return pd.DataFrame.from_records([], columns=self.columns)
except ImportError as e:
@@ -1952,6 +1954,7 @@ def _to_corrected_pandas_type(dt):
"""
When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
"""
import numpy as np
if type(dt) == ByteType:
Expand All @@ -1962,6 +1965,8 @@ def _to_corrected_pandas_type(dt):
return np.int32
elif type(dt) == FloatType:
return np.float32
elif type(dt) == DateType:
return 'datetime64[ns]'
else:
return None
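A small illustration (not part of the diff) of the inference issue the NOTE above describes; the column name is made up:

import datetime
import pandas as pd

pdf = pd.DataFrame.from_records([(datetime.date(2017, 1, 1),)], columns=['d'])
print(pdf['d'].dtype)                           # object: date values are not inferred as datetimes
print(pdf['d'].astype('datetime64[ns]').dtype)  # datetime64[ns] after applying the corrected type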

88 changes: 77 additions & 11 deletions python/pyspark/sql/tests.py
@@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase):

@classmethod
def setUpClass(cls):
from datetime import datetime
ReusedPySparkTestCase.setUpClass()

# Synchronize default timezone between Python and Java
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
tz = "America/Los_Angeles"
os.environ["TZ"] = tz
Member:
Maybe I am worrying too much, but shouldn't we keep the original os.environ["TZ"] and then restore it back in L3115?

Member Author (BryanCutler):
Hmm, I was assuming that os.environ["TZ"] wouldn't be set, but I guess it could be. I'll fix that.

time.tzset()

cls.spark = SparkSession(cls.sc)
cls.spark.conf.set("spark.sql.session.timeZone", tz)
cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
cls.schema = StructType([
StructField("1_str_t", StringType(), True),
StructField("2_int_t", IntegerType(), True),
StructField("3_long_t", LongType(), True),
StructField("4_float_t", FloatType(), True),
StructField("5_double_t", DoubleType(), True)])
cls.data = [("a", 1, 10, 0.2, 2.0),
("b", 2, 20, 0.4, 4.0),
("c", 3, 30, 0.8, 6.0)]
StructField("5_double_t", DoubleType(), True),
StructField("6_date_t", DateType(), True),
StructField("7_timestamp_t", TimestampType(), True)])
cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]

@classmethod
def tearDownClass(cls):
del os.environ["TZ"]
if cls.tz_prev is not None:
os.environ["TZ"] = cls.tz_prev
time.tzset()
ReusedPySparkTestCase.tearDownClass()
cls.spark.stop()

def assertFramesEqual(self, df_with_arrow, df_without):
msg = ("DataFrame from Arrow is not equal" +
@@ -3106,8 +3126,8 @@ def assertFramesEqual(self, df_with_arrow, df_without):
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

def test_unsupported_datatype(self):
schema = StructType([StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
schema = StructType([StructField("decimal", DecimalType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: df.toPandas())

@@ -3385,13 +3405,59 @@ def test_vectorized_udf_varargs(self):

def test_vectorized_udf_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType([StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
f = pandas_udf(lambda x: x, DateType())
schema = StructType([StructField("dt", DecimalType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
f = pandas_udf(lambda x: x, DecimalType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('dt'))).collect()

def test_vectorized_udf_null_date(self):
from pyspark.sql.functions import pandas_udf, col
from datetime import date
schema = StructType().add("date", DateType())
data = [(date(1969, 1, 1),),
(date(2012, 2, 2),),
(None,),
(date(2100, 4, 4),)]
df = self.spark.createDataFrame(data, schema=schema)
date_f = pandas_udf(lambda t: t, returnType=DateType())
res = df.select(date_f(col("date")))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_timestamps(self):
from pyspark.sql.functions import pandas_udf, col
from datetime import datetime
schema = StructType([
StructField("idx", LongType(), True),
StructField("timestamp", TimestampType(), True)])
data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
(1, datetime(2012, 2, 2, 2, 2, 2)),
(2, None),
(3, datetime(2100, 4, 4, 4, 4, 4))]
df = self.spark.createDataFrame(data, schema=schema)

# Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))

@pandas_udf(returnType=BooleanType())
def check_data(idx, timestamp, timestamp_copy):
is_equal = timestamp.isnull() # use this array to check values are equal
for i in range(len(idx)):
# Check that timestamps are as expected in the UDF
is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
timestamp[i].to_pydatetime() == data[idx[i]][1]
return is_equal

result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
col("timestamp_copy"))).collect()
# Check that collection values are correct
self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedPySparkTestCase):
@@ -3550,8 +3616,8 @@ def test_wrong_args(self):
def test_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType(
[StructField("id", LongType(), True), StructField("dt", DateType(), True)])
df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema)
[StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
df = self.spark.createDataFrame([(1, None,)], schema=schema)
f = pandas_udf(lambda x: x, df.schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
22 changes: 22 additions & 0 deletions python/pyspark/sql/types.py
@@ -1619,11 +1619,33 @@ def to_arrow_type(dt):
arrow_type = pa.decimal(dt.precision, dt.scale)
elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == DateType:
arrow_type = pa.date32()
elif type(dt) == TimestampType:
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
Member Author (BryanCutler):
I think it is ok to specify 'UTC' here. In Spark, ArrowColumnVector requires a timezone to be set because it expects the Arrow NullableTimeStampMicroTZVector, but it doesn't do anything with the tz, so as long as the data is in UTC internally it should be fine for Spark to use.

Member:
I think it'd be nicer if we had a comment with this explanation.

else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type


def _check_dataframe_localize_timestamps(df):
Member:
tiny nit: df -> pdf

""" Convert timezone aware timestamps to timezone-naive in local time
Member:
Let's add a comment that says the expected input is pd.DataFrame.

"""
from pandas.api.types import is_datetime64tz_dtype
for column, series in df.iteritems():
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
df[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
Member:
@jreback is this the best route to obtain tz-naive datetimes in localtime?

Reply:
yep this is idiomatic

In [16]: s = Series(pd.date_range('20130101', periods=3, tz='UTC'))

In [17]: s
Out[17]: 
0   2013-01-01 00:00:00+00:00
1   2013-01-02 00:00:00+00:00
2   2013-01-03 00:00:00+00:00
dtype: datetime64[ns, UTC]

In [18]: s.dt.tz_convert('tzlocal()')
Out[18]: 
0   2012-12-31 19:00:00-05:00
1   2013-01-01 19:00:00-05:00
2   2013-01-02 19:00:00-05:00
dtype: datetime64[ns, tzlocal()]

In [19]: s.dt.tz_convert('tzlocal()').dt.tz_localize(None)
Out[19]: 
0   2012-12-31 19:00:00
1   2013-01-01 19:00:00
2   2013-01-02 19:00:00
dtype: datetime64[ns]

return df


def _series_convert_timestamps_internal(s):
""" Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
"""
return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
Member (ueshin, Oct 24, 2017):
I'd prefer the previous implementation, which checks the type in case the series is already tz-aware.

Member Author (BryanCutler):
Yeah, you're right. I figured we were already checking that it is a timestamp type, but it's true that the user could have created tz-aware timestamps, so we need to check.
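For reference, a minimal sketch of the check being discussed (an assumption about the eventual fix, not the code in this diff):

from pandas.api.types import is_datetime64tz_dtype


def _series_convert_timestamps_internal(s):
    """ Convert a timestamp Series to UTC for Spark internal storage,
        localizing tz-naive values as system local time first.
    """
    if is_datetime64tz_dtype(s.dtype):
        # already tz-aware: only normalize to UTC
        return s.dt.tz_convert('UTC')
    else:
        # tz-naive: interpret as local time, then normalize to UTC
        return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')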



def _test():
import doctest
from pyspark.context import SparkContext
@@ -320,6 +320,10 @@ public ArrowColumnVector(ValueVector vector) {
accessor = new StringAccessor((NullableVarCharVector) vector);
} else if (vector instanceof NullableVarBinaryVector) {
accessor = new BinaryAccessor((NullableVarBinaryVector) vector);
} else if (vector instanceof NullableDateDayVector) {
accessor = new DateAccessor((NullableDateDayVector) vector);
} else if (vector instanceof NullableTimeStampMicroTZVector) {
accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
accessor = new ArrayAccessor(listVector);
@@ -575,6 +579,36 @@ final byte[] getBinary(int rowId) {
}
}

private static class DateAccessor extends ArrowVectorAccessor {

private final NullableDateDayVector.Accessor accessor;

DateAccessor(NullableDateDayVector vector) {
super(vector);
this.accessor = vector.getAccessor();
}

@Override
final int getInt(int rowId) {
return accessor.get(rowId);
}
}

private static class TimestampAccessor extends ArrowVectorAccessor {

private final NullableTimeStampMicroTZVector.Accessor accessor;

TimestampAccessor(NullableTimeStampMicroTZVector vector) {
super(vector);
this.accessor = vector.getAccessor();
}

@Override
final long getLong(int rowId) {
return accessor.get(rowId);
}
}

private static class ArrayAccessor extends ArrowVectorAccessor {

private final UInt4Vector.Accessor accessor;
4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3143,9 +3143,11 @@ class Dataset[T] private[sql](
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
queryExecution.toRdd.mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context)
ArrowConverters.toPayloadIterator(
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
}
}
}
@@ -74,9 +74,10 @@ private[sql] object ArrowConverters {
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Int,
timeZoneId: String,
context: TaskContext): Iterator[ArrowPayload] = {

val arrowSchema = ArrowUtils.toArrowSchema(schema)
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator =
ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)

@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow
import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.sql.types._
@@ -31,7 +31,8 @@ object ArrowUtils {

// todo: support more types.

def toArrowType(dt: DataType): ArrowType = dt match {
/** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
case BooleanType => ArrowType.Bool.INSTANCE
case ByteType => new ArrowType.Int(8, true)
case ShortType => new ArrowType.Int(8 * 2, true)
@@ -42,6 +43,13 @@
case StringType => ArrowType.Utf8.INSTANCE
case BinaryType => ArrowType.Binary.INSTANCE
case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
case DateType => new ArrowType.Date(DateUnit.DAY)
case TimestampType =>
if (timeZoneId == null) {
throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter")
} else {
new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
}
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}")
}

@@ -58,22 +66,27 @@ object ArrowUtils {
case ArrowType.Utf8.INSTANCE => StringType
case ArrowType.Binary.INSTANCE => BinaryType
case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = {
/** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
def toArrowField(
name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
dt match {
case ArrayType(elementType, containsNull) =>
val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava)
new Field(name, fieldType,
Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava)
case StructType(fields) =>
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
new Field(name, fieldType,
fields.map { field =>
toArrowField(field.name, field.dataType, field.nullable)
toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
}.toSeq.asJava)
case dataType =>
val fieldType = new FieldType(nullable, toArrowType(dataType), null)
val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
new Field(name, fieldType, Seq.empty[Field].asJava)
}
}
@@ -94,9 +107,10 @@
}
}

def toArrowSchema(schema: StructType): Schema = {
/** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */
def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
new Schema(schema.map { field =>
toArrowField(field.name, field.dataType, field.nullable)
toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
}.asJava)
}
