diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index e1286f7d66e4..045b1366fc56 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -31,6 +31,8 @@ cast, ) +import numpy as np + from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( CaseWhen, @@ -42,7 +44,7 @@ LambdaFunction, ) from pyspark.sql import functions as pysparkfuncs -from pyspark.sql.types import DataType, StructType, ArrayType +from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName @@ -192,6 +194,17 @@ def lit(col: Any) -> Column: if isinstance(col, Column): return col elif isinstance(col, list): + return array(*[lit(c) for c in col]) + elif isinstance(col, np.ndarray) and col.ndim == 1: + if _from_numpy_type(col.dtype) is None: + raise TypeError("The type of array scalar '%s' is not supported" % (col.dtype)) + + # NumpyArrayConverter for Py4J can not support ndarray with int8 values. + # Actually this is not a problem for Connect, but here still convert it + # to int16 for compatibility. + if col.dtype == np.int8: + col = col.astype(np.int16) + return array(*[lit(c) for c in col]) else: return Column(LiteralExpression._from_value(col)) diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py index e763352e9368..84f5c65017e9 100644 --- a/python/pyspark/sql/tests/connect/test_parity_functions.py +++ b/python/pyspark/sql/tests/connect/test_parity_functions.py @@ -69,11 +69,6 @@ def test_lit_np_scalar(self): def test_map_functions(self): super().test_map_functions() - # TODO(SPARK-41903): Support data type ndarray - @unittest.skip("Fails in Spark Connect, should enable.") - def test_ndarray_input(self): - super().test_ndarray_input() - # TODO(SPARK-41902): Parity in String representation of higher_order_function's output @unittest.skip("Fails in Spark Connect, should enable.") def test_nested_higher_order_function(self):