diff --git a/examples/run_sft.py b/examples/run_sft.py index 2b7dd9489f..9377a32fcf 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -28,7 +28,7 @@ from nemo_rl.data.interfaces import TaskDataSpec, DatumSpec from nemo_rl.data.llm_message_utils import get_formatted_message_log from nemo_rl.distributed.virtual_cluster import init_ray -from nemo_rl.utils.config import load_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides from nemo_rl.utils.logger import get_next_experiment_dir @@ -40,10 +40,7 @@ def parse_args(): ) # Parse known args for the script - args, remaining = parser.parse_known_args() - - # Convert remaining args to OmegaConf format - overrides = OmegaConf.from_dotlist(remaining) + args, overrides = parser.parse_known_args() return args, overrides @@ -154,7 +151,7 @@ def main(): if overrides: print(f"Overrides: {overrides}") - config = OmegaConf.merge(config, overrides) + config = parse_hydra_overrides(config, overrides) config: MasterConfig = OmegaConf.to_container(config, resolve=True) print("Applied CLI overrides") diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh index f3474fb0fd..d107d1599f 100755 --- a/tests/functional/sft.sh +++ b/tests/functional/sft.sh @@ -29,7 +29,7 @@ python -u $PROJECT_ROOT/examples/run_sft.py \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \ checkpointing.enabled=true \ - checkpointing.save_every_n_steps=10 \ + checkpointing.save_period=10 \ checkpointing.checkpoint_dir=/tmp/sft_checkpoints \ $@ \ 2>&1 | tee $RUN_LOG