diff --git a/src/exabiome/run/cori.py b/src/exabiome/run/cori.py index bb525b1..ca2835d 100644 --- a/src/exabiome/run/cori.py +++ b/src/exabiome/run/cori.py @@ -17,7 +17,7 @@ class SlurmJob(AbstractJob): debug_queue = 'debug' - def __init__(self, queue='batch', project='m2865', time='1:00:00', nodes=1, jobname=None, output=None, error=None): + def __init__(self, queue='batch', project='m2865', time='1:00:00', nodes=1, jobname=None, output=None, error=None, arch='gpu'): super().__init__() self.queue = queue self.project = project @@ -28,7 +28,7 @@ def __init__(self, queue='batch', project='m2865', time='1:00:00', nodes=1, jobn self.output = f'{self.jobname}.%J' self.error = f'{self.jobname}.%J' - self.add_addl_jobflag('C', 'gpu') + self.add_addl_jobflag('C', arch) def write_run(self, f, command, command_options, options): print(f'srun -u {command}', file=f) diff --git a/src/exabiome/run/run_job.py b/src/exabiome/run/run_job.py index 933a1c0..8c48fc4 100644 --- a/src/exabiome/run/run_job.py +++ b/src/exabiome/run/run_job.py @@ -27,7 +27,7 @@ def check_cori(args): if args.nodes is None: args.nodes = 1 if args.outdir is None: - args.outdir = os.path.abspath("$CSCRATCH/exabiome/deep-index") + args.outdir = os.path.abspath(os.path.expandvars("$CSCRATCH/exabiome/deep-index")) def run_train(argv=None): @@ -48,6 +48,7 @@ def run_train(argv=None): rsc_grp.add_argument('-N', '--jobname', help="the name of the job", default=None) rsc_grp.add_argument('-q', '--queue', help="the queue to submit to", default=None) rsc_grp.add_argument('-P', '--project', help="the project/account to submit under", default=None) + rsc_grp.add_argument('-a', '--arch', help="the architecture to use, e.g., gpu or haswell (cori only)", default='gpu') system_grp = parser.add_argument_group('Compute system') grp = system_grp.add_mutually_exclusive_group() @@ -82,13 +83,14 @@ def run_train(argv=None): if args.summit: check_summit(args) job = LSFJob() - job.set_conda_env(args.conda_env) job.add_modules('open-ce') if not args.load: job.set_use_bb(True) else: check_cori(args) - job = SlurmJob() + job = SlurmJob(arch=args.arch) + + job.set_conda_env(args.conda_env) job.nodes = args.nodes job.time = args.time