From 56aeee16b246c53f5d17b76f5fc37687b13e82a9 Mon Sep 17 00:00:00 2001
From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
Date: Mon, 19 Feb 2024 21:56:18 -0500
Subject: [PATCH] Add simple inference FLOP counter to
 `calc_transformer_flops.py` (#31)

* add --infer arg to flops calculator

* add comment

* fix comment

* Update calc_transformer_flops.py
---
 calc/calc_transformer_flops.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py
index edad7f9..035cd9e 100644
--- a/calc/calc_transformer_flops.py
+++ b/calc/calc_transformer_flops.py
@@ -66,11 +66,13 @@ def config_parser():
                         action='store_false',
                         help='Whether Megatron-style activation checkpointing is being used',
                         dest='checkpoint_activations')
+    parser.add_argument("--infer", "-i",
+                        action='store_true',
+                        help='Pass to calculate FLOPs for inference-only workload (no backward pass)')
     return parser
 
 # calculates the flops of a model given its hparams
 def calc_params(args):
-
     assert args.topk <= args.num_experts, "You cannot route to more experts than you have!"
     assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"
 
@@ -82,6 +84,10 @@ def calc_params(args):
     iter_factor = 3
     if args.checkpoint_activations:
         iter_factor += 1
+    # If inference-only, no bwd pass or activation ckpting necessary
+    # This assumes simply running a single forward pass ('prefill' stage of decoding) and no subsequent autoregressively generated tokens.
+    if args.infer:
+        iter_factor = 1
 
     qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size)
     attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
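
Note (not part of the patch): the `iter_factor` above encodes the usual dense-transformer FLOP approximation: each weight matmul costs ~2 FLOPs per parameter in the forward pass, the backward pass costs roughly twice the forward (factor 3), Megatron-style activation recomputation replays the forward pass (factor 4), and inference runs the forward pass only (factor 1). Below is a minimal standalone sketch of that same logic; the `flops_per_token` helper and the 6.7B example model are illustrative assumptions, not code from `calc_transformer_flops.py`.

def flops_per_token(num_params: int, infer: bool = False,
                    checkpoint_activations: bool = False) -> int:
    """Approximate FLOPs per token for a dense transformer (~2*N forward).

    Hypothetical helper mirroring the patch's iter_factor logic; not the
    repository's API.
    """
    # Forward + backward: the backward pass costs ~2x the forward pass.
    iter_factor = 3
    # Activation recomputation replays the forward pass once more.
    if checkpoint_activations:
        iter_factor += 1
    # Inference-only: a single forward ('prefill') pass, no backward pass,
    # no activation checkpointing -- as in the patch above.
    if infer:
        iter_factor = 1
    return iter_factor * 2 * num_params

if __name__ == "__main__":
    n = 6_700_000_000  # hypothetical 6.7B-parameter model
    print(f"training : {flops_per_token(n):.3e} FLOPs/token")              # ~6N
    print(f"inference: {flops_per_token(n, infer=True):.3e} FLOPs/token")  # ~2N

This recovers the familiar ~6N FLOPs/token for training and ~2N FLOPs/token for inference prefill; per the patch's comment, it does not model the per-step cost of subsequent autoregressive decoding.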