Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion python/pyspark/sql/connect/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
cast,
)

import numpy as np

from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import (
CaseWhen,
Expand All @@ -42,7 +44,7 @@
LambdaFunction,
)
from pyspark.sql import functions as pysparkfuncs
from pyspark.sql.types import DataType, StructType, ArrayType
from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType

if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
Expand Down Expand Up @@ -192,6 +194,17 @@ def lit(col: Any) -> Column:
if isinstance(col, Column):
return col
elif isinstance(col, list):
return array(*[lit(c) for c in col])
elif isinstance(col, np.ndarray) and col.ndim == 1:
if _from_numpy_type(col.dtype) is None:
raise TypeError("The type of array scalar '%s' is not supported" % (col.dtype))

# NumpyArrayConverter for Py4J cannot support ndarray with int8 values.
Copy link
Contributor Author

@zhengruifeng zhengruifeng Jan 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# This is not actually a problem for Connect, but we still convert it
# to int16 here for compatibility.
if col.dtype == np.int8:
col = col.astype(np.int16)

return array(*[lit(c) for c in col])
else:
return Column(LiteralExpression._from_value(col))
Expand Down
5 changes: 0 additions & 5 deletions python/pyspark/sql/tests/connect/test_parity_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,6 @@ def test_lit_np_scalar(self):
def test_map_functions(self):
super().test_map_functions()

# TODO(SPARK-41903): Support data type ndarray
@unittest.skip("Fails in Spark Connect, should enable.")
def test_ndarray_input(self):
super().test_ndarray_input()

# TODO(SPARK-41902): Parity in String representation of higher_order_function's output
@unittest.skip("Fails in Spark Connect, should enable.")
def test_nested_higher_order_function(self):
Expand Down