diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py
index 96ee30e256..1176164d67 100644
--- a/narwhals/_spark_like/utils.py
+++ b/narwhals/_spark_like/utils.py
@@ -58,8 +58,7 @@ def native_to_narwhals_dtype(
     if isinstance(dtype, spark_types.ByteType):
         return dtypes.Int8()
     if isinstance(
-        dtype,
-        (spark_types.StringType, spark_types.VarcharType, spark_types.CharType),
+        dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
     ):
         return dtypes.String()
     if isinstance(dtype, spark_types.BooleanType):
@@ -70,15 +69,26 @@
         return dtypes.Datetime()
     if isinstance(dtype, spark_types.TimestampType):
         return dtypes.Datetime(time_zone="UTC")
-    if isinstance(dtype, spark_types.DecimalType):  # pragma: no cover
-        # TODO(unassigned): cover this in dtypes_test.py
+    if isinstance(dtype, spark_types.DecimalType):
         return dtypes.Decimal()
-    if isinstance(dtype, spark_types.ArrayType):  # pragma: no cover
+    if isinstance(dtype, spark_types.ArrayType):
         return dtypes.List(
             inner=native_to_narwhals_dtype(
                 dtype.elementType, version=version, spark_types=spark_types
             )
         )
+    if isinstance(dtype, spark_types.StructType):
+        return dtypes.Struct(
+            fields=[
+                dtypes.Field(
+                    name=name,
+                    dtype=native_to_narwhals_dtype(
+                        dtype[name].dataType, version=version, spark_types=spark_types
+                    ),
+                )
+                for name in dtype.fieldNames()
+            ]
+        )
     return dtypes.Unknown()
 
 
@@ -113,28 +123,41 @@
             msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
             raise ValueError(msg)
         return spark_types.TimestampType()
-    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
-        inner = narwhals_to_native_dtype(
-            dtype.inner,  # type: ignore[union-attr]
-            version=version,
-            spark_types=spark_types,
+    if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
+        return spark_types.ArrayType(
+            elementType=narwhals_to_native_dtype(
+                dtype.inner,  # type: ignore[union-attr]
+                version=version,
+                spark_types=spark_types,
+            )
         )
-        return spark_types.ArrayType(elementType=inner)
     if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
-        msg = "Converting to Struct dtype is not supported yet"
-        raise NotImplementedError(msg)
-    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
-        inner = narwhals_to_native_dtype(
-            dtype.inner,  # type: ignore[union-attr]
-            version=version,
-            spark_types=spark_types,
+        return spark_types.StructType(
+            fields=[
+                spark_types.StructField(
+                    name=field.name,
+                    dataType=narwhals_to_native_dtype(
+                        field.dtype,
+                        version=version,
+                        spark_types=spark_types,
+                    ),
+                )
+                for field in dtype.fields  # type: ignore[union-attr]
+            ]
         )
-        return spark_types.ArrayType(elementType=inner)
     if isinstance_or_issubclass(
-        dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8)
+        dtype,
+        (
+            dtypes.UInt64,
+            dtypes.UInt32,
+            dtypes.UInt16,
+            dtypes.UInt8,
+            dtypes.Enum,
+            dtypes.Categorical,
+        ),
     ):  # pragma: no cover
-        msg = "Unsigned integer types are not supported by PySpark"
+        msg = "Unsigned integer, Enum and Categorical types are not supported by spark-like backend"
         raise UnsupportedDTypeError(msg)
 
     msg = f"Unknown dtype: {dtype}"  # pragma: no cover
     raise AssertionError(msg)
diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py
index 6021236390..7def70acc6 100644
--- a/tests/expr_and_series/cast_test.py
+++ b/tests/expr_and_series/cast_test.py
@@ -236,9 +236,7 @@ def test_cast_datetime_tz_aware(
 
 
 def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
-    if any(
-        backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
-    ):
+    if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
         request.applymarker(pytest.mark.xfail)
 
     if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
@@ -251,10 +249,24 @@
         ]
     }
 
+    native_df = constructor(data)
+
+    if "spark" in str(constructor):  # pragma: no cover
+        # Special handling for pyspark as it natively maps the input to
+        # a column of type MAP
+        import pyspark.sql.functions as F  # noqa: N812
+        import pyspark.sql.types as T  # noqa: N812
+
+        native_df = native_df.withColumn(  # type: ignore[union-attr]
+            "a",
+            F.struct(
+                F.col("a.movie ").alias("movie ").cast(T.StringType()),
+                F.col("a.rating").alias("rating").cast(T.DoubleType()),
+            ),
+        )
+
     dtype = nw.Struct([nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float64())])
-    result = (
-        nw.from_native(constructor(data)).select(nw.col("a").cast(dtype)).lazy().collect()
-    )
+    result = nw.from_native(native_df).select(nw.col("a").cast(dtype)).lazy().collect()
 
     assert result.schema == {"a": dtype}
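A minimal sketch of the round-trip this change enables, mirroring test_cast_struct above. It assumes a PySpark DataFrame df (the name is illustrative) whose column "a" is already a native StructType with a StringType field "movie" and a DoubleType field "rating":

    import narwhals as nw

    # Nested fields now map to nw.Struct/nw.Field instead of nw.Unknown,
    # and casting to a Struct dtype builds the corresponding StructType.
    dtype = nw.Struct([nw.Field("movie", nw.String()), nw.Field("rating", nw.Float64())])
    result = nw.from_native(df).select(nw.col("a").cast(dtype)).lazy().collect()
    assert result.schema == {"a": dtype}

Conversely, casting to an unsigned integer, Enum, or Categorical dtype on this backend raises UnsupportedDTypeError instead of falling through to the final "Unknown dtype" assertion.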