Skip to content
65 changes: 44 additions & 21 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def native_to_narwhals_dtype(
if isinstance(dtype, spark_types.ByteType):
return dtypes.Int8()
if isinstance(
dtype,
(spark_types.StringType, spark_types.VarcharType, spark_types.CharType),
dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
):
return dtypes.String()
if isinstance(dtype, spark_types.BooleanType):
Expand All @@ -70,15 +69,26 @@ def native_to_narwhals_dtype(
return dtypes.Datetime()
if isinstance(dtype, spark_types.TimestampType):
return dtypes.Datetime(time_zone="UTC")
if isinstance(dtype, spark_types.DecimalType): # pragma: no cover
# TODO(unassigned): cover this in dtypes_test.py
if isinstance(dtype, spark_types.DecimalType):
return dtypes.Decimal()
if isinstance(dtype, spark_types.ArrayType): # pragma: no cover
if isinstance(dtype, spark_types.ArrayType):
return dtypes.List(
inner=native_to_narwhals_dtype(
dtype.elementType, version=version, spark_types=spark_types
)
)
if isinstance(dtype, spark_types.StructType):
return dtypes.Struct(
fields=[
dtypes.Field(
name=name,
dtype=native_to_narwhals_dtype(
dtype[name], version=version, spark_types=spark_types
),
)
for name in dtype.fieldNames()
]
)
return dtypes.Unknown()


Expand Down Expand Up @@ -113,28 +123,41 @@ def narwhals_to_native_dtype(
msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
raise ValueError(msg)
return spark_types.TimestampType()
if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
spark_types=spark_types,
if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
return spark_types.ArrayType(
elementType=narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
spark_types=spark_types,
)
Comment on lines +126 to +132
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `# type: ignore` here is an example of this issue (#1807 (comment))

Off-topic-ish, but should I spin that out into a new issue?

I think it might get lost in that PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @dangotbanned - I'd say let's keep track of it in a dedicated issue, as it isn't even introduced in this specific PR.

)
return spark_types.ArrayType(elementType=inner)
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
msg = "Converting to Struct dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
spark_types=spark_types,
return spark_types.StructType(
fields=[
spark_types.StructField(
name=field.name,
dataType=narwhals_to_native_dtype(
field.dtype,
version=version,
spark_types=spark_types,
),
)
for field in dtype.fields # type: ignore[union-attr]
]
)
return spark_types.ArrayType(elementType=inner)

if isinstance_or_issubclass(
dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8)
dtype,
(
dtypes.UInt64,
dtypes.UInt32,
dtypes.UInt16,
dtypes.UInt8,
dtypes.Enum,
dtypes.Categorical,
),
): # pragma: no cover
msg = "Unsigned integer types are not supported by PySpark"
msg = "Unsigned integer, Enum and Categorical types are not supported by spark-like backend"
raise UnsupportedDTypeError(msg)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
Expand Down
24 changes: 18 additions & 6 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ def test_cast_datetime_tz_aware(


def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(
backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
):
if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
Expand All @@ -251,10 +249,24 @@ def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -
]
}

native_df = constructor(data)

if "spark" in str(constructor): # pragma: no cover
# Special handling for pyspark as it natively maps the input to
# a column of type MAP<STRING, STRING>
import pyspark.sql.functions as F # noqa: N812
import pyspark.sql.types as T # noqa: N812

native_df = native_df.withColumn( # type: ignore[union-attr]
"a",
F.struct(
F.col("a.movie ").alias("movie ").cast(T.StringType()),
F.col("a.rating").alias("rating").cast(T.DoubleType()),
),
)

dtype = nw.Struct([nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float64())])
result = (
nw.from_native(constructor(data)).select(nw.col("a").cast(dtype)).lazy().collect()
)
result = nw.from_native(native_df).select(nw.col("a").cast(dtype)).lazy().collect()

assert result.schema == {"a": dtype}

Expand Down
Loading