diff --git a/benchmarks/sizing/megatron_wrapper.py b/benchmarks/sizing/megatron_wrapper.py
index 37b3153..1b5c4b8 100644
--- a/benchmarks/sizing/megatron_wrapper.py
+++ b/benchmarks/sizing/megatron_wrapper.py
@@ -110,6 +110,7 @@ def get_megatron_args(configuration, override_tensor_mp_size=False):
     args.kv_channels = args.hidden_size // args.num_attention_heads
     args.padded_vocab_size=vocab_size
     args.attention_config = [[["flash"], 0]]
+    args.train_batch_size = train_batch_size
     #megatron.global_vars._GLOBAL_ARGS = args
     neox_args = megatron.NeoXArgs.from_dict(asdict(args))
     return neox_args
diff --git a/benchmarks/sizing/transformer_flops.py b/benchmarks/sizing/transformer_flops.py
index 0376b1d..24a0e8f 100644
--- a/benchmarks/sizing/transformer_flops.py
+++ b/benchmarks/sizing/transformer_flops.py
@@ -29,6 +29,9 @@ def benchmark_transformer_from_mm_and_bmm(args, configuration, seq_length, globa
     elapsed_mlp_time = 0.0
     elapsed_add_bias_dropout_time = 0.0
     elapsed_layer_norm_time = 0.0
+    attention_throughput = 0.0
+    mlp_throughput = 0.0
+    total_throughput = 0.0
 
     if 'qkv_transform' in args.blocks or 'all' in args.blocks:
         elapsed_attention_time += benchmark_mm_b(
@@ -109,9 +112,12 @@ def benchmark_transformer_from_mm_and_bmm(args, configuration, seq_length, globa
             16 * microbatch_size * seq_length * hidden_size * hidden_size / tensor_mp_size
     num_total_floating_point_operations = num_attention_floating_point_operations + \
             num_mlp_floating_point_operations
-    attention_throughput = num_attention_floating_point_operations / (elapsed_attention_time * 10**12)
-    mlp_throughput = num_mlp_floating_point_operations / (elapsed_mlp_time * 10**12)
-    total_throughput = num_total_floating_point_operations / (elapsed_total_time * 10**12)
+    if elapsed_attention_time > 0:
+        attention_throughput = num_attention_floating_point_operations / (elapsed_attention_time * 10**12)
+    if elapsed_mlp_time > 0:
+        mlp_throughput = num_mlp_floating_point_operations / (elapsed_mlp_time * 10**12)
+    if elapsed_total_time > 0:
+        total_throughput = num_total_floating_point_operations / (elapsed_total_time * 10**12)
 
     print()
     for (elapsed_time, throughput, label) in \
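
The second hunk pre-initializes the three throughput values to 0.0 and only divides when the matching elapsed time is nonzero, so benchmarking a subset of blocks (e.g. MLP only, leaving attention untimed) no longer raises ZeroDivisionError. A minimal sketch of that guard pattern, using illustrative names (safe_throughput_tflops, flops, elapsed_seconds are not part of the patch):

    def safe_throughput_tflops(flops, elapsed_seconds):
        # Throughput in TFLOP/s, or 0.0 when nothing was timed.
        if elapsed_seconds > 0:
            return flops / (elapsed_seconds * 10**12)
        return 0.0

    # Attention block was skipped, so its elapsed time stays at 0.0:
    print(safe_throughput_tflops(0.0, 0.0))       # 0.0 instead of ZeroDivisionError
    print(safe_throughput_tflops(2.4e12, 0.012))  # 200.0 TFLOP/s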