Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 42 additions & 59 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,34 +71,28 @@ def __narwhals_lazyframe__(self: Self) -> Self:

def _change_version(self: Self, version: Version) -> Self:
return self.__class__(
self._native_frame,
backend_version=self._backend_version,
version=version,
self.native, backend_version=self._backend_version, version=version
)

def _from_native_frame(self: Self, df: Any) -> Self:
return self.__class__(
df,
backend_version=self._backend_version,
version=self._version,
df, backend_version=self._backend_version, version=self._version
)

def _iter_columns(self) -> Iterator[dx.Series]:
for _col, ser in self._native_frame.items(): # noqa: PERF102
for _col, ser in self.native.items(): # noqa: PERF102
yield ser

def with_columns(self: Self, *exprs: DaskExpr) -> Self:
df = self._native_frame
new_series = evaluate_exprs(self, *exprs)
df = df.assign(**dict(new_series))
return self._from_native_frame(df)
return self._from_native_frame(self.native.assign(**dict(new_series)))

def collect(
self: Self,
backend: Implementation | None,
**kwargs: Any,
) -> CompliantDataFrame[Any, Any, Any]:
result = self._native_frame.compute(**kwargs)
result = self.native.compute(**kwargs)

if backend is None or backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
Expand Down Expand Up @@ -143,14 +137,13 @@ def columns(self: Self) -> list[str]:

