diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 134711fb41..47ff56e298 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -5,8 +5,8 @@ from itertools import chain from typing import TYPE_CHECKING from typing import Any -from typing import Iterable from typing import Literal +from typing import Sequence import pyarrow as pa import pyarrow.compute as pc @@ -17,9 +17,6 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import align_series_full_broadcast from narwhals._arrow.utils import cast_to_comparable_string_types -from narwhals._arrow.utils import diagonal_concat -from narwhals._arrow.utils import horizontal_concat -from narwhals._arrow.utils import vertical_concat from narwhals._compliant import CompliantThen from narwhals._compliant import EagerNamespace from narwhals._compliant import EagerWhen @@ -34,7 +31,6 @@ from narwhals._arrow.typing import ArrowChunkedArray from narwhals._arrow.typing import Incomplete from narwhals.dtypes import DType - from narwhals.typing import ConcatMethod from narwhals.utils import Version @@ -211,30 +207,29 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: context=self, ) - def concat( - self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod - ) -> ArrowDataFrame: - dfs = [item.native for item in items] - - if not dfs: - msg = "No dataframes to concatenate" # pragma: no cover - raise AssertionError(msg) - - if how == "horizontal": - result_table = horizontal_concat(dfs) - elif how == "vertical": - result_table = vertical_concat(dfs) - elif how == "diagonal": - result_table = diagonal_concat(dfs, self._backend_version) - else: - raise NotImplementedError - - return ArrowDataFrame( - result_table, - backend_version=self._backend_version, - version=self._version, - validate_column_names=True, - ) + # NOTE: Stub issue fixed in https://github.com/zen-xu/pyarrow-stubs/pull/203 + def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: + if self._backend_version >= (14,): + return pa.concat_tables(dfs, promote_options="default") # type: ignore[arg-type] + return pa.concat_tables(dfs, promote=True) # type: ignore[arg-type] # pragma: no cover + + def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: + names = list(chain.from_iterable(df.column_names for df in dfs)) + arrays = list(chain.from_iterable(df.itercolumns() for df in dfs)) + return pa.Table.from_arrays(arrays, names=names) + + def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: + cols_0 = dfs[0].column_names + for i, df in enumerate(dfs[1:], start=1): + cols_current = df.column_names + if cols_current != cols_0: + msg = ( + "unable to vstack, column names don't match:\n" + f" - dataframe 0: {cols_0}\n" + f" - dataframe {i}: {cols_current}\n" + ) + raise TypeError(msg) + return pa.concat_tables(dfs) # type: ignore[arg-type] @property def selectors(self: Self) -> ArrowSelectorNamespace: diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 380388d51d..661f50ff68 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -1,7 +1,6 @@ from __future__ import annotations from functools import lru_cache -from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Iterable @@ -280,52 +279,6 @@ def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]: return reshaped -def horizontal_concat(dfs: list[pa.Table]) -> pa.Table: - """Concatenate (native) DataFrames horizontally. - - Should be in namespace. - """ - names = [name for df in dfs for name in df.column_names] - - if len(set(names)) < len(names): # pragma: no cover - msg = "Expected unique column names" - raise ValueError(msg) - arrays = list(chain.from_iterable(df.itercolumns() for df in dfs)) - return pa.Table.from_arrays(arrays, names=names) - - -def vertical_concat(dfs: list[pa.Table]) -> pa.Table: - """Concatenate (native) DataFrames vertically. - - Should be in namespace. - """ - cols_0 = dfs[0].column_names - for i, df in enumerate(dfs[1:], start=1): - cols_current = df.column_names - if cols_current != cols_0: - msg = ( - "unable to vstack, column names don't match:\n" - f" - dataframe 0: {cols_0}\n" - f" - dataframe {i}: {cols_current}\n" - ) - raise TypeError(msg) - - return pa.concat_tables(dfs) - - -def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa.Table: - """Concatenate (native) DataFrames diagonally. - - Should be in namespace. - """ - kwargs: dict[str, Any] = ( - {"promote": True} - if backend_version < (14, 0, 0) - else {"promote_options": "default"} - ) - return pa.concat_tables(dfs, **kwargs) - - def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index fbfe626f2c..4c6c469d0d 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -18,8 +18,8 @@ from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT from narwhals._compliant.typing import LazyExprT +from narwhals._compliant.typing import NativeFrameT from narwhals._compliant.typing import NativeFrameT_co -from narwhals._compliant.typing import NativeFrameT_contra from narwhals._compliant.typing import NativeSeriesT from narwhals.dependencies import is_numpy_array_2d from narwhals.utils import exclude_column_names @@ -130,9 +130,7 @@ def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT: class EagerNamespace( DepthTrackingNamespace[EagerDataFrameT, EagerExprT], - Protocol[ - EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT_contra, NativeSeriesT - ], + Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT], ): @property def _dataframe(self) -> type[EagerDataFrameT]: ... @@ -143,11 +141,11 @@ def when( ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ... @overload - def from_native(self, data: NativeFrameT_contra, /) -> EagerDataFrameT: ... + def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ... @overload def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ... def from_native( - self, data: NativeFrameT_contra | NativeSeriesT | Any, / + self, data: NativeFrameT | NativeSeriesT | Any, / ) -> EagerDataFrameT | EagerSeriesT: if self._dataframe._is_native(data): return self._dataframe.from_native(data, context=self) @@ -181,3 +179,22 @@ def from_numpy( if is_numpy_array_2d(data): return self._dataframe.from_numpy(data, schema=schema, context=self) return self._series.from_numpy(data, context=self) + + def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def _concat_horizontal( + self, dfs: Sequence[NativeFrameT | Any], / + ) -> NativeFrameT: ... + def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def concat( + self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod + ) -> EagerDataFrameT: + dfs = [item.native for item in items] + if how == "horizontal": + native = self._concat_horizontal(dfs) + elif how == "vertical": + native = self._concat_vertical(dfs) + elif how == "diagonal": + native = self._concat_diagonal(dfs) + else: + raise NotImplementedError + return self._dataframe.from_native(native, context=self) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index a7cf2a8e81..ce28aee4cf 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -5,7 +5,6 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any -from typing import Iterable from typing import Literal from typing import Sequence @@ -27,7 +26,6 @@ from narwhals._pandas_like.typing import NDFrameT from narwhals.dtypes import DType - from narwhals.typing import ConcatMethod from narwhals.utils import Implementation from narwhals.utils import Version @@ -274,20 +272,6 @@ def _concat_vertical(self, dfs: Sequence[pd.DataFrame], /) -> pd.DataFrame: return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) - def concat( - self, items: Iterable[PandasLikeDataFrame], *, how: ConcatMethod - ) -> PandasLikeDataFrame: - dfs: list[pd.DataFrame] = [item.native for item in items] - if how == "horizontal": - native = self._concat_horizontal(dfs) - elif how == "vertical": - native = self._concat_vertical(dfs) - elif how == "diagonal": - native = self._concat_diagonal(dfs) - else: - raise NotImplementedError - return self._dataframe.from_native(native, context=self) - def when(self: Self, predicate: PandasLikeExpr) -> PandasWhen: return PandasWhen.from_expr(predicate, context=self)