|
14 | 14 |
|
15 | 15 | import os
|
16 | 16 |
|
| 17 | +import dask |
| 18 | +import numpy as np |
17 | 19 | import pandas as pd
|
18 | 20 | import pytest
|
19 | 21 | from dask import dataframe as dd
|
@@ -508,7 +510,7 @@ def test_repeatedparagraphschar(self):
|
508 | 510 | def test_repeatingtopngrams(self):
|
509 | 511 | dataset = list_to_dataset(
|
510 | 512 | [
|
511 |
| - "this is a totally fine sentence with no repeating ngrams so we are ok", |
| 513 | + "this is a totally fine sentence with no repeat ngrams so we are ok", |
512 | 514 | "a b . a b",
|
513 | 515 | "a a a a a a",
|
514 | 516 | "totally fine small dupe a b a b",
|
@@ -756,3 +758,71 @@ def test_per_extension_filter(self):
|
756 | 758 | assert all_equal(
|
757 | 759 | expected_data, filtered_data
|
758 | 760 | ), f"Expected {expected_data} but got {filtered_data}"
|
| 761 | + |
| 762 | + |
class FakeQualityFilter(DocumentFilter):
    """
    Emulates FastTextQualityFilter without a model.

    Scores each document in a batch with an evenly spaced value in [0, 1)
    and keeps it when a Pareto-distributed random draw exceeds
    ``1 - score``, mimicking the stochastic keep/drop decision of the real
    quality filter.
    """

    def __init__(self, alpha=3, seed=42):
        """
        Args:
            alpha: Shape parameter of the Pareto distribution used in
                keep_document.
            seed: Seed applied to numpy's global RNG so keep decisions
                are reproducible across runs.
        """
        super().__init__()
        self._alpha = alpha
        # Fix: np.random.seed() returns None, so the original
        # `self._seed = np.random.seed(seed)` stored None in _seed.
        # Seed the global RNG exactly as before (same random stream),
        # but record the actual seed value.
        np.random.seed(seed)
        self._seed = seed

    @batched
    def score_document(self, df):
        # Deterministic scores 0/n, 1/n, ..., (n-1)/n for a batch of n docs.
        return pd.Series(np.arange(len(df)) / len(df))

    @batched
    def keep_document(self, df):
        # Keep a document when its Pareto draw beats (1 - score); higher
        # scores are therefore more likely to survive.
        return np.random.pareto(self._alpha, size=len(df)) > 1 - df
| 780 | + |
| 781 | + |
class FakeLangId(DocumentFilter):
    """
    Emulates FastTextLangId without a model.

    Assigns a fixed, repeating pattern of (score, language) pairs to each
    batch and keeps documents whose numeric score meets the cutoff.
    """

    def __init__(self, min_langid_score=0.3, convert_string=False):
        super().__init__()
        self._cutoff = min_langid_score

        # Without this setting Dask coerces the list-valued score column
        # to strings, breaking keep_document's indexing.
        # See https://github.com/NVIDIA/NeMo-Curator/issues/33
        dask.config.set({"dataframe.convert-string": convert_string})

    @batched
    def score_document(self, df):
        # Repeat a fixed (score, language) pattern and truncate it to the
        # batch size, one entry per document.
        pattern = [[0.5, "EN"], [0.7, "HI"], [0.2, "PT"]]
        repeated = (pattern * len(df))[: len(df)]
        return pd.Series(repeated)

    def keep_document(self, score):
        # score is a [value, language] pair; keep when the value is at
        # or above the configured cutoff.
        return score[0] >= self._cutoff
| 805 | + |
| 806 | + |
class TestClassifierFilters:
    """Tests for ScoreFilter driven by the fake, model-free filters."""

    @staticmethod
    def _run_and_check(document_filter, expected_indices):
        # Shared driver: build a tiny 4-doc dataset, run the filter through
        # ScoreFilter, and compare against the rows expected to survive.
        dataset = list_to_dataset(["a", "b", "c", "d"], npartitions=1)
        filtered_data = ScoreFilter(document_filter)(dataset)

        expected_data = DocumentDataset(dataset.df.loc[expected_indices])
        assert all_equal(
            expected_data, filtered_data
        ), f"Expected {expected_data} but got {filtered_data}"

    def test_fake_quality_filter(self):
        self._run_and_check(FakeQualityFilter(), [1, 2, 3])

    def test_fake_langid_filter(self):
        self._run_and_check(FakeLangId(), [0, 1, 3])
0 commit comments