diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index acbbdcda77..faf6c0e037 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -3,6 +3,7 @@ from functools import partial from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterator from typing import Literal from typing import Mapping @@ -16,17 +17,21 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import align_series_full_broadcast from narwhals._arrow.utils import convert_str_slice_to_int_slice +from narwhals._arrow.utils import list_flatten +from narwhals._arrow.utils import lit from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._arrow.utils import select_rows from narwhals._compliant import EagerDataFrame from narwhals._expression_parsing import ExprKind from narwhals.dependencies import is_numpy_array_1d +from narwhals.exceptions import InvalidOperationError from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import check_column_exists from narwhals.utils import check_column_names_are_unique from narwhals.utils import generate_temporary_column_name +from narwhals.utils import import_dtypes_module from narwhals.utils import is_sequence_but_not_str from narwhals.utils import not_implemented from narwhals.utils import parse_columns_to_drop @@ -347,8 +352,6 @@ def estimated_size(self: Self, unit: SizeUnit) -> int | float: sz = self.native.nbytes return scale_bytes(sz, unit) - explode = not_implemented() - @property def columns(self: Self) -> list[str]: return self.native.schema.names @@ -834,3 +837,77 @@ def unpivot( ) # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not # upcast numeric to non-numeric (e.g. 
def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    """Explode one or more list columns into one row per list element.

    Values of the non-exploded columns are repeated for every element
    produced by the row they belong to. A row whose list is empty or null
    yields exactly one output row holding a null in the exploded column(s).

    Arguments:
        columns: Name, or sequence of names, of the column(s) to explode.
        *more_columns: Additional column names to explode.

    Returns:
        A new frame with the same column names, in the same order.

    Raises:
        InvalidOperationError: If any requested column is not of List dtype.
        ShapeError: If the requested columns do not have matching
            per-row element counts.
    """
    dtypes = import_dtypes_module(self._version)

    to_explode = (
        [columns, *more_columns]
        if isinstance(columns, str)
        else [*columns, *more_columns]
    )

    # Validate dtypes up front so that no work is done on invalid input.
    schema = self.collect_schema()
    for col_to_explode in to_explode:
        dtype = schema[col_to_explode]

        if dtype != dtypes.List:
            msg = (
                f"`explode` operation not supported for dtype `{dtype}`, "
                "expected List type"
            )

            raise InvalidOperationError(msg)

    # Per-row list lengths of the first column; every other exploded column
    # must agree with these.
    counts = pc.list_value_length(self.native[to_explode[0]])

    if not all(
        pc.all(pc.equal(pc.list_value_length(self.native[col_name]), counts)).as_py()
        for col_name in to_explode[1:]
    ):
        msg = "exploded columns must have matching element counts"
        raise ShapeError(msg)

    original_columns = self.columns
    exploded_names = set(to_explode)
    ONE = lit(1)  # noqa: N806

    # The fast path is only valid when *every* row contributes at least one
    # element. `skip_nulls=False` makes a null count (i.e. a null list) yield
    # a null/falsy result instead of being ignored; otherwise such rows would
    # be silently dropped, because `list_parent_indices` and `list_flatten`
    # emit nothing for them.
    fast_path = pc.all(pc.greater_equal(counts, ONE), skip_nulls=False).as_py()
    flatten: Callable[..., ArrowChunkedArray]
    if fast_path:
        indices = pc.list_parent_indices(self.native[to_explode[0]])
        flatten = list_flatten
    else:
        # Empty and null lists must still produce one (null) output row, so
        # treat their length as 1 when computing the row index to repeat.
        filled_counts = pc.max_element_wise(counts, ONE, skip_nulls=True)
        indices = pa.array(
            [
                i
                for i, count in enumerate(filled_counts.to_pylist())  # pyright: ignore[reportAttributeAccessIssue]
                for _ in range(count)
            ]
        )
        # Output positions whose source row really holds elements; the
        # remaining positions are the placeholder rows and stay null.
        parent_indices = pc.list_parent_indices(self.native[to_explode[0]])
        is_valid_index = pc.is_in(indices, value_set=parent_indices)
        exploded_size = len(is_valid_index)

        def _flatten(
            array: pa.ChunkedArray[pa.ListScalar[Any]], /
        ) -> ArrowChunkedArray:
            # Start from an all-null array of the final length and write the
            # flattened values into the valid positions only.
            dtype = array.type.value_type
            return pc.replace_with_mask(
                pa.nulls(exploded_size, dtype),
                is_valid_index,
                list_flatten(array).combine_chunks(),
            )

        flatten = _flatten

    arrays = [
        flatten(self.native[col_name])
        if col_name in exploded_names
        else self.native[col_name].take(indices)
        for col_name in original_columns
    ]

    return self._from_native_frame(
        pa.Table.from_arrays(arrays, names=original_columns)
    )
"modin", "cudf", "pyarrow_table") - ): + if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): @@ -132,7 +126,7 @@ def test_explode_shape_error( def test_explode_invalid_operation_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "dask")): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):