Skip to content

Commit

Permalink
Merge branch 'main' into bm_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kjappelbaum committed Aug 12, 2024
2 parents 7728a64 + 1a844a7 commit c92959e
Show file tree
Hide file tree
Showing 2 changed files with 336 additions and 163 deletions.
9 changes: 6 additions & 3 deletions data/postprocess_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
It also merges files that have been created by `dask` if they are chunks of one large dataset.
This script needs to be run after the splitting script.
An independent check (that does not rewrite files is `check_smiles_split.py`;
this checks also for compliance with the predetermined files)
"""

import os
Expand Down Expand Up @@ -143,7 +146,7 @@ def process_file(file: Union[str, Path], id_cols):
# appear multiple times
ddf = read_ddf(file)
ddf = ddf.drop_duplicates(subset=id_cols)
ddf.to_csv("data_clean-{*}.csv", index=False)
ddf.to_csv(os.path.join(dir, "data_clean-{*}.csv"), index=False)
merge_files(dir)

else:
Expand All @@ -154,7 +157,7 @@ def process_file(file: Union[str, Path], id_cols):
for id in id_cols:
test_smiles.extend(df[df["split"] == "test"][id].to_list())
val_smiles.extend(df[df["split"] == "valid"][id].to_list())

df.drop_duplicates(subset=[id], inplace=True)
test_smiles = set(test_smiles)
val_smiles = set(val_smiles)

Expand Down Expand Up @@ -184,7 +187,7 @@ def process_file(file: Union[str, Path], id_cols):
len(this_test_smiles.intersection(this_val_smiles)) == 0
), f"Smiles in test and valid for {id}"

df.to_csv("data_clean.csv", index=False)
df.to_csv(os.path.join(dir, "data_clean.csv"), index=False)


def process_all_files(data_dir):
Expand Down
Loading

0 comments on commit c92959e

Please sign in to comment.