15 run inference on baskerville #30

Merged 14 commits on Dec 6, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -160,6 +160,7 @@ Thumbs.db


slurm_scripts/slurm_logs*
slurm_scripts/experiments*
# other
temp
.vscode
2 changes: 1 addition & 1 deletion config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -1,4 +1,4 @@
OCR:
ocr:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"

2 changes: 2 additions & 0 deletions config/data_configs/l1_fr_to_en.yaml
@@ -5,3 +5,5 @@ level: 1
lang_pair:
source: "fr"
target: "en"

drop_length: 1000
13 changes: 13 additions & 0 deletions config/experiment/baskerville_pipeline_inference_test.yaml
@@ -0,0 +1,13 @@
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
- 42

bask:
jobname: "shortened_input_test"
walltime: '0-12:0:0'
gpu_number: 1
node_number: 1
hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
15 changes: 15 additions & 0 deletions config/experiment/full_experiment_zero_shot.yaml
@@ -0,0 +1,15 @@
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
- 42
- 43
- 44

bask:
jobname: "full_experiment_with_zero_shot"
walltime: '0-24:0:0'
gpu_number: 1
node_number: 1
hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
27 changes: 27 additions & 0 deletions scripts/README.md
@@ -106,3 +106,30 @@ It's called like so e.g. from project root:
```bash
python scripts/pipeline_inference.py [pipeline_config_path] [data_config_path] translator
```

## gen_jobscripts.py

Creates jobscript `.sh` files for an experiment, where an experiment means a `data_config` and `pipeline_config` combination.
It takes a single argument, `experiment_config_path`: the path to a `.yaml` file structured as below:

### e.g. Experiment config:

```yaml
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
- 42
- 43
- 44

bask:
jobname: "full_experiment_with_zero_shot"
walltime: '0-24:0:0'
gpu_number: 1
node_number: 1
hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
```
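
It can be called like so, e.g. from project root:

```bash
python scripts/gen_jobscripts.py config/experiment/full_experiment_zero_shot.yaml
```

For each seed in the config this writes one jobscript per pipeline component, plus a `full_pipeline.sh`, under `slurm_scripts/experiments/<experiment_name>/run_<index>/`. These directories are created with `exist_ok=False`, so existing results are never overwritten.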
82 changes: 82 additions & 0 deletions scripts/gen_jobscripts.py
@@ -0,0 +1,82 @@
import os
from pathlib import Path

from jinja2 import Environment, FileSystemLoader
from jsonargparse import CLI

from arc_spice.utils import open_yaml_path

PROJECT_DIR = Path(__file__, "..", "..").resolve()


def main(experiment_config_path: str):
"""
_summary_

Args:
experiment_config_path: _description_
"""
experiment_name = experiment_config_path.split("/")[-1].split(".")[0]
experiment_config = open_yaml_path(experiment_config_path)
pipeline_conf_dir = (
f"{PROJECT_DIR}/config/RTC_configs/{experiment_config['pipeline_config']}.yaml"
)
data_conf_dir = (
f"{PROJECT_DIR}/config/data_configs/{experiment_config['data_config']}.yaml"
)
pipeline_config = open_yaml_path(pipeline_conf_dir)
# Get jinja template
environment = Environment(
loader=FileSystemLoader(PROJECT_DIR / "src" / "arc_spice" / "config")
)
template = environment.get_template("jobscript_template.sh")
# We don't want to overwrite results

for index, seed in enumerate(experiment_config["seed"]):
os.makedirs(
f"slurm_scripts/experiments/{experiment_name}/run_{index}", exist_ok=False
)
for model in pipeline_config:
model_script_dict: dict = experiment_config["bask"]
model_script_dict.update(
{
"script_name": (
"scripts/single_component_inference.py "
f"{pipeline_conf_dir} {data_conf_dir} {seed}"
f" {experiment_name} {model}"
),
"job_name": f"{experiment_name}_{model}",
"seed": seed,
}
)
model_train_script = template.render(model_script_dict)

with open(
f"slurm_scripts/experiments/{experiment_name}/run_{index}/{model}.sh",
"w",
) as f:
f.write(model_train_script)

pipeline_script_dict: dict = experiment_config["bask"]
pipeline_script_dict.update(
{
"script_name": (
"scripts/pipeline_inference.py "
f"{pipeline_conf_dir} {data_conf_dir} {seed}"
f" {experiment_name}"
),
"job_name": f"{experiment_name}_full_pipeline",
"seed": seed,
}
)
pipeline_train_script = template.render(pipeline_script_dict)

with open(
f"slurm_scripts/experiments/{experiment_name}/run_{index}/full_pipeline.sh",
"w",
) as f:
f.write(pipeline_train_script)


if __name__ == "__main__":
CLI(main)
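
The heart of the script is a plain Jinja2 render of the jobscript template. A minimal, self-contained sketch of that step, with an inline template string standing in for `src/arc_spice/config/jobscript_template.sh` (the values mirror the `full_experiment_zero_shot` config above):

```python
# Minimal sketch: render a jobscript from a template string with Jinja2.
from jinja2 import Environment

template = Environment().from_string(
    "#SBATCH --job-name {{ job_name }}\n"
    "#SBATCH --time {{ walltime }}\n"
    "python {{ script_name }}\n"
)

rendered = template.render(
    job_name="full_experiment_with_zero_shot_full_pipeline",
    walltime="0-24:0:0",
    script_name="scripts/pipeline_inference.py <pipeline_conf> <data_conf> 42 full_experiment_zero_shot",
)
print(rendered)
```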
28 changes: 19 additions & 9 deletions scripts/pipeline_inference.py
@@ -3,28 +3,43 @@

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
from arc_spice.utils import open_yaml_path, seed_everything
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
RTCVariationalPipeline,
)

OUTPUT_DIR = "outputs"


def main(pipeline_config_pth: str, data_config_pth: str):
def main(
pipeline_config_pth: str, data_config_pth: str, seed: int, experiment_name: str
):
"""
Run inference on a given pipeline with provided data config

Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
seed: seed for the inference pass
experiment_name: name of experiment for saving purposes
"""
# create save directory -> fail if already exists
data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"{experiment_name}/seed_{seed}/"
)
# This directory needs to exist for all 4 experiments
os.makedirs(save_loc, exist_ok=True)
# seed experiment
seed_everything(seed=seed)
# initialise pipeline
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
test_loader = data_sets["test"]
rtc_variational_pipeline = RTCVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
@@ -37,11 +52,6 @@ def main(pipeline_config_pth: str, data_config_pth: str):
results_getter=results_getter,
)

data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}"
os.makedirs(save_loc, exist_ok=True)

with open(f"{save_loc}/full_pipeline.json", "w") as save_file:
json.dump(test_results, save_file)

35 changes: 23 additions & 12 deletions scripts/single_component_inference.py
@@ -16,9 +16,9 @@

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
from arc_spice.utils import open_yaml_path, seed_everything
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
ClassificationVariationalPipeline,
RecognitionVariationalPipeline,
@@ -28,19 +28,38 @@
OUTPUT_DIR = "outputs"


def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
def main(
pipeline_config_pth: str,
data_config_pth: str,
seed: int,
experiment_name: str,
model_key: str,
):
"""
Run inference on a given pipeline component with provided data config and model key.

Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
seed: seed for the inference pass
experiment_name: name of experiment for saving purposes
model_key: name of model on which to run inference
"""
# create save directory -> fail if already exists
data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"{experiment_name}/seed_{seed}/"
)
# This directory needs to exist for all 4 experiments
os.makedirs(save_loc, exist_ok=True)
# seed experiment
seed_everything(seed=seed)
# initialise pipeline
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
test_loader = data_sets["test"]
if model_key == "ocr":
rtc_single_component_pipeline = RecognitionVariationalPipeline(
Expand Down Expand Up @@ -69,14 +88,6 @@ def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
results_getter=results_getter,
)

data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"single_component"
)
os.makedirs(save_loc, exist_ok=True)

with open(f"{save_loc}/{model_key}.json", "w") as save_file:
json.dump(test_results, save_file)

28 changes: 28 additions & 0 deletions src/arc_spice/config/jobscript_template.sh
@@ -0,0 +1,28 @@
#!/bin/bash
#SBATCH --account vjgo8416-spice
#SBATCH --qos turing
#SBATCH --job-name {{ job_name }}
#SBATCH --time {{ walltime }}
#SBATCH --nodes {{ node_number }}
#SBATCH --gpus {{ gpu_number }}
#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/{{ job_name }}-%j.out
#SBATCH --cpus-per-gpu 18


# Load required modules here
module purge
module load baskerville
module load bask-apps/live/live
module load Python/3.10.8-GCCcore-12.2.0


# change working directory
cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/

source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate

# change huggingface cache to be in project dir rather than user home
export HF_HOME="{{ hf_cache_dir }}"

# TODO: script uses relative path to project home so must be run from home, fix
python {{ script_name }}
13 changes: 12 additions & 1 deletion src/arc_spice/data/multieurlex_utils.py
@@ -133,6 +133,7 @@ def load_multieurlex(
level: int,
languages: list[str],
drop_empty: bool = True,
drop_length: int | None = None,
split: str | None = None,
) -> tuple[datasets.DatasetDict, dict[str, Any]]:
"""
@@ -188,6 +189,11 @@
lambda x: all(x is not None for x in x["text"].values())
)

if drop_length:
dataset_dict = dataset_dict.filter(
lambda x: len(x["text"][languages[0]]) <= drop_length
)

# return datasets and metadata
return dataset_dict, metadata

@@ -197,11 +203,16 @@ def load_multieurlex_for_pipeline(
level: int,
lang_pair: dict[str, str],
drop_empty: bool = True,
drop_length: int | None = None,
load_ocr_data: bool = False,
) -> tuple[datasets.DatasetDict, dict[str, Any]]:
langs = [lang_pair["source"], lang_pair["target"]]
dataset_dict, meta_data = load_multieurlex(
data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
data_dir=data_dir,
level=level,
languages=langs,
drop_empty=drop_empty,
drop_length=drop_length,
)
# instantiate the preprocessor
preprocesser = TranslationPreProcesser(lang_pair)
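
For illustration, a minimal sketch of the new `drop_length` behaviour, using hypothetical toy data and assuming a `DatasetDict` whose rows carry a `"text"` dict keyed by language code (as in `multieurlex_utils`):

```python
# Toy demonstration of the drop_length filter added above.
from datasets import Dataset, DatasetDict

dataset_dict = DatasetDict(
    {
        "test": Dataset.from_list(
            [
                {"text": {"fr": "un texte court", "en": "a short text"}},
                {"text": {"fr": "x" * 2000, "en": "y" * 2000}},
            ]
        )
    }
)

languages = ["fr", "en"]
drop_length = 1000
# Keep only rows whose source-language text is at most drop_length characters.
filtered = dataset_dict.filter(
    lambda x: len(x["text"][languages[0]]) <= drop_length
)
print(filtered["test"].num_rows)  # 1
```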