diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 14cb7657f7..0568fa7dda 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -20,6 +20,7 @@ check_column_names_are_unique, convert_str_slice_to_int_slice, generate_temporary_column_name, + is_sequence_of, not_implemented, parse_columns_to_drop, scale_bytes, @@ -55,6 +56,7 @@ from narwhals.typing import ( IntoSchema, JoinStrategy, + SchemaDefinition, SizedMultiIndexSelector, SizedMultiNameSelector, SizeUnit, @@ -119,36 +121,39 @@ def from_dict( /, *, context: _LimitedContext, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: + from narwhals._utils import NullableSchema + if not schema and not data: return cls.from_native(pa.table({}), context=context) if not schema: return cls.from_native(pa.table(data), context=context) # type: ignore[arg-type] - if not any(dtype is None for dtype in schema.values()): - from narwhals.schema import Schema + nullable_schema = NullableSchema(schema) + + if nullable_schema.is_nullable: + if context._implementation._backend_version() < (14,): + msg = "Passing `None` dtype in `from_dict` requires PyArrow>=14" + raise NotImplementedError(msg) + res = pa.table( + { + name: pa.chunked_array( # type: ignore[misc] + [data[name] if data else []], + type=narwhals_to_native_dtype(nw_dtype, version=context._version) + if nw_dtype is not None + else None, + ) + for name, nw_dtype in nullable_schema.items() + } + ) + return cls.from_native(pa.table(res), context=context) - pa_schema = Schema(cast("IntoSchema", schema)).to_arrow() - if pa_schema and not data: - native = pa_schema.empty_table() - else: - native = pa.Table.from_pydict(data, schema=pa_schema) - return cls.from_native(native, context=context) - if context._implementation._backend_version() < (14,): - msg = "Passing `None` dtype in `from_dict` requires PyArrow>=14" - raise NotImplementedError(msg) - res = pa.table( - { - name: pa.chunked_array( # type: ignore[misc] - [data[name] if data else []], - type=narwhals_to_native_dtype(nw_dtype, version=context._version) - if nw_dtype is not None - else None, - ) - for name, nw_dtype in schema.items() - } - ) - return cls.from_native(pa.table(res), context=context) + pa_schema = nullable_schema.to_schema().to_arrow() + if pa_schema and not data: + native = pa_schema.empty_table() + else: + native = pa.Table.from_pydict(data, schema=pa_schema) + return cls.from_native(native, context=context) @classmethod def from_dicts( @@ -157,17 +162,15 @@ def from_dicts( /, *, context: _LimitedContext, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: - from narwhals.schema import Schema + from narwhals._utils import NullableSchema - if schema and any(dtype is None for dtype in schema.values()): + if schema and (nullable_schema := NullableSchema(schema)).is_nullable: msg = "`from_dicts` with `schema` where any dtype is `None` is not supported for PyArrow." raise NotImplementedError(msg) pa_schema = ( - Schema(cast("IntoSchema", schema)).to_arrow() - if schema is not None - else schema + nullable_schema.to_schema().to_arrow() if schema is not None else schema ) if pa_schema and not data: native = pa_schema.empty_table() @@ -195,10 +198,10 @@ def from_numpy( from narwhals.schema import Schema arrays = [pa.array(val) for val in data.T] - if isinstance(schema, (Mapping, Schema)): - native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow()) - else: + if is_sequence_of(schema, str) or schema is None: native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema)) + else: + native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow()) return cls.from_native(native, context=context) def __narwhals_namespace__(self) -> ArrowNamespace: diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 458b5a3794..3c3c5efb19 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -66,6 +66,7 @@ MultiColSelector, MultiIndexSelector, PivotAgg, + SchemaDefinition, SingleIndexSelector, SizedMultiIndexSelector, SizedMultiNameSelector, @@ -191,7 +192,7 @@ def from_dict( /, *, context: CompliantNamespaceAny, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: ... @classmethod def from_dicts( @@ -200,7 +201,7 @@ def from_dicts( /, *, context: _LimitedContext, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: ... @classmethod def from_numpy( diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index f16ffd5529..283f131324 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator, Mapping, Sequence from itertools import chain, product from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload @@ -28,6 +27,7 @@ check_column_names_are_unique, exclude_column_names, generate_temporary_column_name, + is_sequence_of, parse_columns_to_drop, scale_bytes, zip_strict, @@ -36,6 +36,7 @@ from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping, Sequence from io import BytesIO from pathlib import Path from types import ModuleType @@ -59,6 +60,7 @@ IntoSchema, JoinStrategy, PivotAgg, + SchemaDefinition, SizedMultiIndexSelector, SizedMultiNameSelector, SizeUnit, @@ -148,8 +150,11 @@ def from_dict( /, *, context: _LimitedContext, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: + from narwhals._utils import NullableSchema + + schema = NullableSchema(schema) if schema is not None else None implementation = context._implementation pdx = implementation.to_native_namespace() Series = cast("type[pd.Series[Any]]", pdx.Series) @@ -196,8 +201,11 @@ def from_dicts( /, *, context: _LimitedContext, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: + from narwhals._utils import NullableSchema + + schema = NullableSchema(schema) if schema is not None else None implementation = context._implementation ns = implementation.to_native_namespace() DataFrame = cast("type[pd.DataFrame]", ns.DataFrame) @@ -250,16 +258,15 @@ def from_numpy( implementation = context._implementation DataFrame: Constructor = implementation.to_native_namespace().DataFrame - if isinstance(schema, (Mapping, Schema)): + if is_sequence_of(schema, str) or schema is None: + native = DataFrame(data, columns=cls._numpy_column_names(data, schema)) + else: + schema = Schema(schema) it: Iterable[DTypeBackend] = ( get_dtype_backend(native_type, implementation) for native_type in schema.values() ) - native = DataFrame(data, columns=schema.keys()).astype( - Schema(schema).to_pandas(it) - ) - else: - native = DataFrame(data, columns=cls._numpy_column_names(data, schema)) + native = DataFrame(data, columns=schema.keys()).astype(schema.to_pandas(it)) return cls.from_native(native, context=context) def __narwhals_dataframe__(self) -> Self: diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index faf1357c79..18b1d04e73 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -23,6 +23,7 @@ is_index_selector, is_range, is_sequence_like, + is_sequence_of, is_slice_index, is_slice_none, parse_columns_to_drop, @@ -55,6 +56,7 @@ MultiColSelector, MultiIndexSelector, PivotAgg, + SchemaDefinition, SingleIndexSelector, UniqueKeepStrategy, _2DArray, @@ -317,14 +319,16 @@ def from_dict( /, *, context: _LimitedContext, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: + from narwhals._utils import NullableSchema + pl_schema = ( { key: narwhals_to_native_dtype(dtype, context._version) if dtype is not None else None - for (key, dtype) in schema.items() + for (key, dtype) in NullableSchema(schema).items() } if schema else None @@ -338,14 +342,16 @@ def from_dicts( /, *, context: _LimitedContext, - schema: IntoSchema | Mapping[str, DType | None] | None, + schema: SchemaDefinition | None, ) -> Self: + from narwhals._utils import NullableSchema + pl_schema = ( { key: narwhals_to_native_dtype(dtype, context._version) if dtype is not None else None - for (key, dtype) in schema.items() + for (key, dtype) in NullableSchema(schema).items() } if schema else None @@ -378,9 +384,9 @@ def from_numpy( from narwhals.schema import Schema pl_schema = ( - Schema(schema).to_polars() - if isinstance(schema, (Mapping, Schema)) - else schema + schema + if is_sequence_of(schema, str) or schema is None + else Schema(schema).to_polars() ) return cls.from_native(pl.from_numpy(data, pl_schema), context=context) diff --git a/narwhals/_utils.py b/narwhals/_utils.py index c7066a0d5f..9066832628 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -3,6 +3,7 @@ import os import re import sys +from collections import OrderedDict from collections.abc import Collection, Container, Iterable, Iterator, Mapping, Sequence from datetime import timezone from enum import Enum, auto @@ -111,6 +112,7 @@ ) from narwhals.dataframe import DataFrame, LazyFrame from narwhals.dtypes import DType + from narwhals.schema import Schema from narwhals.series import Series from narwhals.typing import ( CompliantDataFrame, @@ -118,8 +120,10 @@ CompliantSeries, DTypes, FileSource, + IntoDType, IntoSeriesT, MultiIndexSelector, + SchemaDefinition, SingleIndexSelector, SizedMultiIndexSelector, SizeUnit, @@ -2107,3 +2111,20 @@ def extend_bool( Stolen from https://github.com/pola-rs/polars/blob/b8bfb07a4a37a8d449d6d1841e345817431142df/py-polars/polars/_utils/various.py#L580-L594 """ return (value,) * n_match if isinstance(value, bool) else tuple(value) + + +class NullableSchema(OrderedDict[str, "IntoDType | None"]): + def __init__(self, schema: SchemaDefinition | None = None) -> None: + schema = schema or {} + super().__init__(schema) + self.is_nullable = None in self.values() + + def to_schema(self) -> Schema: + """Converts to Schema by filtering out None values.""" + from narwhals.schema import Schema + + if self.is_nullable: # pragma: no cover + msg = "Cannot convert nullable mapping into `Schema`" + raise AssertionError(msg) + + return Schema(self.items()) # type: ignore[arg-type] diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 428c037c10..76aba9f829 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -72,7 +72,6 @@ from narwhals._expression_parsing import ExprMetadata from narwhals._translate import IntoArrowTable from narwhals._typing import EagerAllowed, IntoBackend, LazyAllowed, Polars - from narwhals.dtypes import DType from narwhals.group_by import GroupBy, LazyGroupBy from narwhals.typing import ( AsofJoinStrategy, @@ -85,6 +84,7 @@ MultiColSelector as _MultiColSelector, MultiIndexSelector as _MultiIndexSelector, PivotAgg, + SchemaDefinition, SingleColSelector, SingleIndexSelector, SizeUnit, @@ -559,7 +559,7 @@ def from_arrow( def from_dict( cls, data: Mapping[str, Any], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed] | None = None, ) -> DataFrame[Any]: @@ -601,8 +601,15 @@ def from_dict( | 1 2 4 | └──────────────────┘ """ + from narwhals._utils import NullableSchema + if backend is None: data, backend = _from_dict_no_backend(data) + if (schema and data) and ( + diff := set(NullableSchema(schema).keys()).symmetric_difference(data.keys()) + ): + msg = f"Keys in `schema` and `data` are expected to match, found unmatched keys: {diff}" + raise InvalidOperationError(msg) implementation = Implementation.from_backend(backend) if is_eager_allowed(implementation): ns = cls._version.namespace.from_backend(implementation).compliant @@ -620,7 +627,7 @@ def from_dict( def from_dicts( cls, data: Sequence[Mapping[str, Any]], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed], ) -> DataFrame[Any]: diff --git a/narwhals/functions.py b/narwhals/functions.py index bff3f27c85..3ba8aa5bf0 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -17,12 +17,7 @@ supports_arrow_c_stream, validate_laziness, ) -from narwhals.dependencies import ( - is_narwhals_series, - is_numpy_array, - is_numpy_array_2d, - is_pyarrow_table, -) +from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_pyarrow_table from narwhals.exceptions import InvalidOperationError from narwhals.expr import Expr from narwhals.translate import from_native, to_native @@ -30,13 +25,12 @@ if TYPE_CHECKING: from types import ModuleType - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import TypeIs from narwhals._native import NativeDataFrame, NativeLazyFrame, NativeSeries from narwhals._translate import IntoArrowTable from narwhals._typing import Backend, EagerAllowed, IntoBackend from narwhals.dataframe import DataFrame, LazyFrame - from narwhals.dtypes import DType from narwhals.series import Series from narwhals.typing import ( ConcatMethod, @@ -46,11 +40,10 @@ IntoExpr, IntoSchema, NonNestedLiteral, + SchemaDefinition, _2DArray, ) - _IntoSchema: TypeAlias = "IntoSchema | Sequence[str] | None" - def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT: """Concatenate multiple DataFrames, LazyFrames into a single entity. @@ -229,7 +222,7 @@ def _new_series_impl( @deprecate_native_namespace(warn_version="1.26.0") def from_dict( data: Mapping[str, Any], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed] | None = None, native_namespace: ModuleType | None = None, # noqa: ARG001 @@ -273,33 +266,7 @@ def from_dict( | 1 2 4 | └──────────────────┘ """ - if backend is None: - data, backend = _from_dict_no_backend(data) - if schema and data and (diff := set(schema.keys()).symmetric_difference(data.keys())): - msg = f"Keys in `schema` and `data` are expected to match, found unmatched keys: {diff}" - raise InvalidOperationError(msg) - implementation = Implementation.from_backend(backend) - if is_eager_allowed(implementation): - ns = Version.MAIN.namespace.from_backend(implementation).compliant - return ns._dataframe.from_dict(data, schema=schema, context=ns).to_narwhals() - if implementation is Implementation.UNKNOWN: # pragma: no cover - _native_namespace = implementation.to_native_namespace() - try: - # implementation is UNKNOWN, Narwhals extension using this feature should - # implement `from_dict` function in the top-level namespace. - native_frame: NativeDataFrame = _native_namespace.from_dict( - data, schema=schema - ) - except AttributeError as e: - msg = "Unknown namespace is expected to implement `from_dict` function." - raise AttributeError(msg) from e - return from_native(native_frame, eager_only=True) - msg = ( - f"{implementation} support in Narwhals is lazy-only, but `from_dict` is an eager-only function.\n\n" - "Hint: you may want to use an eager backend and then call `.lazy`, e.g.:\n\n" - f" nw.from_dict({{'a': [1, 2]}}, backend='pyarrow').lazy('{implementation}')" - ) - raise ValueError(msg) + return Version.MAIN.dataframe.from_dict(data, schema, backend=backend) def _from_dict_no_backend( @@ -318,7 +285,7 @@ def _from_dict_no_backend( def from_dicts( data: Sequence[Mapping[str, Any]], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed], ) -> DataFrame[Any]: @@ -423,41 +390,10 @@ def from_numpy( | e: [[1,3]] | └──────────────────┘ """ - if not is_numpy_array_2d(data): - msg = "`from_numpy` only accepts 2D numpy arrays" - raise ValueError(msg) - if not _is_into_schema(schema): - msg = ( - "`schema` is expected to be one of the following types: " - "IntoSchema | Sequence[str]. " - f"Got {type(schema)}." - ) - raise TypeError(msg) - implementation = Implementation.from_backend(backend) - if is_eager_allowed(implementation): - ns = Version.MAIN.namespace.from_backend(implementation).compliant - return ns.from_numpy(data, schema).to_narwhals() - if implementation is Implementation.UNKNOWN: # pragma: no cover - _native_namespace = implementation.to_native_namespace() - try: - # implementation is UNKNOWN, Narwhals extension using this feature should - # implement `from_numpy` function in the top-level namespace. - native_frame: NativeDataFrame = _native_namespace.from_numpy( - data, schema=schema - ) - except AttributeError as e: - msg = "Unknown namespace is expected to implement `from_numpy` function." - raise AttributeError(msg) from e - return from_native(native_frame, eager_only=True) - msg = ( - f"{implementation} support in Narwhals is lazy-only, but `from_numpy` is an eager-only function.\n\n" - "Hint: you may want to use an eager backend and then call `.lazy`, e.g.:\n\n" - f" nw.from_numpy(arr, backend='pyarrow').lazy('{implementation}')" - ) - raise ValueError(msg) + return Version.MAIN.dataframe.from_numpy(data, schema, backend=backend) -def _is_into_schema(obj: Any) -> TypeIs[_IntoSchema]: +def _is_into_schema(obj: Any) -> TypeIs[IntoSchema]: from narwhals.schema import Schema return ( diff --git a/narwhals/schema.py b/narwhals/schema.py index b459be68fa..44140c6479 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -30,19 +30,20 @@ import pyarrow as pa from typing_extensions import Self - from narwhals.dtypes import DType from narwhals.typing import ( DTypeBackend, IntoArrowSchema, + IntoDType, IntoPandasSchema, IntoPolarsSchema, + IntoSchema, ) __all__ = ["Schema"] -class Schema(OrderedDict[str, "DType"]): +class Schema(OrderedDict[str, "IntoDType"]): """Ordered mapping of column names to their data type. Arguments: @@ -74,9 +75,7 @@ class Schema(OrderedDict[str, "DType"]): _version: ClassVar[Version] = Version.MAIN - def __init__( - self, schema: Mapping[str, DType] | Iterable[tuple[str, DType]] | None = None - ) -> None: + def __init__(self, schema: IntoSchema | None = None) -> None: schema = schema or {} super().__init__(schema) @@ -84,7 +83,7 @@ def names(self) -> list[str]: """Get the column names of the schema.""" return list(self.keys()) - def dtypes(self) -> list[DType]: + def dtypes(self) -> list[IntoDType]: """Get the data types of the schema.""" return list(self.values()) @@ -126,8 +125,10 @@ def from_arrow(cls, schema: IntoArrowSchema, /) -> Self: from narwhals._arrow.utils import native_to_narwhals_dtype return cls( - (field.name, native_to_narwhals_dtype(field.type, cls._version)) - for field in schema + { + field.name: native_to_narwhals_dtype(field.type, cls._version) + for field in schema + } ) @classmethod @@ -232,8 +233,10 @@ def from_polars(cls, schema: IntoPolarsSchema, /) -> Self: from narwhals._polars.utils import native_to_narwhals_dtype return cls( - (name, native_to_narwhals_dtype(dtype, cls._version)) - for name, dtype in schema.items() + { + name: native_to_narwhals_dtype(dtype, cls._version) + for name, dtype in schema.items() + } ) def to_arrow(self) -> pa.Schema: @@ -360,6 +363,10 @@ def _from_pandas_like( impl = implementation return cls( - (name, native_to_narwhals_dtype(dtype, cls._version, impl, allow_object=True)) - for name, dtype in schema.items() + { + name: native_to_narwhals_dtype( + dtype, cls._version, impl, allow_object=True + ) + for name, dtype in schema.items() + } ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 8e5a78672d..edad762b26 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -89,6 +89,7 @@ IntoSchema, IntoSeries, NonNestedLiteral, + SchemaDefinition, SingleColSelector, SingleIndexSelector, _1DArray, @@ -128,7 +129,7 @@ def from_arrow( def from_dict( cls, data: Mapping[str, Any], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed] | None = None, ) -> DataFrame[Any]: @@ -139,7 +140,7 @@ def from_dict( def from_dicts( cls, data: Sequence[Any], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed], ) -> DataFrame[Any]: @@ -150,7 +151,7 @@ def from_dicts( def from_numpy( cls, data: _2DArray, - schema: Mapping[str, DType] | Schema | Sequence[str] | None = None, + schema: IntoSchema | Sequence[str] | None = None, *, backend: IntoBackend[EagerAllowed], ) -> DataFrame[Any]: @@ -453,9 +454,7 @@ class Schema(NwSchema): _version = Version.V1 @inherit_doc(NwSchema) - def __init__( - self, schema: Mapping[str, DType] | Iterable[tuple[str, DType]] | None = None - ) -> None: + def __init__(self, schema: IntoSchema | None = None) -> None: super().__init__(schema) @@ -1252,7 +1251,7 @@ def from_arrow( @deprecate_native_namespace() def from_dict( data: Mapping[str, Any], - schema: Mapping[str, DType] | Schema | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed] | None = None, native_namespace: ModuleType | None = None, # noqa: ARG001 diff --git a/narwhals/stable/v2/__init__.py b/narwhals/stable/v2/__init__.py index cbc5ff21d3..e399ead6de 100644 --- a/narwhals/stable/v2/__init__.py +++ b/narwhals/stable/v2/__init__.py @@ -82,6 +82,7 @@ IntoSchema, IntoSeries, NonNestedLiteral, + SchemaDefinition, SingleColSelector, SingleIndexSelector, _1DArray, @@ -120,7 +121,7 @@ def from_arrow( def from_dict( cls, data: Mapping[str, Any], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed] | None = None, ) -> DataFrame[Any]: @@ -131,7 +132,7 @@ def from_dict( def from_dicts( cls, data: Sequence[Mapping[str, Any]], - schema: IntoSchema | Mapping[str, DType | None] | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed], ) -> DataFrame[Any]: @@ -142,7 +143,7 @@ def from_dicts( def from_numpy( cls, data: _2DArray, - schema: Mapping[str, DType] | Schema | Sequence[str] | None = None, + schema: IntoSchema | Sequence[str] | None = None, *, backend: IntoBackend[EagerAllowed], ) -> DataFrame[Any]: @@ -311,9 +312,7 @@ class Schema(NwSchema): _version = Version.V2 @inherit_doc(NwSchema) - def __init__( - self, schema: Mapping[str, DType] | Iterable[tuple[str, DType]] | None = None - ) -> None: + def __init__(self, schema: IntoSchema | None = None) -> None: super().__init__(schema) @@ -1031,7 +1030,7 @@ def from_arrow( def from_dict( data: Mapping[str, Any], - schema: Mapping[str, DType] | Schema | None = None, + schema: SchemaDefinition | None = None, *, backend: IntoBackend[EagerAllowed] | None = None, ) -> DataFrame[Any]: diff --git a/narwhals/typing.py b/narwhals/typing.py index 8c89b7745b..40e96c274a 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -302,13 +302,12 @@ def Binary(self) -> type[dtypes.Binary]: ... └──────────────────┘ """ - -# TODO @dangotbanned: fix this? -# Constructor allows tuples, but we don't support that *everywhere* yet -IntoSchema: TypeAlias = "Mapping[str, dtypes.DType] | Schema" +IntoSchema: TypeAlias = ( + "Mapping[str, IntoDType] | Sequence[tuple[str, IntoDType]] | Schema" +) """Anything that can be converted into a Narwhals Schema. -Defined by column names and their associated *instantiated* Narwhals DType. +Defined by column names and their associated Narwhals DType. Examples: >>> import narwhals as nw @@ -331,8 +330,31 @@ def Binary(self) -> type[dtypes.Binary]: ... |b: [[null,"hi","howdy"]]| |c: [[2.1,2,null]] | └────────────────────────┘ + + >>> nw.DataFrame.from_dict( + ... data, + ... schema=[("a", nw.Int32()), ("b", nw.String), ("c", nw.Float64())], + ... backend="pyarrow", + ... ) + ┌────────────────────────┐ + | Narwhals DataFrame | + |------------------------| + |pyarrow.Table | + |a: int32 | + |b: string | + |c: float | + |---- | + |a: [[1,2,3]] | + |b: [[null,"hi","howdy"]]| + |c: [[2.1,2,null]] | + └────────────────────────┘ """ +SchemaDefinition: TypeAlias = ( + "Mapping[str, IntoDType | None] | Sequence[tuple[str, IntoDType | None]]" +) +"""Either a {str: IntoDType| None} mapping or a Sequence of (str, IntoDType | None) pairs.""" + IntoArrowSchema: TypeAlias = "pa.Schema | Mapping[str, pa.DataType]" IntoPolarsSchema: TypeAlias = "pl.Schema | Mapping[str, pl.DataType]" IntoPandasSchema: TypeAlias = Mapping[str, PandasLikeDType] @@ -377,6 +399,7 @@ def Binary(self) -> type[dtypes.Binary]: ... "Frame", "FrameT", "IntoBackend", + "IntoDType", "IntoDataFrame", "IntoDataFrameT", "IntoExpr", @@ -384,6 +407,7 @@ def Binary(self) -> type[dtypes.Binary]: ... "IntoFrameT", "IntoLazyFrame", "IntoLazyFrameT", + "IntoSchema", "IntoSeries", "IntoSeriesT", "LazyAllowed", diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index 2fc9e5cf7b..689680bea9 100644 --- a/tests/frame/schema_test.py +++ b/tests/frame/schema_test.py @@ -410,9 +410,9 @@ def test_schema_to_pandas( pytest.importorskip("pyarrow") schema = nw.Schema( { - "a": nw.Int64(), + "a": nw.Int64, "b": nw.String(), - "c": nw.Boolean(), + "c": nw.Boolean, "d": nw.Float64(), "e": nw.Datetime("ns"), } @@ -425,10 +425,10 @@ def test_schema_to_pandas_strict_zip() -> None: schema = nw.Schema( { - "a": nw.Int64(), + "a": nw.Int64, "b": nw.String(), "c": nw.Boolean(), - "d": nw.Float64(), + "d": nw.Float64, "e": nw.Datetime("ns"), } ) @@ -700,3 +700,9 @@ def test_schema_from_to_roundtrip() -> None: assert nw_schema_1 == nw_schema_2 assert nw_schema_2 == nw_schema_3 assert py_schema_1 == py_schema_2 + + +def test_schema_from_sequence() -> None: + into_schema_mapping = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} + into_schema_sequence = list(into_schema_mapping.items()) + assert nw.Schema(into_schema_mapping) == nw.Schema(into_schema_sequence) diff --git a/tests/testing/assert_series_equal_test.py b/tests/testing/assert_series_equal_test.py index d16e1065ce..9f757c0509 100644 --- a/tests/testing/assert_series_equal_test.py +++ b/tests/testing/assert_series_equal_test.py @@ -11,9 +11,11 @@ from tests.utils import PANDAS_VERSION, POLARS_VERSION, PYARROW_VERSION if TYPE_CHECKING: + from collections.abc import Mapping + from typing_extensions import TypeAlias - from narwhals.typing import IntoSchema, IntoSeriesT + from narwhals.typing import IntoDType, IntoSeriesT from tests.conftest import Data from tests.utils import ConstructorEager @@ -31,7 +33,9 @@ def series_from_native(native: IntoSeriesT) -> nw.Series[IntoSeriesT]: def test_self_equal( - constructor_eager: ConstructorEager, testing_data: Data, testing_schema: IntoSchema + constructor_eager: ConstructorEager, + testing_data: Data, + testing_schema: Mapping[str, IntoDType], ) -> None: """Test that a series is equal to itself, including nested dtypes with nulls.""" if "pandas" in str(constructor_eager): diff --git a/tests/testing/conftest.py b/tests/testing/conftest.py index ffb4fcc6fb..af83c78189 100644 --- a/tests/testing/conftest.py +++ b/tests/testing/conftest.py @@ -8,12 +8,14 @@ import narwhals as nw if TYPE_CHECKING: - from narwhals.typing import IntoSchema + from collections.abc import Mapping + + from narwhals.typing import IntoDType from tests.conftest import Data @pytest.fixture(scope="module") -def testing_schema() -> IntoSchema: +def testing_schema() -> Mapping[str, IntoDType]: return { "int": nw.Int32(), "float": nw.Float32(),