Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3061fe9
feat: DataFrame and LazyFrame explode
FBruzzesi Dec 8, 2024
2326b08
arrow refactor
FBruzzesi Dec 9, 2024
32af22e
raise for invalid type and docstrings
FBruzzesi Dec 9, 2024
3b52ab5
Update narwhals/dataframe.py
FBruzzesi Dec 9, 2024
c3bf009
old versions
FBruzzesi Dec 9, 2024
b427e79
merge main
FBruzzesi Dec 13, 2024
c77dc62
Merge branch 'main' into feat/explode-method
FBruzzesi Dec 17, 2024
72314a2
almost all native
FBruzzesi Dec 17, 2024
7f04579
doctest
FBruzzesi Dec 17, 2024
7be326e
Merge branch 'main' into feat/explode-method
FBruzzesi Dec 17, 2024
5da1ad6
Merge branch 'main' into feat/explode-method
FBruzzesi Dec 18, 2024
4a098b8
Merge branch 'main' into feat/explode-method
FBruzzesi Dec 19, 2024
380a6cb
Merge branch 'feat/explode-method' of https://github.com/narwhals-dev…
FBruzzesi Dec 21, 2024
c7a47c9
Merge branch 'main' into feat/explode-method
FBruzzesi Dec 21, 2024
864e932
better error message, fail for arrow with nulls
FBruzzesi Dec 21, 2024
cc72f6b
doctest-modules
FBruzzesi Dec 21, 2024
1156beb
completely remove pyarrow implementation
FBruzzesi Dec 21, 2024
03081cb
feat: ArrowDataFrame explode method
FBruzzesi Dec 21, 2024
8fc8c0a
merge main
FBruzzesi Dec 22, 2024
7369925
Merge remote-tracking branch 'upstream/main' into feat/pyarrow-explode
dangotbanned Mar 25, 2025
d04fc7d
fix: remove `not_implemented`
dangotbanned Mar 25, 2025
fc79540
refactor: move imports
dangotbanned Mar 25, 2025
1f1ac63
chore: use `ArrowDataFrame.native`
dangotbanned Mar 25, 2025
79b8fd4
fix(typing): Resolve most issues
dangotbanned Mar 25, 2025
80fcc02
pyright ignore
dangotbanned Mar 25, 2025
22ea311
fix(typing): Avoid `mypy` redef
dangotbanned Mar 25, 2025
8e1e025
Merge branch 'main' into feat/pyarrow-explode
FBruzzesi Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterator
from typing import Literal
from typing import Mapping
Expand All @@ -16,17 +17,21 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import align_series_full_broadcast
from narwhals._arrow.utils import convert_str_slice_to_int_slice
from narwhals._arrow.utils import list_flatten
from narwhals._arrow.utils import lit
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import select_rows
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import ShapeError
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import check_column_exists
from narwhals.utils import check_column_names_are_unique
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import not_implemented
from narwhals.utils import parse_columns_to_drop
Expand Down Expand Up @@ -347,8 +352,6 @@ def estimated_size(self: Self, unit: SizeUnit) -> int | float:
sz = self.native.nbytes
return scale_bytes(sz, unit)

explode = not_implemented()

