quick fix for transformer_flops.py error (#27)
* quick fix for transformer_flops.py error

* fix breakage

* cleanup

* fix bad merge

* fixes megatron init

* fixes megatron init

---------

Co-authored-by: Stas Bekman <[email protected]>
jahatef and stas00 committed Feb 16, 2024
1 parent 3c075f4 commit f52e7f3
Showing 2 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions benchmarks/sizing/megatron_wrapper.py
@@ -110,6 +110,7 @@ def get_megatron_args(configuration, override_tensor_mp_size=False):
     args.kv_channels = args.hidden_size // args.num_attention_heads
     args.padded_vocab_size=vocab_size
     args.attention_config = [[["flash"], 0]]
+    args.train_batch_size = train_batch_size
     #megatron.global_vars._GLOBAL_ARGS = args
     neox_args = megatron.NeoXArgs.from_dict(asdict(args))
     return neox_args
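The added line forwards the configuration's train_batch_size into the argument object before it is serialized and handed to megatron.NeoXArgs.from_dict, which is what the "fixes megatron init" entries in the commit message refer to. A minimal sketch of the pattern, using a hypothetical dataclass in place of the real, much larger argument namespace:

    from dataclasses import dataclass, asdict

    # Hypothetical stand-in for the argument object that get_megatron_args
    # fills in before calling megatron.NeoXArgs.from_dict(asdict(args)).
    @dataclass
    class Args:
        hidden_size: int = 1024
        num_attention_heads: int = 16
        kv_channels: int = 0
        train_batch_size: int = 0

    def build_args(train_batch_size: int) -> dict:
        args = Args()
        # Derived field, mirroring the wrapper above.
        args.kv_channels = args.hidden_size // args.num_attention_heads
        # The fix: propagate the requested batch size instead of leaving the default.
        args.train_batch_size = train_batch_size
        return asdict(args)

    # build_args(32) -> {'hidden_size': 1024, 'num_attention_heads': 16,
    #                    'kv_channels': 64, 'train_batch_size': 32}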
12 changes: 9 additions & 3 deletions benchmarks/sizing/transformer_flops.py
@@ -29,6 +29,9 @@ def benchmark_transformer_from_mm_and_bmm(args, configuration, seq_length, globa
     elapsed_mlp_time = 0.0
     elapsed_add_bias_dropout_time = 0.0
     elapsed_layer_norm_time = 0.0
+    attention_throughput = 0.0
+    mlp_throughput = 0.0
+    total_throughput = 0.0
 
     if 'qkv_transform' in args.blocks or 'all' in args.blocks:
         elapsed_attention_time += benchmark_mm_b(
@@ -109,9 +112,12 @@ def benchmark_transformer_from_mm_and_bmm(args, configuration, seq_length, globa
         16 * microbatch_size * seq_length * hidden_size * hidden_size / tensor_mp_size
     num_total_floating_point_operations = num_attention_floating_point_operations + \
         num_mlp_floating_point_operations
-    attention_throughput = num_attention_floating_point_operations / (elapsed_attention_time * 10**12)
-    mlp_throughput = num_mlp_floating_point_operations / (elapsed_mlp_time * 10**12)
-    total_throughput = num_total_floating_point_operations / (elapsed_total_time * 10**12)
+    if elapsed_attention_time > 0:
+        attention_throughput = num_attention_floating_point_operations / (elapsed_attention_time * 10**12)
+    if elapsed_mlp_time > 0:
+        mlp_throughput = num_mlp_floating_point_operations / (elapsed_mlp_time * 10**12)
+    if elapsed_total_time > 0:
+        total_throughput = num_total_floating_point_operations / (elapsed_total_time * 10**12)
 
     print()
     for (elapsed_time, throughput, label) in \
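The change above initializes the three throughput values to 0.0 and divides only when the corresponding elapsed time is positive, so benchmarking a subset of blocks via args.blocks no longer triggers a division-by-zero error for blocks that were never timed. Throughput is reported in TFLOP/s: floating-point operations divided by elapsed seconds, scaled by 10**12. A small self-contained sketch of the guarded pattern, with made-up numbers:

    # Hypothetical (flops, elapsed_seconds) pairs; the second block is skipped,
    # so its elapsed time stays at 0.0, as happens when it is not in args.blocks.
    measurements = {
        "attention": (35.3e12, 0.250),
        "mlp": (22.0e12, 0.0),
    }

    for label, (flops, elapsed) in measurements.items():
        throughput = 0.0                                  # default, as in the fix above
        if elapsed > 0:                                   # guard against dividing by zero
            throughput = flops / (elapsed * 10**12)       # TFLOP/s
        print(f"{label}: {elapsed:.3f} s, {throughput:.1f} TFLOP/s")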
