Skip to content

Commit

Permalink
Fix pytests due to loc bug
Browse files Browse the repository at this point in the history
Signed-off-by: Ayush Dattagupta <[email protected]>
  • Loading branch information
ayushdg committed Apr 23, 2024
1 parent e467078 commit 9f38aa1
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def test_retain_score_filter(self, letter_count_data):
filtered_data = filter_step(letter_count_data)

expected_indices = [2, 3]
expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
# Compute before loc due to https://github.com/dask/dask-expr/issues/1036
expected_data = letter_count_data.df.compute().loc[expected_indices]
expected_data = DocumentDataset(dd.from_pandas(expected_data, 2))
expected_data.df[score_field] = pd.Series([5, 7], index=expected_data.df.index)
assert all_equal(
expected_data, filtered_data
Expand All @@ -168,7 +170,9 @@ def test_filter(self, letter_count_data):
filtered_data = filter_step(scored_data)

expected_indices = [2, 3]
expected_data = letter_count_data.df.loc[expected_indices]
# Compute before loc due to https://github.com/dask/dask-expr/issues/1036
expected_data = letter_count_data.df.compute().loc[expected_indices]
expected_data = dd.from_pandas(expected_data, 2)
expected_data[score_field] = pd.Series([5, 7], index=expected_data.index)
expected_data = DocumentDataset(expected_data)
assert all_equal(
Expand Down

0 comments on commit 9f38aa1

Please sign in to comment.