@property
def columns(self: Self) -> list[str]:
return self.native.schema.names
Expand Down Expand Up @@ -834,3 +837,77 @@ def unpivot(
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
# upcast numeric to non-numeric (e.g. string) datatypes

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    """Explode one or more List-typed columns into one row per list element.

    Non-exploded columns are repeated to align with the exploded output.

    Arguments:
        columns: A column name, or a sequence of column names, to explode.
        more_columns: Additional column names to explode, specified positionally.

    Returns:
        A new frame of the same type with the listed columns exploded.

    Raises:
        InvalidOperationError: If any column to explode is not of List dtype.
        ShapeError: If the exploded columns do not have matching per-row
            element counts.
    """
    dtypes = import_dtypes_module(self._version)

    # Normalize the (str | Sequence[str], *str) signature into a flat list.
    to_explode = (
        [columns, *more_columns]
        if isinstance(columns, str)
        else [*columns, *more_columns]
    )

    # Validate every requested column is List-typed before touching data.
    schema = self.collect_schema()
    for col_to_explode in to_explode:
        dtype = schema[col_to_explode]

        if dtype != dtypes.List:
            msg = (
                f"`explode` operation not supported for dtype `{dtype}`, "
                "expected List type"
            )

            raise InvalidOperationError(msg)

    # Per-row list lengths of the first exploded column; the reference that
    # every other exploded column must match.
    counts = pc.list_value_length(self.native[to_explode[0]])

    if not all(
        pc.all(pc.equal(pc.list_value_length(self.native[col_name]), counts)).as_py()
        for col_name in to_explode[1:]
    ):
        msg = "exploded columns must have matching element counts"
        raise ShapeError(msg)

    original_columns = self.columns
    other_columns = [c for c in original_columns if c not in to_explode]
    ONE = lit(1)  # noqa: N806
    # Fast path: every list has at least one element (and no nulls, since a
    # null count fails `>= 1`), so `list_parent_indices` already yields exactly
    # one take-index per output row and plain flattening lines up 1:1.
    fast_path = pc.all(pc.greater_equal(counts, ONE)).as_py()
    flatten: Callable[..., ArrowChunkedArray]
    if fast_path:
        indices = pc.list_parent_indices(self.native[to_explode[0]])
        flatten = list_flatten
    else:
        # Slow path: empty/null lists must still produce one output row
        # (holding null). Clamp each count up to 1 so such rows reserve a
        # single slot; skip_nulls means null counts are treated as 1 here —
        # NOTE(review): confirm against pc.max_element_wise docs.
        filled_counts = pc.max_element_wise(counts, ONE, skip_nulls=True)
        # Repeat each row index `filled_counts[i]` times; built in Python
        # since pyarrow has no direct "repeat by counts" kernel.
        indices = pa.array(
            [
                i
                for i, count in enumerate(filled_counts.to_pylist())  # pyright: ignore[reportAttributeAccessIssue]
                for _ in range(count)
            ]
        )
        # A slot is "valid" iff its row index appears as a parent index,
        # i.e. the row actually contributed at least one list element.
        parent_indices = pc.list_parent_indices(self.native[to_explode[0]])
        is_valid_index = pc.is_in(indices, value_set=parent_indices)
        exploded_size = len(is_valid_index)

        def _flatten(
            array: pa.ChunkedArray[pa.ListScalar[Any]], /
        ) -> ArrowChunkedArray:
            # Start from an all-null column of the list's value type, then
            # scatter the flattened elements into the valid slots; slots for
            # empty/null source lists stay null.
            dtype = array.type.value_type
            return pc.replace_with_mask(
                pa.nulls(exploded_size, dtype),
                is_valid_index,
                list_flatten(array).combine_chunks(),
            )

        flatten = _flatten

    # Non-exploded columns are gathered (repeated) via `take`; exploded
    # columns are flattened with whichever strategy was selected above.
    arrays = [
        self.native[col_name].take(indices)
        if col_name in other_columns
        else flatten(self.native[col_name])
        for col_name in original_columns
    ]

    return self._from_native_frame(
        pa.Table.from_arrays(arrays, names=original_columns)
    )
7 changes: 7 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,15 @@ def extract_regex(
options: Any = None,
memory_pool: Any = None,
) -> ChunkedArrayStructArray: ...

@overload
def list_flatten(lists: ArrowChunkedArray, /, **kwds: Any) -> ArrowChunkedArray: ...
@overload
def list_flatten(lists: pa.ListArray[Any], /, **kwds: Any) -> pa.ListArray[Any]: ...
def list_flatten(lists: Any, /, recursive: bool = False, **kwds: Any) -> Any: ... # noqa: FBT001, FBT002
else:
from pyarrow.compute import extract_regex
from pyarrow.compute import list_flatten # noqa: F401
from pyarrow.types import is_dictionary # noqa: F401
from pyarrow.types import is_duration
from pyarrow.types import is_fixed_size_list
Expand Down
14 changes: 4 additions & 10 deletions tests/frame/explode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ def test_explode_single_col(
column: str,
expected_values: list[int | None],
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
Expand Down Expand Up @@ -88,7 +85,7 @@ def test_explode_multiple_cols(
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark")
for backend in ("dask", "modin", "cudf", "duckdb", "pyspark")
):
request.applymarker(pytest.mark.xfail)

Expand All @@ -107,10 +104,7 @@ def test_explode_multiple_cols(
def test_explode_shape_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
Expand All @@ -132,7 +126,7 @@ def test_explode_shape_error(
def test_explode_invalid_operation_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if any(x in str(constructor) for x in ("pyarrow_table", "dask")):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):
Expand Down
Loading