diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index c93862f2f6..d937bf287c 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -12,6 +12,7 @@ - drop - drop_nulls - estimated_size + - explode - filter - gather_every - get_column diff --git a/docs/api-reference/lazyframe.md b/docs/api-reference/lazyframe.md index 515069d1cb..07667ab044 100644 --- a/docs/api-reference/lazyframe.md +++ b/docs/api-reference/lazyframe.md @@ -10,6 +10,7 @@ - columns - drop - drop_nulls + - explode - filter - gather_every - group_by diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index a2b739d34c..cdbfd034ef 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -949,3 +949,55 @@ def unpivot( value_name=value_name if value_name is not None else "value", ) ) + + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + from narwhals.exceptions import InvalidOperationError + + dtypes = import_dtypes_module(self._version) + + to_explode = ( + [columns, *more_columns] + if isinstance(columns, str) + else [*columns, *more_columns] + ) + schema = self.collect_schema() + for col_to_explode in to_explode: + dtype = schema[col_to_explode] + + if dtype != dtypes.List: + msg = ( + f"`explode` operation not supported for dtype `{dtype}`, " + "expected List type" + ) + raise InvalidOperationError(msg) + + if len(to_explode) == 1: + return self._from_native_frame(self._native_frame.explode(to_explode[0])) + else: + native_frame = self._native_frame + anchor_series = native_frame[to_explode[0]].list.len() + + if not all( + (native_frame[col_name].list.len() == anchor_series).all() + for col_name in to_explode[1:] + ): + from narwhals.exceptions import ShapeError + + msg = "exploded columns must have matching element counts" + raise ShapeError(msg) + + original_columns = self.columns + other_columns = [c for c in original_columns if c not in to_explode] + + exploded_frame = native_frame[[*other_columns, to_explode[0]]].explode( + to_explode[0] + ) + exploded_series = [ + native_frame[col_name].explode().to_frame() for col_name in to_explode[1:] + ] + + plx = self.__native_namespace__() + + return self._from_native_frame( + plx.concat([exploded_frame, *exploded_series], axis=1)[original_columns] + ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index cff5d1cc87..8f120532cb 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -335,6 +335,14 @@ def __eq__(self, other: object) -> NoReturn: ) raise NotImplementedError(msg) + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + return self._from_compliant_dataframe( + self._compliant_frame.explode( + columns, + *more_columns, + ) + ) + class DataFrame(BaseFrame[DataFrameT]): """Narwhals DataFrame, backed by a native eager dataframe. @@ -592,8 +600,6 @@ def to_pandas(self) -> pd.DataFrame: 0 1 6.0 a 1 2 7.0 b 2 3 8.0 c - - """ return self._compliant_frame.to_pandas() @@ -3125,6 +3131,68 @@ def unpivot( on=on, index=index, variable_name=variable_name, value_name=value_name ) + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + """Explode the dataframe to long format by exploding the given columns. + + Notes: + It is possible to explode multiple columns only if these columns must have + matching element counts. + + Arguments: + columns: Column names. The underlying columns being exploded must be of the `List` data type. + *more_columns: Additional names of columns to explode, specified as positional arguments. + + Returns: + New DataFrame + + Examples: + >>> import narwhals as nw + >>> from narwhals.typing import IntoDataFrameT + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = { + ... "a": ["x", "y", "z", "w"], + ... "lst1": [[1, 2], None, [None], []], + ... "lst2": [[3, None], None, [42], []], + ... } + + We define a library agnostic function: + + >>> def agnostic_explode(df_native: IntoDataFrameT) -> IntoDataFrameT: + ... return ( + ... nw.from_native(df_native) + ... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32()))) + ... .explode("lst1", "lst2") + ... .to_native() + ... ) + + We can then pass any supported library such as pandas, Polars (eager), + or PyArrow to `agnostic_explode`: + + >>> agnostic_explode(pd.DataFrame(data)) + a lst1 lst2 + 0 x 1 3 + 0 x 2 + 1 y + 2 z 42 + 3 w + >>> agnostic_explode(pl.DataFrame(data)) + shape: (5, 3) + ┌─────┬──────┬──────┐ + │ a ┆ lst1 ┆ lst2 │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i32 ┆ i32 │ + ╞═════╪══════╪══════╡ + │ x ┆ 1 ┆ 3 │ + │ x ┆ 2 ┆ null │ + │ y ┆ null ┆ null │ + │ z ┆ null ┆ 42 │ + │ w ┆ null ┆ null │ + └─────┴──────┴──────┘ + """ + return super().explode(columns, *more_columns) + class LazyFrame(BaseFrame[FrameT]): """Narwhals LazyFrame, backed by a native lazyframe. @@ -4910,3 +4978,56 @@ def unpivot( return super().unpivot( on=on, index=index, variable_name=variable_name, value_name=value_name ) + + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + """Explode the dataframe to long format by exploding the given columns. + + Notes: + It is possible to explode multiple columns only if these columns must have + matching element counts. + + Arguments: + columns: Column names. The underlying columns being exploded must be of the `List` data type. + *more_columns: Additional names of columns to explode, specified as positional arguments. + + Returns: + New LazyFrame + + Examples: + >>> import narwhals as nw + >>> from narwhals.typing import IntoFrameT + >>> import polars as pl + >>> data = { + ... "a": ["x", "y", "z", "w"], + ... "lst1": [[1, 2], None, [None], []], + ... "lst2": [[3, None], None, [42], []], + ... } + + We define a library agnostic function: + + >>> def agnostic_explode(df_native: IntoFrameT) -> IntoFrameT: + ... return ( + ... nw.from_native(df_native) + ... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32()))) + ... .explode("lst1", "lst2") + ... .to_native() + ... ) + + We can then pass any supported library such as pandas, Polars (eager), + or PyArrow to `agnostic_explode`: + + >>> agnostic_explode(pl.LazyFrame(data)).collect() + shape: (5, 3) + ┌─────┬──────┬──────┐ + │ a ┆ lst1 ┆ lst2 │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i32 ┆ i32 │ + ╞═════╪══════╪══════╡ + │ x ┆ 1 ┆ 3 │ + │ x ┆ 2 ┆ null │ + │ y ┆ null ┆ null │ + │ z ┆ null ┆ 42 │ + │ w ┆ null ┆ null │ + └─────┴──────┴──────┘ + """ + return super().explode(columns, *more_columns) diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 12f85d1ad5..ee4b79b6a0 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -35,6 +35,10 @@ def from_missing_and_available_column_names( return ColumnNotFoundError(message) +class ShapeError(Exception): + """Exception raised when trying to perform operations on data structures with incompatible shapes.""" + + class InvalidOperationError(Exception): """Exception raised during invalid operations.""" diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py new file mode 100644 index 0000000000..631da02556 --- /dev/null +++ b/tests/frame/explode_test.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest +from polars.exceptions import InvalidOperationError as PlInvalidOperationError +from polars.exceptions import ShapeError as PlShapeError + +import narwhals.stable.v1 as nw +from narwhals.exceptions import InvalidOperationError +from narwhals.exceptions import ShapeError +from tests.utils import PANDAS_VERSION +from tests.utils import POLARS_VERSION +from tests.utils import Constructor +from tests.utils import assert_equal_data + +# For context, polars allows to explode multiple columns only if the columns +# have matching element counts, therefore, l1 and l2 but not l1 and l3 together. +data = { + "a": ["x", "y", "z", "w"], + "l1": [[1, 2], None, [None], []], + "l2": [[3, None], None, [42], []], + "l3": [[1, 2], [3], [None], [1]], + "l4": [[1, 2], [3], [123], [456]], +} + + +@pytest.mark.parametrize( + ("column", "expected_values"), + [ + ("l2", [3, None, None, 42, None]), + ("l3", [1, 2, 3, None, 1]), # fast path for arrow + ], +) +def test_explode_single_col( + request: pytest.FixtureRequest, + constructor: Constructor, + column: str, + expected_values: list[int | None], +) -> None: + if any( + backend in str(constructor) + for backend in ("dask", "modin", "cudf", "pyarrow_table") + ): + request.applymarker(pytest.mark.xfail) + + if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): + request.applymarker(pytest.mark.xfail) + + result = ( + nw.from_native(constructor(data)) + .with_columns(nw.col(column).cast(nw.List(nw.Int32()))) + .explode(column) + .select("a", column) + ) + expected = {"a": ["x", "x", "y", "z", "w"], column: expected_values} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("columns", "more_columns", "expected"), + [ + ( + "l1", + ["l2"], + { + "a": ["x", "x", "y", "z", "w"], + "l1": [1, 2, None, None, None], + "l2": [3, None, None, 42, None], + }, + ), + ( + "l3", + ["l4"], + { + "a": ["x", "x", "y", "z", "w"], + "l3": [1, 2, 3, None, 1], + "l4": [1, 2, 3, 123, 456], + }, + ), + ], +) +def test_explode_multiple_cols( + request: pytest.FixtureRequest, + constructor: Constructor, + columns: str | Sequence[str], + more_columns: Sequence[str], + expected: dict[str, list[str | int | None]], +) -> None: + if any( + backend in str(constructor) + for backend in ("dask", "modin", "cudf", "pyarrow_table") + ): + request.applymarker(pytest.mark.xfail) + + if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): + request.applymarker(pytest.mark.xfail) + + result = ( + nw.from_native(constructor(data)) + .with_columns(nw.col(columns, *more_columns).cast(nw.List(nw.Int32()))) + .explode(columns, *more_columns) + .select("a", columns, *more_columns) + ) + assert_equal_data(result, expected) + + +def test_explode_shape_error( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if any( + backend in str(constructor) + for backend in ("dask", "modin", "cudf", "pyarrow_table") + ): + request.applymarker(pytest.mark.xfail) + + if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): + request.applymarker(pytest.mark.xfail) + + with pytest.raises( + (ShapeError, PlShapeError), + match="exploded columns must have matching element counts", + ): + _ = ( + nw.from_native(constructor(data)) + .lazy() + .with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))) + .explode("l1", "l3") + .collect() + ) + + +def test_explode_invalid_operation_error( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6): + request.applymarker(pytest.mark.xfail) + + with pytest.raises( + (InvalidOperationError, PlInvalidOperationError), + match="`explode` operation not supported for dtype", + ): + _ = nw.from_native(constructor(data)).lazy().explode("a").collect()