@@ -188,19 +188,25 @@ def __init__(self) -> None:

     def __call__(self, dict_config: DictConfig) -> None:
         self.config = dict_config
-        if self.config.job.scheduler.mode == SchedulerType.SLURM:
-            # modify the config metadata to add slurm info, this should be only time we intentionally modify the metadata
-            self.config.job.metadata.slurm_env = get_slurm_env()
-            remove_runner_state_from_submission(
-                dict_config.job.metadata.log_dir,
-                self.config.job.metadata.slurm_env.slurm_id,
-            )
+        # modify the config metadata to add slurm info if they exist
+        self.config.job.metadata.slurm_env = get_slurm_env()

         setup_env_vars()
         setup_logging()

         dist_config = map_job_config_to_dist_config(self.config.job)
         distutils.setup(dist_config)
+        distutils.synchronize()
+        if (
+            distutils.is_master()
+            and self.config.job.scheduler.mode == SchedulerType.SLURM
+        ):
+            # this pickle file is shared across all processes so can only modify this on the main rank
+            remove_runner_state_from_submission(
+                dict_config.job.metadata.log_dir,
+                self.config.job.metadata.slurm_env.slurm_id,
+            )
+
         if self.config.job.graph_parallel_group_size is not None:
             gp_utils.setup_graph_parallel_groups(
                 self.config.job.graph_parallel_group_size,
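For context on the new guard: when several ranks share one submission pickle, the safe pattern is to synchronize first and then let only the main rank rewrite the file. Below is a minimal sketch of that pattern written against plain torch.distributed rather than the repo's distutils wrappers; the file layout and the remove_runner_state helper are hypothetical stand-ins for remove_runner_state_from_submission, shown only to illustrate the barrier-then-master-mutation ordering.

import pickle
from pathlib import Path

import torch.distributed as dist


def is_master() -> bool:
    # Treat single-process runs (no process group) as the master rank.
    return not dist.is_initialized() or dist.get_rank() == 0


def remove_runner_state(log_dir: str, job_id: str) -> None:
    # Hypothetical helper: drop cached runner state from a pickle that
    # every rank reads but only the main rank is allowed to rewrite.
    path = Path(log_dir) / f"{job_id}_submission.pkl"  # assumed layout
    if not path.exists():
        return
    state = pickle.loads(path.read_bytes())
    state.pop("runner_state", None)
    path.write_bytes(pickle.dumps(state))


def post_setup(log_dir: str, job_id: str) -> None:
    # Wait until every rank has finished reading the shared file, then
    # mutate it on the master rank only, mirroring the synchronize() /
    # is_master() guard added in the diff above.
    if dist.is_initialized():
        dist.barrier()
    if is_master():
        remove_runner_state(log_dir, job_id)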