Skip to content
Merged
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ repos:
language: pygrep
files: ^narwhals/
- id: dtypes-import
name: don't import from narwhals.dtypes (use `import_dtypes_module` instead)
name: don't import from narwhals.dtypes (use `Version.dtypes` instead)
entry: |
(?x)
import\ narwhals.dtypes|
Expand Down
5 changes: 2 additions & 3 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
from narwhals._arrow.typing import ArrayOrScalarAny
Expand Down Expand Up @@ -135,15 +134,15 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
)

def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
dtypes = import_dtypes_module(self._version)
int_64 = self._version.dtypes.Int64()

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
series = align_series_full_broadcast(
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
)
non_na = align_series_full_broadcast(
*(1 - s.is_null().cast(dtypes.Int64()) for s in expr_results)
*(1 - s.is_null().cast(int_64) for s in expr_results)
)
return [reduce(operator.add, series) / reduce(operator.add, non_na)]

Expand Down
3 changes: 1 addition & 2 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_list_of
from narwhals.utils import not_implemented
from narwhals.utils import requires
Expand Down Expand Up @@ -865,7 +864,7 @@ def is_finite(self) -> Self:
return self._with_native(pc.is_finite(self.native))

def cum_count(self, *, reverse: bool) -> Self:
dtypes = import_dtypes_module(self._version)
dtypes = self._version.dtypes
return (~self.is_null()).cast(dtypes.UInt32()).cum_sum(reverse=reverse)

@requires.backend_version((13,))
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_arrow/series_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from narwhals._arrow.utils import ArrowSeriesNamespace
from narwhals._arrow.utils import floordiv_compat
from narwhals._arrow.utils import lit
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
Expand Down Expand Up @@ -47,7 +46,7 @@ def convert_time_zone(self, time_zone: str) -> ArrowSeries:

def timestamp(self, time_unit: TimeUnit) -> ArrowSeries:
ser = self.compliant
dtypes = import_dtypes_module(ser._version)
dtypes = ser._version.dtypes
if isinstance(ser.dtype, dtypes.Datetime):
unit = ser.dtype.time_unit
s_cast = self.native.cast(pa.int64())
Expand Down
5 changes: 2 additions & 3 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from narwhals._compliant.series import _SeriesNamespace
from narwhals.exceptions import ShapeError
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -92,7 +91,7 @@ def nulls_like(n: int, series: ArrowSeries) -> ArrowArray:

@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:
dtypes = import_dtypes_module(version)
dtypes = version.dtypes
if pa.types.is_int64(dtype):
return dtypes.Int64()
if pa.types.is_int32(dtype):
Expand Down Expand Up @@ -157,7 +156,7 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa.DataType:
dtypes = import_dtypes_module(version)
dtypes = version.dtypes
if isinstance_or_issubclass(dtype, dtypes.Decimal):
msg = "Casting to Decimal is not supported yet."
raise NotImplementedError(msg)
Expand Down
9 changes: 4 additions & 5 deletions narwhals/_compliant/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from narwhals.utils import _parse_time_unit_and_time_zone
from narwhals.utils import dtype_matches_time_unit_and_time_zone
from narwhals.utils import get_column_names
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_compliant_dataframe

if not TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -150,13 +149,13 @@ def names(df: FrameT) -> Sequence[str]:
return self._selector.from_callables(series, names, context=self)

def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(import_dtypes_module(self._version).Categorical)
return self._is_dtype(self._version.dtypes.Categorical)

def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(import_dtypes_module(self._version).String)
return self._is_dtype(self._version.dtypes.String)

def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(import_dtypes_module(self._version).Boolean)
return self._is_dtype(self._version.dtypes.Boolean)

def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
Expand All @@ -172,7 +171,7 @@ def datetime(
time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
matches = partial(
dtype_matches_time_unit_and_time_zone,
dtypes=import_dtypes_module(version=self._version),
dtypes=self._version.dtypes,
time_units=time_units,
time_zones=time_zones,
)
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_dask/expr_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from narwhals._pandas_like.utils import calculate_timestamp_datetime
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -115,7 +114,7 @@ def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
)
is_pyarrow_dtype = "pyarrow" in str(dtype)
mask_na = s.isna()
dtypes = import_dtypes_module(self._compliant_expr._version)
dtypes = self._compliant_expr._version.dtypes
if dtype == dtypes.Date:
# Date is only supported in pandas dtypes if pyarrow-backed
s_cast = s.astype("Int32[pyarrow]")
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from narwhals.dependencies import get_pyarrow
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version

