diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 83bb1c2bce..85e11ee2c0 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from functools import reduce +from functools import partial, reduce from itertools import chain from typing import TYPE_CHECKING, Literal @@ -18,7 +18,8 @@ combine_alias_output_names, combine_evaluate_output_names, ) -from narwhals._utils import Implementation +from narwhals._utils import Implementation, safe_cast +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -176,6 +177,32 @@ def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: return pa.concat_tables(dfs, promote_options="default") return pa.concat_tables(dfs, promote=True) # pragma: no cover + def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: + native_schemas = tuple(table.schema for table in dfs) + out_schema = reduce( + merge_schemas, (Schema.from_arrow(pa_schema) for pa_schema in native_schemas) + ) + to_schemas = ( + { + name: dtype + for name, dtype in out_schema.items() + if name in native_schema.names + } + for native_schema in native_schemas + ) + version = self._version + to_compliant = partial( + self._dataframe, + version=version, + validate_backend_version=False, + validate_column_names=False, + ) + tables = tuple( + to_compliant(tbl).select(*safe_cast(self, to_schema)).native + for tbl, to_schema in zip(dfs, to_schemas) + ) + return self._concat_diagonal(tables) + def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: names = list(chain.from_iterable(df.column_names for df in dfs)) arrays = tuple(chain.from_iterable(df.itercolumns() for df in dfs)) @@ -194,6 +221,21 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: raise TypeError(msg) return pa.concat_tables(dfs) + def 
_concat_vertical_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table: + out_schema = reduce(to_supertype, (Schema.from_arrow(tbl.schema) for tbl in dfs)) + version = self._version + to_compliant = partial( + self._dataframe, + version=version, + validate_backend_version=False, + validate_column_names=False, + ) + tables = ( + to_compliant(tbl).select(*safe_cast(self, out_schema)).native for tbl in dfs + ) + + return pa.concat_tables(tables) + @property def selectors(self) -> ArrowSelectorNamespace: return ArrowSelectorNamespace.from_namespace(self) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 98b1e16b50..eaf276793d 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -273,10 +273,16 @@ def from_numpy( return self._series.from_numpy(data, context=self) def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def _concat_diagonal_relaxed( + self, dfs: Sequence[NativeFrameT], / + ) -> NativeFrameT: ... def _concat_horizontal( self, dfs: Sequence[NativeFrameT | Any], / ) -> NativeFrameT: ... def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def _concat_vertical_relaxed( + self, dfs: Sequence[NativeFrameT], / + ) -> NativeFrameT: ... 
def concat( self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod ) -> EagerDataFrameT: @@ -285,8 +291,12 @@ def concat( native = self._concat_horizontal(dfs) elif how == "vertical": native = self._concat_vertical(dfs) + elif how == "vertical_relaxed": + native = self._concat_vertical_relaxed(dfs) elif how == "diagonal": native = self._concat_diagonal(dfs) + elif how == "diagonal_relaxed": + native = self._concat_diagonal_relaxed(dfs) else: # pragma: no cover raise NotImplementedError return self._dataframe.from_native(native, context=self) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index af3a9bc2f9..935e73d5d6 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -22,7 +22,9 @@ combine_alias_output_names, combine_evaluate_output_names, ) +from narwhals._pandas_like.utils import promote_dtype_backends from narwhals._utils import Implementation, is_nested_literal, zip_strict +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -143,10 +145,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def concat( self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod ) -> DaskLazyFrame: - if not items: - msg = "No items to concatenate" # pragma: no cover - raise AssertionError(msg) - dfs = [i._native_frame for i in items] + dfs = [item.native for item in items] cols_0 = dfs[0].columns if how == "vertical": for i, df in enumerate(dfs[1:], start=1): @@ -167,6 +166,24 @@ def concat( return DaskLazyFrame( dd.concat(dfs, axis=0, join="outer"), version=self._version ) + if how == "vertical_relaxed": + schemas = tuple(df.dtypes.to_dict() for df in dfs) + out_schema = reduce( + to_supertype, (Schema.from_pandas_like(schema) for schema in schemas) + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) + + to_concat = [df.astype(out_schema) for df in dfs] + return DaskLazyFrame( + dd.concat(to_concat, axis=0, join="inner"), 
version=self._version + ) + if how == "diagonal_relaxed": + schemas = tuple(df.dtypes.to_dict() for df in dfs) + out_schema = reduce( + merge_schemas, (Schema.from_pandas_like(schema) for schema in schemas) + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) + + native_res = dd.concat(dfs, axis=0, join="outer").astype(out_schema) + return DaskLazyFrame(native_res, version=self._version) raise NotImplementedError diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 71a94e4a95..86ae1b1048 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any import duckdb -from duckdb import CoalesceOperator, Expression +from duckdb import CoalesceOperator, DuckDBPyRelation, Expression from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.expr import DuckDBExpr @@ -26,13 +26,12 @@ combine_evaluate_output_names, ) from narwhals._sql.namespace import SQLNamespace -from narwhals._utils import Implementation +from narwhals._utils import Implementation, safe_cast +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable - from duckdb import DuckDBPyRelation # noqa: F401 - from narwhals._compliant.window import WindowInputs from narwhals._utils import Version from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral @@ -82,23 +81,55 @@ def _coalesce(self, *exprs: Expression) -> Expression: def concat( self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod ) -> DuckDBLazyFrame: - native_items = [item._native_frame for item in items] - items = list(items) + items = tuple(items) first = items[0] - schema = first.schema - if how == "vertical" and not all(x.schema == schema for x in items[1:]): - msg = "inputs should all have the same schema" - raise TypeError(msg) + + if how == "vertical": + schema = first.schema + if not all(x.schema == schema for x in items[1:]): + msg = 
"inputs should all have the same schema" + raise TypeError(msg) + + res = reduce(DuckDBPyRelation.union, (item.native for item in items)) + return first._with_native(res) + + if how == "vertical_relaxed": + schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) + out_schema = reduce(to_supertype, schemas) + native_items = ( + item.select(*safe_cast(self, out_schema)).native for item in items + ) + res = reduce(DuckDBPyRelation.union, native_items) + return first._with_native(res) + if how == "diagonal": - res = first.native - for _item in native_items[1:]: + res, *others = (item.native for item in items) + for _item in others: # TODO(unassigned): use relational API when available https://github.com/duckdb/duckdb/discussions/16996 res = duckdb.sql(""" from res select * union all by name from _item select * """) return first._with_native(res) - res = reduce(lambda x, y: x.union(y), native_items) - return first._with_native(res) + + if how == "diagonal_relaxed": + schemas = [Schema(df.collect_schema()) for df in items] + out_schema = reduce(merge_schemas, schemas) + native_items = ( + item.select( + *( + self.col(name) + if name in schema + else self.lit(None, dtype=dtype).alias(name) + for name, dtype in out_schema.items() + ) + ) + .select(*safe_cast(self, out_schema)) + .native + for item, schema in zip(items, schemas) + ) + res = reduce(DuckDBPyRelation.union, native_items) + return first._with_native(res) + raise NotImplementedError def concat_str( self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index 801a04dbd4..eaf5305e76 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -18,7 +18,8 @@ from narwhals._ibis.selectors import IbisSelectorNamespace from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype from narwhals._sql.namespace import SQLNamespace -from narwhals._utils import Implementation +from narwhals._utils 
import Implementation, safe_cast +from narwhals.schema import Schema, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -68,8 +69,13 @@ def concat( self, items: Iterable[IbisLazyFrame], *, how: ConcatMethod ) -> IbisLazyFrame: frames: Sequence[IbisLazyFrame] = tuple(items) - if how == "diagonal": + if how.startswith("diagonal"): frames = self.align_diagonal(frames) + + if how.endswith("relaxed"): + schemas = (Schema(frame.collect_schema()) for frame in frames) + out_schema = reduce(to_supertype, schemas) + frames = [frame.select(*safe_cast(self, out_schema)) for frame in frames] try: result = ibis.union(*(lf.native for lf in frames)) except ibis.IbisError: @@ -78,7 +84,8 @@ def concat( msg = "inputs should all have the same schema" raise TypeError(msg) from None raise - return frames[0]._with_native(result) + else: + return self._lazyframe.from_native(result, context=self) def concat_str( self, *exprs: IbisExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 604d06bd25..542f95c282 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -16,8 +16,15 @@ from narwhals._pandas_like.selectors import PandasSelectorNamespace from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT -from narwhals._pandas_like.utils import is_non_nullable_boolean +from narwhals._pandas_like.utils import ( + cast_native, + is_non_nullable_boolean, + iter_cast_native, + native_schema, + promote_dtype_backends, +) from narwhals._utils import zip_strict +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -286,6 +293,21 @@ def _concat_diagonal(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) + def 
_concat_diagonal_relaxed( + self, dfs: Sequence[NativeDataFrameT], / + ) -> NativeDataFrameT: + schemas = tuple(native_schema(df) for df in dfs) + out_schema = reduce( + merge_schemas, (Schema.from_pandas_like(schema) for schema in schemas) + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) + + native_res = ( + self._concat(dfs, axis=VERTICAL, copy=False) + if self._implementation.is_pandas() and self._backend_version < (3,) + else self._concat(dfs, axis=VERTICAL) + ) + return cast_native(native_res, out_schema) + def _concat_horizontal( self, dfs: Sequence[NativeDataFrameT | NativeSeriesT], / ) -> NativeDataFrameT: @@ -318,6 +340,20 @@ def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFram return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) + def _concat_vertical_relaxed( + self, dfs: Sequence[NativeDataFrameT], / + ) -> NativeDataFrameT: + schemas = tuple(native_schema(df) for df in dfs) + out_schema = reduce( + to_supertype, (Schema.from_pandas_like(schema) for schema in schemas) + ).to_pandas(promote_dtype_backends(schemas, self._implementation)) + + if self._implementation.is_pandas() and self._backend_version < (3,): + return self._concat( + iter_cast_native(dfs, out_schema), axis=VERTICAL, copy=False + ) + return self._concat(iter_cast_native(dfs, out_schema), axis=VERTICAL) + def concat_str( self, *exprs: PandasLikeExpr, separator: str, ignore_nulls: bool ) -> PandasLikeExpr: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 1b5b60e71b..1ae640dfa5 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -1,8 +1,8 @@ from __future__ import annotations -import functools import operator import re +from functools import lru_cache from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast import pandas as pd @@ -26,6 +26,7 @@ requires, ) from narwhals.exceptions import ShapeError +from narwhals.typing import 
DTypeBackend if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping @@ -45,7 +46,13 @@ NativeSeriesT, ) from narwhals.dtypes import DType - from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray + from narwhals.typing import ( + DTypeBackend, + IntoDType, + IntoPandasSchema, + TimeUnit, + _1DArray, + ) ExprT = TypeVar("ExprT", bound=PandasLikeExpr) UnitCurrent: TypeAlias = TimeUnit @@ -53,6 +60,7 @@ BinOpBroadcast: TypeAlias = Callable[[Any, int], Any] IntoRhs: TypeAlias = int +Incomplete: TypeAlias = Any PANDAS_LIKE_IMPLEMENTATION = { Implementation.PANDAS, @@ -188,7 +196,7 @@ def rename( return cast("NativeNDFrameT", result) # type: ignore[redundant-cast] -@functools.lru_cache(maxsize=16) +@lru_cache(maxsize=16) def is_dtype_non_pyarrow_string(native_dtype: Any) -> bool: """*There is no problem which can't be solved by adding an extra string type* pandas.""" # TODO @dangotbanned: Investigate how we could handle `cudf` without `str(native_dtype)` @@ -202,7 +210,7 @@ def is_dtype_non_pyarrow_string(native_dtype: Any) -> bool: } -@functools.lru_cache(maxsize=16) +@lru_cache(maxsize=16) def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType: # noqa: C901, PLR0912 dtype = str(native_dtype) @@ -381,7 +389,7 @@ def iter_dtype_backends( return (get_dtype_backend(dtype, implementation) for dtype in dtypes) -@functools.lru_cache(maxsize=16) +@lru_cache(maxsize=16) def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]: return hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype) @@ -697,3 +705,57 @@ def broadcast_series_to_index( return series_class(pa_array, index=index, name=native.name) return series_class(value, index=index, dtype=native.dtype, name=native.name) + + +_DTYPE_BACKEND_PRIORITY: dict[DTypeBackend, Literal[0, 1, 2]] = { + "pyarrow": 2, + "numpy_nullable": 1, + None: 0, +} + + +def iter_names_dtype_backends( + schema: IntoPandasSchema, impl: Implementation, / +) -> 
Iterator[tuple[str, DTypeBackend]]: + yield from zip(schema, iter_dtype_backends(schema.values(), impl)) + + +def promote_dtype_backends( + schemas: Iterable[IntoPandasSchema], implementation: Implementation +) -> Iterable[DTypeBackend]: + """Promote dtype backends for each column based on priority rules. + + Priority: pyarrow > numpy_nullable > None + """ + impl = implementation + it_schemas = iter(schemas) + col_backends = dict(iter_names_dtype_backends(next(it_schemas), impl)) + priority = _DTYPE_BACKEND_PRIORITY.__getitem__ + current = col_backends.__getitem__ + seen = col_backends.keys() + for schema in it_schemas: + col_backends.update( + (name, backend) + for name, backend in iter_names_dtype_backends(schema, impl) + if name not in seen or priority(backend) > priority(current(name)) + ) + return col_backends.values() + + +def native_schema(df: Incomplete) -> IntoPandasSchema: + return df.dtypes.to_dict() + + +def cast_native(df: NativeDataFrameT, schema: IntoPandasSchema) -> NativeDataFrameT: + df_: Incomplete = df + return cast("NativeDataFrameT", df_.astype(schema)) + + +def iter_cast_native( + dfs: Iterable[NativeDataFrameT], schema: IntoPandasSchema +) -> Iterator[NativeDataFrameT]: + if TYPE_CHECKING: + for df in dfs: + yield cast_native(df, schema) + else: + yield from (df.astype(schema) for df in dfs) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index ac8da364be..a8ddae6fe0 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import TYPE_CHECKING, Any, cast, overload import polars as pl @@ -22,7 +22,14 @@ from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame from narwhals._polars.typing import FrameT from narwhals._utils import Version, _LimitedContext - from narwhals.typing import Into1DArray, IntoDType, IntoSchema, TimeUnit, 
_2DArray + from narwhals.typing import ( + ConcatMethod, + Into1DArray, + IntoDType, + IntoSchema, + TimeUnit, + _2DArray, + ) class PolarsNamespace: @@ -130,10 +137,7 @@ def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr: return self._expr(pl.any_horizontal(*(expr.native for expr in it)), self._version) def concat( - self, - items: Iterable[FrameT], - *, - how: Literal["vertical", "horizontal", "diagonal"], + self, items: Iterable[FrameT], *, how: ConcatMethod ) -> PolarsDataFrame | PolarsLazyFrame: result = pl.concat((item.native for item in items), how=how) if isinstance(result, pl.DataFrame): diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index b96b807575..f756bfa336 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,6 +19,8 @@ true_divide, ) from narwhals._sql.namespace import SQLNamespace +from narwhals._utils import safe_cast +from narwhals.schema import Schema, merge_schemas, to_supertype if TYPE_CHECKING: from collections.abc import Iterable @@ -159,10 +161,11 @@ def func(cols: Iterable[Column]) -> Column: def concat( self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod ) -> SparkLikeLazyFrame: - dfs = [item._native_frame for item in items] + items = tuple(items) if how == "vertical": - cols_0 = dfs[0].columns - for i, df in enumerate(dfs[1:], start=1): + first, *others = (item.native for item in items) + cols_0 = first.columns + for i, df in enumerate(others, start=1): cols_current = df.columns if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)): msg = ( @@ -173,7 +176,7 @@ def concat( raise TypeError(msg) return SparkLikeLazyFrame( - native_dataframe=reduce(lambda x, y: x.union(y), dfs), + native_dataframe=reduce(lambda x, y: x.union(y), others, first), version=self._version, implementation=self._implementation, ) @@ -181,11 +184,48 @@ def concat( if how == "diagonal": return SparkLikeLazyFrame( 
native_dataframe=reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs + lambda x, y: x.unionByName(y, allowMissingColumns=True), + (item.native for item in items), ), version=self._version, implementation=self._implementation, ) + + if how == "vertical_relaxed": + schemas: Iterable[Schema] = (Schema(df.collect_schema()) for df in items) + out_schema = reduce(to_supertype, schemas) + native_items = ( + item.select(*safe_cast(self, out_schema)).native for item in items + ) + union = items[0].native.__class__.union + return SparkLikeLazyFrame( + native_dataframe=reduce(union, native_items), + version=self._version, + implementation=self._implementation, + ) + + if how == "diagonal_relaxed": + schemas = tuple(Schema(item.collect_schema()) for item in items) + out_schema = reduce(merge_schemas, schemas) + native_items = ( + item.select( + *( + self.col(name) + if name in schema + else self.lit(None, dtype=dtype).alias(name) + for name, dtype in out_schema.items() + ) + ) + .select(*safe_cast(self, out_schema)) + .native + for item, schema in zip(items, schemas) + ) + union = items[0].native.__class__.union + return SparkLikeLazyFrame( + native_dataframe=reduce(union, native_items), + version=self._version, + implementation=self._implementation, + ) raise NotImplementedError def concat_str( diff --git a/narwhals/_utils.py b/narwhals/_utils.py index bc1795fb59..ac904d2060 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -68,7 +68,14 @@ TypeIs, ) - from narwhals._compliant import CompliantExprT, CompliantSeriesT, NativeSeriesT_co + from narwhals import dtypes + from narwhals._compliant import ( + CompliantExprT, + CompliantFrameT, + CompliantNamespace, + CompliantSeriesT, + NativeSeriesT_co, + ) from narwhals._compliant.any_namespace import NamespaceAccessor from narwhals._compliant.typing import ( Accessor, @@ -2153,3 +2160,14 @@ def __repr__(self) -> str: # pragma: no cover # Can be imported from types in Python 3.10 NoneType = type(None) + + 
+def safe_cast( + ns: CompliantNamespace[CompliantFrameT, CompliantExprT], + mapping: Mapping[str, dtypes.DType], +) -> Iterable[CompliantExprT]: + Unknown = ns._version.dtypes.Unknown() # noqa: N806 + return ( + ns.col(name) if dtype == Unknown else ns.col(name).cast(dtype) + for name, dtype in mapping.items() + ) diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index b6a445cb7e..27f88e5747 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -101,6 +101,10 @@ class UnsupportedDTypeError(NarwhalsError): """Exception raised when trying to convert to a DType which is not supported by the given backend.""" +class SchemaMismatchError(NarwhalsError): + """Exception raised when an unexpected schema mismatch causes an error.""" + + class NarwhalsUnstableWarning(UserWarning): """Warning issued when a method or function is considered unstable in the stable api.""" diff --git a/narwhals/functions.py b/narwhals/functions.py index a19478eb21..2655b0576f 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -62,11 +62,15 @@ def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT how: concatenating strategy - vertical: Concatenate vertically. Column names must match. + - vertical_relaxed: Same as vertical, but additionally coerces columns to + their common supertype if they are mismatched (eg: Int32 → Int64). - horizontal: Concatenate horizontally. If lengths don't match, then missing rows are filled with null values. This is only supported when all inputs are (eager) DataFrames. - diagonal: Finds a union between the column schemas and fills missing column values with null. + - diagonal_relaxed: Same as diagonal, but additionally coerces columns to + their common supertype if they are mismatched (eg: Int32 → Int64). 
Raises: TypeError: The items to concatenate should either all be eager, or all lazy @@ -140,8 +144,14 @@ def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT raise ValueError(msg) items = tuple(items) validate_laziness(items) - if how not in {"horizontal", "vertical", "diagonal"}: # pragma: no cover - msg = "Only vertical, horizontal and diagonal concatenations are supported." + if how not in { + "horizontal", + "vertical", + "diagonal", + "vertical_relaxed", + "diagonal_relaxed", + }: # pragma: no cover + msg = "Only vertical, vertical_relaxed, horizontal, diagonal and diagonal_relaxed concatenations are supported." raise NotImplementedError(msg) first_item = items[0] if is_narwhals_lazyframe(first_item) and how == "horizontal": diff --git a/narwhals/schema.py b/narwhals/schema.py index b459be68fa..aae3f5c370 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -21,6 +21,8 @@ is_pyarrow_data_type, is_pyarrow_schema, ) +from narwhals.dtypes._supertyping import get_supertype +from narwhals.exceptions import ComputeError, SchemaMismatchError if TYPE_CHECKING: from collections.abc import Iterable @@ -363,3 +365,63 @@ def _from_pandas_like( (name, native_to_narwhals_dtype(dtype, cls._version, impl, allow_object=True)) for name, dtype in schema.items() ) + + +def _supertype(left: DType, right: DType) -> DType: + if promoted_dtype := get_supertype(left, right): + return promoted_dtype + msg = f"failed to determine supertype of {left} and {right}" + raise SchemaMismatchError(msg) + + +def _ensure_names_match(left: Schema, right: Schema) -> tuple[Schema, Schema]: + if len(left) != len(right): + msg = "schema lengths differ" + raise ComputeError(msg) + if left.names() != right.names(): + it = ((lname, rname) for (lname, rname) in zip(left, right) if lname != rname) + lname, rname = next(it) + msg = f"schema names differ: got {rname}, expected {lname}" + raise ComputeError(msg) + return left, right + + +def to_supertype(left: Schema, right: 
Schema) -> Schema: + # Adapted from polars https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/schema/mod.rs#L83-L96 + left, right = _ensure_names_match(left, right) + it = zip(left.keys(), left.values(), right.values()) + return Schema((name, _supertype(ltype, rtype)) for (name, ltype, rtype) in it) + + +def merge_schemas(left: Schema, right: Schema) -> Schema: + """Merge two schemas, combining columns and resolving types to their supertype. + + This function merges two schemas by: + + 1. Taking all columns from the left schema (preserving order). + 2. Appending any columns from the right schema that are missing in the left. + 3. For columns present in both schemas, promoting to the common supertype. + + Arguments: + left: The primary schema whose column order takes precedence. + right: The secondary schema to merge. Columns missing in left are appended. + + Returns: + A new Schema with merged columns and supertypes resolved. + + Raises: + SchemaMismatchError: If no common supertype exists for a shared column. + + Note: + For columns present only in one schema, the type from that schema is used + (via `right.get(name, left_type)` fallback for left-only columns). + """ + left_names = set(left.keys()) + missing_in_left = (kv for kv in right.items() if kv[0] not in left_names) + + extended_left = Schema((*left.items(), *missing_in_left)) + # Reorder right to match: left keys first, then right-only keys + extended_right = Schema((kv[0], right.get(*kv)) for kv in extended_left.items()) + + it = zip(extended_left.keys(), extended_left.values(), extended_right.values()) + return Schema((name, _supertype(ltype, rtype)) for (name, ltype, rtype) in it) diff --git a/narwhals/typing.py b/narwhals/typing.py index 9454186722..b6346e1d3d 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -192,14 +192,20 @@ def Binary(self) -> type[dtypes.Binary]: ... 
ClosedInterval: TypeAlias = Literal["left", "right", "none", "both"] """Define which sides of the interval are closed (inclusive).""" -ConcatMethod: TypeAlias = Literal["horizontal", "vertical", "diagonal"] +ConcatMethod: TypeAlias = Literal[ + "horizontal", "vertical", "vertical_relaxed", "diagonal", "diagonal_relaxed" +] """Concatenating strategy. - *"vertical"*: Concatenate vertically. Column names must match. +- *"vertical_relaxed"*: Same as vertical, but additionally coerces columns to their common + supertype if they are mismatched (eg: Int32 → Int64). - *"horizontal"*: Concatenate horizontally. If lengths don't match, then missing rows are filled with null values. - *"diagonal"*: Finds a union between the column schemas and fills missing column values with null. +- *"diagonal_relaxed"*: Same as diagonal, but additionally coerces columns to their common + supertype if they are mismatched (eg: Int32 → Int64). """ FillNullStrategy: TypeAlias = Literal["forward", "backward"] diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 8d8b4b2ce6..83d34d9c7c 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -3,17 +3,35 @@ import datetime as dt import re from typing import TYPE_CHECKING, Any +from uuid import uuid4 as make_uuid import pytest import narwhals as nw from narwhals._utils import Implementation from narwhals.exceptions import InvalidOperationError, NarwhalsError -from tests.utils import POLARS_VERSION, Constructor, ConstructorEager, assert_equal_data +from narwhals.schema import Schema +from tests.utils import ( + POLARS_VERSION, + PYARROW_VERSION, + Constructor, + ConstructorEager, + assert_equal_data, +) if TYPE_CHECKING: from collections.abc import Iterator + from narwhals.dtypes import DType + from narwhals.typing import IntoSchema + +if TYPE_CHECKING: + from narwhals.typing import LazyFrameT + + +def _cast(frame: LazyFrameT, schema: IntoSchema) -> LazyFrameT: + return frame.select(nw.col(name).cast(dtype) for 
name, dtype in schema.items()) + def test_concat_horizontal(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} @@ -138,3 +156,204 @@ def test_concat_diagonal_invalid( context = pytest.raises((InvalidOperationError, TypeError), match=r"same schema") with context: nw.concat([df_1, bad_schema], how="diagonal").collect().to_dict(as_series=False) + + +@pytest.mark.parametrize( + ("ldata", "lschema", "rdata", "rschema", "expected_data", "expected_schema"), + [ + ( + {"a": [1, 2, 3], "b": [True, False, None]}, + {"a": nw.Int8(), "b": nw.Boolean()}, + {"a": [43, 2, 3], "b": [32, 1, None]}, + {"a": nw.Int16(), "b": nw.Int64()}, + {"a": [1, 2, 3, 43, 2, 3], "b": [1, 0, None, 32, 1, None]}, + {"a": nw.Int16(), "b": nw.Int64()}, + ), + ( + {"a": [1, 2], "b": [2, 1]}, + {"a": nw.Int32(), "b": nw.Int32()}, + {"a": [1.0, 0.2], "b": [None, 0.1]}, + {"a": nw.Float32(), "b": nw.Float32()}, + {"a": [1.0, 2.0, 1.0, 0.2], "b": [2.0, 1.0, None, 0.1]}, + {"a": nw.Float64(), "b": nw.Float64()}, + ), + ], + ids=["nullable-integer", "nullable-float"], +) +def test_concat_vertically_relaxed( + constructor: Constructor, + ldata: dict[str, Any], + lschema: dict[str, DType], + rdata: dict[str, Any], + rschema: dict[str, DType], + expected_data: dict[str, Any], + expected_schema: dict[str, DType], + request: pytest.FixtureRequest, +) -> None: + # Adapted from https://github.com/pola-rs/polars/blob/b0fdbd34d430d934bda9a4ca3f75e136223bd95b/py-polars/tests/unit/functions/test_concat.py#L64 + is_nullable_int = request.node.callspec.id.endswith("nullable-integer") + if is_nullable_int and any( + x in str(constructor) + for x in ("dask", "pandas_constructor", "modin_constructor", "cudf") + ): + reason = "Cannot convert non-finite values (NA or inf)" + request.applymarker(pytest.mark.xfail(reason=reason)) + left = nw.from_native(constructor(ldata)).lazy().pipe(_cast, lschema) + right = nw.from_native(constructor(rdata)).lazy().pipe(_cast, 
rschema) + result = nw.concat([left, right], how="vertical_relaxed") + + assert result.collect_schema() == Schema(expected_schema) + assert_equal_data(result.collect(), expected_data) + + result = nw.concat([right, left], how="vertical_relaxed") + assert result.collect_schema() == Schema(expected_schema) + + +@pytest.mark.parametrize( + ("schema1", "schema2", "schema3", "expected_schema"), + [ + ( + {"a": nw.Int32(), "c": nw.Int64()}, + {"a": nw.Float64(), "b": nw.Float32()}, + {"b": nw.Int32(), "c": nw.Int32()}, + {"a": nw.Float64(), "c": nw.Int64(), "b": nw.Float64()}, + ), + ( + {"a": nw.Float32(), "c": nw.Float32()}, + {"a": nw.Float64(), "b": nw.Float32()}, + {"b": nw.Float32(), "c": nw.Float32()}, + {"a": nw.Float64(), "c": nw.Float32(), "b": nw.Float32()}, + ), + ], + ids=["nullable-integer", "nullable-float"], +) +def test_concat_diagonal_relaxed( + constructor: Constructor, + schema1: dict[str, DType], + schema2: dict[str, DType], + schema3: dict[str, DType], + expected_schema: dict[str, DType], + request: pytest.FixtureRequest, +) -> None: + # Adapted from https://github.com/pola-rs/polars/blob/b0fdbd34d430d934bda9a4ca3f75e136223bd95b/py-polars/tests/unit/functions/test_concat.py#L265C1-L288C41 + is_nullable_int = request.node.callspec.id.endswith("nullable-integer") + if is_nullable_int and any( + x in str(constructor) + for x in ("dask", "pandas_constructor", "modin_constructor", "cudf") + ): + reason = "Cannot convert non-finite values (NA or inf)" + request.applymarker(pytest.mark.xfail(reason=reason)) + + base_schema = {"idx": nw.Int32()} + + data1 = {"idx": [0, 1], "a": [1, 2], "c": [10, 20]} + df1 = nw.from_native(constructor(data1)).lazy().pipe(_cast, base_schema | schema1) + + data2 = {"a": [3.5, 4.5], "b": [30.1, 40.2], "idx": [2, 3]} + df2 = nw.from_native(constructor(data2)).lazy().pipe(_cast, base_schema | schema2) + + data3 = {"b": [5, 6], "idx": [4, 5], "c": [50, 60]} + df3 = nw.from_native(constructor(data3)).lazy().pipe(_cast, base_schema 
| schema3) + + result = nw.concat([df1, df2, df3], how="diagonal_relaxed").sort("idx") + out_schema = result.collect_schema() + assert out_schema == Schema(base_schema | expected_schema) + + expected_data = { + "idx": [0, 1, 2, 3, 4, 5], + "a": [1.0, 2.0, 3.5, 4.5, None, None], + "c": [10, 20, None, None, 50, 60], + "b": [None, None, 30.1, 40.2, 5.0, 6.0], + } + assert_equal_data(result.collect(), expected_data) + + +@pytest.mark.skipif(PYARROW_VERSION < (19, 0, 0), reason="Too old for pyarrow.uuid type") +def test_pyarrow_concat_vertical_uuid() -> None: + # Test that concat vertical and vertical_relaxed preserves unsupported types like UUID + pa = pytest.importorskip("pyarrow") + + id1, id2, id3, id4 = [make_uuid() for _ in range(4)] + + data1 = {"id": pa.array([id1.bytes, id2.bytes], type=pa.uuid()), "a": [1, 2]} + data2 = {"id": pa.array([id3.bytes, id4.bytes], type=pa.uuid()), "a": [3.14, 4.2]} + frame1, frame2 = nw.from_native(pa.table(data1)), nw.from_native(pa.table(data2)) + + # Vertical + res_v = nw.concat([frame1, frame1], how="vertical").to_native() + + assert res_v.schema.field("id").type == pa.uuid() + assert res_v["id"].to_pylist() == [id1, id2, id1, id2] + assert res_v["a"].to_pylist() == [1, 2, 1, 2] + + # Vertical relaxed + res_vr = nw.concat([frame1, frame2], how="vertical_relaxed").to_native() + + assert res_vr.schema.field("id").type == pa.uuid() + assert res_vr["id"].to_pylist() == [id1, id2, id3, id4] + assert res_vr["a"].to_pylist() == [1.0, 2.0, 3.14, 4.2] + + +def test_concat_vertical_relaxed_duckdb_uuid() -> None: + # Test that concat vertical_relaxed preserves UUID type for DuckDB + duckdb = pytest.importorskip("duckdb") + + id1, id2, id3, id4 = [make_uuid() for _ in range(4)] + conn = duckdb.connect() + + rel1 = conn.sql( + f"SELECT '{id1}'::UUID as id, 1 as a UNION ALL SELECT '{id2}'::UUID, 2" + ) + rel2 = conn.sql( + f"SELECT '{id3}'::UUID as id, 3.14 as a UNION ALL SELECT '{id4}'::UUID, 4.2" + ) + + frame1, frame2 = 
nw.from_native(rel1), nw.from_native(rel2) + + # Vertical + res_v = nw.concat([frame1, frame1], how="vertical") + + assert res_v.to_native().types[0] == "UUID" + res_v_pa = res_v.collect().to_native() + assert res_v_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id1, id2)] + assert res_v_pa["a"].to_pylist() == [1, 2, 1, 2] + + # Vertical relaxed + res_vr = nw.concat([frame1, frame2], how="vertical_relaxed") + + assert res_vr.to_native().types[0] == "UUID" + res_vr_pa = res_vr.collect().to_native() + assert res_vr_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id3, id4)] + assert [float(v) for v in res_vr_pa["a"].to_pylist()] == [1.0, 2.0, 3.14, 4.2] + + +def test_concat_vertical_relaxed_ibis_uuid() -> None: + """Test that concat vertical_relaxed preserves UUID type for Ibis.""" + ibis = pytest.importorskip("ibis") + + id1, id2, id3, id4 = [make_uuid() for _ in range(4)] + + t1 = ibis.memtable( + {"id": [str(id1), str(id2)], "a": [1, 2]}, schema={"id": "uuid", "a": "int"} + ) + t2 = ibis.memtable( + {"id": [str(id3), str(id4)], "a": [3.14, 4.2]}, + schema={"id": "uuid", "a": "float"}, + ) + frame1, frame2 = nw.from_native(t1), nw.from_native(t2) + + # Vertical + res_v = nw.concat([frame1, frame1], how="vertical") + + assert res_v.to_native().schema()["id"] == ibis.dtype("uuid") + res_v_pa = res_v.collect().to_native() + assert res_v_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id1, id2)] + assert res_v_pa["a"].to_pylist() == [1, 2, 1, 2] + + # Vertical relaxed + res_vr = nw.concat([frame1, frame2], how="vertical_relaxed") + + assert res_vr.to_native().schema()["id"] == ibis.dtype("uuid") + res_vr_pa = res_vr.collect().to_native() + assert res_vr_pa["id"].to_pylist() == [str(v) for v in (id1, id2, id3, id4)] + assert [float(v) for v in res_vr_pa["a"].to_pylist()] == [1.0, 2.0, 3.14, 4.2] diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index d90916b029..129bea142e 100644 --- a/tests/frame/schema_test.py +++ 
b/tests/frame/schema_test.py @@ -8,7 +8,8 @@ import pytest import narwhals as nw -from narwhals.exceptions import PerformanceWarning +from narwhals.exceptions import ComputeError, PerformanceWarning, SchemaMismatchError +from narwhals.schema import to_supertype from tests.utils import PANDAS_VERSION, POLARS_VERSION, ConstructorPandasLike if TYPE_CHECKING: @@ -22,6 +23,7 @@ IntoArrowSchema, IntoPandasSchema, IntoPolarsSchema, + IntoSchema, ) from tests.utils import Constructor, ConstructorEager @@ -704,3 +706,55 @@ def test_schema_from_to_roundtrip() -> None: assert nw_schema_1 == nw_schema_2 assert nw_schema_2 == nw_schema_3 assert py_schema_1 == py_schema_2 + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + ( + {"a": nw.Int64(), "b": nw.String()}, + {"a": nw.Int64(), "b": nw.String()}, + {"a": nw.Int64(), "b": nw.String()}, + ), + ( + {"a": nw.Int32(), "b": nw.Float32()}, + {"a": nw.Int64(), "b": nw.Float64()}, + {"a": nw.Int64(), "b": nw.Float64()}, + ), + ({"a": nw.Int32()}, {"a": nw.Float64()}, {"a": nw.Float64()}), + ({"a": nw.Datetime("ns")}, {"a": nw.Datetime("us")}, {"a": nw.Datetime("us")}), + ], +) +def test_to_supertype(left: IntoSchema, right: IntoSchema, expected: IntoSchema) -> None: + result = to_supertype(nw.Schema(left), nw.Schema(right)) + assert result == expected + + +@pytest.mark.parametrize( + ("left", "right", "context"), + [ + ( + {"a": nw.Int64()}, + {"a": nw.Int64(), "b": nw.String()}, + pytest.raises(ComputeError, match="schema lengths differ"), + ), + ( + {"a": nw.Int64()}, + {"b": nw.Int64()}, + pytest.raises(ComputeError, match="schema names differ: got b, expected a"), + ), + ( + {"a": nw.Binary()}, + {"a": nw.Int64()}, + pytest.raises( + SchemaMismatchError, + match="failed to determine supertype of Binary and Int64", + ), + ), + ], +) +def test_to_supertype_exceptions( + left: IntoSchema, right: IntoSchema, context: pytest.RaisesExc +) -> None: + with context: + to_supertype(nw.Schema(left), nw.Schema(right))