Skip to content
53 changes: 24 additions & 29 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal
from typing import Sequence

import pyarrow as pa
import pyarrow.compute as pc
Expand All @@ -17,9 +17,6 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import align_series_full_broadcast
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._arrow.utils import diagonal_concat
from narwhals._arrow.utils import horizontal_concat
from narwhals._arrow.utils import vertical_concat
from narwhals._compliant import CompliantThen
from narwhals._compliant import EagerNamespace
from narwhals._compliant import EagerWhen
Expand All @@ -34,7 +31,6 @@
from narwhals._arrow.typing import ArrowChunkedArray
from narwhals._arrow.typing import Incomplete
from narwhals.dtypes import DType
from narwhals.typing import ConcatMethod
from narwhals.utils import Version


Expand Down Expand Up @@ -211,30 +207,29 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
context=self,
)

def concat(
self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod
) -> ArrowDataFrame:
dfs = [item.native for item in items]

if not dfs:
msg = "No dataframes to concatenate" # pragma: no cover
raise AssertionError(msg)

if how == "horizontal":
result_table = horizontal_concat(dfs)
elif how == "vertical":
result_table = vertical_concat(dfs)
elif how == "diagonal":
result_table = diagonal_concat(dfs, self._backend_version)
else:
raise NotImplementedError

return ArrowDataFrame(
result_table,
backend_version=self._backend_version,
version=self._version,
validate_column_names=True,
)
# NOTE: Stub issue fixed in https://github.com/zen-xu/pyarrow-stubs/pull/203
def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
if self._backend_version >= (14,):
return pa.concat_tables(dfs, promote_options="default") # type: ignore[arg-type]
return pa.concat_tables(dfs, promote=True) # type: ignore[arg-type] # pragma: no cover

def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
names = list(chain.from_iterable(df.column_names for df in dfs))
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
return pa.Table.from_arrays(arrays, names=names)

def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
cols_0 = dfs[0].column_names
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.column_names
if cols_current != cols_0:
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)
return pa.concat_tables(dfs) # type: ignore[arg-type]

@property
def selectors(self: Self) -> ArrowSelectorNamespace:
Expand Down
47 changes: 0 additions & 47 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
Expand Down Expand Up @@ -280,52 +279,6 @@ def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]:
return reshaped


def horizontal_concat(dfs: list[pa.Table]) -> pa.Table:
"""Concatenate (native) DataFrames horizontally.

Should be in namespace.
"""
names = [name for df in dfs for name in df.column_names]

if len(set(names)) < len(names): # pragma: no cover
msg = "Expected unique column names"
raise ValueError(msg)
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
return pa.Table.from_arrays(arrays, names=names)


def vertical_concat(dfs: list[pa.Table]) -> pa.Table:
"""Concatenate (native) DataFrames vertically.

Should be in namespace.
"""
cols_0 = dfs[0].column_names
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.column_names
if cols_current != cols_0:
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)

return pa.concat_tables(dfs)


def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa.Table:
"""Concatenate (native) DataFrames diagonally.

Should be in namespace.
"""
kwargs: dict[str, Any] = (
{"promote": True}
if backend_version < (14, 0, 0)
else {"promote_options": "default"}
)
return pa.concat_tables(dfs, **kwargs)


def floordiv_compat(left: Any, right: Any) -> Any:
# The following lines are adapted from pandas' pyarrow implementation.
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
Expand Down
29 changes: 23 additions & 6 deletions narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from narwhals._compliant.typing import EagerExprT
from narwhals._compliant.typing import EagerSeriesT
from narwhals._compliant.typing import LazyExprT
from narwhals._compliant.typing import NativeFrameT
from narwhals._compliant.typing import NativeFrameT_co
from narwhals._compliant.typing import NativeFrameT_contra
from narwhals._compliant.typing import NativeSeriesT
from narwhals.dependencies import is_numpy_array_2d
from narwhals.utils import exclude_column_names
Expand Down Expand Up @@ -130,9 +130,7 @@ def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:

class EagerNamespace(
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
Protocol[
EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT_contra, NativeSeriesT
],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
@property
def _dataframe(self) -> type[EagerDataFrameT]: ...
Expand All @@ -143,11 +141,11 @@ def when(
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...

@overload
def from_native(self, data: NativeFrameT_contra, /) -> EagerDataFrameT: ...
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
@overload
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
def from_native(
self, data: NativeFrameT_contra | NativeSeriesT | Any, /
self, data: NativeFrameT | NativeSeriesT | Any, /
) -> EagerDataFrameT | EagerSeriesT:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
Expand Down Expand Up @@ -181,3 +179,22 @@ def from_numpy(
if is_numpy_array_2d(data):
return self._dataframe.from_numpy(data, schema=schema, context=self)
return self._series.from_numpy(data, context=self)

def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def _concat_horizontal(
self, dfs: Sequence[NativeFrameT | Any], /
) -> NativeFrameT: ...
def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
Copy link
Member Author

@dangotbanned dangotbanned Apr 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note

We could pretty easily support how="vertical" for NativeSeriesT

That is the only ConcatMethod that polars allows for Series

  • pyarrow would use pa.concat_arrays
  • (pandas|polars) just use their concat function again

I'm not in a rush to add it, but this PR certainly makes that next step quite simple πŸ˜…

def concat(
self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
) -> EagerDataFrameT:
dfs = [item.native for item in items]
if how == "horizontal":
native = self._concat_horizontal(dfs)
elif how == "vertical":
native = self._concat_vertical(dfs)
elif how == "diagonal":
native = self._concat_diagonal(dfs)
else:
raise NotImplementedError
return self._dataframe.from_native(native, context=self)
16 changes: 0 additions & 16 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal
from typing import Sequence

Expand All @@ -27,7 +26,6 @@

from narwhals._pandas_like.typing import NDFrameT
from narwhals.dtypes import DType
from narwhals.typing import ConcatMethod
from narwhals.utils import Implementation
from narwhals.utils import Version

Expand Down Expand Up @@ -274,20 +272,6 @@ def _concat_vertical(self, dfs: Sequence[pd.DataFrame], /) -> pd.DataFrame:
return self._concat(dfs, axis=VERTICAL, copy=False)
return self._concat(dfs, axis=VERTICAL)

def concat(
self, items: Iterable[PandasLikeDataFrame], *, how: ConcatMethod
) -> PandasLikeDataFrame:
dfs: list[pd.DataFrame] = [item.native for item in items]
if how == "horizontal":
native = self._concat_horizontal(dfs)
elif how == "vertical":
native = self._concat_vertical(dfs)
elif how == "diagonal":
native = self._concat_diagonal(dfs)
else:
raise NotImplementedError
return self._dataframe.from_native(native, context=self)

def when(self: Self, predicate: PandasLikeExpr) -> PandasWhen:
return PandasWhen.from_expr(predicate, context=self)

Expand Down
Loading