Expand Down Expand Up @@ -96,7 +95,7 @@ def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:


def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> Any:
dtypes = import_dtypes_module(version)
dtypes = version.dtypes
if isinstance_or_issubclass(dtype, dtypes.Float64):
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
from narwhals.utils import not_implemented
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
Expand Down Expand Up @@ -414,7 +413,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self:
return self._with_native(self.native.filter(keep_condition))

def explode(self, columns: Sequence[str]) -> Self:
dtypes = import_dtypes_module(self._version)
dtypes = self._version.dtypes
schema = self.collect_schema()
for name in columns:
dtype = schema[name]
Expand Down
7 changes: 3 additions & 4 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import duckdb

from narwhals.utils import Version
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -83,7 +82,7 @@ def evaluate_exprs(

def native_to_narwhals_dtype(duckdb_dtype: DuckDBPyType, version: Version) -> DType:
duckdb_dtype_id = duckdb_dtype.id
dtypes = import_dtypes_module(version)
dtypes = version.dtypes

# Handle nested data types first
if duckdb_dtype_id == "list":
Expand Down Expand Up @@ -121,7 +120,7 @@ def native_to_narwhals_dtype(duckdb_dtype: DuckDBPyType, version: Version) -> DT

@lru_cache(maxsize=16)
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
dtypes = import_dtypes_module(version)
dtypes = version.dtypes
return {
"hugeint": dtypes.Int128(),
"bigint": dtypes.Int64(),
Expand Down Expand Up @@ -150,7 +149,7 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version)


def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> str: # noqa: PLR0915
dtypes = import_dtypes_module(version)
dtypes = version.dtypes
if isinstance_or_issubclass(dtype, dtypes.Decimal):
msg = "Casting to Decimal is not supported yet."
raise NotImplementedError(msg)
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from narwhals.dependencies import get_ibis
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import import_dtypes_module
from narwhals.utils import validate_backend_version

if TYPE_CHECKING:
Expand All @@ -25,7 +24,7 @@

@lru_cache(maxsize=16)
def native_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType:
dtypes = import_dtypes_module(version)
dtypes = version.dtypes
if ibis_dtype.is_int64():
return dtypes.Int64()
if ibis_dtype.is_int32():
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_interchange/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any
from typing import NoReturn

from narwhals.utils import import_dtypes_module
from narwhals.utils import parse_version

if TYPE_CHECKING:
Expand Down Expand Up @@ -33,7 +32,7 @@ class DtypeKind(enum.IntEnum):
def map_interchange_dtype_to_narwhals_dtype(
interchange_dtype: tuple[DtypeKind, int, Any, Any], version: Version
) -> DType:
dtypes = import_dtypes_module(version)
dtypes = version.dtypes
if interchange_dtype[0] == DtypeKind.INT:
if interchange_dtype[1] == 64:
return dtypes.Int64()
Expand Down
10 changes: 3 additions & 7 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from narwhals.utils import _remap_full_join_keys
from narwhals.utils import check_column_exists
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import scale_bytes
Expand Down Expand Up @@ -838,12 +837,11 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
return self.native.to_numpy(dtype=dtype, copy=copy)
return self.native.to_numpy(copy=copy)

dtypes = import_dtypes_module(self._version)

dtype_datetime = self._version.dtypes.Datetime
to_convert = [
key
for key, val in self.schema.items()
if val == dtypes.Datetime and val.time_zone is not None # type: ignore[attr-defined]
if isinstance(val, dtype_datetime) and val.time_zone is not None
]
if to_convert:
df = self.with_columns(
Expand Down Expand Up @@ -1053,7 +1051,7 @@ def unpivot(
)

def explode(self, columns: Sequence[str]) -> Self:
dtypes = import_dtypes_module(self._version)
dtypes = self._version.dtypes

schema = self.collect_schema()
for col_to_explode in columns:
Expand All @@ -1078,8 +1076,6 @@ def explode(self, columns: Sequence[str]) -> Self:
(native_frame[col_name].list.len() == anchor_series).all()
for col_name in columns[1:]
):
from narwhals.exceptions import ShapeError

msg = "exploded columns must have matching element counts"
raise ShapeError(msg)

Expand Down
7 changes: 2 additions & 5 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from narwhals._pandas_like.selectors import PandasSelectorNamespace
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.utils import align_series_full_broadcast
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -283,13 +282,11 @@ def concat_str(
separator: str,
ignore_nulls: bool,
) -> PandasLikeExpr:
dtypes = import_dtypes_module(self._version)
string = self._version.dtypes.String()

def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
expr_results = [s for _expr in exprs for s in _expr(df)]
series = align_series_full_broadcast(
*(s.cast(dtypes.String()) for s in expr_results)
)
series = align_series_full_broadcast(*(s.cast(string) for s in expr_results))
null_mask = align_series_full_broadcast(*(s.is_null() for s in expr_results))

if not ignore_nulls:
Expand Down
21 changes: 7 additions & 14 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from narwhals.dependencies import is_pandas_like_series
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_list_of
from narwhals.utils import parse_version
from narwhals.utils import validate_backend_version
Expand Down Expand Up @@ -678,31 +677,25 @@ def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray:
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
# the default is meant to be None, but pandas doesn't allow it?
# https://numpy.org/doc/stable/reference/generated/numpy.ndarray.__array__.html
copy = copy or self._implementation is Implementation.CUDF
dtypes = import_dtypes_module(self._version)
dtypes = self._version.dtypes
if isinstance(self.dtype, dtypes.Datetime) and self.dtype.time_zone is not None:
s = self.dt.convert_time_zone("UTC").dt.replace_time_zone(None).native
else:
s = self.native

has_missing = s.isna().any()
kwargs: dict[Any, Any] = {"copy": copy or self._implementation.is_cudf()}
if has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING:
if self._implementation is Implementation.PANDAS and self._backend_version < (
1,
): # pragma: no cover
kwargs = {}
...
else:
kwargs = {"na_value": float("nan")}
return s.to_numpy(
dtype=dtype or PANDAS_TO_NUMPY_DTYPE_MISSING[str(s.dtype)],
copy=copy,
**kwargs,
)
kwargs.update({"na_value": float("nan")})
dtype = dtype or PANDAS_TO_NUMPY_DTYPE_MISSING[str(s.dtype)]
if not has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_NO_MISSING:
return s.to_numpy(
dtype=dtype or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(s.dtype)], copy=copy
)
return s.to_numpy(dtype=dtype, copy=copy)
dtype = dtype or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(s.dtype)]
return s.to_numpy(dtype=dtype, **kwargs)

