Skip to content

Commit

Permalink
chore: More typing in _polars
Browse files (browse the repository at this point in the history)
  • Loading branch information
MarcoGorelli committed Nov 13, 2024
1 parent babe141 commit e086cea
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
44 changes: 25 additions & 19 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
from types import ModuleType

import numpy as np
import polars as pl
from typing_extensions import Self

from narwhals.dtypes import DType
from narwhals.typing import DTypes


class PolarsDataFrame:
def __init__(
self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes
self, df: pl.DataFrame, *, backend_version: tuple[int, ...], dtypes: DTypes
) -> None:
self._native_frame = df
self._backend_version = backend_version
Expand Down Expand Up @@ -91,19 +93,21 @@ def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.nd
return self._native_frame.__array__(dtype)
return self._native_frame.__array__(dtype)

def collect_schema(self) -> dict[str, Any]:
def collect_schema(self) -> dict[str, DType]:
if self._backend_version < (1,):
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in self._native_frame.schema.items()
}
else:
schema = dict(self._native_frame.collect_schema())
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in schema.items()
}
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in self._native_frame.collect_schema().items()
}

@property
def shape(self) -> tuple[int, int]:
return self._native_frame.shape # type: ignore[no-any-return]
return self._native_frame.shape

def __getitem__(self, item: Any) -> Any:
if self._backend_version > (0, 20, 30):
Expand Down Expand Up @@ -179,7 +183,7 @@ def is_empty(self) -> bool:

@property
def columns(self) -> list[str]:
return self._native_frame.columns # type: ignore[no-any-return]
return self._native_frame.columns

def lazy(self) -> PolarsLazyFrame:
return PolarsLazyFrame(
Expand Down Expand Up @@ -246,7 +250,7 @@ def unpivot(

class PolarsLazyFrame:
def __init__(
self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes
self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], dtypes: DTypes
) -> None:
self._native_frame = df
self._backend_version = backend_version
Expand Down Expand Up @@ -285,7 +289,7 @@ def func(*args: Any, **kwargs: Any) -> Any:

@property
def columns(self) -> list[str]:
return self._native_frame.columns # type: ignore[no-any-return]
return self._native_frame.columns

@property
def schema(self) -> dict[str, Any]:
Expand All @@ -295,15 +299,17 @@ def schema(self) -> dict[str, Any]:
for name, dtype in schema.items()
}

def collect_schema(self) -> dict[str, Any]:
def collect_schema(self) -> dict[str, DType]:
if self._backend_version < (1,):
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in self._native_frame.schema.items()
}
else:
schema = dict(self._native_frame.collect_schema())
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in schema.items()
}
return {
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in self._native_frame.collect_schema().items()
}

def collect(self) -> PolarsDataFrame:
return PolarsDataFrame(
Expand Down
11 changes: 6 additions & 5 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, An
return args, kwargs


def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
def native_to_narwhals_dtype(dtype: pl.DataType, dtypes: DTypes) -> DType:
import polars as pl # ignore-banned-import()

if dtype == pl.Float64:
Expand Down Expand Up @@ -80,18 +80,19 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
return dtypes.Struct(
[
dtypes.Field(field_name, native_to_narwhals_dtype(field_type, dtypes))
for field_name, field_type in dtype
for field_name, field_type in dtype # type: ignore[attr-defined]
]
)
if dtype == pl.List:
return dtypes.List(native_to_narwhals_dtype(dtype.inner, dtypes))
return dtypes.List(native_to_narwhals_dtype(dtype.inner, dtypes)) # type: ignore[attr-defined]
if dtype == pl.Array:
if parse_version(pl.__version__) < (0, 20, 30): # pragma: no cover
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, dtypes), dtype.width
native_to_narwhals_dtype(dtype.inner, dtypes), # type: ignore[attr-defined]
dtype.width, # type: ignore[attr-defined]
)
else:
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, dtypes), dtype.size)
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, dtypes), dtype.size) # type: ignore[attr-defined]
return dtypes.Unknown()


Expand Down

0 comments on commit e086cea

Please sign in to comment.