diff --git a/narwhals/functions.py b/narwhals/functions.py index 56457d3f31..b5f10927b6 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -9,8 +9,8 @@ from typing import Literal from typing import Mapping from typing import Sequence +from typing import TypeVar from typing import cast -from typing import overload from narwhals._expression_parsing import ExpansionKind from narwhals._expression_parsing import ExprKind @@ -57,46 +57,21 @@ from narwhals.dtypes import DType from narwhals.schema import Schema from narwhals.series import Series - from narwhals.typing import IntoDataFrameT from narwhals.typing import IntoExpr - from narwhals.typing import IntoFrameT from narwhals.typing import IntoSeriesT from narwhals.typing import NativeFrame from narwhals.typing import NativeLazyFrame from narwhals.typing import _2DArray _IntoSchema: TypeAlias = "Mapping[str, DType] | Schema | Sequence[str] | None" + FrameT = TypeVar("FrameT", "DataFrame[Any]", "LazyFrame[Any]") -@overload def concat( - items: Iterable[DataFrame[IntoDataFrameT]], + items: Iterable[FrameT], *, how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> DataFrame[IntoDataFrameT]: ... - - -@overload -def concat( - items: Iterable[LazyFrame[IntoFrameT]], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> LazyFrame[IntoFrameT]: ... - - -@overload -def concat( - items: Iterable[DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]: ... - - -def concat( - items: Iterable[DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]: +) -> FrameT: """Concatenate multiple DataFrames, LazyFrames into a single entity. Arguments: diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index deb33e6481..79712278f4 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -99,6 +99,7 @@ from narwhals.typing import _1DArray from narwhals.typing import _2DArray + FrameT = TypeVar("FrameT", "DataFrame[Any]", "LazyFrame[Any]") IntoSeriesT = TypeVar("IntoSeriesT", bound="IntoSeries", default=Any) T = TypeVar("T", default=Any) else: @@ -2049,35 +2050,11 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return _stableify(nw.max_horizontal(*exprs)) -@overload -def concat( - items: Iterable[DataFrame[IntoDataFrameT]], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> DataFrame[IntoDataFrameT]: ... - - -@overload -def concat( - items: Iterable[LazyFrame[IntoFrameT]], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> LazyFrame[IntoFrameT]: ... - - -@overload -def concat( - items: Iterable[DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]: ... - - def concat( - items: Iterable[DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]], + items: Iterable[FrameT], *, how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoFrameT]: +) -> FrameT: """Concatenate multiple DataFrames, LazyFrames into a single entity. Arguments: @@ -2096,7 +2073,7 @@ def concat( Raises: TypeError: The items to concatenate should either all be eager, or all lazy """ - return _stableify(nw.concat(items, how=how)) + return cast("FrameT", _stableify(nw.concat(items, how=how))) def concat_str( diff --git a/tests/frame/invalid_test.py b/tests/frame/invalid_test.py index 0891d55adc..888a4c0010 100644 --- a/tests/frame/invalid_test.py +++ b/tests/frame/invalid_test.py @@ -83,7 +83,7 @@ def test_validate_laziness() -> None: TypeError, match=("The items to concatenate should either all be eager, or all lazy"), ): - nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()]) + nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()]) # type: ignore[type-var] @pytest.mark.slow