def to_pandas(self) -> pd.Series[Any]:
if self._implementation is Implementation.PANDAS:
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_pandas_like/series_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from narwhals._pandas_like.utils import calculate_timestamp_datetime
from narwhals._pandas_like.utils import int_dtype_mapper
from narwhals._pandas_like.utils import is_pyarrow_dtype_backend
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
from narwhals._pandas_like.series import PandasLikeSeries
Expand Down Expand Up @@ -173,7 +172,7 @@ def timestamp(self, time_unit: TimeUnit) -> PandasLikeSeries:
s = self.native
dtype = self.compliant.dtype
mask_na = s.isna()
dtypes = import_dtypes_module(self.version)
dtypes = self.version.dtypes
if dtype == dtypes.Date:
# Date is only supported in pandas dtypes if pyarrow-backed
s_cast = s.astype("Int32[pyarrow]")
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_pandas_like/series_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from narwhals._pandas_like.utils import get_dtype_backend
from narwhals._pandas_like.utils import narwhals_to_native_dtype
from narwhals._pandas_like.utils import set_index
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
from narwhals._pandas_like.series import PandasLikeSeries
Expand All @@ -28,7 +27,7 @@ def len(self) -> PandasLikeSeries:
backend_version=backend_version,
)
dtype = narwhals_to_native_dtype(
import_dtypes_module(self.version).UInt32(),
self.version.dtypes.UInt32(),
get_dtype_backend(result.dtype, implementation),
implementation,
backend_version,
Expand Down
Loading
Loading