Skip to content

Commit 47068db

Browse files
zhengruifeng authored and HyukjinKwon committed
[SPARK-41903][CONNECT][PYTHON] `Literal` should support 1-dim ndarray
### What changes were proposed in this pull request?
`Literal` should support 1-dim ndarray

### Why are the changes needed?
parity

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
enabled UT

Closes #39570 from zhengruifeng/connect_lit_ndaray.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent bf0d3c5 commit 47068db

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

python/pyspark/sql/connect/functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
cast,
3232
)
3333

34+
import numpy as np
35+
3436
from pyspark.sql.connect.column import Column
3537
from pyspark.sql.connect.expressions import (
3638
CaseWhen,
@@ -42,7 +44,7 @@
4244
LambdaFunction,
4345
)
4446
from pyspark.sql import functions as pysparkfuncs
45-
from pyspark.sql.types import DataType, StructType, ArrayType
47+
from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType
4648

4749
if TYPE_CHECKING:
4850
from pyspark.sql.connect._typing import ColumnOrName
@@ -192,6 +194,17 @@ def lit(col: Any) -> Column:
192194
if isinstance(col, Column):
193195
return col
194196
elif isinstance(col, list):
197+
return array(*[lit(c) for c in col])
198+
elif isinstance(col, np.ndarray) and col.ndim == 1:
199+
if _from_numpy_type(col.dtype) is None:
200+
raise TypeError("The type of array scalar '%s' is not supported" % (col.dtype))
201+
202+
# NumpyArrayConverter for Py4J can not support ndarray with int8 values.
203+
# Actually this is not a problem for Connect, but here still convert it
204+
# to int16 for compatibility.
205+
if col.dtype == np.int8:
206+
col = col.astype(np.int16)
207+
195208
return array(*[lit(c) for c in col])
196209
else:
197210
return Column(LiteralExpression._from_value(col))

python/pyspark/sql/tests/connect/test_parity_functions.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,6 @@ def test_lit_np_scalar(self):
6464
def test_map_functions(self):
6565
super().test_map_functions()
6666

67-
# TODO(SPARK-41903): Support data type ndarray
68-
@unittest.skip("Fails in Spark Connect, should enable.")
69-
def test_ndarray_input(self):
70-
super().test_ndarray_input()
71-
7267
# TODO(SPARK-41902): Parity in String representation of higher_order_function's output
7368
@unittest.skip("Fails in Spark Connect, should enable.")
7469
def test_nested_higher_order_function(self):

0 commit comments

Comments
 (0)