From a17c66f6d20d711c43f9cb60013a7eb2c29245ef Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Sun, 11 Aug 2024 14:17:37 +0200
Subject: [PATCH] Add MLP Linears Argument (#37)

* Add mlp linears arg and clean up args indentation

* add linears per mlp and expansion factor args to flop calculation

* add updated options to calc readme
---
 calc/README.md                  | 22 ++++++++++++++++------
 calc/calc_transformer_flops.py  | 15 ++++++++++++---
 calc/calc_transformer_mem.py    |  6 +++++-
 calc/calc_transformer_params.py | 40 ++++++++++++++++++++++------------------
 4 files changed, 55 insertions(+), 28 deletions(-)

diff --git a/calc/README.md b/calc/README.md
index 259283a..1e599d7 100644
--- a/calc/README.md
+++ b/calc/README.md
@@ -20,7 +20,8 @@ Currently, scripts are entirely self-contained. This is for the dual purpose of:
 ```
 Example with Fairseq-MoE 15B: python calc_transformer_flops.py -l 12 -hs 768 --moe -e 512
 Example with GPT-3 175B: python calc_transformer_flops.py -l 96 -hs 12288
-usage: calc_transformer_flops.py [-h] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--kv-size-ratio KV_SIZE_RATIO] [--moe] [--num-experts NUM_EXPERTS] [--expert-interval EXPERT_INTERVAL] [--topk TOPK] [--swiglu] [--batch-size BATCH_SIZE] [--tokens TOKENS] [--no-checkpoint-activations]
+usage: calc_transformer_flops.py [-h] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--kv-size-ratio KV_SIZE_RATIO] [--moe] [--num-experts NUM_EXPERTS] [--expert-interval EXPERT_INTERVAL]
+                                 [--topk TOPK] [--swiglu] [--batch-size BATCH_SIZE] [--tokens TOKENS] [--no-checkpoint-activations] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--num-mlp-linears NUM_MLP_LINEARS] [--infer]
 
 options:
   -h, --help            show this help message and exit
@@ -46,6 +47,11 @@ options:
   --tokens TOKENS       Number of tokens you are training over
   --no-checkpoint-activations, -ca
                         Whether Megatron-style activation checkpointing is being used
+  --ffn-expansion-factor FFN_EXPANSION_FACTOR, -ff FFN_EXPANSION_FACTOR
+                        How much the MLP hidden size expands
+  --num-mlp-linears NUM_MLP_LINEARS, -nl NUM_MLP_LINEARS
+                        How many linear layers per MLP block
+  --infer, -i           Pass to calculate FLOPs for inference-only workload (no backward pass)
 ```
 
@@ -56,8 +62,8 @@ options:
 ```
 Example with Fairseq-MoE 15B: python calc_transformer_params.py -l 12 -hs 768 --moe -e 512
 Example with GPT-3 175B: python calc_transformer_params.py -l 96 -hs 12288
-usage: calc_transformer_params.py [-h] [--vocab-size VOCAB_SIZE] [--tied-embeddings] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--moe] [--num-experts NUM_EXPERTS]
-                                  [--expert-interval EXPERT_INTERVAL] [--topk TOPK] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--kv-size-ratio KV_SIZE_RATIO]
+usage: calc_transformer_params.py [-h] [--vocab-size VOCAB_SIZE] [--tied-embeddings] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--moe] [--num-experts NUM_EXPERTS] [--expert-interval EXPERT_INTERVAL] [--topk TOPK]
+                                  [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--num-mlp-linears NUM_MLP_LINEARS] [--kv-size-ratio KV_SIZE_RATIO]
 
 options:
   -h, --help            show this help message and exit
@@ -78,6 +84,8 @@ options:
   --topk TOPK, -t TOPK  Top k routing for MoE
   --ffn-expansion-factor FFN_EXPANSION_FACTOR, -ff FFN_EXPANSION_FACTOR
                         How much the MLP hidden size expands
+  --num-mlp-linears NUM_MLP_LINEARS, -nl NUM_MLP_LINEARS
+                        How many linear layers per MLP block
   --kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO
                         What fraction of num. query heads is num. key/value heads
 ```
@@ -94,9 +102,9 @@ Example with default 20B: python calc_transformer_mem.py --num-layers=44 --seque
 usage: calc_transformer_mem.py [-h] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE]
                                [--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE]
-                               [--num-attention-heads NUM_ATTENTION_HEADS] [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--output-tokens OUTPUT_TOKENS] [--disable-mixed-precision]
-                               [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS] [--expert-parallelism EXPERT_PARALLELISM]
-                               [--misc-mem-gib MISC_MEM_GIB]
+                               [--num-attention-heads NUM_ATTENTION_HEADS] [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--num-mlp-linears NUM_MLP_LINEARS] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--output-tokens OUTPUT_TOKENS]
+                               [--disable-mixed-precision] [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS]
+                               [--expert-parallelism EXPERT_PARALLELISM] [--misc-mem-gib MISC_MEM_GIB]
 
 options:
   -h, --help            show this help message and exit
@@ -129,6 +137,8 @@ options:
                         Number of transformer layers used in model
   --ffn-expansion-factor FFN_EXPANSION_FACTOR, -ff FFN_EXPANSION_FACTOR
                         How much the MLP hidden size expands
+  --num-mlp-linears NUM_MLP_LINEARS, -nl NUM_MLP_LINEARS
+                        How many linear layers per MLP block
   --infer               whether we're doing inference
   --kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO
                         Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA.
diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py
index 035cd9e..d9f1cd2 100644
--- a/calc/calc_transformer_flops.py
+++ b/calc/calc_transformer_flops.py
@@ -66,13 +66,21 @@ def config_parser():
                         action='store_false',
                         help='Whether Megatron-style activation checkpointing is being used',
                         dest='checkpoint_activations')
+    parser.add_argument("--ffn-expansion-factor", "-ff",
+                        type=int,
+                        default=4,
+                        help='How much the MLP hidden size expands')
+    parser.add_argument("--num-mlp-linears", "-nl",
+                        type=int,
+                        default=2,
+                        help='How many linear layers per MLP block')
     parser.add_argument("--infer", "-i",
                         action='store_true',
                         help='Pass to calculate FLOPs for inference-only workload (no backward pass)')
     return parser
 
 # calculates the flops of a model given its hparams
-def calc_params(args):
+def calc_flops(args):
     assert args.topk <= args.num_experts, "You cannot route to more experts than you have!"
     assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"
@@ -89,11 +97,12 @@ def calc_params(args):
     if args.infer:
         iter_factor = 1
 
+    # The factor of 2 from all these terms comes from the multiply + accumulate
     qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size)
     attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
     attention_over_values_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
     linear_projection_flops = iter_factor * 2 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
-    ffn_flops = iter_factor * 16 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
+    ffn_flops = iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
     if args.swiglu:
         ffn_flops = 3/2 * ffn_flops
     # no activation checkpointing for embeddings
@@ -126,4 +135,4 @@ def calc_params(args):
     print('Example with GPT-3 175B: python calc_transformer_flops.py -l 96 -hs 12288')
 
     args = config_parser().parse_args()
-    calc_params(args)
+    calc_flops(args)
diff --git a/calc/calc_transformer_mem.py b/calc/calc_transformer_mem.py
index a3a8bb1..f303d02 100644
--- a/calc/calc_transformer_mem.py
+++ b/calc/calc_transformer_mem.py
@@ -77,6 +77,10 @@ def config_parser():
                         type=int,
                         default=4,
                         help='How much the MLP hidden size expands')
+    parser.add_argument("--num-mlp-linears", "-nl",
+                        type=int,
+                        default=2,
+                        help='How many linear layers per MLP block')
     # Inference settings
     parser.add_argument("--infer",
                         action="store_true",
@@ -134,7 +138,7 @@ def calc_mem(args):
     positional_params = args.hidden_size * args.sequence_length
     ln_params = 8 * args.hidden_size * args.num_layers + (2 * args.hidden_size)
     attention_params = int(2 * (1 + args.kv_size_ratio) * args.num_layers * args.hidden_size * args.hidden_size)
-    mlp_params = 2 * args.num_layers * args.hidden_size * args.ffn_expansion_factor * args.hidden_size
+    mlp_params = args.num_mlp_linears * args.num_layers * args.hidden_size * args.ffn_expansion_factor * args.hidden_size
     total_params = embed_params + positional_params + ln_params + attention_params + mlp_params
 
     # --- MODEL MEMORY ---
diff --git a/calc/calc_transformer_params.py b/calc/calc_transformer_params.py
index 6490ed3..94d63d6 100644
--- a/calc/calc_transformer_params.py
+++ b/calc/calc_transformer_params.py
@@ -35,28 +35,32 @@ def config_parser():
                         default=44,
                         help='Number of transformer layers used in model')
     parser.add_argument("--moe",
-                    action="store_true",
-                    help='Whether our model is MoE')
+                        action="store_true",
+                        help='Whether our model is MoE')
     parser.add_argument("--num-experts", "-e",
-                    type=int,
-                    default=8,
-                    help='Number of experts for MoE')
+                        type=int,
+                        default=8,
+                        help='Number of experts for MoE')
     parser.add_argument("--expert-interval", "-ei",
-                    type=int,
-                    default=1,
-                    help='Expert interval for MoE')
+                        type=int,
+                        default=1,
+                        help='Expert interval for MoE')
     parser.add_argument("--topk", "-t",
                         type=int,
                         default=1,
                         help='Top k routing for MoE')
     parser.add_argument("--ffn-expansion-factor", "-ff",
-                    type=int,
-                    default=4,
-                    help='How much the MLP hidden size expands')
+                        type=int,
+                        default=4,
+                        help='How much the MLP hidden size expands')
+    parser.add_argument("--num-mlp-linears", "-nl",
+                        type=int,
+                        default=2,
+                        help='How many linear layers per MLP block')
     parser.add_argument("--kv-size-ratio", "-kv",
-                    type=float,
-                    default=1.0,
-                    help='What fraction of num. query heads is num. key/value heads')
+                        type=float,
+                        default=1.0,
+                        help='What fraction of num. query heads is num. key/value heads')
     return parser
 
 # calculates the params of a model given their hparams
@@ -79,15 +83,15 @@ def calc_params(args):
         # the number of layers that are MoE. (e.g. interval is 2 for GShard)
         num_expert_layers = args.num_layers / args.expert_interval
         # the number of FFN params for each MoE layer
-        ffn_expert_params = 2 * args.ffn_expansion_factor * num_expert_layers * args.num_experts * args.hidden_size * args.hidden_size
+        ffn_expert_params = args.num_mlp_linears * args.ffn_expansion_factor * num_expert_layers * args.num_experts * args.hidden_size * args.hidden_size
         # the number of FFN params for every dense layer
-        ffn_dense_params = 2 * args.ffn_expansion_factor * (args.num_layers - num_expert_layers) * args.hidden_size * args.hidden_size
+        ffn_dense_params = args.num_mlp_linears * args.ffn_expansion_factor * (args.num_layers - num_expert_layers) * args.hidden_size * args.hidden_size
         ffn_params = ffn_expert_params + ffn_dense_params
         # the number of gating layer params assuming it's implemented as a simple linear layer
         gating_params = num_expert_layers * args.hidden_size * args.num_experts
     else:
-        # two (h x [ffn_expansion_factor * h]) FFN matrices
-        ffn_params = 2 * args.ffn_expansion_factor * args.num_layers * args.hidden_size * args.hidden_size
+        # num_mlp_linears * (h x [ffn_expansion_factor * h]) FFN matrices
+        ffn_params = args.num_mlp_linears * args.ffn_expansion_factor * args.num_layers * args.hidden_size * args.hidden_size
 
     total_params = embedding_params + attention_params + ffn_params + position_embedding_params + layernorm_params
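
As a quick sanity check of the generalized FFN terms above: this is a standalone sketch, not part of the diff, and the model shape and token count are illustrative (loosely following the GPT-3 175B example in the README). With the new defaults of `--num-mlp-linears 2` and `--ffn-expansion-factor 4`, the generalized formulas reproduce the previously hard-coded constants (a `2 * 4 = 8` hidden-size-squared weight factor per layer for params, and `2 * 2 * 4 = 16` for FFN FLOPs).

```python
# Standalone sketch (not part of the patch): check that the generalized FFN terms
# with default flag values match the constants they replace.
num_mlp_linears = 2       # new --num-mlp-linears default
ffn_expansion_factor = 4  # --ffn-expansion-factor default
num_layers = 96           # illustrative GPT-3 175B-like shape
hidden_size = 12288
tokens = 300e9            # illustrative token count
iter_factor = 1           # forward pass only, as with --infer

# Dense FFN parameters, mirroring calc_transformer_params.py after this patch
ffn_params = num_mlp_linears * ffn_expansion_factor * num_layers * hidden_size * hidden_size

# FFN FLOPs, mirroring calc_transformer_flops.py after this patch
ffn_flops = iter_factor * 2 * num_mlp_linears * ffn_expansion_factor * num_layers * tokens * hidden_size * hidden_size
old_ffn_flops = iter_factor * 16 * num_layers * tokens * hidden_size * hidden_size
assert ffn_flops == old_ffn_flops  # defaults reproduce the old constant of 16

print(f"FFN params: {ffn_params / 1e9:.1f}B, FFN FLOPs: {ffn_flops:.3e}")
```

Setting `--num-mlp-linears 3` (counting a gated MLP as three matrices at the stated expansion factor) scales both terms by 3/2, which mirrors the existing `--swiglu` handling (a 3/2 multiplier on `ffn_flops`) in calc_transformer_flops.py.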