def filter(self: Self, predicate: DaskExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = predicate._call(self)[0]

return self._from_native_frame(self._native_frame.loc[mask])
mask = predicate(self)[0]
return self._from_native_frame(self.native.loc[mask])

def simple_select(self: Self, *column_names: str) -> Self:
return self._from_native_frame(
select_columns_by_name(
self._native_frame,
self.native,
list(column_names),
self._backend_version,
self._implementation,
Expand All @@ -165,7 +158,7 @@ def aggregate(self: Self, *exprs: DaskExpr) -> Self:
def select(self: Self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
df = select_columns_by_name(
self._native_frame.assign(**dict(new_series)),
self.native.assign(**dict(new_series)),
[s[0] for s in new_series],
self._backend_version,
self._implementation,
Expand All @@ -174,19 +167,19 @@ def select(self: Self, *exprs: DaskExpr) -> Self:

def drop_nulls(self: Self, subset: Sequence[str] | None) -> Self:
if subset is None:
return self._from_native_frame(self._native_frame.dropna())
return self._from_native_frame(self.native.dropna())
plx = self.__narwhals_namespace__()
return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))

@property
def schema(self: Self) -> dict[str, DType]:
if self._cached_schema is None:
native_dtypes = self._native_frame.dtypes
native_dtypes = self.native.dtypes
self._cached_schema = {
col: native_to_narwhals_dtype(
native_dtypes[col], self._version, self._implementation
)
for col in self._native_frame.columns
for col in self.native.columns
}
return self._cached_schema

Expand All @@ -198,23 +191,21 @@ def drop(self: Self, columns: Sequence[str], *, strict: bool) -> Self:
compliant_frame=self, columns=columns, strict=strict
)

return self._from_native_frame(self._native_frame.drop(columns=to_drop))
return self._from_native_frame(self.native.drop(columns=to_drop))

def with_row_index(self: Self, name: str) -> Self:
# Implementation is based on the following StackOverflow reply:
# https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
return self._from_native_frame(
add_row_index(
self._native_frame, name, self._backend_version, self._implementation
)
add_row_index(self.native, name, self._backend_version, self._implementation)
)

def rename(self: Self, mapping: Mapping[str, str]) -> Self:
return self._from_native_frame(self._native_frame.rename(columns=mapping))
return self._from_native_frame(self.native.rename(columns=mapping))

def head(self: Self, n: int) -> Self:
return self._from_native_frame(
self._native_frame.head(n=n, compute=False, npartitions=-1)
self.native.head(n=n, compute=False, npartitions=-1)
)

def unique(
Expand All @@ -224,17 +215,16 @@ def unique(
keep: Literal["any", "none"],
) -> Self:
check_column_exists(self.columns, subset)
native_frame = self._native_frame
if keep == "none":
subset = subset or self.columns
token = generate_temporary_column_name(n_bytes=8, columns=subset)
ser = native_frame.groupby(subset).size().rename(token)
ser = self.native.groupby(subset).size().rename(token)
ser = ser[ser == 1]
unique = ser.reset_index().drop(columns=token)
result = native_frame.merge(unique, on=subset, how="inner")
result = self.native.merge(unique, on=subset, how="inner")
else:
mapped_keep = {"any": "first"}.get(keep, keep)
result = native_frame.drop_duplicates(subset=subset, keep=mapped_keep)
result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
return self._from_native_frame(result)

def sort(
Expand All @@ -243,14 +233,13 @@ def sort(
descending: bool | Sequence[bool],
nulls_last: bool,
) -> Self:
df = self._native_frame
if isinstance(descending, bool):
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
na_position = "last" if nulls_last else "first"
position = "last" if nulls_last else "first"
return self._from_native_frame(
df.sort_values(list(by), ascending=ascending, na_position=na_position)
self.native.sort_values(list(by), ascending=ascending, na_position=position)
)

def join(
Expand All @@ -268,15 +257,15 @@ def join(
)

return self._from_native_frame(
self._native_frame.assign(**{key_token: 0})
self.native.assign(**{key_token: 0})
.merge(
other._native_frame.assign(**{key_token: 0}),
other.native.assign(**{key_token: 0}),
how="inner",
left_on=key_token,
right_on=key_token,
suffixes=("", suffix),
)
.drop(columns=key_token),
.drop(columns=key_token)
)

if how == "anti":
Expand All @@ -289,7 +278,7 @@ def join(
raise TypeError(msg)
other_native = (
select_columns_by_name(
other._native_frame,
other.native,
list(right_on),
self._backend_version,
self._implementation,
Expand All @@ -299,7 +288,7 @@ def join(
)
.drop_duplicates()
)
df = self._native_frame.merge(
df = self.native.merge(
other_native,
how="outer",
indicator=indicator_token, # pyright: ignore[reportArgumentType]
Expand All @@ -316,7 +305,7 @@ def join(
raise TypeError(msg)
other_native = (
select_columns_by_name(
other._native_frame,
other.native,
list(right_on),
self._backend_version,
self._implementation,
Expand All @@ -327,18 +316,14 @@ def join(
.drop_duplicates() # avoids potential rows duplication from inner join
)
return self._from_native_frame(
self._native_frame.merge(
other_native,
how="inner",
left_on=left_on,
right_on=left_on,
self.native.merge(
other_native, how="inner", left_on=left_on, right_on=left_on
)
)

if how == "left":
other_native = other._native_frame
result_native = self._native_frame.merge(
other_native,
result_native = self.native.merge(
other.native,
how="left",
left_on=left_on,
right_on=right_on,
Expand All @@ -361,29 +346,27 @@ def join(
assert right_on is not None # noqa: S101

right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)

other_native = other._native_frame
other_native = other_native.rename(columns=right_on_mapper)
other_native = other.native.rename(columns=right_on_mapper)
check_column_names_are_unique(other_native.columns)
right_on = list(right_on_mapper.values()) # we now have the suffixed keys
return self._from_native_frame(
self._native_frame.merge(
self.native.merge(
other_native,
left_on=left_on,
right_on=right_on,
how="outer",
suffixes=("", suffix),
),
)
)

return self._from_native_frame(
self._native_frame.merge(
other._native_frame,
self.native.merge(
other.native,
left_on=left_on,
right_on=right_on,
how=how,
suffixes=("", suffix),
),
)
)

def join_asof(
Expand All @@ -400,8 +383,8 @@ def join_asof(
plx = self.__native_namespace__()
return self._from_native_frame(
plx.merge_asof(
self._native_frame,
other._native_frame,
self.native,
other.native,
left_on=left_on,
right_on=right_on,
left_by=by_left,
Expand All @@ -417,11 +400,11 @@ def group_by(self: Self, *by: str, drop_null_keys: bool) -> DaskLazyGroupBy:
return DaskLazyGroupBy(self, by, drop_null_keys=drop_null_keys)

def tail(self: Self, n: int) -> Self: # pragma: no cover
native_frame = self._native_frame
native_frame = self.native
n_partitions = native_frame.npartitions

if n_partitions == 1:
return self._from_native_frame(self._native_frame.tail(n=n, compute=False))
return self._from_native_frame(self.native.tail(n=n, compute=False))
else:
msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
Expand All @@ -446,7 +429,7 @@ def unpivot(
value_name: str,
) -> Self:
return self._from_native_frame(
self._native_frame.melt(
self.native.melt(
id_vars=index,
value_vars=on,
var_name=variable_name,
Expand Down
Loading
Loading