Skip to content

Commit 820f102

Browse files
authored
Merge pull request #87 from epfLLM/evalonly_and_wbresume
Added eval only and wandb resume options
2 parents (402f7e8 + 47e6f9c), commit 820f102

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

megatron/arguments.py

+6 -2
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,12 @@ def _add_logging_args(parser):
536536
help='Project name for Weights & Biases.')
537537
group.add_argument('--wandb_entity', type=str, default="meditron",
538538
help='Entity/team name for Weights & Biases.')
539+
group.add_argument('--wandb_name', type=str, default=None,
540+
help='Name for this run, alternatively can set `WANDB_NAME`.')
539541
group.add_argument('--wandb_id',type=str,default=None,
540542
help="Unique ID to identify this run, alternatively can set `WANDB_RUN_ID`.")
541-
group.add_argument('--wandb_resume',action="store_true",
542-
help="If set, we resume logging for the id given instead of launching a new run (errors if id given and resume=False).")
543+
group.add_argument('--wandb_resume',type=str,default="allow",
544+
help="If set, we resume logging for the id given instead of launching a new run (errors if id given and resume=None).")
543545
group.add_argument("--wandb_api_key",type=str,default=None,
544546
help="API key for Weights & Biases, needs to be set if not set in environment variable `WANDB_API_KEY`.")
545547
group.add_argument("--metrics", default=[], nargs="+", choices=list(METRICS) + ["all"],
@@ -878,6 +880,8 @@ def _add_distributed_args(parser):
878880

879881
def _add_validation_args(parser):
880882
group = parser.add_argument_group(title='validation')
883+
group.add_argument('--eval_only', action='store_true',
884+
help='Run evaluation only.')
881885
group.add_argument('--eval_iters', type=int, default=100,
882886
help='Number of iterations to run for evaluation'
883887
'validation/test for.')

megatron/training.py

+5
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,11 @@ def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provid
936936
args.do_valid = flags[1].item()
937937
args.do_test = flags[2].item()
938938

939+
if args.eval_only:
940+
args.do_train = False
941+
args.do_valid = False
942+
args.do_test = True
943+
939944
# Build iterators.
940945
dl_type = args.dataloader_type
941946
assert dl_type in ['single', 'cyclic']

megatron/wandb_logger.py

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def from_args(args)->'WandBConfig':
6060
log_interval=args.log_interval,
6161
config=args,entity=args.wandb_entity,
6262
project=args.wandb_project,
63+
name=args.wandb_name,
6364
run_id=args.wandb_id,
6465
resume=args.wandb_resume,
6566
api_key=args.wandb_api_key,

0 commit comments

Comments
 (0)