@@ -233,7 +233,7 @@ def get_batch(source):
233
233
234
234
235
235
def verify_peak_memory (golden_config , std_dev ):
236
- print ( "Peak allocated bytes on cuda:0: {:1d}" . format ( torch . cuda . memory_stats ( 0 )[ "allocated_bytes.all.peak" ]))
236
+
237
237
current_device_usage = torch .cuda .memory_stats (0 )["allocated_bytes.all.peak" ]
238
238
golden_ref = golden_config ["peak_mem_usage" ]
239
239
if not current_device_usage < golden_ref * std_dev :
@@ -246,7 +246,6 @@ def verify_peak_memory(golden_config, std_dev):
246
246
def verify_lm_throughput (wps , golden_config , args ):
247
247
"""Verify that words per second for a given benchmark run matches the golden data."""
248
248
249
- print ("Throughput(wps) is {:.2f}." .format (wps ))
250
249
if not wps > (golden_config ["avg_wps" ] - (3 * golden_config ["std_dev_wps" ])):
251
250
raise RuntimeError (
252
251
"Throughput(wps):{:.2f} is below the golden threshold of an "
@@ -272,9 +271,12 @@ def benchmark_language_model(model_config, model, benchmark_config, model_specs,
272
271
raise RuntimeError (
273
272
f"Golden data verification is only supported for the Transformer(lm) model and not { args .model_name } "
274
273
)
275
- golden_config = get_golden_config (args .model_name , args )
276
- verify_lm_throughput (wps , golden_config , args )
277
- verify_peak_memory (golden_config , 1.1 )
274
+ print ("Throughput(wps) is {:.2f}." .format (wps ))
275
+ print ("Peak allocated bytes on cuda:0: {:1d}" .format (torch .cuda .memory_stats (0 )["allocated_bytes.all.peak" ]))
276
+ if not args .dry_run :
277
+ golden_config = get_golden_config (args .model_name , args )
278
+ verify_lm_throughput (wps , golden_config , args )
279
+ verify_peak_memory (golden_config , 1.1 )
278
280
279
281
280
282
def get_synthetic_dataloaders (args , device , benchmark_config , model_specs ):
@@ -343,11 +345,11 @@ def create_model_config(args, benchmark_config=None, model_specs=None):
343
345
raise RuntimeError (f"Unrecognized args.model_mame { args .model_name } " )
344
346
345
347
346
- def create_benchmark_config (model_name ):
348
+ def create_benchmark_config (args ):
347
349
"""Return a dict with configurations required for benchmarking `model_name` model."""
348
350
349
351
if args .model_name == "lm" :
350
- return lm_wikitext2 .get_benchmark_config ()
352
+ return lm_wikitext2 .get_benchmark_config (checkpoint_activation = args . checkpoint_activation )
351
353
elif args .model_name == "seq" :
352
354
return offload_seq .get_benchmark_config ()
353
355
else :
@@ -383,17 +385,15 @@ def run_benchmark(args):
383
385
init_random_seed (0 )
384
386
385
387
if args .model_name == "lm" :
386
- benchmark_config = create_benchmark_config (args . model_name )
388
+ benchmark_config = create_benchmark_config (args )
387
389
model_specs = get_model_specs (args .model_name )
388
390
model_config = create_model_config (args , benchmark_config = benchmark_config , model_specs = model_specs )
389
391
model = model_config ["model" ]
390
392
391
- if args .dry_run :
392
- train (model_config , model , benchmark_config , model_specs , args )
393
- else :
394
- benchmark_language_model (model_config , model , benchmark_config , model_specs , args )
393
+ benchmark_language_model (model_config , model , benchmark_config , model_specs , args )
394
+
395
395
elif args .model_name == "seq" :
396
- benchmark_config = create_benchmark_config (args . model_name )
396
+ benchmark_config = create_benchmark_config (args )
397
397
model_specs = get_model_specs (args .model_name )
398
398
model_config = create_model_config (args , benchmark_config = benchmark_config , model_specs = model_specs )
399
399
model = model_config ["model" ]
@@ -419,7 +419,7 @@ def run_benchmark(args):
419
419
"--use_synthetic_data" , default = True , action = "store_true" , help = "Uses synthetic data for running benchmarks."
420
420
)
421
421
parser .add_argument ("--use_fp16" , action = "store_true" , default = False )
422
- parser .add_argument ("--checkpoint_activation" , action = "store_true" , default = True )
422
+ parser .add_argument ("--checkpoint_activation" , action = "store_true" , default = False )
423
423
parser .add_argument ("--use_profiler" , action = "store_true" , default = False )
424
424
425
425
0 commit comments