
Commit 5b187a8

HyukjinKwon authored and BryanCutler committed
[SPARK-24976][PYTHON] Allow None for Decimal type conversion (specific to PyArrow 0.9.0)
## What changes were proposed in this pull request?

See [ARROW-2432](https://jira.apache.org/jira/browse/ARROW-2432). It seems that using `from_pandas` to convert decimals fails if it encounters a value of `None`:

```python
import pyarrow as pa
import pandas as pd
from decimal import Decimal

pa.Array.from_pandas(pd.Series([Decimal('3.14'), None]), type=pa.decimal128(3, 2))
```

**Arrow 0.8.0**

```
<pyarrow.lib.Decimal128Array object at 0x10a572c58>
[
  Decimal('3.14'),
  NA
]
```

**Arrow 0.9.0**

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "array.pxi", line 383, in pyarrow.lib.Array.from_pandas
  File "array.pxi", line 177, in pyarrow.lib.array
  File "error.pxi", line 77, in pyarrow.lib.check_status
  File "error.pxi", line 77, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Error converting from Python objects to Decimal: Got Python object of type NoneType but can only handle these types: decimal.Decimal
```

This PR proposes to work around this via Decimal NaN:

```python
pa.Array.from_pandas(pd.Series([Decimal('3.14'), Decimal('NaN')]), type=pa.decimal128(3, 2))
```

```
<pyarrow.lib.Decimal128Array object at 0x10ffd2e68>
[
  Decimal('3.14'),
  NA
]
```

## How was this patch tested?

Manually tested:

```bash
SPARK_TESTING=1 ./bin/pyspark pyspark.sql.tests ScalarPandasUDFTests
```

**Before**

```
Traceback (most recent call last):
  File "/.../spark/python/pyspark/sql/tests.py", line 4672, in test_vectorized_udf_null_decimal
    self.assertEquals(df.collect(), res.collect())
  File "/.../spark/python/pyspark/sql/dataframe.py", line 533, in collect
    sock_info = self._jdf.collectToPython()
  File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling o51.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 1.0 failed 1 times, most recent failure: Lost task 3.0 in stage 1.0 (TID 7, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/.../spark/python/pyspark/worker.py", line 320, in main
    process()
  File "/.../spark/python/pyspark/worker.py", line 315, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/.../spark/python/pyspark/serializers.py", line 274, in dump_stream
    batch = _create_batch(series, self._timezone)
  File "/.../spark/python/pyspark/serializers.py", line 243, in _create_batch
    arrs = [create_array(s, t) for s, t in series]
  File "/.../spark/python/pyspark/serializers.py", line 241, in create_array
    return pa.Array.from_pandas(s, mask=mask, type=t)
  File "array.pxi", line 383, in pyarrow.lib.Array.from_pandas
  File "array.pxi", line 177, in pyarrow.lib.array
  File "error.pxi", line 77, in pyarrow.lib.check_status
  File "error.pxi", line 77, in pyarrow.lib.check_status
ArrowInvalid: Error converting from Python objects to Decimal: Got Python object of type NoneType but can only handle these types: decimal.Decimal
```

**After**

```
Running tests...
----------------------------------------------------------------------
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
.......S.............................
----------------------------------------------------------------------
Ran 37 tests in 21.980s
```

Author: hyukjinkwon <[email protected]>

Closes #21928 from HyukjinKwon/SPARK-24976.

(cherry picked from commit f4772fd)
Signed-off-by: Bryan Cutler <[email protected]>
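For reference, a minimal standalone sketch of the workaround (outside Spark) is below. It mirrors the `LooseVersion` gate the patch adds; the variable names `s`, `mask`, and `t` are chosen here for illustration and are not taken from the Spark code.

```python
# Standalone sketch of the workaround, assuming pandas and PyArrow are installed.
from decimal import Decimal
from distutils.version import LooseVersion

import pandas as pd
import pyarrow as pa

s = pd.Series([Decimal('3.14'), None])
mask = s.isnull()        # True where the value is null; passed to from_pandas below
t = pa.decimal128(3, 2)

if LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
    # PyArrow 0.9.x rejects NoneType when converting to Decimal (ARROW-2432), so
    # substitute a Decimal('NaN') placeholder; the mask still marks that slot as null.
    s = s.apply(lambda v: Decimal('NaN') if v is None else v)

arr = pa.Array.from_pandas(s, mask=mask, type=t)
print(arr)  # the second element should come out as NA/null either way
```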
1 parent fc3df45 commit 5b187a8


1 file changed: +8 -2 lines changed


python/pyspark/serializers.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -215,9 +215,10 @@ def _create_batch(series, timezone):
     :param timezone: A timezone to respect when handling timestamp values
     :return: Arrow RecordBatch
     """
-    from pyspark.sql.types import _check_series_convert_timestamps_internal
+    import decimal
+    from distutils.version import LooseVersion
     import pyarrow as pa
+    from pyspark.sql.types import _check_series_convert_timestamps_internal
     # Make input conform to [(series1, type1), (series2, type2), ...]
     if not isinstance(series, (list, tuple)) or \
             (len(series) == 2 and isinstance(series[1], pa.DataType)):
@@ -235,6 +236,11 @@ def create_array(s, t):
             # TODO: need decode before converting to Arrow in Python 2
             return pa.Array.from_pandas(s.apply(
                 lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
+        elif t is not None and pa.types.is_decimal(t) and \
+                LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
+            return pa.Array.from_pandas(s.apply(
+                lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
         return pa.Array.from_pandas(s, mask=mask, type=t)

     arrs = [create_array(s, t) for s, t in series]
```
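As a usage illustration, the sketch below shows the kind of scalar Pandas UDF over a nullable decimal column that exercises this code path. It is only loosely modeled on the `test_vectorized_udf_null_decimal` test referenced in the traceback above (the exact test body is not reproduced here); the session setup, schema, and column name are assumptions made for this example.

```python
# Hedged illustration: a scalar pandas_udf over a decimal column that contains a null.
# Before this fix, serializing the None row back to Arrow hit ARROW-2432 on PyArrow 0.9.x.
from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DecimalType, StructField, StructType

spark = SparkSession.builder.master("local[1]").appName("decimal-null-example").getOrCreate()

schema = StructType([StructField("value", DecimalType(38, 18), True)])
df = spark.createDataFrame([(Decimal("3.14"),), (None,)], schema)

# Identity UDF: the Python worker converts the pandas Series (including the None) to Arrow.
identity = pandas_udf(lambda s: s, DecimalType(38, 18))

df.select(identity(df.value).alias("value")).show()
spark.stop()
```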
