Skip to content

Commit 9849164

Browse files
ryantwolf and nicoleeeluo
authored and committed
Fix indexing in PII Modifier (NVIDIA#55)
* Fix pii index issue Signed-off-by: Ryan Wolf <[email protected]> * Add sequential wrapper Signed-off-by: Ryan Wolf <[email protected]> * Fix pii tests Signed-off-by: Ryan Wolf <[email protected]> --------- Signed-off-by: Ryan Wolf <[email protected]> Signed-off-by: Nicole Luo <[email protected]>
1 parent 0bab063 commit 9849164

File tree

6 files changed

+119
-5
lines changed

6 files changed

+119
-5
lines changed

docs/user-guide/QualityFiltering.rst

+27
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,33 @@ Here is the ``WordCountFilter`` rewritten to use batches in the ``keep_document`
153153
pass_max = score <= self._max_words
154154
return pass_min & pass_max
155155
156+
When you use the ``batched`` decorator, the index of the series returned from the function must remain the same as the index that was passed in.
157+
The index may not be contiguous due to filters being applied prior to the current filter.
158+
In the above code, the index will be the same automatically so no change is required.
159+
However, when writing functions that transform the series into a different structure like a list, special care is needed.
160+
The following code example demonstrates what this error may look like, and how to fix it.
161+
162+
.. code-block:: python
163+
164+
class BuggyLengthFilter(DocumentFilter):
165+
166+
@batched
167+
def score_document(self, documents: pd.Series):
168+
scores = []
169+
for document in documents:
170+
scores.append(len(document))
171+
172+
return pd.Series(scores) # Bad! Does not preserve the index
173+
174+
class CorrectLengthFilter(DocumentFilter):
175+
176+
@batched
177+
def score_document(self, documents: pd.Series):
178+
scores = []
179+
for document in documents:
180+
scores.append(len(document))
181+
182+
return pd.Series(scores, index=documents.index) # Good! Preserves the index
156183
157184
158185
-----------------------------------------

nemo_curator/filters/classifier_filter.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42):
3737
self._name = "fasttext_quality_filter"
3838

3939
@batched
40-
def score_document(self, df):
40+
def score_document(self, df: pd.Series):
4141
model_attr = f"{self._name}_{self._model_path}"
4242
try:
4343
model = load_object_on_worker(model_attr, self._load_model, {})
@@ -56,7 +56,7 @@ def _score_document(text):
5656
return df.apply(_score_document)
5757

5858
@batched
59-
def keep_document(self, df):
59+
def keep_document(self, df: pd.Series):
6060
return np.random.pareto(self._alpha, size=len(df)) > 1 - df
6161

6262
def _load_model(self):
@@ -82,7 +82,7 @@ def __init__(self, model_path=None, min_langid_score=0.3):
8282
dask.config.set({"dataframe.convert-string": False})
8383

8484
@batched
85-
def score_document(self, df):
85+
def score_document(self, df: pd.Series):
8686
model_attr = f"{self._name}_{self._model_path}"
8787
try:
8888
model = load_object_on_worker(model_attr, self._load_model, {})

nemo_curator/modifiers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
from .c4 import BoilerPlateStringModifier
1616
from .doc_modifier import DocumentModifier
1717
from .fasttext import FastTextLabelModifier
18+
from .pii_modifier import PiiModifier
1819
from .unicode_reformatter import UnicodeReformatter
1920

2021
__all__ = [
2122
"DocumentModifier",
2223
"BoilerPlateStringModifier",
2324
"FastTextLabelModifier",
2425
"UnicodeReformatter",
26+
"PiiModifier",
2527
]

nemo_curator/modifiers/pii_modifier.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def modify_document(self, text: pd.Series, partition_info: Dict = None):
8585
logging.error(
8686
f"Encountered error {str(e)} in partition {partition_info['number']}"
8787
)
88-
return pd.Series([True])
89-
output: pd.Series = pd.Series(output)
88+
return pd.Series([True], index=text.index)
89+
output: pd.Series = pd.Series(output, text.index)
9090
return output
9191

9292
def load_deidentifier(self):

tests/test_filters.py

+17
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,23 @@ def test_score_type(self, letter_count_data):
282282
expected_scores == scores.compute()
283283
), f"Expected {expected_scores} but got {scores}"
284284

285+
def test_chain_filter(self, letter_count_data):
286+
letter_count_filter = LetterCountFilter(min_count=4)
287+
length_filter = BatchedLengthFilter(min_length=8, max_length=11)
288+
filters = Sequential(
289+
[
290+
ScoreFilter(letter_count_filter, text_field="documents"),
291+
ScoreFilter(length_filter, text_field="documents"),
292+
]
293+
)
294+
filtered_data = filters(letter_count_data)
295+
296+
expected_indices = [2]
297+
expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices])
298+
assert all_equal(
299+
expected_data, filtered_data
300+
), f"Expected {expected_data} but got {filtered_data}"
301+
285302

286303
class TestHeuristicFilters:
287304
def test_nonalpha(self):

tests/test_pii_accuracy.py

+68
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,17 @@
1616
import re
1717
from pathlib import Path
1818

19+
import pandas as pd
1920
import pytest
21+
from dask import dataframe as dd
22+
from dask.distributed import Client, LocalCluster
2023

24+
import nemo_curator as nc
25+
from nemo_curator.datasets import DocumentDataset
26+
from nemo_curator.filters import DocumentFilter
27+
from nemo_curator.modifiers import PiiModifier
2128
from nemo_curator.pii.algorithm import PiiDeidentifier
29+
from nemo_curator.utils.decorators import batched
2230

2331
LOGGER = logging.getLogger(__name__)
2432

@@ -118,3 +126,63 @@ def test_batch_accuracy(self):
118126
match = all(compare_outputs(x, y) for x, y in zip(outputs, targets))
119127
print("Matches:", "No" if not match else "Yes")
120128
assert match == True
129+
130+
131+
class BatchedLengthFilter(DocumentFilter):
132+
"""
133+
Keeps documents of a given length
134+
"""
135+
136+
def __init__(self, min_length=5, max_length=10):
137+
super().__init__()
138+
self.min_length = min_length
139+
self.max_length = max_length
140+
141+
@batched
142+
def score_document(self, df):
143+
return df.str.len()
144+
145+
@batched
146+
def keep_document(self, scores):
147+
min_threshold = self.min_length <= scores
148+
max_threshold = scores <= self.max_length
149+
return min_threshold & max_threshold
150+
151+
152+
class TestPIIModule:
153+
def test_filter_chain(self):
154+
inputs = [
155+
"Alice goes on a walk",
156+
"Bob goes on a walk",
157+
"Someone named Charlie goes on a walk",
158+
"A human walking is David",
159+
"A human walking is Eliza",
160+
]
161+
targets = [
162+
"***** goes on a walk",
163+
"*** goes on a walk",
164+
"A human walking is *****",
165+
"A human walking is *****",
166+
]
167+
input_df = pd.DataFrame({"text": inputs})
168+
target_df = pd.DataFrame({"text": targets})
169+
with LocalCluster(n_workers=1, threads_per_worker=1) as cluster:
170+
with Client(cluster):
171+
input_dataset = DocumentDataset(dd.from_pandas(input_df, npartitions=1))
172+
pipeline = nc.Sequential(
173+
[
174+
nc.ScoreFilter(
175+
BatchedLengthFilter(min_length=0, max_length=25)
176+
),
177+
nc.Modify(
178+
PiiModifier(
179+
language="en", anonymize_action="mask", device="cpu"
180+
)
181+
),
182+
]
183+
)
184+
output_dataset = pipeline(input_dataset)
185+
186+
output_df = output_dataset.df.compute().reset_index(drop=True)
187+
match = all(output_df["text"] == target_df["text"])
188+
assert match

0 commit comments

Comments (0)