diff --git a/src/chemnlp/data/sampler.py b/src/chemnlp/data/sampler.py index 34a5d50e1..bbea75dfb 100644 --- a/src/chemnlp/data/sampler.py +++ b/src/chemnlp/data/sampler.py @@ -13,6 +13,7 @@ import yaml import json from loguru import logger +from tqdm import tqdm class TemplateSampler: @@ -836,8 +837,11 @@ def export(self, output_dir: str, template: str) -> pd.DataFrame: ) for split in self.df["split"].unique(): df_split = self.df[self.df["split"] == split] - samples = [self.sample(row, template) for _, row in df_split.iterrows()] - + samples = [] + for _, row in tqdm(df_split.iterrows(), total=len(df_split)): + sample_dict = row.to_dict() + sample = self._fill_template(template, sample_dict) + samples.append(sample) df_out = pd.DataFrame(samples) # if self.benchmarking_templates: