diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 30ec676b96c46..f8b8834fd2df7 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1033,6 +1033,29 @@ def test_np_scalar_input(self):
         res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
         self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        arr_dtype_to_spark_dtypes = [
+            ("int8", [("b", "array<smallint>")]),
+            ("int16", [("b", "array<smallint>")]),
+            ("int32", [("b", "array<int>")]),
+            ("int64", [("b", "array<bigint>")]),
+            ("float32", [("b", "array<float>")]),
+            ("float64", [("b", "array<double>")]),
+        ]
+        for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes:
+            arr = np.array([1, 2]).astype(t)
+            self.assertEqual(
+                expected_spark_dtypes, self.spark.range(1).select(lit(arr).alias("b")).dtypes
+            )
+        arr = np.array([1, 2]).astype(np.uint)
+        with self.assertRaisesRegex(
+            TypeError, "The type of array scalar '%s' is not supported" % arr.dtype
+        ):
+            self.spark.range(1).select(lit(arr).alias("b"))
+
     def test_binary_math_function(self):
         funcs, expected = zip(*[(atan2, 0.13664), (hypot, 8.07527), (pow, 2.14359), (pmod, 1.1)])
         df = self.spark.range(1).select(*(func(1.1, 8) for func in funcs))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c1e6a738bc6aa..365c903487ce9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -46,7 +46,7 @@
 )
 
 from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.utils import has_numpy
@@ -2268,12 +2268,59 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def _from_numpy_type_to_java_type(
+        self, nt: "np.dtype", gateway: JavaGateway
+    ) -> Optional[JavaClass]:
+        """Convert NumPy type to Py4J Java type."""
+        if nt in [np.dtype("int8"), np.dtype("int16")]:
+            # Mapping int8 to gateway.jvm.byte causes
+            #   TypeError: 'bytes' object does not support item assignment
+            return gateway.jvm.short
+        elif nt == np.dtype("int32"):
+            return gateway.jvm.int
+        elif nt == np.dtype("int64"):
+            return gateway.jvm.long
+        elif nt == np.dtype("float32"):
+            return gateway.jvm.float
+        elif nt == np.dtype("float64"):
+            return gateway.jvm.double
+        elif nt == np.dtype("bool"):
+            return gateway.jvm.boolean
+
+        return None
+
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+
+        if len(obj) > 0 and isinstance(plist[0], str):
+            jtpe = gateway.jvm.String
+        else:
+            jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
+            if jtpe is None:
+                raise TypeError("The type of array scalar '%s' is not supported" % (obj.dtype))
+        jarr = gateway.new_array(jtpe, len(obj))
+        for i in range(len(plist)):
+            jarr[i] = plist[i]
+        return jarr
+
+
 # datetime is a subclass of date, we should register DatetimeConverter first
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
 register_input_converter(NumpyScalarConverter())
+# NumPy array satisfies py4j.java_collections.ListConverter,
+# so prepend NumpyArrayConverter
+register_input_converter(NumpyArrayConverter(), prepend=True)
 
 
 def _test() -> None:
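A minimal usage sketch of what this change enables, mirroring the test case above (assuming an active SparkSession bound to `spark`): a one-dimensional np.ndarray passed to lit() is now converted to a Java array by the prepended NumpyArrayConverter, so the NumPy dtype surfaces as the corresponding Spark array element type.

    import numpy as np
    from pyspark.sql.functions import lit

    # int64 maps to Java long, which Spark reads back as bigint
    df = spark.range(1).select(lit(np.array([1, 2], dtype=np.int64)).alias("b"))
    print(df.dtypes)  # [('b', 'array<bigint>')]

    # A dtype with no mapping (e.g. np.uint, typically 'uint64') raises:
    #   TypeError: The type of array scalar 'uint64' is not supported
    spark.range(1).select(lit(np.array([1, 2], dtype=np.uint)).alias("b"))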