Skip to content

Commit

Permalink
feat: update to latest sampling setup
Browse files (browse the repository at this point in the history)
  • Loading branch information
MicPie committed Feb 8, 2024
1 parent 3c6521e commit 7314837
Showing 1 changed file with 22 additions and 42 deletions.
64 changes: 22 additions & 42 deletions data/text_sampling/text_sampling.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -911,9 +911,7 @@ def __getitem__(self, sample_idx: int, template_idx: int = None):
sample = self.df.iloc[sample_idx]
return self.sample(sample, template_idx)

def apply_sampling(
self, template_idx: int = None, class_balanced: bool = True
): # TODO: set class_balanced to False !!!
def apply_sampling(self, template_idx: int = None, class_balanced: bool = True): ## todo: set class_balanced to False !!!
"""Applies the sampling to the entire data frame."""
if template_idx is not None and class_balanced is True:
# create a copy of the original self.df to restore self.df after class balanced sampling
Expand All @@ -925,10 +923,8 @@ def apply_sampling(
target_to_balance = []
for target in self.meta["targets"]:
for var in template.input_variables:
if (target["id"] in var.replace("#", "")) or (
target["id"] in var.replace("%", "")
):
# print(f"{target['id']=}")
if (target["id"] in var.replace("#", "")) or (target["id"] in var.replace("%", "")):
#print(f"{target['id']=}")
target_to_balance.append(target["id"])
target_to_balance = list(set(target_to_balance))

Expand All @@ -949,17 +945,13 @@ def apply_sampling(
dfs = []
# cycle through all values and get a sample of size vc_min
for values in df_vc.index.tolist():
dfs.append(
self.df_orig[self.df_orig[target_to_balance] == values].sample(
vc_min
)
)
dfs.append(self.df_orig[self.df_orig[target_to_balance] == values].sample(vc_min))
self.df = pd.concat(dfs)
else:
self.df = self.df_orig
print(self.df[target_to_balance].value_counts())
# else:
# assert template_idx is None and class_balanced is True, "class_balanced sampling is only supported with template_idx." # noqa: E501
#else:
# assert template_idx is None and class_balanced is True, "class_balanced sampling is only supported with template_idx."

self.df["sample"] = self.df.apply(
lambda sample: self.sample(sample, template_idx), axis=1
Expand Down Expand Up @@ -1096,10 +1088,7 @@ def export(self, fn_suffix: str = None):
return pd.DataFrame(print_data)

def apply_sampling_and_export(
self,
template_idx: int = None,
fn_suffix: str = None,
class_balanced=True,
self, template_idx: int = None, fn_suffix: str = None, class_balanced = True,
):
"""Applies the sampling and exports the data."""
self.apply_sampling(template_idx=template_idx, class_balanced=class_balanced)
Expand All @@ -1122,8 +1111,8 @@ def apply_sampling_and_export(

index = [i for i, x in enumerate(path_data_dir) if x.find("odd_one_out") != -1][0]
print(index)
# path_data_dir = path_data_dir[index:]
# path_data_dir = path_data_dir[index + 1 :]
#path_data_dir = path_data_dir[index:]
path_data_dir = path_data_dir[index+1:]
# path_data_dir = [path_data_dir[index]]
# path_data_dir = [
# '/weka/proj-chemnlp/micpie/chemnlp/data/tabular/ames_mutagenicity',
Expand Down Expand Up @@ -1178,37 +1167,28 @@ def apply_sampling_and_export(

for path in path_data_dir:
# subselect one path
# if path.find("data/kg/") == -1: continue
# if path.find("data/tabular/") == -1:
# continue
#if path.find("data/kg/") == -1: continue
if path.find("data/tabular/") == -1: continue
# if path.find("data/kg/") == -1: continue
# if path.find("chembl33") != -1: continue
# if path.find("data/kg/compound_chebi") == -1: continue
# if path.find("data/tabular/cyp3a4_substrate_carbonmangels") == -1: continue
# if path.find("data/tabular/bio_ner") == -1: continue
# if path.find("rdkit_features") != -1: continue

# exclude data_clean.csv files with more than 1GB
if path.find("rdkit_features") != -1:
continue
if path.find("iupac_smiles") != -1:
continue
if path.find("orbnet_denali") != -1:
continue
if path.find("ord_masked") != -1:
continue
if path.find("ord_predictions") != -1:
continue
if path.find("chembl_v29") != -1:
continue
if path.find("rdkit_features") != -1: continue
if path.find("iupac_smiles") != -1: continue
if path.find("orbnet_denali") != -1: continue
if path.find("ord_masked") != -1: continue
if path.find("ord_predictions") != -1: continue
if path.find("chembl_v29") != -1: continue

# needs fix
if path.find("bicerano_dataset") != -1:
continue
# if path.find("BACE") != -1: continue
# if path.find("BBBP") != -1: continue
# if path.find("ames_mutagenicity") == -1: continue
# if path.find("bio_ner") == -1: continue
if path.find("bicerano_dataset") != -1: continue
#if path.find("BACE") != -1: continue
#if path.find("BBBP") != -1: continue
#if path.find("ames_mutagenicity") == -1: continue
#if path.find("bio_ner") == -1: continue

print(f"\n###### {path}")
path_meta = path + "/meta.yaml"
Expand Down

0 comments on commit 7314837

Please sign in to comment.