Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
21b936e
feat(typing): Add `FromNative` protocol
dangotbanned Mar 29, 2025
91870db
chore(typing): Add `FromNative` to `CompliantSeries`
dangotbanned Mar 29, 2025
4f1578f
feat: Implement for `(Arrow|PandasLike)Series`
dangotbanned Mar 29, 2025
f278e67
feat: Implement for `PolarsSeries`
dangotbanned Mar 29, 2025
9c55759
chore: `ArrowSeries` coverage
dangotbanned Mar 29, 2025
8aabde8
chore: `PandasLikeSeries` partial coverage
dangotbanned Mar 29, 2025
f80c693
Merge remote-tracking branch 'upstream/main' into compliant-from-native
dangotbanned Mar 29, 2025
edf4c8c
ignore coverage for now ...
dangotbanned Mar 29, 2025
9e21718
feat(typing): Add `CompliantDataFrame.from_native`
dangotbanned Mar 30, 2025
0118eb0
feat: Implement for `ArrowDataFrame`
dangotbanned Mar 30, 2025
4701561
feat: Implement for `PandasLikeDataFrame`
dangotbanned Mar 30, 2025
2d9da37
feat: Implement for `PolarsDataFrame`
dangotbanned Mar 30, 2025
44a36f7
refactor: Found one more
dangotbanned Mar 30, 2025
ec187d5
chore(typing): Fix missing `SQLExpression` ignore
dangotbanned Mar 30, 2025
766b67a
feat: Implement `EagerNamespace.from_native`
dangotbanned Mar 30, 2025
2de6b56
feat: Add `Polars(Namespace|LazyFrame).from_native`
dangotbanned Mar 30, 2025
c4e8f56
chore: Ignore coverage `PandasLikeDataFrame._is_native`
dangotbanned Mar 30, 2025
ba7506a
feat: Add all `CompliantLazyFrame.from_native`
dangotbanned Mar 30, 2025
fc9506a
feat: Add `LazyNamespace.from_native`
dangotbanned Mar 30, 2025
6ed3c9d
refactor: Get some lazy coverage
dangotbanned Mar 30, 2025
9396bbc
refactor: More `polars` coverage
dangotbanned Mar 30, 2025
43cf204
Merge branch 'main' into compliant-from-native
dangotbanned Mar 30, 2025
d628a9f
refactor: reuse `is_spark_like_dataframe`
dangotbanned Mar 30, 2025
e58419a
Merge branch 'compliant-from-native' of https://github.com/narwhals-d…
dangotbanned Mar 30, 2025
de9483e
Merge branch 'main' into compliant-from-native
dangotbanned Mar 30, 2025
3b4feb2
Merge remote-tracking branch 'upstream/main' into compliant-from-native
dangotbanned Mar 31, 2025
4109c99
Merge branch 'main' into compliant-from-native
dangotbanned Mar 31, 2025
cdd1913
Merge branch 'main' into compliant-from-native
dangotbanned Mar 31, 2025
4ab34e0
Merge branch 'main' into compliant-from-native
dangotbanned Apr 3, 2025
4c08fff
Update narwhals/_compliant/namespace.py
MarcoGorelli Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 27 additions & 72 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@
import polars as pl
from typing_extensions import Self
from typing_extensions import TypeAlias
from typing_extensions import TypeIs

from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.group_by import ArrowGroupBy
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import ArrowChunkedArray
from narwhals._arrow.typing import Indices # type: ignore[attr-defined]
from narwhals._arrow.typing import Mask # type: ignore[attr-defined]
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
backend_version = context._backend_version
if isinstance(data, pa.Table):
if cls._is_native(data):
native = data
elif backend_version >= (14,) or isinstance(data, Collection):
native = pa.table(data)
Expand All @@ -109,12 +109,7 @@ def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
else: # pragma: no cover
msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
raise TypeError(msg)
return cls(
native,
backend_version=backend_version,
version=context._version,
validate_column_names=True,
)
return cls.from_native(native, context=context)

