Skip to content
17 changes: 11 additions & 6 deletions dataframe_api_compat/pandas_standard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,12 @@ def convert_to_standard_compliant_column(
raise ValueError(msg)
if ser.name is None:
ser = ser.rename("")
return Column(ser, api_version=api_version or "2023.11-beta", df=None)
return Column(
ser,
api_version=api_version or "2023.11-beta",
df=None,
is_persisted=True,
)


def convert_to_standard_compliant_dataframe(
Expand Down Expand Up @@ -253,15 +258,14 @@ def dataframe_from_columns(
api_versions.add(col._api_version) # type: ignore[attr-defined]
return DataFrame(pd.DataFrame(data), api_version=list(api_versions)[0])

def column_from_1d_array(
def column_from_1d_array( # type: ignore[override]
self,
data: Any,
*,
dtype: DType,
name: str | None = None,
) -> Column:
ser = pd.Series(data, dtype=map_standard_dtype_to_pandas_dtype(dtype), name=name)
return Column(ser, api_version=self._api_version, df=None)
ser = pd.Series(data, name=name)
return Column(ser, api_version=self._api_version, df=None, is_persisted=True)

def column_from_sequence(
self,
Expand All @@ -272,10 +276,11 @@ def column_from_sequence(
) -> Column:
ser = pd.Series(
sequence,
# todo make optional?
dtype=map_standard_dtype_to_pandas_dtype(dtype),
name=name,
)
return Column(ser, api_version=self._api_version, df=None)
return Column(ser, api_version=self._api_version, df=None, is_persisted=True)

def concat(
self,
Expand Down
55 changes: 35 additions & 20 deletions dataframe_api_compat/pandas_standard/column_object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from datetime import datetime
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -51,15 +52,23 @@ def __init__(
df
DataFrame this column originates from.
"""
from dataframe_api_compat.pandas_standard.scalar_object import Scalar

self._name = series.name or ""
self._series = series
self._api_version = api_version
self._df = df
self._scalar = Scalar
self._is_persisted = is_persisted

def _to_scalar(self, value: Any) -> Scalar:
from dataframe_api_compat.pandas_standard.scalar_object import Scalar

return Scalar(
value,
api_version=self._api_version,
df=self._df,
is_persisted=self._is_persisted,
)

def __repr__(self) -> str: # pragma: no cover
header = f" Standard Column (api_version={self._api_version}) "
length = len(header)
Expand All @@ -83,6 +92,7 @@ def _from_series(self, series: pd.Series) -> Column:
series.reset_index(drop=True),
api_version=self._api_version,
df=self._df,
is_persisted=self._is_persisted,
)

def _materialise(self) -> pd.Series:
Expand All @@ -102,6 +112,12 @@ def __column_namespace__(
)

def persist(self) -> Column:
if self._is_persisted:
warnings.warn(
"Calling `.persist` on Column that was already persisted",
UserWarning,
stacklevel=2,
)
return Column(
self.column,
df=None,
Expand Down Expand Up @@ -136,11 +152,8 @@ def filter(self, mask: Column) -> Column:

def get_value(self, row_number: int) -> Any:
ser = self.column
return self._scalar(
return self._to_scalar(
ser.iloc[row_number],
api_version=self._api_version,
df=self._df,
is_persisted=self._is_persisted,
)

def slice_rows(
Expand Down Expand Up @@ -270,35 +283,35 @@ def __invert__(self: Column) -> Column:

def any(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
ser = self.column
return self._scalar(ser.any(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.any())

def all(self, *, skip_nulls: bool | Scalar = True) -> Scalar:
ser = self.column
return self._scalar(ser.all(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.all())

def min(self, *, skip_nulls: bool | Scalar = True) -> Any:
ser = self.column
return self._scalar(ser.min(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.min())

def max(self, *, skip_nulls: bool | Scalar = True) -> Any:
ser = self.column
return self._scalar(ser.max(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.max())

def sum(self, *, skip_nulls: bool | Scalar = True) -> Any:
ser = self.column
return self._scalar(ser.sum(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.sum())

def prod(self, *, skip_nulls: bool | Scalar = True) -> Any:
ser = self.column
return self._scalar(ser.prod(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.prod())

def median(self, *, skip_nulls: bool | Scalar = True) -> Any:
ser = self.column
return self._scalar(ser.median(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.median())

def mean(self, *, skip_nulls: bool | Scalar = True) -> Any:
ser = self.column
return self._scalar(ser.mean(), api_version=self._api_version, df=self._df)
return self._to_scalar(ser.mean())

def std(
self,
Expand All @@ -307,10 +320,8 @@ def std(
skip_nulls: bool | Scalar = True,
) -> Any:
ser = self.column
return self._scalar(
return self._to_scalar(
ser.std(ddof=correction),
api_version=self._api_version,
df=self._df,
)

def var(
Expand All @@ -320,16 +331,20 @@ def var(
skip_nulls: bool | Scalar = True,
) -> Any:
ser = self.column
return self._scalar(
return self._to_scalar(
ser.var(ddof=correction),
api_version=self._api_version,
df=self._df,
)

def __len__(self) -> int:
ser = self._materialise()
return len(ser)

def n_unique(self) -> Scalar: # pragma: no cover (todo, still needs adding upstream)
ser = self.column
return self._to_scalar(
ser.nunique(),
)

# Transformations

def is_null(self) -> Column:
Expand Down
8 changes: 8 additions & 0 deletions dataframe_api_compat/pandas_standard/dataframe_object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
Expand Down Expand Up @@ -91,6 +92,7 @@ def _from_dataframe(self, df: pd.DataFrame) -> DataFrame:
return DataFrame(
df,
api_version=self._api_version,
is_persisted=self._is_persisted,
)

# Properties
Expand Down Expand Up @@ -515,6 +517,12 @@ def join(
)

def persist(self) -> DataFrame:
if self._is_persisted:
warnings.warn(
"Calling `.persist` on DataFrame that was already persisted",
UserWarning,
stacklevel=2,
)
return DataFrame(
self.dataframe,
api_version=self._api_version,
Expand Down
14 changes: 13 additions & 1 deletion dataframe_api_compat/pandas_standard/scalar_object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import Any

Expand Down Expand Up @@ -35,7 +36,12 @@ def __scalar_namespace__(self) -> Namespace:
return Namespace(api_version=self._api_version)

def _from_scalar(self, scalar: Scalar) -> Scalar:
return Scalar(scalar, df=self._df, api_version=self._api_version)
return Scalar(
scalar,
df=self._df,
api_version=self._api_version,
is_persisted=self._is_persisted,
)

@property
def dtype(self) -> DType: # pragma: no cover # todo
Expand All @@ -57,6 +63,12 @@ def _materialise(self) -> Any:
return self._value

def persist(self) -> Scalar:
if self._is_persisted:
warnings.warn(
"Calling `.persist` on Scalar that was already persisted",
UserWarning,
stacklevel=2,
)
return Scalar(
self._value,
df=None,
Expand Down
25 changes: 13 additions & 12 deletions dataframe_api_compat/polars_standard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,14 @@ def dataframe_from_columns(
api_version=list(api_version)[0],
)

def column_from_1d_array(
def column_from_1d_array( # type: ignore[override]
self,
array: Any,
*,
dtype: DType,
name: str = "",
) -> Column:
ser = pl.Series(
values=array,
dtype=_map_standard_to_polars_dtypes(dtype),
name=name,
)
return Column(pl.lit(ser), api_version=self.api_version, df=None)
ser = pl.Series(values=array, name=name)
return Column(ser, api_version=self.api_version, df=None, is_persisted=True)

def column_from_sequence(
self,
Expand All @@ -174,7 +169,7 @@ def column_from_sequence(
dtype=_map_standard_to_polars_dtypes(dtype),
name=name,
)
return Column(pl.lit(ser), api_version=self.api_version, df=None)
return Column(ser, api_version=self.api_version, df=None, is_persisted=True)

def dataframe_from_2d_array(
self,
Expand Down Expand Up @@ -303,16 +298,17 @@ def concat(
dataframes: Sequence[DataFrameT],
) -> DataFrame:
dataframes = cast("Sequence[DataFrame]", dataframes)
dfs: list[pl.LazyFrame] = []
dfs: list[pl.LazyFrame | pl.DataFrame] = []
api_versions: set[str] = set()
for df in dataframes:
dfs.append(df.dataframe)
api_versions.add(df._api_version)
if len(api_versions) > 1: # pragma: no cover
msg = f"Multiple api versions found: {api_versions}"
raise ValueError(msg)
# todo raise if not all share persistedness
return DataFrame(
pl.concat(dfs),
pl.concat(dfs), # type: ignore[type-var]
api_version=api_versions.pop(),
)

Expand Down Expand Up @@ -427,7 +423,12 @@ def convert_to_standard_compliant_column(
ser: pl.Series,
api_version: str | None = None,
) -> Column:
return Column(pl.lit(ser), api_version=api_version or "2023.11-beta", df=None)
return Column(
ser,
api_version=api_version or "2023.11-beta",
df=None,
is_persisted=True,
)


def convert_to_standard_compliant_dataframe(
Expand Down
Loading