Skip to content

Commit eee5065

Browse files
VibhuJawasarahyurick
authored andcommitted
Add tests and fix bugs found during testing (NVIDIA#151)
Signed-off-by: Vibhu Jawa <[email protected]>
1 parent 0cbe447 commit eee5065

File tree

3 files changed

+85
-3
lines changed

3 files changed

+85
-3
lines changed

nemo_curator/modules/semantic_dedup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,8 @@ def __init__(
525525
cache_dir = config.cache_dir
526526
self.embedding_creator = EmbeddingCreator(
527527
embedding_model_name_or_path=config.embedding_model_name_or_path,
528-
max_memory=config.embedding_max_mem_gb,
529-
batch_size=config.embedding_batch_size,
528+
embedding_max_mem_gb=config.embedding_max_mem_gb,
529+
embedding_batch_size=config.embedding_batch_size,
530530
input_column=config.input_column,
531531
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
532532
logger=logger,

nemo_curator/utils/semdedup_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def rank_within_cluster(
157157
sort_descending = keep_hard
158158
cluster_sorted = sorted(
159159
zip(example_id, cluster_dists_to_cent, cluster_label),
160-
key=lambda x: x[2],
160+
key=lambda x: x[1],
161161
reverse=sort_descending,
162162
) # -- sort_descending = True for descending sort
163163

tests/test_semdedup.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import pytest
17+
18+
os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"
19+
from dask.dataframe.utils import assert_eq
20+
from distributed import Client
21+
22+
from nemo_curator import SemDedup, SemDedupConfig
23+
from nemo_curator.datasets import DocumentDataset
24+
from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from
25+
26+
cudf = gpu_only_import("cudf")
27+
dask_cudf = gpu_only_import("dask_cudf")
28+
LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster")
29+
30+
31+
@pytest.fixture
32+
def dedup_data():
33+
df = cudf.DataFrame(
34+
{
35+
"id": [1, 2, 3, 4, 100, 200, 300],
36+
"text": [
37+
"The quick brown fox jumps over the lazy dog",
38+
"The quick brown foxes jumps over the lazy dog",
39+
"The quick brown wolf jumps over the lazy dog",
40+
"The quick black cat jumps over the lazy dog",
41+
"A test string",
42+
"Another test string",
43+
"A different object",
44+
],
45+
}
46+
)
47+
df = dask_cudf.from_cudf(df, 2)
48+
return DocumentDataset(df)
49+
50+
51+
@pytest.mark.gpu
52+
class TestFuzzyDuplicates:
53+
@pytest.fixture(autouse=True, scope="class")
54+
def gpu_client(self, request):
55+
with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
56+
request.cls.client = client
57+
request.cls.cluster = cluster
58+
yield
59+
60+
def test_fuzzy_dedup(
61+
self,
62+
dedup_data,
63+
tmpdir,
64+
):
65+
print("client", self.client)
66+
cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache")
67+
config = SemDedupConfig(
68+
cache_dir=cache_dir,
69+
id_col_name="id",
70+
id_col_type="int",
71+
input_column="text",
72+
seed=42,
73+
n_clusters=3,
74+
eps_thresholds=[0.10],
75+
eps_to_extract=0.10,
76+
)
77+
sem_duplicates = SemDedup(config=config)
78+
result = sem_duplicates(dedup_data)
79+
result_df = result.df.compute()
80+
duplicate_docs = [2, 3, 4, 200, 300]
81+
expected_df = cudf.Series(duplicate_docs, name="id")
82+
assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)

0 commit comments

Comments
 (0)