Skip to content

Commit

Permalink
Refine Frame._reduce_for_stat_function. (#1975)
Browse files Browse the repository at this point in the history
Refines `DataFrame/Series._reduce_for_stat_function` to avoid special handling based on a specific function.

Also:
- Consolidates the implementations of `count` and adds support for the `numeric_only` parameter.
- Adds argument type annotations.
  • Loading branch information
ueshin authored Dec 21, 2020
1 parent b81afcc commit c973195
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 240 deletions.
72 changes: 1 addition & 71 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,9 +626,6 @@ def _reduce_for_stat_function(self, sfun, name, axis=None, numeric_only=True):
from databricks.koalas import Series
from databricks.koalas.series import first_series

if name not in ("count", "min", "max") and not numeric_only:
raise ValueError("Disabling 'numeric_only' parameter is not supported.")

axis = validate_axis(axis)
if axis == 0:
exprs = [F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)]
Expand All @@ -639,15 +636,9 @@ def _reduce_for_stat_function(self, sfun, name, axis=None, numeric_only=True):
col_type = self._internal.spark_type_for(label)

is_numeric_or_boolean = isinstance(col_type, (NumericType, BooleanType))
min_or_max = name in ("min", "max")
keep_column = not numeric_only or is_numeric_or_boolean or min_or_max
keep_column = not numeric_only or is_numeric_or_boolean

if keep_column:
if isinstance(col_type, BooleanType) and not min_or_max:
# Stat functions cannot be used with boolean values by default
# Thus, cast to integer (true to 1 and false to 0)
# Exclude the min and max methods though since those work with booleans
col_sdf = col_sdf.cast("integer")
if num_args == 1:
# Only pass in the column if sfun accepts only one arg
col_sdf = sfun(col_sdf)
Expand Down Expand Up @@ -6025,67 +6016,6 @@ def select_dtypes(self, include=None, exclude=None) -> "DataFrame":
self._internal.with_new_columns(data_spark_columns, column_labels=column_labels)
)

def count(self, axis=None) -> pd.Series:
    """
    Count the non-NA cells along the requested axis.

    `None` and `NaN` values are treated as NA and are therefore not counted.

    Parameters
    ----------
    axis : {0 or 'index', 1 or 'columns'}, default 0
        With 0 or 'index' one count is produced per column; with 1 or
        'columns' one count is produced per row.

    Returns
    -------
    pandas.Series

    See Also
    --------
    Series.count: Number of non-NA elements in a Series.
    DataFrame.shape: Number of DataFrame rows and columns (including NA
        elements).
    DataFrame.isna: Boolean same-sized DataFrame showing places of NA
        elements.

    Examples
    --------
    Constructing DataFrame from a dictionary:

    >>> df = ks.DataFrame({"Person":
    ...                    ["John", "Myla", "Lewis", "John", "Myla"],
    ...                    "Age": [24., np.nan, 21., 33, 26],
    ...                    "Single": [False, True, True, True, False]},
    ...                   columns=["Person", "Age", "Single"])
    >>> df
      Person   Age  Single
    0   John  24.0   False
    1   Myla   NaN    True
    2  Lewis  21.0    True
    3   John  33.0    True
    4   Myla  26.0   False

    Notice the uncounted NA values:

    >>> df.count()
    Person    5
    Age       4
    Single    5
    dtype: int64

    >>> df.count(axis=1)
    0    3
    1    2
    2    3
    3    3
    4    3
    dtype: int64
    """
    # numeric_only is disabled so that non-numeric columns (such as the
    # string column "Person" in the example above) are counted as well.
    counts = self._reduce_for_stat_function(
        Frame._count_expr,
        name="count",
        axis=axis,
        numeric_only=False,
    )
    return counts

def droplevel(self, level, axis=0) -> "DataFrame":
"""
Return DataFrame with requested index / column level(s) removed.
Expand Down
Loading

0 comments on commit c973195

Please sign in to comment.