Commit 52665fa

feat: Allow drop_duplicates over unordered dataframe (#2303)
1 parent: 41630b5

6 files changed: +49 −33 lines

bigframes/core/block_transforms.py

Lines changed: 22 additions & 23 deletions

```diff
@@ -67,40 +67,39 @@ def indicate_duplicates(
     if keep not in ["first", "last", False]:
         raise ValueError("keep must be one of 'first', 'last', or False'")
 
+    rownums = agg_expressions.WindowExpression(
+        agg_expressions.NullaryAggregation(
+            agg_ops.RowNumberOp(),
+        ),
+        window=windows.unbound(grouping_keys=tuple(columns)),
+    )
+    count = agg_expressions.WindowExpression(
+        agg_expressions.NullaryAggregation(
+            agg_ops.SizeOp(),
+        ),
+        window=windows.unbound(grouping_keys=tuple(columns)),
+    )
+
     if keep == "first":
         # Count how many copies occur up to current copy of value
         # Discard this value if there are copies BEFORE
-        window_spec = windows.cumulative_rows(
-            grouping_keys=tuple(columns),
-        )
+        predicate = ops.gt_op.as_expr(rownums, ex.const(0))
     elif keep == "last":
         # Count how many copies occur up to current copy of values
         # Discard this value if there are copies AFTER
-        window_spec = windows.inverse_cumulative_rows(
-            grouping_keys=tuple(columns),
-        )
+        predicate = ops.lt_op.as_expr(rownums, ops.sub_op.as_expr(count, ex.const(1)))
     else:  # keep == False
         # Count how many copies of the value occur in entire series.
         # Discard this value if there are copies ANYWHERE
-        window_spec = windows.unbound(grouping_keys=tuple(columns))
-    block, dummy = block.create_constant(1)
-    # use row number as will work even with partial ordering
-    block, val_count_col_id = block.apply_window_op(
-        dummy,
-        agg_ops.sum_op,
-        window_spec=window_spec,
-    )
-    block, duplicate_indicator = block.project_expr(
-        ops.gt_op.as_expr(val_count_col_id, ex.const(1))
+        predicate = ops.gt_op.as_expr(count, ex.const(1))
+
+    block = block.project_block_exprs(
+        [predicate],
+        labels=[None],
     )
     return (
-        block.drop_columns(
-            (
-                dummy,
-                val_count_col_id,
-            )
-        ),
-        duplicate_indicator,
+        block,
+        block.value_columns[-1],
     )
```
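The rewrite computes two analytic values over a single window partitioned by the key columns — a per-group row number (`RowNumberOp`) and a per-group size (`SizeOp`) — and reduces each `keep` mode to one boolean predicate, projected with `project_block_exprs`. A rough pandas analogue of the three predicates (a sketch for illustration only, not the BigQuery DataFrames API):

```python
import pandas as pd

df = pd.DataFrame({"bool_col": [True, True, False, None, True]})
grouped = df.groupby("bool_col", dropna=False)

# Mirrors RowNumberOp / SizeOp over a window partitioned by the key columns.
rownums = grouped.cumcount()
count = grouped["bool_col"].transform("size")

dup_first = rownums > 0          # keep="first": a copy exists BEFORE this row
dup_last = rownums < count - 1   # keep="last":  a copy exists AFTER this row
dup_none = count > 1             # keep=False:   a copy exists ANYWHERE
```

A row number needs no total ordering — any consistent numbering within each group suffices — which is what lets the `enforce_ordered` guards in the files below be removed.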

bigframes/core/compile/polars/compiler.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -547,6 +547,9 @@ def compile_agg_op(
         return pl.col(*inputs).first()
     if isinstance(op, agg_ops.LastOp):
         return pl.col(*inputs).last()
+    if isinstance(op, agg_ops.RowNumberOp):
+        # pl.row_index is not yet stable enough to use here, and only supports polars>=1.32
+        return pl.int_range(pl.len(), dtype=pl.Int64)
     if isinstance(op, agg_ops.ShiftOp):
         return pl.col(*inputs).shift(op.periods)
     if isinstance(op, agg_ops.DiffOp):
```
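The new `RowNumberOp` branch uses `pl.int_range(pl.len())` rather than the newer `pl.row_index`, as the comment notes, to stay compatible with older polars releases. A minimal standalone sketch of the same idiom (example data made up):

```python
import polars as pl

df = pl.DataFrame({"key": ["a", "a", "b", "a"], "val": [1, 2, 3, 4]})

# int_range over the window length yields a 0-based per-group row number.
out = df.with_columns(
    pl.int_range(pl.len(), dtype=pl.Int64).over("key").alias("rownum")
)
print(out)  # rownum: 0, 1, 0, 2
```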

bigframes/core/indexes/base.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -624,8 +624,6 @@ def dropna(self, how: typing.Literal["all", "any"] = "any") -> Index:
         return Index(result)
 
     def drop_duplicates(self, *, keep: __builtins__.str = "first") -> Index:
-        if keep is not False:
-            validations.enforce_ordered(self, "drop_duplicates")
         block = block_ops.drop_duplicates(self._block, self._block.index_columns, keep)
         return Index(block)
```
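With the `enforce_ordered` guard gone, `Index.drop_duplicates` accepts every `keep` mode on unordered indexes too. A hypothetical usage sketch (assumes a configured BigQuery DataFrames session):

```python
import bigframes.pandas as bpd

idx = bpd.Index([10, 10, 20, 30, 30])
idx.drop_duplicates(keep="last")  # previously required a totally ordered session
idx.drop_duplicates(keep=False)   # was already allowed: keeps only 20
```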

bigframes/dataframe.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -5054,8 +5054,6 @@ def drop_duplicates(
         *,
         keep: str = "first",
     ) -> DataFrame:
-        if keep is not False:
-            validations.enforce_ordered(self, "drop_duplicates(keep != False)")
         if subset is None:
             column_ids = self._block.value_columns
         elif utils.is_list_like(subset):
@@ -5069,8 +5067,6 @@ def drop_duplicates(
         return DataFrame(block)
 
     def duplicated(self, subset=None, keep: str = "first") -> bigframes.series.Series:
-        if keep is not False:
-            validations.enforce_ordered(self, "duplicated(keep != False)")
         if subset is None:
             column_ids = self._block.value_columns
         else:
```
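The same relaxation at the DataFrame level: `drop_duplicates` and `duplicated` no longer reject `keep="first"`/`"last"` on unordered frames. A sketch of what now works; the `ordering_mode = "partial"` option is the opt-in for unordered frames, while the table name is just an example:

```python
import bigframes.pandas as bpd

bpd.options.bigquery.ordering_mode = "partial"  # opt in to unordered frames

df = bpd.read_gbq("bigquery-public-data.ml_datasets.penguins")
deduped = df.drop_duplicates(subset="species", keep="first")  # now allowed
mask = df.duplicated(subset="species", keep="last")           # now allowed
```

Note that under partial ordering, the row kept for `keep="first"`/`"last"` is an arbitrary representative rather than a positional first or last.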

bigframes/series.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -2227,8 +2227,6 @@ def reindex_like(self, other: Series, *, validate: typing.Optional[bool] = None)
         return self.reindex(other.index, validate=validate)
 
     def drop_duplicates(self, *, keep: str = "first") -> Series:
-        if keep is not False:
-            validations.enforce_ordered(self, "drop_duplicates(keep != False)")
         block = block_ops.drop_duplicates(self._block, (self._value_column,), keep)
         return Series(block)
 
@@ -2249,8 +2247,6 @@ def unique(self, keep_order=True) -> Series:
         return Series(block.select_columns(result).reset_index())
 
     def duplicated(self, keep: str = "first") -> Series:
-        if keep is not False:
-            validations.enforce_ordered(self, "duplicated(keep != False)")
         block, indicator = block_ops.indicate_duplicates(
             self._block, (self._value_column,), keep
         )
```
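And the Series methods mirror it. A small sketch (assumes a configured session):

```python
import bigframes.pandas as bpd

s = bpd.Series([1, 1, 2, 3, 3, 3])
s.duplicated(keep="first")     # exactly one False per distinct value
s.drop_duplicates(keep=False)  # only the unique value 2 survives
```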

tests/system/large/test_dataframe.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -40,3 +40,27 @@ def test_cov_150_columns(scalars_df_numeric_150_columns_maybe_ordered):
         check_index_type=False,
         check_column_type=False,
     )
+
+
+@pytest.mark.parametrize(
+    ("keep",),
+    [
+        ("first",),
+        ("last",),
+        (False,),
+    ],
+)
+def test_drop_duplicates_unordered(
+    scalars_df_unordered, scalars_pandas_df_default_index, keep
+):
+    uniq_scalar_rows = scalars_df_unordered.drop_duplicates(
+        subset="bool_col", keep=keep
+    )
+    uniq_pd_rows = scalars_pandas_df_default_index.drop_duplicates(
+        subset="bool_col", keep=keep
+    )
+
+    assert len(uniq_scalar_rows) == len(uniq_pd_rows)
+    assert len(uniq_scalar_rows.groupby("bool_col")) == len(
+        uniq_pd_rows.groupby("bool_col")
+    )
```
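The assertions compare only row counts and group counts because, on an unordered frame, which duplicate row survives is nondeterministic, so exact frame equality against pandas cannot be expected. A hypothetical stricter check could still compare the surviving key values themselves:

```python
import pandas as pd

# Hypothetical: the surviving bool_col values must match even though
# the surviving rows may differ.
pd.testing.assert_series_equal(
    uniq_scalar_rows["bool_col"].to_pandas().sort_values(ignore_index=True),
    uniq_pd_rows["bool_col"].sort_values(ignore_index=True),
    check_dtype=False,
    check_names=False,
)
```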
