Skip to content

Commit b192e92

Browse files
ryantwolf and nicoleeeluo
authored and committed
Fix lang id example (NVIDIA#37)
* Fix lang id example Signed-off-by: Ryan Wolf <[email protected]> * Add classifier unit tests Signed-off-by: Ryan Wolf <[email protected]> * Add test for failure Signed-off-by: Ryan Wolf <[email protected]> * Remove failure test Signed-off-by: Ryan Wolf <[email protected]> --------- Signed-off-by: Ryan Wolf <[email protected]> Signed-off-by: Nicole Luo <[email protected]>
1 parent f2b3904 commit b192e92

File tree

3 files changed

+78
-2
lines changed

3 files changed

+78
-2
lines changed

examples/identify_languages_and_fix_unicode.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def main(args):
6060

6161
# Remove the language score
6262
filtered_dataset.df[language_field] = filtered_dataset.df[language_field].apply(
63-
lambda score: score[1]
63+
lambda score: score[1], meta=(None, str)
6464
)
6565

6666
# Split the dataset by language

nemo_curator/filters/classifier_filter.py

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import dask
1516
import fasttext
1617
import numpy as np
1718
import pandas as pd
@@ -75,6 +76,11 @@ def __init__(self, model_path=None, min_langid_score=0.3):
7576
self._cutoff = min_langid_score
7677
self._name = "lang_id"
7778

79+
# Dask will automatically convert the list score type
80+
# to a string without this option.
81+
# See https://github.com/NVIDIA/NeMo-Curator/issues/33
82+
dask.config.set({"dataframe.convert-string": False})
83+
7884
@batched
7985
def score_document(self, df):
8086
model_attr = f"{self._name}_{self._model_path}"

tests/test_filters.py

+71-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import os
1616

17+
import dask
18+
import numpy as np
1719
import pandas as pd
1820
import pytest
1921
from dask import dataframe as dd
@@ -508,7 +510,7 @@ def test_repeatedparagraphschar(self):
508510
def test_repeatingtopngrams(self):
509511
dataset = list_to_dataset(
510512
[
511-
"this is a totally fine sentence with no repeating ngrams so we are ok",
513+
"this is a totally fine sentence with no repeat ngrams so we are ok",
512514
"a b . a b",
513515
"a a a a a a",
514516
"totally fine small dupe a b a b",
@@ -756,3 +758,71 @@ def test_per_extension_filter(self):
756758
assert all_equal(
757759
expected_data, filtered_data
758760
), f"Expected {expected_data} but got {filtered_data}"
761+
762+
763+
class FakeQualityFilter(DocumentFilter):
    """
    Emulates FastTextQualityFilter without a model.

    Documents are scored with a deterministic ramp in [0, 1), and the
    keep decision is drawn from a Pareto distribution seeded in
    __init__, mirroring the real filter's stochastic behaviour.
    """

    def __init__(self, alpha=3, seed=42):
        """
        Args:
            alpha: Shape parameter of the Pareto distribution used by
                keep_document.
            seed: Seed for numpy's global RNG so keep decisions are
                reproducible across runs.
        """
        super().__init__()
        self._alpha = alpha
        # Fix: np.random.seed returns None, so the original
        # `self._seed = np.random.seed(seed)` stored None. Seed the
        # global RNG and keep the actual seed value for introspection.
        np.random.seed(seed)
        self._seed = seed

    @batched
    def score_document(self, df):
        # Deterministic scores 0, 1/n, 2/n, ... for the n documents
        # in this partition.
        return pd.Series(np.arange(len(df)) / len(df))

    @batched
    def keep_document(self, df):
        # Keep a document when its score beats a random Pareto draw;
        # reproducible because the RNG was seeded in __init__.
        return np.random.pareto(self._alpha, size=len(df)) > 1 - df
780+
781+
782+
class FakeLangId(DocumentFilter):
    """
    Emulates FastTextLangId without a model.

    Each document receives a fake [score, language] prediction taken
    from a fixed three-entry rotation.
    """

    def __init__(self, min_langid_score=0.3, convert_string=False):
        super().__init__()
        self._cutoff = min_langid_score

        # Dask will automatically convert the list score type
        # to a string without this option.
        # See https://github.com/NVIDIA/NeMo-Curator/issues/33
        dask.config.set({"dataframe.convert-string": convert_string})

    @batched
    def score_document(self, df):
        # Cycle through the fixed predictions so every document gets a
        # deterministic fake [confidence, language] pair.
        fake_predictions = [[0.5, "EN"], [0.7, "HI"], [0.2, "PT"]]
        assigned = [
            fake_predictions[i % len(fake_predictions)] for i in range(len(df))
        ]
        return pd.Series(assigned)

    def keep_document(self, score):
        # Keep the document when the fake confidence meets the cutoff.
        return score[0] >= self._cutoff
805+
806+
807+
class TestClassifierFilters:
    """Exercises ScoreFilter with the fake classifier filters."""

    def test_fake_quality_filter(self):
        dataset = list_to_dataset(["a", "b", "c", "d"], npartitions=1)
        filtered_data = ScoreFilter(FakeQualityFilter())(dataset)

        # With the default seed, the Pareto draw rejects only the first
        # document (index 0).
        expected_data = DocumentDataset(dataset.df.loc[[1, 2, 3]])
        assert all_equal(
            expected_data, filtered_data
        ), f"Expected {expected_data} but got {filtered_data}"

    def test_fake_langid_filter(self):
        dataset = list_to_dataset(["a", "b", "c", "d"], npartitions=1)
        filtered_data = ScoreFilter(FakeLangId())(dataset)

        # Fake confidences cycle 0.5, 0.7, 0.2, 0.5; only the 0.2 entry
        # (index 2) falls below the 0.3 cutoff.
        expected_data = DocumentDataset(dataset.df.loc[[0, 1, 3]])
        assert all_equal(
            expected_data, filtered_data
        ), f"Expected {expected_data} but got {filtered_data}"

0 commit comments

Comments
 (0)