diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 4b91c6a0f8730..63fb8562799e3 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -153,13 +153,16 @@ def create_array(s, t):
                 s = s.astype(s.dtypes.categories.dtype)
             try:
                 array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
-            except pa.ArrowException as e:
-                error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
-                            "Array (%s). It can be caused by overflows or other unsafe " + \
-                            "conversions warned by Arrow. Arrow safe type check can be " + \
-                            "disabled by using SQL config " + \
-                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
-                raise RuntimeError(error_msg % (s.dtype, t), e)
+            except ValueError as e:
+                if self._safecheck:
+                    error_msg = "Exception thrown when converting pandas.Series (%s) to " + \
+                                "Arrow Array (%s). It can be caused by overflows or other " + \
+                                "unsafe conversions warned by Arrow. Arrow safe type check " + \
+                                "can be disabled by using SQL config " + \
+                                "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                    raise ValueError(error_msg % (s.dtype, t)) from e
+                else:
+                    raise e
             return array
 
         arrs = []
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index fb4f619c8bf63..1d8cd76aa819c 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -264,11 +264,12 @@ def test_createDataFrame_with_schema(self):
     def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
         fields = list(self.schema)
-        fields[0], fields[1] = fields[1], fields[0]  # swap str with int
+        fields[5], fields[6] = fields[6], fields[5]  # swap decimal with date
         wrong_schema = StructType(fields)
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, "integer.*required"):
-                self.spark.createDataFrame(pdf, schema=wrong_schema)
+        with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
+                    self.spark.createDataFrame(pdf, schema=wrong_schema)
 
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py
index 6eb5355044bb0..eca4f8b12f9c4 100644
--- a/python/pyspark/sql/tests/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py
@@ -446,15 +446,16 @@ def int_index(pdf):
         def column_name_typo(pdf):
             return pd.DataFrame({'iid': pdf.id, 'v': pdf.v})
 
-        @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
+        @pandas_udf('id long, v decimal', PandasUDFType.GROUPED_MAP)
         def invalid_positional_types(pdf):
-            return pd.DataFrame([(u'a', 1.2)])
+            return pd.DataFrame([(1, datetime.date(2020, 10, 5))])
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
-                grouped_df.apply(column_name_typo).collect()
-            with self.assertRaisesRegexp(Exception, "an integer is required"):
-                grouped_df.apply(invalid_positional_types).collect()
+        with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
+                    grouped_df.apply(column_name_typo).collect()
+                with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
+                    grouped_df.apply(invalid_positional_types).collect()
 
     def test_positional_assignment_conf(self):
         with self.sql_conf({
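
For context (not part of the patch): a minimal standalone sketch of the pyarrow behavior the new `except ValueError` clause relies on. It assumes only pandas and pyarrow are installed; the series and values are illustrative, not taken from the tests above. pyarrow rejects lossy conversions such as float -> int64 with pyarrow.lib.ArrowInvalid, which subclasses ValueError, so catching ValueError covers the safe-check failure, while safe=False (safecheck disabled) lets the cast proceed unchecked.

import pandas as pd
import pyarrow as pa

s = pd.Series([1.5, 2.7])  # non-integral floats: casting to int64 is lossy

try:
    # safe=True corresponds to spark.sql.execution.pandas.convertToArrowArraySafely=true.
    # pyarrow rejects the lossy cast with pyarrow.lib.ArrowInvalid, a ValueError
    # subclass -- hence `except ValueError` in serializers.py above.
    pa.Array.from_pandas(s, type=pa.int64(), safe=True)
except ValueError as e:
    print("unsafe conversion rejected:", e)

# safe=False (safecheck disabled) allows the cast and silently truncates.
print(pa.Array.from_pandas(s, type=pa.int64(), safe=False))  # -> [1, 2]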