Skip to content

Commit

Permalink
feat: draft class balanced sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
MicPie committed Feb 8, 2024
1 parent 6718fab commit f5e3f10
Showing 1 changed file with 91 additions and 8 deletions.
99 changes: 91 additions & 8 deletions data/text_sampling/text_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ def check_targets_and_identifiers(meta: dict, df: pd.DataFrame):
if "split" not in df.columns:
df["split"] = "train"
self.df = df
self.df_orig = None # only used for class_balanced sampling to keep a copy of the original self.df

# text templates
self.benchmarking_templates = benchmarking_templates
Expand Down Expand Up @@ -906,12 +907,60 @@ def sample(self, sample: pd.Series, template_idx: int = None):

def __getitem__(self, sample_idx: int, template_idx: int = None):
"""Get item from data with sample and template index.
A random template will be ised if no template index is handed over."""
A random template will be used if no template index is handed over."""
sample = self.df.iloc[sample_idx]
return self.sample(sample, template_idx)

def apply_sampling(self, template_idx: int = None):
def apply_sampling(
self, template_idx: int = None, class_balanced: bool = True
): # TODO: set class_balanced to False !!!
"""Applies the sampling to the entire data frame."""
if template_idx is not None and class_balanced is True:
# create a copy of the original self.df to restore self.df after class balanced sampling
if self.df_orig is None:
self.df_orig = self.df.copy()

# get targets for balancing
template = self.get_prompt_template_from_template_idx(template_idx)
target_to_balance = []
for target in self.meta["targets"]:
for var in template.input_variables:
if (target["id"] in var.replace("#", "")) or (
target["id"] in var.replace("%", "")
):
# print(f"{target['id']=}")
target_to_balance.append(target["id"])
target_to_balance = list(set(target_to_balance))

# create class balanced self.df
if len(target_to_balance) > 1:
print("TEMPLATE USES MORE THAN ONE TARGET!")
print(f"{target_to_balance=}")
target_to_balance = random.sample(target_to_balance, k=1)[0]
print(f"{target_to_balance=}")
else:
# unwrap list of length 1
target_to_balance = target_to_balance[0]
df_vc = self.df_orig[target_to_balance].value_counts()
print(df_vc)
vc_min = df_vc.min()
vc_max = df_vc.max()
if vc_max > 1:
dfs = []
# cycle through all values and get a sample of size vc_min
for values in df_vc.index.tolist():
dfs.append(
self.df_orig[self.df_orig[target_to_balance] == values].sample(
vc_min
)
)
self.df = pd.concat(dfs)
else:
self.df = self.df_orig
print(self.df[target_to_balance].value_counts())
# else:
# assert template_idx is None and class_balanced is True, "class_balanced sampling is only supported with template_idx." # noqa: E501

self.df["sample"] = self.df.apply(
lambda sample: self.sample(sample, template_idx), axis=1
)
Expand Down Expand Up @@ -1047,11 +1096,19 @@ def export(self, fn_suffix: str = None):
return pd.DataFrame(print_data)

def apply_sampling_and_export(
self, template_idx: int = None, fn_suffix: str = None
self,
template_idx: int = None,
fn_suffix: str = None,
class_balanced=True,
):
"""Applies the sampling and exports the data."""
self.apply_sampling(template_idx=template_idx)
self.apply_sampling(template_idx=template_idx, class_balanced=class_balanced)
df_results = self.export(fn_suffix=fn_suffix)

# if class_balanced restore self.df to original df that is not balanced
if class_balanced:
self.df = self.df_orig

print(f"\n### results\n{df_results.to_string()}")


Expand All @@ -1063,12 +1120,15 @@ def apply_sampling_and_export(
)
path_lm_eval_data_dir = path_base + "text_sampling/export"

# index = [i for i, x in enumerate(path_data_dir) if x.find("RedDB") != -1][0]
# print(index)
index = [i for i, x in enumerate(path_data_dir) if x.find("odd_one_out") != -1][0]
print(index)
# path_data_dir = path_data_dir[index:]
# path_data_dir = path_data_dir[index + 1 :]
# path_data_dir = [path_data_dir[index]]
# path_data_dir = [
# '/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ames_mutagenicity',
# '/weka/proj-chemnlp/micpie/chemnlp/data/tabular/bioavailability_ma_et_al',
# '/weka/proj-chemnlp/micpie/chemnlp/data/tabular/caco2_wang',
# '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/RedDB',
# '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/SIDER',
# '/fsx/proj-chemnlp/micpie/chemnlp/data/tabular/aminoacids',
Expand Down Expand Up @@ -1118,13 +1178,36 @@ def apply_sampling_and_export(

for path in path_data_dir:
# subselect one path
# if path.find("data/tabular/") == -1: continue
# if path.find("data/kg/") == -1: continue
# if path.find("data/tabular/") == -1:
# continue
# if path.find("data/kg/") == -1: continue
# if path.find("chembl33") != -1: continue
# if path.find("data/kg/compound_chebi") == -1: continue
# if path.find("data/tabular/cyp3a4_substrate_carbonmangels") == -1: continue
# if path.find("data/tabular/bio_ner") == -1: continue
# if path.find("rdkit_features") != -1: continue

# exclude data_clean.csv files with more than 1GB
if path.find("rdkit_features") != -1:
continue
if path.find("iupac_smiles") != -1:
continue
if path.find("orbnet_denali") != -1:
continue
if path.find("ord_masked") != -1:
continue
if path.find("ord_predictions") != -1:
continue
if path.find("chembl_v29") != -1:
continue

# needs fix
if path.find("bicerano_dataset") != -1:
continue
# if path.find("BACE") != -1: continue
# if path.find("BBBP") != -1: continue
# if path.find("ames_mutagenicity") == -1: continue
# if path.find("bio_ner") == -1: continue

print(f"\n###### {path}")
path_meta = path + "/meta.yaml"
Expand Down

0 comments on commit f5e3f10

Please sign in to comment.