Add codepath for computing buckets without int conversion #326
base: main
Changes from all commits
ccb1e31, f2b1888, 816940b, 30f383c, d7a2617, 954a043, 3b51aad, d119740, 8dbc48a, dccd964
First changed file in the diff (the fuzzy-deduplication module containing the LSH, FuzzyDuplicates, and BucketsToEdges classes):
@@ -48,6 +48,7 @@
 from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str
 from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
     aggregated_anchor_docs_with_bk_read,
+    check_empty_buckets,
     get_restart_offsets,
     update_restart_offsets,
 )
@@ -261,6 +262,7 @@ def __init__(
         num_hashes: int,
         num_buckets: int,
         buckets_per_shuffle: int = 1,
+        buckets_as_int: bool = False,
         logger: Union[logging.LoggerAdapter, str] = "./",
         id_fields: Union[str, list] = "id",
         minhash_field: str = "_minhash_signature",
@@ -291,6 +293,7 @@ def __init__(
         self.bucket_ranges = self._generate_bucket_ranges(
             self.num_buckets, self.num_hashes
         )
+        self.buckets_as_int = buckets_as_int

         if cache_dir is None:
             raise ValueError(
@@ -379,10 +382,19 @@ def lsh(
         self,
         write_path: str,
         df: dask_cudf.DataFrame,
-    ) -> None:
+    ) -> bool:
         """
-        Computes buckets and writes them as parquet files to the write_path
+        Computes hash buckets for the DataFrame and writes them as parquet files to the specified path.
+
+        Parameters:
+            - write_path (str): The directory path to write parquet files.
+            - df (dask_cudf.DataFrame): The input DataFrame with minhashes to be bucketed.
+        Returns:
+            are_buckets_empty: True if buckets were empty (no duplicates found), False otherwise.
         """
+        wrote_buckets = False
+        are_buckets_empty = True
+
         meta = self._minhash_to_bucket_meta(df)
         df = df.map_partitions(
             self.minhash_to_buckets,
@@ -391,12 +403,14 @@ def lsh(
         )
         bucket_start_id = 0
         for i in range(0, self.num_buckets, self.buckets_per_shuffle):
-            value_vars = [
+            bucket_columns = [
                 f"_bucket_{i}"
                 for i in range(i, min(self.num_buckets, i + self.buckets_per_shuffle))
             ]
             df2 = df.melt(
-                id_vars=self.id_fields, value_name="_bucket_id", value_vars=value_vars
+                id_vars=self.id_fields,
+                value_name="_bucket_id",
+                value_vars=bucket_columns,
             )[self.id_fields + ["_bucket_id"]]

             df2 = df2.shuffle(
@@ -406,40 +420,90 @@
             ).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])

             df2 = df2.reset_index(drop=True)
-            df2, end_id = self.bucket_id_to_int(
-                df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
-            )
-            # If bucketing return empty dataframe
-            if end_id < bucket_start_id:
-                continue
-            bucket_start_id = end_id + 1
+            # Buckets to Int
+            if self.buckets_as_int:
+                df2, end_id = self.bucket_id_to_int(
+                    df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
+                )
+                # If bucketing return empty dataframe
+                if end_id < bucket_start_id:
+                    self._logger.info(
+                        f"No duplicate documents found for buckets: {bucket_columns}"
+                    )
+                    continue
+                bucket_start_id = end_id + 1
+                are_buckets_empty = False

             # Workaround for dtype mismatches with empty partitions
-            dtypes = df2.dtypes.to_dict()
-            df2 = df2.map_partitions(lambda x: x.astype(dtypes))
+            # dtypes = df2.dtypes.to_dict()
+            # df2 = df2.map_partitions(lambda x: x.astype(dtypes))
+            wrote_buckets, are_buckets_empty = self._write_bucket_parquet(
+                df2,
+                write_path,
+                wrote_buckets,
+                are_buckets_empty,
+                bucket_columns,
+            )

-            if i == 0:
-                if os.path.exists(write_path):
-                    warnings.warn(
-                        f"Output path {write_path} already exists and will be overwritten"
-                    )
-                df2.to_parquet(write_path, write_index=False, overwrite=True)
-            else:
-                df2.to_parquet(write_path, write_index=False, append=True)
+        if are_buckets_empty:
+            self._logger.info("No duplicate documents found during LSH")
+            if os.path.exists(write_path):
+                import shutil
+
+                shutil.rmtree(write_path)
+
+        return are_buckets_empty
Review comment: Variable for tracking if all the buckets were empty.
+
+    def _write_bucket_parquet(
Review comment: Reviewers ptal at this logic. I've tried to cover most edge cases.
Review comment: The only case I could think of was whether we ever have to worry about scalability here?
Review comment: There is a non-zero cost to checking whether the buckets are empty or not. I've tried to write it so that once a non-empty bucket is detected, that state is persisted through the next set of iterations and the check is skipped in future iterations.
+        self,
+        df: dask_cudf.DataFrame,
+        write_path: str,
+        wrote_buckets: bool,
+        are_buckets_empty: bool,
+        buckets_to_write: List[str],
+    ) -> tuple[bool, bool]:
+        """
+        Utility function to write the bucketed data to parquet
+        handling cases of overwriting and appending as needed.
+        """
+        if not wrote_buckets:
+            if os.path.exists(write_path):
+                warnings.warn(
+                    f"Output path {write_path} already exists and will be overwritten"
+                )
+            df.to_parquet(write_path, write_index=False, overwrite=True)
+        else:
+            df.to_parquet(
+                write_path,
+                write_index=False,
+                overwrite=are_buckets_empty,
+                append=not are_buckets_empty,
+            )
+        # Only check if buckets written so far are empty
+        if are_buckets_empty:
+            are_buckets_empty = check_empty_buckets(write_path)
Review comment: The reason we need to do this in the first place is that there's no way to know whether we're writing out an empty dataframe or not, unless we persist it, or write it out, check the metadata, and then overwrite on the next iteration.
+        wrote_buckets = True
+
-            self._logger.info(f"Wrote data for buckets: {value_vars}")
+        if are_buckets_empty:
+            self._logger.info(
+                f"No duplicate documents found for buckets: {buckets_to_write}"
+            )
+        else:
+            self._logger.info(f"Wrote data for buckets: {buckets_to_write}")
+        return wrote_buckets, are_buckets_empty

     def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
         df = dataset.df

         write_path = os.path.join(self.cache_dir, "_buckets.parquet")
         t0 = time.time()
         with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"):
-            self.lsh(write_path=write_path, df=df)
+            empty_result = self.lsh(write_path=write_path, df=df)
         self._logger.info(
             f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
         )

+        if empty_result:
+            return None
         buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False)
         return DocumentDataset(buckets_df)

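To make the cost-management pattern from the review discussion above concrete (probe for emptiness only while everything written so far has been empty, then stop probing), here is a minimal, self-contained sketch; it is not the NeMo Curator implementation, and the function and callback names are invented for illustration:

```python
from typing import Callable, Iterable, List


def write_all_batches(
    batches: Iterable[List[int]],
    write_batch: Callable[[List[int]], None],
    output_is_empty: Callable[[], bool],
) -> bool:
    """Write every batch; return True only if the written output stayed empty."""
    are_buckets_empty = True
    for batch in batches:
        write_batch(batch)
        # The emptiness probe has a non-zero cost, so it only runs while the
        # output is still believed to be empty; once data has been seen, the
        # flag sticks and the probe is skipped for all remaining iterations.
        if are_buckets_empty:
            are_buckets_empty = output_is_empty()
    return are_buckets_empty
```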
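The other point raised above, that there is no way to tell whether a lazy DataFrame is empty without either computing it or inspecting what was actually written, can be illustrated with plain dask.dataframe; this sketch is illustrative only and uses pandas/Dask rather than cuDF/dask_cudf:

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"id": [1, 2, 3], "_bucket_id": ["a", "b", "c"]})
ddf = dd.from_pandas(pdf, npartitions=2)

# Keep only rows whose bucket id repeats within a partition (the real code
# shuffles on "_bucket_id" first, so the per-partition check is global there).
filtered = ddf.map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])

# The row count of `filtered` is unknown until the graph executes. len() forces
# a full compute, which is the extra cost the comment refers to; checking the
# parquet metadata after writing avoids that extra pass over the data.
print(len(filtered))  # prints 0 here, but only after computing the whole graph
```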
@@ -488,6 +552,8 @@ def __init__(
             num_hashes=self.config.num_hashes,
             num_buckets=self.config.num_buckets,
             buckets_per_shuffle=self.config.buckets_per_shuffle,
+            # Only convert buckets to int if we are running false positive check
+            buckets_as_int=self.config.false_positive_check,
             logger=self._logger,
             id_fields=[self.config.id_field],
             profile_dir=self.config.profile_dir,
@@ -556,6 +622,11 @@ def __call__(self, dataset: DocumentDataset):
         minhashLSH = Sequential([self.minhash, self.lsh])
         buckets_df = minhashLSH(dataset)
         print(f"Stage{stage_num}: Minhash + LSH complete!")
+        if buckets_df is None:
+            print(
+                f"Stage{stage_num}: No potential duplicate documents found during LSH"
+            )
+            return None
Review comment: Should this return None or an empty …
Review comment: I prefer returning …
Review comment: Makes sense, but then for …
Review comment: I haven't seen …
         stage_num += 1

         if self.config.false_positive_check:
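On the question above of returning None versus an empty DocumentDataset, here is a hypothetical caller-side helper (not part of the PR) showing the trade-off between the two conventions:

```python
def resolve_lsh_result(buckets_df):
    """Hypothetical helper: normalize 'no duplicates found' across both conventions."""
    if buckets_df is None:
        # Convention adopted in this PR: LSH.__call__ returns None when no
        # duplicates were found, so callers can short-circuit cheaply.
        return None
    # Alternative convention: an empty DocumentDataset would have to be detected
    # by counting rows, which forces the lazy Dask graph to execute.
    if len(buckets_df.df) == 0:
        return None
    return buckets_df
```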
@@ -740,6 +811,7 @@ def buckets_to_edges(

     def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
         buckets_df = dataset.df
+        self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist")
         if len(self.id_fields) > 1:
             buckets_df = buckets_df.map_partitions(
                 BucketsToEdges._combine_multiple_ids,
Second changed file in the diff (the fuzzy-dedup I/O utilities module, from which check_empty_buckets is imported above):
@@ -202,3 +202,16 @@ def strip_trailing_sep(path: str):
     Strips a path string of trailing path seperators like `/` if any.
     """
     return path.rstrip(os.path.sep)
+
+
+def check_empty_buckets(bucket_path):
+    """
+    Inspects parquet metadata of the buckets dataset to check if it's an empty dataset.
+    """
+    from pyarrow.dataset import dataset
+
+    ds = dataset(bucket_path, format="parquet")
+    for fragment in ds.get_fragments():
+        if fragment.metadata.num_rows > 0:
+            return False
+213
to
+216
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic can probably be simplified by using a global metadata file when writing out the parquet dataset |
+    return True
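Along the lines of the review suggestion above, a possible simplification, sketched under the assumption that the dataset is written with a global `_metadata` file (for example via Dask's `to_parquet(..., write_metadata_file=True)`); the function name below is hypothetical:

```python
import os

import pyarrow.parquet as pq


def check_empty_buckets_from_metadata(bucket_path: str) -> bool:
    """Return True if the dataset's global _metadata file reports zero rows."""
    metadata_file = os.path.join(bucket_path, "_metadata")
    if not os.path.isfile(metadata_file):
        # No global metadata file was written; a caller could fall back to the
        # fragment-level check_empty_buckets shown above.
        raise FileNotFoundError(f"No _metadata file found under {bucket_path}")
    return pq.read_metadata(metadata_file).num_rows == 0
```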
Review comment: What do you think about calling this `false_positive_check` on the user-facing side? I'm fine with then doing something like `self.buckets_as_int = false_positive_check` and referring to it as `self.buckets_as_int` everywhere else, but from a user perspective I think it might make it a little clearer how to set this parameter.
Review comment: I think it's a good suggestion. We can update the docstrings to indicate that it writes out data in the format required by the false-positive check if set to True.
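For reference, a hedged sketch of how the new flag could be set when constructing LSH directly; the import path and argument values are assumptions, and within FuzzyDuplicates the flag is wired automatically as `buckets_as_int=self.config.false_positive_check`, as shown in the diff above:

```python
from nemo_curator import LSH  # assumed public import path

lsh = LSH(
    cache_dir="./fuzzy_dedup_cache",  # illustrative path
    num_hashes=260,                   # illustrative values
    num_buckets=20,
    buckets_per_shuffle=1,
    # New in this PR: only convert bucket ids to integers when the downstream
    # false-positive check needs them; otherwise the bucket ids are written
    # without the extra int conversion.
    buckets_as_int=False,
    id_fields=["id"],
    minhash_field="_minhash_signature",
)
```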