diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a255324d82..26c3595cb1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,7 +60,7 @@ repos: language: pygrep files: ^narwhals/ - id: dtypes-import - name: don't import from narwhals.dtypes (use `import_dtypes_module` instead) + name: don't import from narwhals.dtypes (use `Version.dtypes` instead) entry: | (?x) import\ narwhals.dtypes| diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index fb90e4a456..ac270abd40 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -22,7 +22,6 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import import_dtypes_module if TYPE_CHECKING: from narwhals._arrow.typing import ArrayOrScalarAny @@ -135,7 +134,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: ) def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: - dtypes = import_dtypes_module(self._version) + int_64 = self._version.dtypes.Int64() def func(df: ArrowDataFrame) -> list[ArrowSeries]: expr_results = list(chain.from_iterable(expr(df) for expr in exprs)) @@ -143,7 +142,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: *(s.fill_null(0, strategy=None, limit=None) for s in expr_results) ) non_na = align_series_full_broadcast( - *(1 - s.is_null().cast(dtypes.Int64()) for s in expr_results) + *(1 - s.is_null().cast(int_64) for s in expr_results) ) return [reduce(operator.add, series) / reduce(operator.add, non_na)] diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index b08b5374ca..69befdc886 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -32,7 +32,6 @@ from narwhals.exceptions import InvalidOperationError from narwhals.utils import Implementation from narwhals.utils import generate_temporary_column_name -from narwhals.utils import import_dtypes_module from narwhals.utils import is_list_of from narwhals.utils import not_implemented from narwhals.utils import requires @@ -865,7 +864,7 @@ def is_finite(self) -> Self: return self._with_native(pc.is_finite(self.native)) def cum_count(self, *, reverse: bool) -> Self: - dtypes = import_dtypes_module(self._version) + dtypes = self._version.dtypes return (~self.is_null()).cast(dtypes.UInt32()).cum_sum(reverse=reverse) @requires.backend_version((13,)) diff --git a/narwhals/_arrow/series_dt.py b/narwhals/_arrow/series_dt.py index b4905bd4d3..4d0e326626 100644 --- a/narwhals/_arrow/series_dt.py +++ b/narwhals/_arrow/series_dt.py @@ -10,7 +10,6 @@ from narwhals._arrow.utils import ArrowSeriesNamespace from narwhals._arrow.utils import floordiv_compat from narwhals._arrow.utils import lit -from narwhals.utils import import_dtypes_module if TYPE_CHECKING: from narwhals._arrow.series import ArrowSeries @@ -47,7 +46,7 @@ def convert_time_zone(self, time_zone: str) -> ArrowSeries: def timestamp(self, time_unit: TimeUnit) -> ArrowSeries: ser = self.compliant - dtypes = import_dtypes_module(ser._version) + dtypes = ser._version.dtypes if isinstance(ser.dtype, dtypes.Datetime): unit = ser.dtype.time_unit s_cast = self.native.cast(pa.int64()) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index da51f8a025..bd8c65a28a 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -13,7 +13,6 @@ from narwhals._compliant.series import _SeriesNamespace from narwhals.exceptions import ShapeError -from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: @@ -92,7 +91,7 @@ def nulls_like(n: int, series: ArrowSeries) -> ArrowArray: @lru_cache(maxsize=16) def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if pa.types.is_int64(dtype): return dtypes.Int64() if pa.types.is_int32(dtype): @@ -157,7 +156,7 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa.DataType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if isinstance_or_issubclass(dtype, dtypes.Decimal): msg = "Casting to Decimal is not supported yet." raise NotImplementedError(msg) diff --git a/narwhals/_compliant/selectors.py b/narwhals/_compliant/selectors.py index 08a279bab4..5f28e8e63b 100644 --- a/narwhals/_compliant/selectors.py +++ b/narwhals/_compliant/selectors.py @@ -18,7 +18,6 @@ from narwhals.utils import _parse_time_unit_and_time_zone from narwhals.utils import dtype_matches_time_unit_and_time_zone from narwhals.utils import get_column_names -from narwhals.utils import import_dtypes_module from narwhals.utils import is_compliant_dataframe if not TYPE_CHECKING: # pragma: no cover @@ -150,13 +149,13 @@ def names(df: FrameT) -> Sequence[str]: return self._selector.from_callables(series, names, context=self) def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]: - return self._is_dtype(import_dtypes_module(self._version).Categorical) + return self._is_dtype(self._version.dtypes.Categorical) def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]: - return self._is_dtype(import_dtypes_module(self._version).String) + return self._is_dtype(self._version.dtypes.String) def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]: - return self._is_dtype(import_dtypes_module(self._version).Boolean) + return self._is_dtype(self._version.dtypes.Boolean) def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]: def series(df: FrameT) -> Sequence[SeriesOrExprT]: @@ -172,7 +171,7 @@ def datetime( time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone) matches = partial( dtype_matches_time_unit_and_time_zone, - dtypes=import_dtypes_module(version=self._version), + dtypes=self._version.dtypes, time_units=time_units, time_zones=time_zones, ) diff --git a/narwhals/_dask/expr_dt.py b/narwhals/_dask/expr_dt.py index 2f07a8d362..ace841be44 100644 --- a/narwhals/_dask/expr_dt.py +++ b/narwhals/_dask/expr_dt.py @@ -6,7 +6,6 @@ from narwhals._pandas_like.utils import calculate_timestamp_datetime from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals.utils import Implementation -from narwhals.utils import import_dtypes_module if TYPE_CHECKING: try: @@ -115,7 +114,7 @@ def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series: ) is_pyarrow_dtype = "pyarrow" in str(dtype) mask_na = s.isna() - dtypes = import_dtypes_module(self._compliant_expr._version) + dtypes = self._compliant_expr._version.dtypes if dtype == dtypes.Date: # Date is only supported in pandas dtypes if pyarrow-backed s_cast = s.astype("Int32[pyarrow]") diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index ba152bd9b5..114b7335db 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -10,7 +10,6 @@ from narwhals.dependencies import get_pyarrow from narwhals.utils import Implementation from narwhals.utils import Version -from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass from narwhals.utils import parse_version @@ -96,7 +95,7 @@ def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None: def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> Any: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if isinstance_or_issubclass(dtype, dtypes.Float64): return "float64" if isinstance_or_issubclass(dtype, dtypes.Float32): diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 718c3a6117..6e306bb12b 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -26,7 +26,6 @@ from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import generate_temporary_column_name -from narwhals.utils import import_dtypes_module from narwhals.utils import not_implemented from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version @@ -414,7 +413,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: return self._with_native(self.native.filter(keep_condition)) def explode(self, columns: Sequence[str]) -> Self: - dtypes = import_dtypes_module(self._version) + dtypes = self._version.dtypes schema = self.collect_schema() for name in columns: dtype = schema[name] diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 2eaea270d5..53382a0381 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -8,7 +8,6 @@ import duckdb from narwhals.utils import Version -from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: @@ -83,7 +82,7 @@ def evaluate_exprs( def native_to_narwhals_dtype(duckdb_dtype: DuckDBPyType, version: Version) -> DType: duckdb_dtype_id = duckdb_dtype.id - dtypes = import_dtypes_module(version) + dtypes = version.dtypes # Handle nested data types first if duckdb_dtype_id == "list": @@ -121,7 +120,7 @@ def native_to_narwhals_dtype(duckdb_dtype: DuckDBPyType, version: Version) -> DT @lru_cache(maxsize=16) def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes return { "hugeint": dtypes.Int128(), "bigint": dtypes.Int64(), @@ -150,7 +149,7 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> str: # noqa: PLR0915 - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if isinstance_or_issubclass(dtype, dtypes.Decimal): msg = "Casting to Decimal is not supported yet." raise NotImplementedError(msg) diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 4b34412fe0..9d29aed8da 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -9,7 +9,6 @@ from narwhals.dependencies import get_ibis from narwhals.utils import Implementation from narwhals.utils import Version -from narwhals.utils import import_dtypes_module from narwhals.utils import validate_backend_version if TYPE_CHECKING: @@ -25,7 +24,7 @@ @lru_cache(maxsize=16) def native_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if ibis_dtype.is_int64(): return dtypes.Int64() if ibis_dtype.is_int32(): diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index c52cb0a506..a96fba1687 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -5,7 +5,6 @@ from typing import Any from typing import NoReturn -from narwhals.utils import import_dtypes_module from narwhals.utils import parse_version if TYPE_CHECKING: @@ -33,7 +32,7 @@ class DtypeKind(enum.IntEnum): def map_interchange_dtype_to_narwhals_dtype( interchange_dtype: tuple[DtypeKind, int, Any, Any], version: Version ) -> DType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if interchange_dtype[0] == DtypeKind.INT: if interchange_dtype[1] == 64: return dtypes.Int64() diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 2117f68eac..9774349249 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -34,7 +34,6 @@ from narwhals.utils import _remap_full_join_keys from narwhals.utils import check_column_exists from narwhals.utils import generate_temporary_column_name -from narwhals.utils import import_dtypes_module from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import scale_bytes @@ -838,12 +837,11 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.native.to_numpy(dtype=dtype, copy=copy) return self.native.to_numpy(copy=copy) - dtypes = import_dtypes_module(self._version) - + dtype_datetime = self._version.dtypes.Datetime to_convert = [ key for key, val in self.schema.items() - if val == dtypes.Datetime and val.time_zone is not None # type: ignore[attr-defined] + if isinstance(val, dtype_datetime) and val.time_zone is not None ] if to_convert: df = self.with_columns( @@ -1053,7 +1051,7 @@ def unpivot( ) def explode(self, columns: Sequence[str]) -> Self: - dtypes = import_dtypes_module(self._version) + dtypes = self._version.dtypes schema = self.collect_schema() for col_to_explode in columns: @@ -1078,8 +1076,6 @@ def explode(self, columns: Sequence[str]) -> Self: (native_frame[col_name].list.len() == anchor_series).all() for col_name in columns[1:] ): - from narwhals.exceptions import ShapeError - msg = "exploded columns must have matching element counts" raise ShapeError(msg) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index a49d459963..5417971c2b 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -18,7 +18,6 @@ from narwhals._pandas_like.selectors import PandasSelectorNamespace from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.utils import align_series_full_broadcast -from narwhals.utils import import_dtypes_module if TYPE_CHECKING: import pandas as pd @@ -283,13 +282,11 @@ def concat_str( separator: str, ignore_nulls: bool, ) -> PandasLikeExpr: - dtypes = import_dtypes_module(self._version) + string = self._version.dtypes.String() def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: expr_results = [s for _expr in exprs for s in _expr(df)] - series = align_series_full_broadcast( - *(s.cast(dtypes.String()) for s in expr_results) - ) + series = align_series_full_broadcast(*(s.cast(string) for s in expr_results)) null_mask = align_series_full_broadcast(*(s.is_null() for s in expr_results)) if not ignore_nulls: diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 08d94746e0..2dbb53edf0 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -28,7 +28,6 @@ from narwhals.dependencies import is_pandas_like_series from narwhals.exceptions import InvalidOperationError from narwhals.utils import Implementation -from narwhals.utils import import_dtypes_module from narwhals.utils import is_list_of from narwhals.utils import parse_version from narwhals.utils import validate_backend_version @@ -678,31 +677,25 @@ def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: # the default is meant to be None, but pandas doesn't allow it? # https://numpy.org/doc/stable/reference/generated/numpy.ndarray.__array__.html - copy = copy or self._implementation is Implementation.CUDF - dtypes = import_dtypes_module(self._version) + dtypes = self._version.dtypes if isinstance(self.dtype, dtypes.Datetime) and self.dtype.time_zone is not None: s = self.dt.convert_time_zone("UTC").dt.replace_time_zone(None).native else: s = self.native has_missing = s.isna().any() + kwargs: dict[Any, Any] = {"copy": copy or self._implementation.is_cudf()} if has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING: if self._implementation is Implementation.PANDAS and self._backend_version < ( 1, ): # pragma: no cover - kwargs = {} + ... else: - kwargs = {"na_value": float("nan")} - return s.to_numpy( - dtype=dtype or PANDAS_TO_NUMPY_DTYPE_MISSING[str(s.dtype)], - copy=copy, - **kwargs, - ) + kwargs.update({"na_value": float("nan")}) + dtype = dtype or PANDAS_TO_NUMPY_DTYPE_MISSING[str(s.dtype)] if not has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_NO_MISSING: - return s.to_numpy( - dtype=dtype or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(s.dtype)], copy=copy - ) - return s.to_numpy(dtype=dtype, copy=copy) + dtype = dtype or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(s.dtype)] + return s.to_numpy(dtype=dtype, **kwargs) def to_pandas(self) -> pd.Series[Any]: if self._implementation is Implementation.PANDAS: diff --git a/narwhals/_pandas_like/series_dt.py b/narwhals/_pandas_like/series_dt.py index 2fff762bef..641368fcef 100644 --- a/narwhals/_pandas_like/series_dt.py +++ b/narwhals/_pandas_like/series_dt.py @@ -9,7 +9,6 @@ from narwhals._pandas_like.utils import calculate_timestamp_datetime from narwhals._pandas_like.utils import int_dtype_mapper from narwhals._pandas_like.utils import is_pyarrow_dtype_backend -from narwhals.utils import import_dtypes_module if TYPE_CHECKING: from narwhals._pandas_like.series import PandasLikeSeries @@ -173,7 +172,7 @@ def timestamp(self, time_unit: TimeUnit) -> PandasLikeSeries: s = self.native dtype = self.compliant.dtype mask_na = s.isna() - dtypes = import_dtypes_module(self.version) + dtypes = self.version.dtypes if dtype == dtypes.Date: # Date is only supported in pandas dtypes if pyarrow-backed s_cast = s.astype("Int32[pyarrow]") diff --git a/narwhals/_pandas_like/series_list.py b/narwhals/_pandas_like/series_list.py index ce521b63a3..7d32423fe1 100644 --- a/narwhals/_pandas_like/series_list.py +++ b/narwhals/_pandas_like/series_list.py @@ -7,7 +7,6 @@ from narwhals._pandas_like.utils import get_dtype_backend from narwhals._pandas_like.utils import narwhals_to_native_dtype from narwhals._pandas_like.utils import set_index -from narwhals.utils import import_dtypes_module if TYPE_CHECKING: from narwhals._pandas_like.series import PandasLikeSeries @@ -28,7 +27,7 @@ def len(self) -> PandasLikeSeries: backend_version=backend_version, ) dtype = narwhals_to_native_dtype( - import_dtypes_module(self.version).UInt32(), + self.version.dtypes.UInt32(), get_dtype_backend(result.dtype, implementation), implementation, backend_version, diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index d186a8095f..3d241de2ba 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -3,6 +3,7 @@ import functools import re from contextlib import suppress +from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -19,7 +20,6 @@ from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import Version -from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass T = TypeVar("T", bound=Sized) @@ -215,7 +215,7 @@ def rename( def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType: dtype = str(native_dtype) - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: return dtypes.Int64() if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}: @@ -283,7 +283,7 @@ def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> def object_native_to_narwhals_dtype( series: PandasLikeSeries, version: Version, implementation: Implementation ) -> DType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if implementation is Implementation.CUDF: # pragma: no cover # Per conversations with their maintainers, they don't support arbitrary # objects, so we can just return String. @@ -307,7 +307,7 @@ def native_categorical_to_narwhals_dtype( version: Version, get_categories: Callable[[], tuple[str, ...]], ) -> DType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if version is Version.V1: return dtypes.Categorical() if native_dtype.ordered: @@ -341,8 +341,7 @@ def native_to_narwhals_dtype( elif implementation is Implementation.DASK: # Per conversations with their maintainers, they don't support arbitrary # objects, so we can just return String. - dtypes = import_dtypes_module(version) - return dtypes.String() + return version.dtypes.String() msg = ( "Unreachable code, object dtype should be handled separately" # pragma: no cover ) @@ -383,7 +382,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915 if dtype_backend is not None and dtype_backend not in {"pyarrow", "numpy_nullable"}: msg = f"Expected one of {{None, 'pyarrow', 'numpy_nullable'}}, got: '{dtype_backend}'" raise ValueError(msg) - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if isinstance_or_issubclass(dtype, dtypes.Decimal): msg = "Casting to Decimal is not supported yet." raise NotImplementedError(msg) @@ -674,32 +673,23 @@ def pivot_table( columns: Sequence[str], aggregate_function: str | None, ) -> Any: - dtypes = import_dtypes_module(df._version) + categorical = df._version.dtypes.Categorical + kwds: dict[Any, Any] = {"observed": True} if df._implementation is Implementation.CUDF: - if any( - x == dtypes.Categorical - for x in df.simple_select(*[*values, *index, *columns]).schema.values() - ): + kwds.pop("observed") + cols = set(chain(values, index, columns)) + schema = df.schema.items() + if any(tp for name, tp in schema if name in cols and isinstance(tp, categorical)): msg = "`pivot` with Categoricals is not implemented for cuDF backend" raise NotImplementedError(msg) - # cuDF doesn't support `observed` argument - result = df._native_frame.pivot_table( - values=values, - index=index, - columns=columns, - aggfunc=aggregate_function, - margins=False, - ) - else: - result = df._native_frame.pivot_table( - values=values, - index=index, - columns=columns, - aggfunc=aggregate_function, - margins=False, - observed=True, - ) - return result + return df.native.pivot_table( + values=values, + index=index, + columns=columns, + aggfunc=aggregate_function, + margins=False, + **kwds, + ) def check_column_names_are_unique(columns: pd.Index[str]) -> None: diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 5e7cdab109..6a2bf3e961 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -19,7 +19,6 @@ from narwhals.exceptions import NarwhalsError from narwhals.exceptions import ShapeError from narwhals.utils import Version -from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: @@ -76,7 +75,7 @@ def extract_args_kwargs( def native_to_narwhals_dtype( dtype: pl.DataType, version: Version, backend_version: tuple[int, ...] ) -> DType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if dtype == pl.Float64: return dtypes.Float64() if dtype == pl.Float32: @@ -156,7 +155,7 @@ def native_to_narwhals_dtype( def narwhals_to_native_dtype( dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...] ) -> pl.DataType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if dtype == dtypes.Float64: return pl.Float64() if dtype == dtypes.Float32: diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 15f69e7128..2ba77db442 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -22,7 +22,6 @@ from narwhals.utils import check_column_exists from narwhals.utils import find_stacklevel from narwhals.utils import generate_temporary_column_name -from narwhals.utils import import_dtypes_module from narwhals.utils import not_implemented from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version @@ -422,7 +421,7 @@ def join( ) def explode(self, columns: Sequence[str]) -> Self: - dtypes = import_dtypes_module(self._version) + dtypes = self._version.dtypes schema = self.collect_schema() for col_to_explode in columns: diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index cd2afaef9a..bccb199aca 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -7,7 +7,6 @@ from narwhals.exceptions import UnsupportedDTypeError from narwhals.utils import Implementation -from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: @@ -44,7 +43,7 @@ def __init__( def native_to_narwhals_dtype( dtype: _NativeDType, version: Version, spark_types: ModuleType ) -> DType: - dtypes = import_dtypes_module(version=version) + dtypes = version.dtypes if TYPE_CHECKING: native = sqlframe_types else: @@ -104,7 +103,7 @@ def native_to_narwhals_dtype( def narwhals_to_native_dtype( dtype: DType | type[DType], version: Version, spark_types: ModuleType ) -> _NativeDType: - dtypes = import_dtypes_module(version) + dtypes = version.dtypes if TYPE_CHECKING: native = sqlframe_types else: diff --git a/narwhals/utils.py b/narwhals/utils.py index 39314169d7..01e3574891 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -195,6 +195,16 @@ def namespace(self) -> type[Namespace[Any]]: return Namespace + @property + def dtypes(self) -> DTypes: + if self is Version.MAIN: + from narwhals import dtypes + + return dtypes + from narwhals.stable.v1 import dtypes as v1_dtypes + + return v1_dtypes + class Implementation(Enum): """Implementation of native object (pandas, Polars, PyArrow, ...).""" @@ -602,24 +612,6 @@ def validate_backend_version( raise ValueError(msg) -def import_dtypes_module(version: Version) -> DTypes: - if version is Version.MAIN: - from narwhals import dtypes - - return dtypes - elif version is Version.V1: - from narwhals.stable.v1 import dtypes as v1_dtypes - - return v1_dtypes - else: # pragma: no cover - msg = ( - "Congratulations, you have entered unreachable code.\n" - "Please report an issue at https://github.com/narwhals-dev/narwhals/issues.\n" - f"Version: {version}" - ) - raise AssertionError(msg) - - def remove_prefix(text: str, prefix: str) -> str: # pragma: no cover if text.startswith(prefix): return text[len(prefix) :] @@ -1154,13 +1146,12 @@ def is_ordered_categorical(series: Series[Any]) -> bool: """ from narwhals._interchange.series import InterchangeSeries - dtypes = import_dtypes_module(series._compliant_series._version) - - if ( - isinstance(series._compliant_series, InterchangeSeries) - and series.dtype == dtypes.Categorical + dtypes = series._compliant_series._version.dtypes + compliant = series._compliant_series + if isinstance(compliant, InterchangeSeries) and isinstance( + series.dtype, dtypes.Categorical ): - return series._compliant_series.native.describe_categorical["is_ordered"] + return compliant.native.describe_categorical["is_ordered"] if series.dtype == dtypes.Enum: return True if series.dtype != dtypes.Categorical: