1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.pandas/groupby.rst
@@ -73,6 +73,7 @@ Computations / Descriptive Stats
GroupBy.mean
GroupBy.median
GroupBy.min
GroupBy.nth
GroupBy.rank
GroupBy.sem
GroupBy.std
98 changes: 98 additions & 0 deletions python/pyspark/pandas/groupby.py
@@ -895,6 +895,104 @@ def sem(col: Column) -> Column:
bool_to_numeric=True,
)

# TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
Contributor:
Maybe we want to create a ticket as a sub-task of SPARK-40327?

Contributor Author:
Let's add it when we start implementing those parameters.

def nth(self, n: int) -> FrameLike:
Member:
def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column:

Since 3.1 there has been an nth_value function in Spark, but considering negative indexes and that we are going to support list and slice in the future, I think using row_number is right here. Just FYI in case you have another idea.

Contributor Author:
I guess we cannot apply nth_value for this purpose: it returns the n-th row within the partition for each input row, so it cannot be used to filter out the unnecessary rows.
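To make that concrete, here is a minimal PySpark sketch (not part of this patch; the toy data, column names, and the __order__ helper column are made up for illustration, and a running SparkSession is assumed) contrasting the two approaches:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (1, 4.0), (2, 3.0), (2, 5.0)], ["A", "B"]
).withColumn("__order__", F.monotonically_increasing_id())

w = Window.partitionBy("A").orderBy("__order__")
full = w.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# nth_value only annotates each input row with its group's 2nd value;
# all five rows are still returned, so nothing gets filtered out.
sdf.withColumn("second_B", F.nth_value("B", 2).over(full)).show()

# row_number lets us keep exactly the 2nd row of each group, which is what nth(1) needs.
sdf.withColumn("rn", F.row_number().over(w)).where(F.col("rn") == 2).drop("rn", "__order__").show()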

"""
Take the nth row from each group.

.. versionadded:: 3.4.0

Parameters
----------
n : int
A single nth value for the row

Member:
Suggested change:

Returns
-------

Returns
-------
Series or DataFrame

Notes
-----
There is a behavior difference between pandas-on-Spark and pandas:

* when there is no aggregation column, and `n` is not equal to 0 or -1,
  the returned empty DataFrame may have an index with a different length `__len__`.

Examples
--------
>>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
... 'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
>>> g = df.groupby('A')
>>> g.nth(0)
B
A
1 NaN
2 3.0
>>> g.nth(1)
B
A
1 2.0
2 5.0
>>> g.nth(-1)
B
A
1 4.0
2 5.0

See Also
--------
pyspark.pandas.Series.groupby
pyspark.pandas.DataFrame.groupby
"""
if isinstance(n, slice) or is_list_like(n):
raise NotImplementedError("n doesn't support slice or list for now")
if not isinstance(n, int):
raise TypeError("Invalid index %s" % type(n).__name__)

groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
internal, agg_columns, sdf = self._prepare_reduce(
groupkey_names=groupkey_names,
accepted_spark_types=None,
bool_to_numeric=False,
)
psdf: DataFrame = DataFrame(internal)

if len(psdf._internal.column_labels) > 0:
window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
tmp_row_number_col = verify_temp_column_name(sdf, "__tmp_row_number_col__")
if n >= 0:
Member:
Validate n with a friendly exception?

>>> g.nth('C')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/yikun/venv/lib/python3.9/site-packages/pandas/core/groupby/groupby.py", line 2304, in nth
    raise TypeError("n needs to be an int or a list/set/tuple of ints")
TypeError: n needs to be an int or a list/set/tuple of ints
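For reference, the type check at the top of this nth implementation already raises a comparable error; a doctest-style sketch of what it should produce, assuming the g from the docstring example above and the TypeError("Invalid index %s" ...) raised earlier in this diff:

>>> g.nth('C')
Traceback (most recent call last):
    ...
TypeError: Invalid index str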

sdf = (
psdf._internal.spark_frame.withColumn(
tmp_row_number_col, F.row_number().over(window1)
)
.where(F.col(tmp_row_number_col) == n + 1)
.drop(tmp_row_number_col)
)
else:
window2 = Window.partitionBy(*groupkey_names).rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
tmp_group_size_col = verify_temp_column_name(sdf, "__tmp_group_size_col__")
sdf = (
psdf._internal.spark_frame.withColumn(
tmp_group_size_col, F.count(F.lit(0)).over(window2)
)
.withColumn(tmp_row_number_col, F.row_number().over(window1))
.where(F.col(tmp_row_number_col) == F.col(tmp_group_size_col) + 1 + n)
.drop(tmp_group_size_col, tmp_row_number_col)
)
else:
sdf = sdf.select(*groupkey_names).distinct()
Member:
Add a test to cover this? I'm a little fuzzy about this.

Contributor Author:
There seems to be a bug in pandas' GroupBy.nth; its returned index varies with n:

In [23]: pdf
Out[23]: 
   A    B  C      D
0  1  3.1  a   True
1  2  4.1  b  False
2  1  4.1  b  False
3  2  3.1  a   True

In [24]: pdf.groupby(["A", "B", "C", "D"]).nth(0)
Out[24]: 
Empty DataFrame
Columns: []
Index: [(1, 3.1, a, True), (1, 4.1, b, False), (2, 3.1, a, True), (2, 4.1, b, False)]

In [25]: pdf.groupby(["A", "B", "C", "D"]).nth(0).index
Out[25]: 
MultiIndex([(1, 3.1, 'a',  True),
            (1, 4.1, 'b', False),
            (2, 3.1, 'a',  True),
            (2, 4.1, 'b', False)],
           names=['A', 'B', 'C', 'D'])

In [26]: pdf.groupby(["A", "B", "C", "D"]).nth(1)
Out[26]: 
Empty DataFrame
Columns: []
Index: []

In [27]: pdf.groupby(["A", "B", "C", "D"]).nth(1).index
Out[27]: MultiIndex([], names=['A', 'B', 'C', 'D'])

In [28]: pdf.groupby(["A", "B", "C", "D"]).nth(-1)
Out[28]: 
Empty DataFrame
Columns: []
Index: [(1, 3.1, a, True), (1, 4.1, b, False), (2, 3.1, a, True), (2, 4.1, b, False)]

In [29]: pdf.groupby(["A", "B", "C", "D"]).nth(-1).index
Out[29]: 
MultiIndex([(1, 3.1, 'a',  True),
            (1, 4.1, 'b', False),
            (2, 3.1, 'a',  True),
            (2, 4.1, 'b', False)],
           names=['A', 'B', 'C', 'D'])

In [30]: pdf.groupby(["A", "B", "C", "D"]).nth(-2)
Out[30]: 
Empty DataFrame
Columns: []
Index: []

In [31]: pdf.groupby(["A", "B", "C", "D"]).nth(-2).index
Out[31]: MultiIndex([], names=['A', 'B', 'C', 'D'])

while other functions behave like this in pandas and pandas-on-Spark:

In [17]: pdf
Out[17]: 
   A    B  C      D
0  1  3.1  a   True
1  2  4.1  b  False
2  1  4.1  b  False
3  2  3.1  a   True

In [18]: pdf.groupby(["A", "B", "C", "D"]).max()
Out[18]: 
Empty DataFrame
Columns: []
Index: [(1, 3.1, a, True), (1, 4.1, b, False), (2, 3.1, a, True), (2, 4.1, b, False)]

In [19]: pdf.groupby(["A", "B", "C", "D"]).mad()
Out[19]: 
Empty DataFrame
Columns: []
Index: [(1, 3.1, a, True), (1, 4.1, b, False), (2, 3.1, a, True), (2, 4.1, b, False)]

In [20]: psdf.groupby(["A", "B", "C", "D"]).max()
Out[20]: 
Empty DataFrame
Columns: []
Index: [(1, 3.1, a, True), (2, 4.1, b, False), (1, 4.1, b, False), (2, 3.1, a, True)]

In [21]: psdf.groupby(["A", "B", "C", "D"]).mad()
Out[21]: 
Empty DataFrame
Columns: []
Index: [(1, 3.1, a, True), (2, 4.1, b, False), (1, 4.1, b, False), (2, 3.1, a, True)]


In [22]: psdf.groupby(["A", "B", "C", "D"]).nth(0)
Out[22]: 
Empty DataFrame
Columns: []
Index: [(1, 3.1, a, True), (2, 4.1, b, False), (1, 4.1, b, False), (2, 3.1, a, True)]

Contributor Author:
So I think we cannot add a test for it for now.

Contributor:
If there is a bug in pandas, maybe we should add a test by manually creating the expected result rather than just skipping the test?

e.g.

if LooseVersion("1.1.1") <= LooseVersion(pd.__version__) < LooseVersion("1.1.4"):
    # a pandas bug: https://github.com/databricks/koalas/pull/1818#issuecomment-703961980
    self.assert_eq(psser.astype(str).tolist(), ["hi", "hi ", " ", " \t", "", "None"])
else:
    self.assert_eq(psser.astype(str), pser.astype(str))
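Along those lines, a hypothetical sketch of such a version-gated test for nth; the version bound, the manually built expected frame, and the data are illustrative assumptions reusing the reproduction above, and pd, ps, and LooseVersion are assumed to be imported as in the existing test module:

def test_nth_no_agg_columns(self):
    # Every column is a group key, so the grouped frame has no aggregation columns.
    pdf = pd.DataFrame(
        {
            "A": [1, 2, 1, 2],
            "B": [3.1, 4.1, 4.1, 3.1],
            "C": ["a", "b", "b", "a"],
            "D": [True, False, False, True],
        }
    )
    psdf = ps.from_pandas(pdf)
    keys = ["A", "B", "C", "D"]
    if LooseVersion(pd.__version__) < LooseVersion("1.5.0"):
        # Hypothetical guard: pandas' nth index varies with n here (see the
        # reproduction above), so compare against a manually built expected
        # frame (no columns, indexed by every distinct group key) instead.
        expected = pdf.groupby(keys).max()
        self.assert_eq(psdf.groupby(keys).nth(1).sort_index(), expected.sort_index())
    else:
        self.assert_eq(
            psdf.groupby(keys).nth(1).sort_index(),
            pdf.groupby(keys).nth(1).sort_index(),
        )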

Contributor (@itholic, Sep 7, 2022):
Oh... I just noticed that we're following the pandas behavior even though there is a bug in pandas.

When there is a bug in pandas, we usually do something like this:

  • We don't follow the behavior of pandas; we just assume it works properly and implement it.

  • We comment the test with a link to the related pandas issue from the pandas repository (https://github.com/pandas-dev/pandas/issues/...), as below:

    if LooseVersion(pd.__version__) < LooseVersion("1.1.3"):
        # pandas < 1.1.0: object dtype is returned after negation
        # pandas 1.1.1 and 1.1.2:
        # a TypeError "bad operand type for unary -: 'IntegerArray'" is raised
        # Please refer to https://github.com/pandas-dev/pandas/issues/36063.
        self.check_extension(pd.Series([-1, -2, -3, None], dtype=pser.dtype), -psser)
    else:
        self.check_extension(-pser, -psser)

    if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"):
        # Limit pandas version due to
        # https://github.com/pandas-dev/pandas/issues/31204
        self.check_extension(pser.astype(dtype), psser.astype(dtype))
    else:
        self.check_extension(pser.astype(dtype), psser.astype(dtype))

  • If it's not clear that it's a bug (i.e., it hasn't been officially discussed as a bug in the pandas community), we can just follow the pandas behavior.
Contributor (@itholic, Sep 7, 2022):
Or we can post a question to the pandas community asking whether it's a bug or intended behavior, and link that question in a comment if they reply "yes, it's a bug".


internal = internal.copy(
spark_frame=sdf,
index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
data_fields=None,
)

return self._prepare_return(DataFrame(internal))

def all(self, skipna: bool = True) -> FrameLike:
"""
Returns True if all values in the group are truthful, else False.
Expand Down
2 changes: 0 additions & 2 deletions python/pyspark/pandas/missing/groupby.py
@@ -59,7 +59,6 @@ class MissingPandasLikeDataFrameGroupBy:
# Functions
boxplot = _unsupported_function("boxplot")
ngroup = _unsupported_function("ngroup")
nth = _unsupported_function("nth")
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
@@ -93,7 +92,6 @@ class MissingPandasLikeSeriesGroupBy:
aggregate = _unsupported_function("aggregate")
describe = _unsupported_function("describe")
ngroup = _unsupported_function("ngroup")
nth = _unsupported_function("nth")
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
11 changes: 11 additions & 0 deletions python/pyspark/pandas/tests/test_groupby.py
@@ -1380,6 +1380,17 @@ def test_last(self):
self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))

def test_nth(self):
for n in [0, 1, 2, 128, -1, -2, -128]:
self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))

with self.assertRaisesRegex(NotImplementedError, "slice or list"):
self.psdf.groupby("B").nth(slice(0, 2))
with self.assertRaisesRegex(NotImplementedError, "slice or list"):
self.psdf.groupby("B").nth([0, 1, -1])
with self.assertRaisesRegex(TypeError, "Invalid index"):
self.psdf.groupby("B").nth("x")

def test_cumcount(self):
pdf = pd.DataFrame(
{