Skip to content

Commit 95fba86

Browse files
zhengruifengHyukjinKwon
authored andcommitted
[SPARK-39534][PS] Series.argmax only needs single pass
### What changes were proposed in this pull request? compute `Series.argmax ` with one pass ### Why are the changes needed? existing implemation of `Series.argmax` needs two pass on the dataset, the first one is to compute the maximum value, and the second one is to get the index. However, they can be computed on one pass. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing UT Closes apache#36927 from zhengruifeng/ps_series_argmax_opt. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 9e468cf commit 95fba86

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

python/pyspark/pandas/series.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -6301,22 +6301,18 @@ def argmax(self, axis: Axis = None, skipna: bool = True) -> int:
63016301
scol = scol_for(sdf, self._internal.data_spark_column_names[0])
63026302

63036303
if skipna:
6304-
sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME)
6304+
sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
63056305
else:
6306-
sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME)
6306+
sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
63076307

6308-
max_value = sdf.select(
6309-
F.first(scol),
6310-
F.first(NATURAL_ORDER_COLUMN_NAME),
6311-
).head()
6308+
results = sdf.select(scol, seq_col_name).take(1)
63126309

6313-
if max_value[1] is None:
6310+
if len(results) == 0:
63146311
raise ValueError("attempt to get argmax of an empty sequence")
6315-
elif max_value[0] is None:
6316-
return -1
6317-
6318-
# If the maximum is achieved in multiple locations, the first row position is returned.
6319-
return sdf.filter(scol == max_value[0]).head()[0]
6312+
else:
6313+
max_value = results[0]
6314+
# If the maximum is achieved in multiple locations, the first row position is returned.
6315+
return -1 if max_value[0] is None else max_value[1]
63206316

63216317
def argmin(self) -> int:
63226318
"""

0 commit comments

Comments
 (0)