diff --git a/scripts/performance/argument_parser.py b/scripts/performance/argument_parser.py index 6db57bac1a..0b6fdca692 100644 --- a/scripts/performance/argument_parser.py +++ b/scripts/performance/argument_parser.py @@ -419,6 +419,13 @@ def parse_cli_args(): help="Comma separated string of srun arguments", default=[], ) + slurm_args.add_argument( + "-cb", + "--custom_bash_cmds", + type=list_of_strings, + help="Comma separated string of bash commands", + default=[], + ) slurm_args.add_argument( "--gres", type=str, diff --git a/scripts/performance/setup_experiment.py b/scripts/performance/setup_experiment.py index 9408e1277c..f40f60717c 100755 --- a/scripts/performance/setup_experiment.py +++ b/scripts/performance/setup_experiment.py @@ -206,6 +206,7 @@ def main( custom_mounts: List[str], custom_env_vars: Dict[str, str], custom_srun_args: List[str], + custom_bash_cmds: List[str], nccl_ub: bool, pretrained_checkpoint: Optional[str], num_gpus: int, @@ -293,6 +294,7 @@ def main( custom_mounts=custom_mounts, custom_env_vars=custom_env_vars, custom_srun_args=custom_srun_args, + custom_bash_cmds=custom_bash_cmds, gres=args.gres, hf_token=hf_token, nemo_home=nemo_home, @@ -563,6 +565,7 @@ def main( custom_mounts=args.custom_mounts, custom_env_vars=args.custom_env_vars, custom_srun_args=args.custom_srun_args, + custom_bash_cmds=args.custom_bash_cmds, nccl_ub=args.nccl_ub, pretrained_checkpoint=args.pretrained_checkpoint, num_gpus=args.num_gpus,