From 02d78eaa163307cfa9f5744171026c2f36f9e786 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 16 Apr 2025 17:44:18 -0700 Subject: [PATCH 1/3] Add hydra style overrides to SFT Signed-off-by: Hemil Desai --- examples/run_sft.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/run_sft.py b/examples/run_sft.py index 875aa9a000..20eceadb3a 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -28,7 +28,7 @@ from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log from nemo_reinforcer.distributed.virtual_cluster import init_ray -from nemo_reinforcer.utils.config import load_config +from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides from nemo_reinforcer.utils.logger import get_next_experiment_dir @@ -40,10 +40,7 @@ def parse_args(): ) # Parse known args for the script - args, remaining = parser.parse_known_args() - - # Convert remaining args to OmegaConf format - overrides = OmegaConf.from_dotlist(remaining) + args, overrides = parser.parse_known_args() return args, overrides @@ -154,7 +151,7 @@ def main(): if overrides: print(f"Overrides: {overrides}") - config = OmegaConf.merge(config, overrides) + config = parse_hydra_overrides(config, overrides) config: MasterConfig = OmegaConf.to_container(config, resolve=True) print("Applied CLI overrides") From 7685a7fa8ab385dd45e8dceffa0f7552236e0c09 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 28 Apr 2025 12:31:44 -0700 Subject: [PATCH 2/3] renaming fix Signed-off-by: ashors1 --- examples/run_sft.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/run_sft.py b/examples/run_sft.py index 20eceadb3a..9377a32fcf 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -21,15 +21,15 @@ from omegaconf import OmegaConf from transformers import AutoTokenizer -from nemo_reinforcer.algorithms.sft import MasterConfig, sft_train, setup -from nemo_reinforcer.algorithms.utils import get_tokenizer -from nemo_reinforcer.data import DataConfig, hf_datasets -from nemo_reinforcer.data.datasets import AllTaskProcessedDataset -from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec -from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log -from nemo_reinforcer.distributed.virtual_cluster import init_ray -from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides -from nemo_reinforcer.utils.logger import get_next_experiment_dir +from nemo_rl.algorithms.sft import MasterConfig, sft_train, setup +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig, hf_datasets +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.interfaces import TaskDataSpec, DatumSpec +from nemo_rl.data.llm_message_utils import get_formatted_message_log +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir def parse_args(): From 7c6dfe9189f9550ba1fddcfe8a6b5f5e43f429cf Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 28 Apr 2025 12:33:56 -0700 Subject: [PATCH 3/3] fix failing sft test Signed-off-by: ashors1 --- tests/functional/sft.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh index f3474fb0fd..d107d1599f 100755 --- a/tests/functional/sft.sh +++ b/tests/functional/sft.sh @@ -29,7 +29,7 @@ python -u $PROJECT_ROOT/examples/run_sft.py \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \ checkpointing.enabled=true \ - checkpointing.save_every_n_steps=10 \ + checkpointing.save_period=10 \ checkpointing.checkpoint_dir=/tmp/sft_checkpoints \ $@ \ 2>&1 | tee $RUN_LOG