     macros,
     runopts,
 )
+from torchx.workspace.dir_workspace import DirWorkspace
 
+SLURM_JOB_DIRS = ".torchxslurmjobdirs"
 
 SLURM_STATES: Mapping[str, AppState] = {
     "BOOT_FAIL": AppState.FAILED,
@@ -166,6 +168,7 @@ class SlurmBatchRequest:
 
     cmd: List[str]
     replicas: Dict[str, SlurmReplicaRequest]
+    job_dir: Optional[str]
 
     def materialize(self) -> str:
         """
@@ -200,7 +203,7 @@ def materialize(self) -> str:
         return script
 
 
-class SlurmScheduler(Scheduler):
+class SlurmScheduler(Scheduler, DirWorkspace):
     """
     SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
     that slurm CLI tools are locally installed and job accounting is enabled.
@@ -254,11 +257,8 @@ class SlurmScheduler(Scheduler):
             Partial support. SlurmScheduler will return job and replica
             status but does not provide the complete original AppSpec.
         workspaces: |
-            Partial support. Typical Slurm usage is from a shared NFS mount
-            so code will automatically be updated on the workers.
-            SlurmScheduler does not support programmatic patching via
-            WorkspaceScheduler.
-
+            If ``job_dir`` is specified, the DirWorkspace will create a new
+            isolated directory with a snapshot of the workspace.
     """
 
     def __init__(self, session_name: str) -> None:
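
For readers unfamiliar with ``DirWorkspace``: the "snapshot" in the new docstring amounts to copying the local workspace into the freshly created ``job_dir``. A minimal sketch of that idea follows; this is an illustration, not the actual ``DirWorkspace`` implementation, and ``snapshot_workspace`` is a hypothetical helper name:

    import os
    import shutil

    def snapshot_workspace(workspace: str, job_dir: str) -> None:
        # The job_dir runopt below states the directory "must not exist
        # and will be created", hence no exist_ok here.
        os.makedirs(job_dir)
        # Copy the workspace contents so later local edits cannot change
        # what the queued job runs.
        for entry in os.listdir(workspace):
            src = os.path.join(workspace, entry)
            dst = os.path.join(job_dir, entry)
            if os.path.isdir(src):
                shutil.copytree(src, dst)
            else:
                shutil.copy2(src, dst)
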
@@ -276,7 +276,9 @@ def run_opts(self) -> runopts:
             "time",
             type_=str,
             default=None,
-            help="The maximum time the job is allowed to run for.",
+            help='The maximum time the job is allowed to run for. Formats: \
+            "minutes", "minutes:seconds", "hours:minutes:seconds", "days-hours", \
+            "days-hours:minutes" or "days-hours:minutes:seconds"',
         )
         opts.add(
             "nomem",
@@ -304,25 +306,45 @@ def run_opts(self) -> runopts:
             type_=str,
             help="What events to mail users on.",
         )
+        opts.add(
+            "job_dir",
+            type_=str,
+            help="""The directory to place the job code and outputs. The
+            directory must not exist and will be created. To enable log
+            iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
+            """,
+        )
         return opts
 
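
How the new runopt might be exercised end to end (a sketch; the app definition and paths are made up, and ``submit_dryrun`` is the public wrapper around the ``_submit_dryrun`` changed below):

    from torchx.schedulers.slurm_scheduler import create_scheduler
    from torchx.specs import AppDef, Role

    # Hypothetical single-role app; image and entrypoint are made up.
    app = AppDef(
        name="trainer",
        roles=[Role(name="worker", image="/home/me/proj", entrypoint="train.py")],
    )
    scheduler = create_scheduler(session_name="demo")
    # job_dir must not exist yet; after scheduling it will hold
    # torchx-sbatch.sh plus the slurm-<jobid>-<role>-<replica> log files.
    info = scheduler.submit_dryrun(app, cfg={"job_dir": "/mnt/shared/torchx/job1"})
    print(info)  # renders the materialized batch request for inspection
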
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
         req = dryrun_info.request
+        job_dir = req.job_dir
         with tempfile.TemporaryDirectory() as tmpdir:
             script = req.materialize()
-            path = os.path.join(tmpdir, "job.sh")
+            path = os.path.join(job_dir or tmpdir, "torchx-sbatch.sh")
 
             with open(path, "w") as f:
                 f.write(script)
 
-            cmd = req.cmd + [path]
+            cmd = req.cmd
+            if job_dir is not None:
+                cmd += [f"--chdir={job_dir}"]
+            cmd += [path]
 
             p = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
-            return p.stdout.decode("utf-8").strip()
+            job_id = p.stdout.decode("utf-8").strip()
+
+        if job_dir is not None:
+            _save_job_dir(job_id, job_dir)
+
+        return job_id
 
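
To make the command assembly above concrete, here is a minimal, self-contained sketch; the base command and paths are made up for illustration:

    from typing import List, Optional

    def build_cmd(base: List[str], script_path: str, job_dir: Optional[str]) -> List[str]:
        cmd = list(base)  # copy: avoids mutating the caller's list in place
        if job_dir is not None:
            # --chdir makes slurm resolve its relative output paths (the
            # slurm-<jobid>-*.out files) inside job_dir.
            cmd += [f"--chdir={job_dir}"]
        return cmd + [script_path]

    assert build_cmd(["sbatch", "--parsable"], "/mnt/j1/torchx-sbatch.sh", "/mnt/j1") == [
        "sbatch",
        "--parsable",
        "--chdir=/mnt/j1",
        "/mnt/j1/torchx-sbatch.sh",
    ]

Note that in the hunk above, ``cmd = req.cmd`` followed by ``cmd += [...]`` extends ``req.cmd`` in place; copying first, as in the sketch, would keep the request reusable across calls.
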
     def _submit_dryrun(
         self, app: AppDef, cfg: Mapping[str, CfgVal]
     ) -> AppDryRunInfo[SlurmBatchRequest]:
+        job_dir = cfg.get("job_dir")
+        assert job_dir is None or isinstance(job_dir, str), "job_dir must be str"
+
         replicas = {}
         for role in app.roles:
             for replica_id in range(role.num_replicas):
@@ -344,6 +366,7 @@ def _submit_dryrun(
         req = SlurmBatchRequest(
             cmd=cmd,
             replicas=replicas,
+            job_dir=job_dir,
         )
         return AppDryRunInfo(req, repr)
 
@@ -435,6 +458,9 @@ def log_iter(
         )
 
         log_file = f"slurm-{app_id}-{role_name}-{k}.{extension}"
+        job_dirs = _get_job_dirs()
+        if app_id in job_dirs:
+            log_file = os.path.join(job_dirs[app_id], log_file)
 
         return LogIterator(
             app_id, regex or ".*", log_file, self, should_tail=should_tail
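
The effect of the lookup, sketched with made-up values: when a job was submitted with a ``job_dir``, its logs are read from that directory rather than from the submit-time working directory.

    import os

    # A job submitted with job_dir="/mnt/j1" that received slurm id "42".
    job_dirs = {"42": "/mnt/j1"}
    log_file = "slurm-42-worker-0.out"
    if "42" in job_dirs:
        log_file = os.path.join(job_dirs["42"], log_file)
    assert log_file == "/mnt/j1/slurm-42-worker-0.out"
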
@@ -445,3 +472,24 @@ def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
     return SlurmScheduler(
         session_name=session_name,
     )
+
+
+def _save_job_dir(job_id: str, job_dir: str) -> None:
+    with open(SLURM_JOB_DIRS, "at") as f:
+        f.write(f"{job_id} = {job_dir}\n")
+
+
+def _get_job_dirs() -> Mapping[str, str]:
+    try:
+        with open(SLURM_JOB_DIRS, "rt") as f:
+            lines = f.readlines()
+    except FileNotFoundError:
+        return {}
+
+    out = {}
+    for line in lines:
+        first, _, second = line.partition("=")
+        if not first or not second:
+            continue
+        out[first.strip()] = second.strip()
+    return out
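
The tracking file is one ``job_id = job_dir`` line per submission, appended to ``.torchxslurmjobdirs`` in the current working directory, so ``log_iter`` has to run from the same directory to find it. A round-trip sketch of the two helpers above (ids and paths made up):

    # Appends to .torchxslurmjobdirs in the CWD:
    #   1234 = /mnt/shared/job1
    #   1235 = /mnt/shared/job2
    _save_job_dir("1234", "/mnt/shared/job1")
    _save_job_dir("1235", "/mnt/shared/job2")

    assert _get_job_dirs() == {
        "1234": "/mnt/shared/job1",
        "1235": "/mnt/shared/job2",
    }
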