New fixes #522

Merged 33 commits on Feb 8, 2024.

Commits
c880d91  feat: add data/check_pandas.py (MicPie, Jan 29, 2024)
30f160c  add data/check_smiles_split.py (MicPie, Jan 29, 2024)
3e6e541  update kg meta.yaml files (MicPie, Jan 29, 2024)
f8e8f65  add data/natural (MicPie, Jan 29, 2024)
e837fea  add dataset scripts (MicPie, Jan 29, 2024)
4df139f  update data/text_sampling/ (MicPie, Jan 29, 2024)
019c91a  update meta and transform (MicPie, Jan 29, 2024)
fafc650  additional fixes (MicPie, Jan 29, 2024)
66dd58a  apply pre-commit hook (MicPie, Jan 29, 2024)
625d6eb  sort exclude_from_standard_tabular_text_templates (MicPie, Jan 29, 2024)
896c055  more fixes (#517) (kjappelbaum, Jan 31, 2024)
73c197c  additional fixes (MicPie, Jan 31, 2024)
96067cf  Update data/check_pandas.py (MicPie, Feb 1, 2024)
5807e44  Update data/check_pandas.py (MicPie, Feb 1, 2024)
68b20dc  Update data/check_smiles_split.py (MicPie, Feb 1, 2024)
171b945  Update data/natural/preprocess_europepmc.py (MicPie, Feb 1, 2024)
2294ad2  Update data/natural/preprocess_msds.py (MicPie, Feb 1, 2024)
981e7de  Update data/natural/preprocess_nougat.py (MicPie, Feb 1, 2024)
361291f  Update data/postprocess_split.py (MicPie, Feb 1, 2024)
90a1ebd  additional fixes 2 (MicPie, Feb 5, 2024)
747c9bc  additional fixes 3 (MicPie, Feb 5, 2024)
642d8b2  additional fixes 4 (MicPie, Feb 5, 2024)
f7a2ac8  additional fixes 5 (MicPie, Feb 5, 2024)
33054f5  additional fixes 6 (MicPie, Feb 5, 2024)
df19548  additional fixes 7 (MicPie, Feb 5, 2024)
1d101a2  additional fixes 8 (MicPie, Feb 5, 2024)
3f6920a  remove linebreak (kjappelbaum, Feb 5, 2024)
fcf7706  remove linebreak (kjappelbaum, Feb 5, 2024)
5b8ef87  Delete data/tabular/bicerano_dataset/meta.yaml (kjappelbaum, Feb 5, 2024)
6718fab  feat: update yamls (MicPie, Feb 6, 2024)
6054934  Update data/text_sampling/preprocess_kg.py (MicPie, Feb 6, 2024)
ad1dd94  Update data/text_sampling/preprocess_kg.py (MicPie, Feb 6, 2024)
ac9c8c0  Update data/text_sampling/preprocess_kg.py (MicPie, Feb 6, 2024)
145 changes: 145 additions & 0 deletions data/check_pandas.py
@@ -0,0 +1,145 @@
"""
This script performs a basic check for data leakage. The checks only cover SMILES.
The train/test split must be created before this script is run, and
`test_smiles.txt` and `val_smiles.txt` must exist in the current working directory.

If leakage is detected, an `AssertionError` will be thrown.

This script has a command line interface. You can run it using
`python check_pandas.py <data_dir>`, where `<data_dir>` points to a nested set of
directories with `data_clean.csv` files.
"""
import os
from glob import glob
from pathlib import Path

import fire
import pandas as pd
from pandarallel import pandarallel
from tqdm import tqdm

pandarallel.initialize(progress_bar=False)

with open("test_smiles.txt", "r") as f:
test_smiles_ref = f.readlines()
test_smiles_ref = [x.strip() for x in test_smiles_ref]

with open("val_smiles.txt", "r") as f:
valid_smiles_ref = f.readlines()
valid_smiles_ref = [x.strip() for x in valid_smiles_ref]


def leakage_check(file, outdir="out"):
# mirror subdir structures in outdir
if not os.path.exists(outdir):
os.makedirs(outdir)
print(f"Checking {file}")
df = pd.read_csv(file, low_memory=False)
print(df["split"].value_counts())
train_smiles = df[df["split"] == "train"]["SMILES"].to_list()
train_smiles = set(train_smiles)
test_smiles = df[df["split"] == "test"]["SMILES"].to_list()
test_smiles = set(test_smiles)
valid_smiles = df[df["split"] == "valid"]["SMILES"].to_list()
valid_smiles = set(valid_smiles)

try:
assert (
len(train_smiles.intersection(test_smiles)) == 0
), "Smiles in train and test"
assert (
len(train_smiles.intersection(valid_smiles)) == 0
), "Smiles in train and valid"
assert (
len(test_smiles.intersection(valid_smiles)) == 0
), "Smiles in test and valid"
except AssertionError as e:
path = os.path.join(outdir, Path(file).parts[-2], Path(file).name)
print(f"Leakage in {file}: {e}. Fixing... {path}")
is_in_test = df["SMILES"].isin(test_smiles)
is_in_val = df["SMILES"].isin(valid_smiles)

df.loc[is_in_test, "split"] = "test"
df.loc[is_in_val, "split"] = "valid"

os.makedirs(os.path.dirname(path), exist_ok=True)
df.to_csv(path, index=False)
print(f"Saved fixed file to {path}")
print("Checking fixed file...")
leakage_check(path, outdir)

    try:
        assert (
            len(train_smiles.intersection(test_smiles_ref)) == 0
        ), "Smiles in train and scaffold test"

        assert (
            len(train_smiles.intersection(valid_smiles_ref)) == 0
        ), "Smiles in train and scaffold valid"

        assert (
            len(test_smiles.intersection(valid_smiles_ref)) == 0
        ), "Smiles in test and scaffold valid"
    except AssertionError as e:
        path = os.path.join(outdir, Path(file).parts[-2], Path(file).name)
        print(f"Leakage in {file}: {e}. Fixing... {path}")
        # move rows that collide with the reference scaffold splits
        is_in_test = df["SMILES"].isin(test_smiles_ref)
        is_in_val = df["SMILES"].isin(valid_smiles_ref)

        df.loc[is_in_test, "split"] = "test"
        df.loc[is_in_val, "split"] = "valid"

        # refresh the per-split sets so the summary below reflects the fix
        test_smiles = df[df["split"] == "test"]["SMILES"].to_list()
        test_smiles = set(test_smiles)

        valid_smiles = df[df["split"] == "valid"]["SMILES"].to_list()
        valid_smiles = set(valid_smiles)

        os.makedirs(os.path.dirname(path), exist_ok=True)
        df.to_csv(path, index=False)
        print(f"Saved fixed file to {path}")
        print("Checking fixed file...")
        leakage_check(path, outdir)

print(f"No leakage in {file}")
with open("leakage_check.txt", "a") as f:
f.write(f"No leakage in {file}\n")
f.write(f"train: {len(train_smiles)}\n")
f.write(f"test: {len(test_smiles)}\n")
f.write(f"valid: {len(valid_smiles)}\n")
return True


def check_all_files(data_dir):
    # without recursive=True, each ** matches a single directory level,
    # i.e. this finds <data_dir>/*/*/data_clean.csv
    all_csv_files = glob(os.path.join(data_dir, "**", "**", "data_clean.csv"))
    for csv_file in tqdm(all_csv_files):
        if Path(csv_file).parts[-2] not in [
            "odd_one_out",
            "uniprot_binding_single",
            "uniprot_binding_sites_multiple",
            "uniprot_organisms",
            "uniprot_reactions",
            "uniprot_sentences",
            "fda_adverse_reactions",
            "drugchat_liang_zhang_et_al",
            "herg_central",
            # those files were checked manually
        ]:
            # only check files smaller than 35 GB
            if os.path.getsize(csv_file) < 35 * 1024 * 1024 * 1024:
                try:
                    leakage_check(csv_file)
                except Exception as e:
                    print(f"Could not process {csv_file}: {e}")
            else:
                print(f"Skipping {csv_file} due to size")


if __name__ == "__main__":
    fire.Fire(check_all_files)
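
For reviewers who want to exercise the checker locally, here is a minimal sketch that sets up a toy input with a deliberate train/test overlap. The file layout (`test_smiles.txt` and `val_smiles.txt` in the working directory, `data_clean.csv` two directories below the data root) follows the docstring above; the toy SMILES and the `data/tabular/toy_dataset` path are illustrative assumptions, not part of this PR.

# make_toy_dataset.py -- illustrative smoke test for check_pandas.py (not part of this PR)
import os

import pandas as pd

# reference scaffold splits expected in the current working directory
with open("test_smiles.txt", "w") as f:
    f.write("CCO\n")
with open("val_smiles.txt", "w") as f:
    f.write("CCN\n")

# nested layout matched by the glob in check_all_files:
# <data_dir>/<group>/<dataset>/data_clean.csv
os.makedirs("data/tabular/toy_dataset", exist_ok=True)
pd.DataFrame(
    {
        "SMILES": ["CCO", "CCC", "CCO", "CCN"],  # "CCO" appears in train and test
        "split": ["train", "train", "test", "valid"],
    }
).to_csv("data/tabular/toy_dataset/data_clean.csv", index=False)

Running `python check_pandas.py data` after this setup should flag the overlap, move the duplicated `CCO` row out of train, and write a repaired copy under `out/toy_dataset/data_clean.csv` before re-checking it.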