@classmethod
def from_dict(
Expand All @@ -129,8 +124,16 @@ def from_dict(

pa_schema = Schema(schema).to_arrow() if schema is not None else schema
native = pa.Table.from_pydict(data, schema=pa_schema)
return cls.from_native(native, context=context)

# Runtime check for this backend's native frame type: returns True only when
# `obj` is a `pa.Table` instance. The `TypeIs[pa.Table]` return annotation lets
# a static type checker narrow `obj` to `pa.Table` where the check succeeds.
@staticmethod
def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
return isinstance(obj, pa.Table)

@classmethod
def from_native(cls, data: pa.Table, /, *, context: _FullContext) -> Self:
return cls(
native,
data,
backend_version=context._backend_version,
version=context._version,
validate_column_names=True,
Expand All @@ -152,12 +155,7 @@ def from_numpy(
native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
else:
native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
return cls(
native,
backend_version=context._backend_version,
version=context._version,
validate_column_names=True,
)
return cls.from_native(native, context=context)

def __narwhals_namespace__(self: Self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace
Expand Down Expand Up @@ -224,15 +222,8 @@ def rows(self: Self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, A
return self.native.to_pylist()

def iter_columns(self) -> Iterator[ArrowSeries]:
from narwhals._arrow.series import ArrowSeries

for name, series in zip(self.columns, self.native.itercolumns()):
yield ArrowSeries(
series,
name=name,
backend_version=self._backend_version,
version=self._version,
)
yield ArrowSeries.from_native(series, context=self, name=name)

_iter_columns = iter_columns

Expand All @@ -251,18 +242,10 @@ def iter_rows(
yield from df[i : i + buffer_size].to_pylist()

def get_column(self: Self, name: str) -> ArrowSeries:
from narwhals._arrow.series import ArrowSeries

if not isinstance(name, str):
msg = f"Expected str, got: {type(name)}"
raise TypeError(msg)

return ArrowSeries(
self.native[name],
name=name,
backend_version=self._backend_version,
version=self._version,
)
return ArrowSeries.from_native(self.native[name], context=self, name=name)

def __array__(self: Self, dtype: Any, *, copy: bool | None) -> _2DArray:
return self.native.__array__(dtype, copy=copy)
Expand Down Expand Up @@ -304,14 +287,7 @@ def __getitem__(
item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) # pyright: ignore[reportAssignmentType]

if isinstance(item, str):
from narwhals._arrow.series import ArrowSeries

return ArrowSeries(
self.native[item],
name=item,
backend_version=self._backend_version,
version=self._version,
)
return ArrowSeries.from_native(self.native[item], context=self, name=item)
elif (
isinstance(item, tuple)
and len(item) == 2
Expand Down Expand Up @@ -345,7 +321,6 @@ def __getitem__(
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover
from narwhals._arrow.series import ArrowSeries

# PyArrow columns are always strings
col_name = (
Expand All @@ -357,18 +332,12 @@ def __getitem__(
msg = "Can not slice with tuple with the first element as a str"
raise TypeError(msg)
if (isinstance(item[0], slice)) and (item[0] == slice(None)):
return ArrowSeries(
self.native[col_name],
name=col_name,
backend_version=self._backend_version,
version=self._version,
return ArrowSeries.from_native(
self.native[col_name], context=self, name=col_name
)
selected_rows = select_rows(self.native, item[0])
return ArrowSeries(
selected_rows[col_name],
name=col_name,
backend_version=self._backend_version,
version=self._version,
return ArrowSeries.from_native(
selected_rows[col_name], context=self, name=col_name
)

elif isinstance(item, slice):
Expand Down Expand Up @@ -589,18 +558,10 @@ def to_dict(
self: Self, *, as_series: bool
) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
df = self.native

names_and_values = zip(df.column_names, df.columns)
if as_series:
from narwhals._arrow.series import ArrowSeries

return {
name: ArrowSeries(
col,
name=name,
backend_version=self._backend_version,
version=self._version,
)
name: ArrowSeries.from_native(col, context=self, name=name)
for name, col in names_and_values
}
else:
Expand Down Expand Up @@ -778,26 +739,20 @@ def write_csv(self: Self, file: str | Path | BytesIO | None) -> str | None:
return None

def is_unique(self: Self) -> ArrowSeries:
from narwhals._arrow.series import ArrowSeries

col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
row_index = pa.array(range(len(self)))
keep_idx = (
self.native.append_column(col_token, row_index)
.group_by(self.columns)
.aggregate([(col_token, "min"), (col_token, "max")])
)
return ArrowSeries(
pa.chunked_array(
pc.and_(
pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
)
),
name="",
backend_version=self._backend_version,
version=self._version,
native = pa.chunked_array(
pc.and_(
pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
)
)
return ArrowSeries.from_native(native, context=self)

def unique(
self: ArrowDataFrame,
Expand Down
6 changes: 5 additions & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
from narwhals.utils import Version


class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]):
class ArrowNamespace(
EagerNamespace[
ArrowDataFrame, ArrowSeries, ArrowExpr, "pa.Table", "ArrowChunkedArray"
]
):
@property
def _dataframe(self) -> type[ArrowDataFrame]:
return ArrowDataFrame
Expand Down
32 changes: 20 additions & 12 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import pandas as pd
import polars as pl
from typing_extensions import Self
from typing_extensions import TypeIs

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
Expand Down Expand Up @@ -135,12 +136,7 @@ def _with_native(
*,
preserve_broadcast: bool = False,
) -> Self:
result = self.__class__(
chunked_array(series),
name=self._name,
backend_version=self._backend_version,
version=self._version,
)
result = self.from_native(chunked_array(series), name=self.name, context=self)
if preserve_broadcast:
result._broadcast = self._broadcast
return result
Expand All @@ -156,18 +152,30 @@ def from_iterable(
) -> Self:
version = context._version
dtype_pa = narwhals_to_native_dtype(dtype, version) if dtype else None
return cls(
chunked_array([data], dtype_pa),
name=name,
backend_version=context._backend_version,
version=version,
return cls.from_native(
chunked_array([data], dtype_pa), name=name, context=context
)

# Delegates to the parent `_from_scalar`, but first replaces `value` with
# `value.as_py()` when the backend version is below (13,) and `value` exposes
# an `as_py` attribute.
# NOTE(review): presumably older pyarrow cannot build from its own scalar
# objects directly - confirm against the parent implementation.
def _from_scalar(self, value: Any) -> Self:
if self._backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
return super()._from_scalar(value)

# Runtime check for this backend's native series type: returns True only when
# `obj` is a `pa.ChunkedArray` instance. The `TypeIs[ArrowChunkedArray]` return
# annotation lets a static type checker narrow `obj` where the check succeeds.
@staticmethod
def _is_native(obj: ArrowChunkedArray | Any) -> TypeIs[ArrowChunkedArray]:
return isinstance(obj, pa.ChunkedArray)

# Alternate constructor: wraps an already-native chunked array.
#
# `data` is positional-only and is passed to `cls` unchanged; `context` and
# `name` are keyword-only. `backend_version` and `version` are read from
# `context._backend_version` / `context._version`, so callers do not thread
# them through by hand. `name` defaults to the empty string.
@classmethod
def from_native(
cls, data: ArrowChunkedArray, /, *, context: _FullContext, name: str = ""
) -> Self:
return cls(
data,
backend_version=context._backend_version,
version=context._version,
name=name,
)

@classmethod
def from_numpy(cls, data: Into1DArray, /, *, context: _FullContext) -> Self:
return cls.from_iterable(
Expand Down Expand Up @@ -546,7 +554,7 @@ def tail(self: Self, n: int) -> Self:
return self._with_native(self.native.slice(abs(n)))

def is_in(self: Self, other: Any) -> Self:
if isinstance(other, pa.ChunkedArray):
if self._is_native(other):
value_set: ArrowChunkedArray | ArrowArray = other
else:
value_set = pa.array(other)
Expand Down
2 changes: 2 additions & 0 deletions narwhals/_compliant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from narwhals._compliant.group_by import LazyGroupBy
from narwhals._compliant.namespace import CompliantNamespace
from narwhals._compliant.namespace import EagerNamespace
from narwhals._compliant.namespace import LazyNamespace
from narwhals._compliant.selectors import CompliantSelector
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.selectors import EagerSelectorNamespace
Expand Down Expand Up @@ -64,6 +65,7 @@
"IntoCompliantExpr",
"LazyExpr",
"LazyGroupBy",
"LazyNamespace",
"LazySelectorNamespace",
"LazyWhen",
"NativeFrameT_co",
Expand Down
35 changes: 22 additions & 13 deletions narwhals/_compliant/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from narwhals._compliant.typing import CompliantSeriesT
from narwhals._compliant.typing import EagerExprT_contra
from narwhals._compliant.typing import EagerSeriesT
from narwhals._compliant.typing import NativeFrameT_co
from narwhals._compliant.typing import NativeFrameT
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._translate import ArrowConvertible
from narwhals._translate import DictConvertible
from narwhals._translate import FromNative
from narwhals._translate import NumpyConvertible
from narwhals.utils import Version
from narwhals.utils import _StoresNative
Expand Down Expand Up @@ -57,11 +58,12 @@ class CompliantDataFrame(
NumpyConvertible["_2DArray", "_2DArray"],
DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
ArrowConvertible["pa.Table", "IntoArrowTable"],
_StoresNative[NativeFrameT_co],
_StoresNative[NativeFrameT],
FromNative[NativeFrameT],
Comment on lines +61 to +62
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Been thinking about combining these as one protocol.

_StoresNative might still be useful on its own, but something like this could be nice:

from typing import Protocol, TypeVar

from narwhals._translate import FromNative
from narwhals.utils import _StoresNative

NativeT = TypeVar("NativeT")

class Compliant(_StoresNative[NativeT], FromNative[NativeT], Protocol[NativeT]): ...

Not super important though 😅

Sized,
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT_co],
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT],
):
_native_frame: Any
_native_frame: NativeFrameT
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version
Expand All @@ -80,6 +82,8 @@ def from_dict(
schema: Mapping[str, DType] | Schema | None,
) -> Self: ...
# Protocol stub (body is `...`): alternate constructor taking a native frame.
# `data` is positional-only and typed as the protocol's `NativeFrameT`;
# `context` is keyword-only. Concrete classes supply the implementation.
@classmethod
def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
@classmethod
def from_numpy(
cls,
data: _2DArray,
Expand All @@ -105,8 +109,8 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
def _with_version(self, version: Version) -> Self: ...

@property
def native(self) -> NativeFrameT_co:
return self._native_frame # type: ignore[no-any-return]
def native(self) -> NativeFrameT:
return self._native_frame

@property
def columns(self) -> Sequence[str]: ...
Expand Down Expand Up @@ -210,16 +214,21 @@ def write_parquet(self, file: str | Path | BytesIO) -> None: ...


class CompliantLazyFrame(
_StoresNative[NativeFrameT_co], Protocol[CompliantExprT_contra, NativeFrameT_co]
_StoresNative[NativeFrameT],
FromNative[NativeFrameT],
Protocol[CompliantExprT_contra, NativeFrameT],
):
_native_frame: Any
_native_frame: NativeFrameT
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version

def __narwhals_lazyframe__(self) -> Self: ...
def __narwhals_namespace__(self) -> Any: ...

# Protocol stub (body is `...`): alternate constructor taking a native frame.
# `data` is positional-only and typed as the protocol's `NativeFrameT`;
# `context` is keyword-only. Concrete classes supply the implementation.
@classmethod
def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...

def simple_select(self, *column_names: str) -> Self:
"""`select` where all args are column names."""
...
Expand All @@ -234,8 +243,8 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
def _with_version(self, version: Version) -> Self: ...

@property
def native(self) -> NativeFrameT_co:
return self._native_frame # type: ignore[no-any-return]
def native(self) -> NativeFrameT:
return self._native_frame

@property
def columns(self) -> Sequence[str]: ...
Expand Down Expand Up @@ -307,9 +316,9 @@ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:


class EagerDataFrame(
CompliantDataFrame[EagerSeriesT, EagerExprT_contra, NativeFrameT_co],
CompliantLazyFrame[EagerExprT_contra, NativeFrameT_co],
Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT_co],
CompliantDataFrame[EagerSeriesT, EagerExprT_contra, NativeFrameT],
CompliantLazyFrame[EagerExprT_contra, NativeFrameT],
Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT],
):
def _evaluate_expr(self, expr: EagerExprT_contra, /) -> EagerSeriesT:
"""Evaluate `expr` and ensure it has a **single** output."""
Expand Down
Loading
Loading