From e91071d369b0beafcf0dc93d6b7d9aae07e70757 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sat, 27 Jan 2024 13:48:45 -0800 Subject: [PATCH] Improve Memory Calculation Script (#22) * Update README and add llm directory site source * Add model directory source * Create static.yml (#13) * Update mem calculation to include per-gpu, any mixed precision, and fix a few errors * Update README and add swiglu to flops --- calc/README.md | 57 +++---- calc/calc_transformer_flops.py | 5 + calc/calc_transformer_mem.py | 264 ++++++++++++++++++++------------- 3 files changed, 196 insertions(+), 130 deletions(-) diff --git a/calc/README.md b/calc/README.md index 860633c..af2c8c0 100644 --- a/calc/README.md +++ b/calc/README.md @@ -20,8 +20,7 @@ Currently, scripts are entirely self-contained. This is for the dual purpose of: ``` Example with Fairseq-MoE 15B: python calc_transformer_flops.py -l 12 -hs 768 --moe -e 512 Example with GPT-3 175B: python calc_transformer_flops.py -l 96 -hs 12288 -usage: calc_transformer_flops.py [-h] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--moe] [--num-experts NUM_EXPERTS] [--expert-interval EXPERT_INTERVAL] - [--topk TOPK] [--batch-size BATCH_SIZE] [--tokens TOKENS] [--no-checkpoint-activations] +usage: calc_transformer_flops.py [-h] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--kv-size-ratio KV_SIZE_RATIO] [--moe] [--num-experts NUM_EXPERTS] [--expert-interval EXPERT_INTERVAL] [--topk TOPK] [--swiglu] [--batch-size BATCH_SIZE] [--tokens TOKENS] [--no-checkpoint-activations] options: -h, --help show this help message and exit @@ -33,12 +32,15 @@ options: Sequence length used for training --num-layers NUM_LAYERS, -l NUM_LAYERS Number of transformer layers used in model + --kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO + Ratio of kv heads to query heads used in model. 1.0 for MHA --moe Whether our model is MoE --num-experts NUM_EXPERTS, -e NUM_EXPERTS Number of experts for MoE --expert-interval EXPERT_INTERVAL, -ei EXPERT_INTERVAL Expert interval for MoE --topk TOPK, -t TOPK Top k routing for MoE + --swiglu Use swiglu MLP. If set, ffn-hidden-size is defined as the inner dimension of each of the three MLP weights. --batch-size BATCH_SIZE, -b BATCH_SIZE Global batch size in units of samples --tokens TOKENS Number of tokens you are training over @@ -83,19 +85,15 @@ options: `calc_transformer_mem.py` calculates the amount of device memory required to train or infer a model. See [Transformers Math 101](https://blog.eleuther.ai/transformer-math/) for more details on how memory overhead is calculated. Take this estimation with a grain of salt, because every implementation is different and these calculations were written to match the GPT-NeoX library as close as possible. Even for other training and inference libraries, however, we expect our script to give approximate memory estimations within acceptable error. (Please see [LLM finetuning memory requirements](https://blog.scottlogic.com/2023/11/24/llm-mem.html) for a treatment of how specific memory costs may vary framework-to-framework). Other good resources that we consulted are [the ZeRO Paper](https://arxiv.org/abs/1910.02054) and [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/pdf/2205.05198.pdf). ``` -Example with pythia 6.9B: python transformer_mem.py --num-layers=32 --sequence-length=2048 --num-attention-heads=32 --hidden-size=4096 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=2 --num-gpus=128 --params=6900000000 -Example with pythia 12B: python transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256 --params=11849420800 -Example with default 20B: python transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1 --params=20000000000 +Example with pythia 6.9B: python calc_transformer_mem.py --num-layers=32 --sequence-length=2048 --num-attention-heads=32 --hidden-size=4096 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=2 --num-gpus=128 +Example with pythia 12B: python calc_transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256 +Example with default 20B: python calc_transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1 -usage: calc_transformer_mem.py [-h] [--params PARAMS] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] - [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--hidden-size HIDDEN_SIZE] [--num-attention-heads NUM_ATTENTION_HEADS] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] - [--fp32-model] [--fp32-grads] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE] [--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--misc-mem-gb MISC_MEM_GB] [--num-experts NUM_EXPERTS] - [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--expert-parallelism EXPERT_PARALLELISM] [--vocab-size VOCAB_SIZE] +usage: calc_transformer_mem.py [-h] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE] [--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--num-attention-heads NUM_ATTENTION_HEADS] + [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--disable-mixed-precision] [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS] [--expert-parallelism EXPERT_PARALLELISM] [--misc-mem-gib MISC_MEM_GIB] options: -h, --help show this help message and exit - --params PARAMS, -p PARAMS - Number of Parameters --num-gpus NUM_GPUS Number of GPUs used for training --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE Tensor parallel degree (1 if not used) @@ -105,34 +103,43 @@ options: Whether we use ZeRO-R to partition activation memory across tensor-parallel degree --zero-stage {0,1,2,3}, -z {0,1,2,3} Stage of the ZeRO optimizer + --zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE, -zbs ZERO_ALLGATHER_BUCKET_SIZE + Size of allgather buckets used by ZeRO + --zero3-max-live-params ZERO3_MAX_LIVE_PARAMS, -zmlp ZERO3_MAX_LIVE_PARAMS + Maximum number of parameters ZeRO3 keeps in GPU memory --checkpoint-activations, -ca Whether Megatron-style activation checkpointing is being used --batch-size-per-gpu BATCH_SIZE_PER_GPU, -b BATCH_SIZE_PER_GPU Batch size per GPU + --sequence-length SEQUENCE_LENGTH, -s SEQUENCE_LENGTH + Sequence length used for training + --vocab-size VOCAB_SIZE, -v VOCAB_SIZE + How many tokens are in the embedding layer --hidden-size HIDDEN_SIZE, -hs HIDDEN_SIZE Dimension of the model's hidden size --num-attention-heads NUM_ATTENTION_HEADS, -a NUM_ATTENTION_HEADS Number of attention heads used in model - --sequence-length SEQUENCE_LENGTH, -s SEQUENCE_LENGTH - Sequence length used for training --num-layers NUM_LAYERS, -l NUM_LAYERS Number of transformer layers used in model - --fp32-model Whether model is stored in fp32 - --fp32-grads Whether grads are stored in fp32 - --zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE, -zbs ZERO_ALLGATHER_BUCKET_SIZE - Size of allgather buckets used by ZeRO - --zero3-max-live-params ZERO3_MAX_LIVE_PARAMS, -zmlp ZERO3_MAX_LIVE_PARAMS - Maximum number of parameters ZeRO3 keeps in GPU memory - --misc-mem-gb MISC_MEM_GB - Miscellaneous memory overhead by DL framework(s), communication libraries, etc - --num-experts NUM_EXPERTS - Number of experts --ffn-expansion-factor FFN_EXPANSION_FACTOR, -ff FFN_EXPANSION_FACTOR How much the MLP hidden size expands + --infer whether we're doing inference + --kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO + Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA. + --disable-mixed-precision + Disables mixed precision training + --high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL + The high-precision bytes per value (parameter, optimizer state, etc) in mixed precision + --low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL + The low-precision bytes per value (parameter, optimizer state, etc) in mixed precision + --bytes-per-grad-ele BYTES_PER_GRAD_ELE + The precision of gradient elements as bytes per value + --num-experts NUM_EXPERTS + Number of experts --expert-parallelism EXPERT_PARALLELISM, -ep EXPERT_PARALLELISM How many ways are the experts sharded across ranks - --vocab-size VOCAB_SIZE, -v VOCAB_SIZE - How many ways are the experts sharded across ranks + --misc-mem-gib MISC_MEM_GIB + Miscellaneous memory overhead per GPU by DL framework(s), communication libraries, etc ``` diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py index 64d550d..edad7f9 100644 --- a/calc/calc_transformer_flops.py +++ b/calc/calc_transformer_flops.py @@ -51,6 +51,9 @@ def config_parser(): type=int, default=1, help='Top k routing for MoE') + parser.add_argument("--swiglu", + action="store_true", + help='Use swiglu MLP. If set, ffn-hidden-size is defined as the inner dimension of each of the three MLP weights.') parser.add_argument("--batch-size", "-b", type=int, default=1, @@ -85,6 +88,8 @@ def calc_params(args): attention_over_values_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size linear_projection_flops = iter_factor * 2 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size ffn_flops = iter_factor * 16 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size + if args.swiglu: + ffn_flops = 3/2 * ffn_flops # no activation checkpointing for embeddings embedding_flops = 6 * args.tokens * args.hidden_size * args.vocab_size diff --git a/calc/calc_transformer_mem.py b/calc/calc_transformer_mem.py index 9e1a419..0a38b46 100644 --- a/calc/calc_transformer_mem.py +++ b/calc/calc_transformer_mem.py @@ -15,10 +15,7 @@ def convert_params(params): def config_parser(): parser = argparse.ArgumentParser() - parser.add_argument("--params", "-p", - type=int, - default=20000000000, - help='Number of Parameters') + # Distributed Settings parser.add_argument("--num-gpus", type=int, default=1, @@ -39,6 +36,15 @@ def config_parser(): default=1, choices=[0,1,2,3], help='Stage of the ZeRO optimizer') + parser.add_argument("--zero-allgather-bucket-size", "-zbs", + type=int, + default=5e8, + help='Size of allgather buckets used by ZeRO') + parser.add_argument("--zero3-max-live-params", "-zmlp", + type=int, + default=1e9, + help='Maximum number of parameters ZeRO3 keeps in GPU memory') + # Training settings parser.add_argument("--checkpoint-activations", "-ca", action="store_true", help='Whether Megatron-style activation checkpointing is being used') @@ -46,6 +52,15 @@ def config_parser(): type=int, default=1, help='Batch size per GPU') + parser.add_argument("--sequence-length", "-s", + type=int, + default=2048, + help='Sequence length used for training') + parser.add_argument("--vocab-size", "-v", + type=int, + default=51200, + help='How many tokens are in the embedding layer') + # Model settings parser.add_argument("--hidden-size", "-hs", type=int, default=6144, @@ -54,180 +69,219 @@ def config_parser(): type=int, default=64, help='Number of attention heads used in model') - parser.add_argument("--kv-size-ratio", "-kv", - type=float, - default=1.0, - help='Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA.') - parser.add_argument("--sequence-length", "-s", - type=int, - default=2048, - help='Sequence length used for training') parser.add_argument("--num-layers", "-l", type=int, default=44, help='Number of transformer layers used in model') - parser.add_argument("--fp32-model", - action="store_true", - help='Whether model is stored in fp32') - parser.add_argument("--fp32-grads", + parser.add_argument("--ffn-expansion-factor", "-ff", + type=int, + default=4, + help='How much the MLP hidden size expands') + # Inference settings + parser.add_argument("--infer", action="store_true", - help='Whether grads are stored in fp32') - parser.add_argument("--zero-allgather-bucket-size", "-zbs", + help="whether we're doing inference") + parser.add_argument("--kv-size-ratio", "-kv", + type=float, + default=1.0, + help='Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA.') + # Precision settings + parser.add_argument("--disable-mixed-precision", + action="store_false", + help='Disables mixed precision training', + dest='is_mixed_precision') + parser.add_argument("--high-prec-bytes-per-val", type=int, - default=5e8, - help='Size of allgather buckets used by ZeRO') - parser.add_argument("--zero3-max-live-params", "-zmlp", + default=4, + help='The high-precision bytes per value (parameter, optimizer state, etc) in mixed precision') + parser.add_argument("--low-prec-bytes-per-val", type=int, - default=1e9, - help='Maximum number of parameters ZeRO3 keeps in GPU memory') - parser.add_argument("--misc-mem-gb", + default=2, + help='The low-precision bytes per value (parameter, optimizer state, etc) in mixed precision') + parser.add_argument("--bytes-per-grad-ele", type=int, - default=0, - help='Miscellaneous memory overhead by DL framework(s), communication libraries, etc') + default=4, + help='The precision of gradient elements as bytes per value') + # MoE Settings parser.add_argument("--num-experts", - type=int, - default=0, - help='Number of experts') - parser.add_argument("--ffn-expansion-factor", "-ff", - type=int, - default=4, - help='How much the MLP hidden size expands') + type=int, + default=0, + help='Number of experts') parser.add_argument("--expert-parallelism", "-ep", - type=int, - default=1, - help='How many ways are the experts sharded across ranks') - parser.add_argument("--vocab-size", "-v", - type=int, - default=51200, - help='How many tokens are in the embedding layer') - parser.add_argument("--infer", - action="store_true", - help="whether we're doing inference") + type=int, + default=1, + help='How many ways are the experts sharded across ranks') + # Miscellaneous memory (good for accounting for implementation-dependent fudge factors) + parser.add_argument("--misc-mem-gib", + type=int, + default=0, + help='Miscellaneous memory overhead per GPU by DL framework(s), communication libraries, etc') + return parser -# calculates the total memory necessary for training a model + +# Calculates the total memory necessary for model training or inference def calc_mem(args): dp_degree = args.num_gpus / (args.tensor_parallel_size * args.pipeline_parallel_size) - # 4 bytes in fp32, 2 bytes in fp16/bf16 - if args.fp32_model: - bytes_per_param = 4 - else: - bytes_per_param = 2 - - - # compute total parameters from the config + # Compute total parameters from the config embed_params = 2 * args.vocab_size * args.hidden_size positional_params = args.hidden_size * args.sequence_length ln_params = 8 * args.hidden_size * args.num_layers + (2 * args.hidden_size) attention_params = int(2 * (1 + args.kv_size_ratio) * args.num_layers * args.hidden_size * args.hidden_size) mlp_params = 2 * args.num_layers * args.hidden_size * args.ffn_expansion_factor * args.hidden_size total_params = embed_params + positional_params + ln_params + attention_params + mlp_params + + # --- MODEL MEMORY --- + # 4 bytes in fp32, 2 bytes in fp16/bf16, 1 byte in fp8 + if args.is_mixed_precision: + bytes_per_param = args.low_prec_bytes_per_val + else: + bytes_per_param = args.high_prec_bytes_per_val + + # Compute memory from param calculation and parallelism settings + model_mem = total_params * bytes_per_param + per_gpu_model_mem = model_mem if args.num_experts > 0: total_moe_params = embed_params + positional_params + ln_params + attention_params + (args.num_experts * mlp_params) # Split the model with 3D parallelism if args.num_experts == 0: - model_mem = (total_params * bytes_per_param) / (args.tensor_parallel_size * args.pipeline_parallel_size) + per_gpu_model_mem = (total_params * bytes_per_param) / (args.tensor_parallel_size * args.pipeline_parallel_size) else: EP_total_params = embed_params + positional_params + ln_params + attention_params + ((args.num_experts/args.expert_parallelism) * mlp_params) - model_mem = (EP_total_params * bytes_per_param) / (args.tensor_parallel_size * args.pipeline_parallel_size) + per_gpu_model_mem = (EP_total_params * bytes_per_param) / (args.tensor_parallel_size * args.pipeline_parallel_size) # ZeRO stage 3 shards the model parameters across GPUs (plus the gradients and optimizer states) if args.zero_stage == 3: - model_mem /= args.num_gpus + model_mem_per_gpu /= args.num_gpus - # 4 bytes in fp32, 2 bytes in fp16/bf16 - if args.fp32_grads: - bytes_per_grad_element = 4 - else: - bytes_per_grad_element = 2 + # --- GRADIENT MEMORY --- + # E.g. 4 bytes in fp32, 2 bytes in fp16/bf16, 1 byte in fp8 + # Gradient precision is sometimes configurable in training frameworks. + # Since high batch size means many accumulations, higher precision grads may reduce grad overflow. + bytes_per_grad_element = args.bytes_per_grad_ele if args.num_experts > 0: gradient_mem = EP_total_params * bytes_per_grad_element else: gradient_mem = total_params * bytes_per_grad_element + per_gpu_gradient_mem = gradient_mem # ZeRO stage 2 shards the gradients across GPUs (plus the optimizer states) if args.zero_stage >= 2: - gradient_mem /= args.num_gpus - gradient_mem /= args.pipeline_parallel_size + per_gpu_gradient_mem /= args.num_gpus + # --- OPTIMIZER MEMORY --- # For mixed-precision Adam/AdamW, the optimizer must store fp32 copies of the parameters, momentum, and variance (4 + 4 + 4 = 12 bytes per optimizer parameter) # Feel free to change the multiplier for your optimizer (examples include SGD (4 + 4 = 8) and 8-bit ADAM (2 + 2 + 2 = 6) if args.num_experts > 0: optimizer_mem = EP_total_params * 12 else: optimizer_mem = total_params * 12 + per_gpu_optimizer_mem = optimizer_mem # ZeRO stage 3 shards the optimizer states across GPUs if args.zero_stage >= 1: - optimizer_mem /= args.num_gpus + per_gpu_optimizer_mem /= args.num_gpus - communication_mem = 0 + # --- COMMUNICATION MEMORY --- + # Temporary GPU storage for communication buffers may become significant + per_gpu_communication_mem = 0 # The size of the communication buffer DeepSpeed uses to store ZeRO optimizer elements - if args.zero_stage >= 1: - communication_mem += args.zero_allgather_bucket_size * bytes_per_param + if args.zero_stage >= 1 and args.num_gpus > 1: + per_gpu_communication_mem += args.zero_allgather_bucket_size * bytes_per_param # The number of parameters ZeRO-3 keeps alive in GPU memory at a time - if args.zero_stage == 3: - communication_mem += args.zero3_max_live_params * bytes_per_param + if args.zero_stage == 3 and args.num_gpus > 1: + per_gpu_communication_mem += args.zero3_max_live_params * bytes_per_param - # Taken from Table 2 in https://arxiv.org/pdf/1910.02054.pdf - # We find these don't perfectly match with experiment, but are good approximations + # --- ACTIVATION MEMORY --- + # Taken from Table 2 in https://arxiv.org/pdf/1910.02054.pdf and generalized to any precision (instead of just fp16 from the paper) + # 3 cases: [training with activation checkpointing, training without activation checkpointing, inferencing] if not args.infer and args.checkpoint_activations: - activation_mem = bytes_per_param * args.sequence_length * args.batch_size_per_gpu * args.hidden_size * args.num_layers * (10 + (24 / args.tensor_parallel_size)) + activation_mem = args.sequence_length * args.batch_size_per_gpu * args.hidden_size * args.num_layers * ((16 * args.low_prec_bytes_per_val + 2)) elif not args.infer and not args.checkpoint_activations: - activation_mem = bytes_per_param * args.sequence_length * args.batch_size_per_gpu * args.hidden_size * args.num_layers * (10 + (24 / args.tensor_parallel_size) + 5 * ((args.num_attention_heads * args.sequence_length) / (args.hidden_size * args.tensor_parallel_size))) + activation_mem = args.sequence_length * args.batch_size_per_gpu * args.hidden_size * args.num_layers * ((16 * args.low_prec_bytes_per_val + 2) + (2 * args.low_prec_bytes_per_val + 1) * (args.num_attention_heads * args.sequence_length / args.hidden_size)) # If using inference, assume just a single layer's activation memory at peak elif args.infer: - activation_mem = bytes_per_param * args.sequence_length * args.batch_size_per_gpu * args.hidden_size * (10 + (24 / args.tensor_parallel_size) + 5 * ((args.num_attention_heads * args.sequence_length) / (args.hidden_size * args.tensor_parallel_size))) - + activation_mem = args.sequence_length * args.batch_size_per_gpu * args.hidden_size * ((16 * args.low_prec_bytes_per_val + 2)) + per_gpu_activation_mem = activation_mem # DeepSpeed's ZeRO-R partitions activation memory across tensor-parallel GPUs if args.partition_activations: - activation_mem /= args.tensor_parallel_size - + per_gpu_activation_mem = activation_mem / args.tensor_parallel_size + # --- KV CACHE MEMORY (IF INFERENCE) --- if args.infer: - if args.fp32_model: - bytes_per_param = 4 - else: - bytes_per_param = 2 - kv_cache_mem = bytes_per_param * 2 * args.num_layers * args.num_attention_heads * (args.hidden_size / args.num_attention_heads) * args.sequence_length - - # We include a "Miscellaneous Memory" term because we find some 3D-parallel frameworks add a constant memory overhead (~5GB in our experiments with Megatron-DeepSpeed) that we cannot explain. If you know the source of this, add a comment! - gradient_mem_gb = gradient_mem / 1024**3 - activation_mem_gb = activation_mem / 1024**3 - model_mem_gb = model_mem / 1024**3 - optimizer_mem_gb = optimizer_mem / 1024**3 - communication_mem_gb = communication_mem / 1024**3 - total_mem_gb = activation_mem_gb + gradient_mem_gb + model_mem_gb + optimizer_mem_gb + communication_mem_gb + args.misc_mem_gb + # See https://kipp.ly/transformer-inference-arithmetic/ for details + bytes_per_param = args.low_prec_bytes_per_val + per_gpu_kv_cache_mem = bytes_per_param * 2 * args.num_layers * args.num_attention_heads * (args.hidden_size / args.num_attention_heads) * args.sequence_length + kv_cache_mem = args.num_gpus * per_gpu_kv_cache_mem + + gradient_mem_gib = gradient_mem / 1024**3 + activation_mem_gib = activation_mem / 1024**3 + model_mem_gib = model_mem / 1024**3 + optimizer_mem_gib = optimizer_mem / 1024**3 + + per_gpu_gradient_mem_gib = per_gpu_gradient_mem / 1024**3 + per_gpu_activation_mem_gib = per_gpu_activation_mem / 1024**3 + per_gpu_model_mem_gib = per_gpu_model_mem / 1024**3 + per_gpu_optimizer_mem_gib = per_gpu_optimizer_mem / 1024**3 + per_gpu_communication_mem_gib = per_gpu_communication_mem / 1024**3 + + + # We include a "Miscellaneous Memory" per GPU term because we find some 3D-parallel frameworks add a constant memory overhead (~5GiB in our experiments with Megatron-DeepSpeed) that we cannot explain. If you know the source of this, add a comment! if args.infer: - kv_cache_mem_gb = kv_cache_mem / 1024**3 + kv_cache_mem_gib = kv_cache_mem / 1024**3 + per_gpu_kv_cache_mem_gib = per_gpu_kv_cache_mem / 1024**3 if args.infer: - total_mem_gb = activation_mem_gb + kv_cache_mem_gb + model_mem_gb + args.misc_mem_gb + per_gpu_mem_gib = per_gpu_activation_mem_gib + per_gpu_kv_cache_mem_gib + per_gpu_model_mem_gib + args.misc_mem_gib + single_replica_mem_gib = activation_mem_gib + kv_cache_mem_gib + model_mem_gib + args.misc_mem_gib * args.num_gpus else: - total_mem_gb = activation_mem_gb + gradient_mem_gb + model_mem_gb + optimizer_mem_gb + communication_mem_gb + args.misc_mem_gb + per_gpu_mem_gib = per_gpu_activation_mem_gib + per_gpu_gradient_mem_gib + per_gpu_model_mem_gib + per_gpu_optimizer_mem_gib + per_gpu_communication_mem_gib + args.misc_mem_gib + single_replica_mem_gib = activation_mem_gib + gradient_mem_gib + model_mem_gib + optimizer_mem_gib + args.misc_mem_gib * args.num_gpus + # Print number of forward-pass parameters, and account for experts if using MoE print(f'Calculating memory with training configuration: {vars(args)}\n') print(f'Number of Parameters: {convert_params(total_params)}') if args.num_experts > 0: print(f'Total Number of MoE Parameters: {convert_params(total_moe_params)}') - print(f'Activation Memory: {activation_mem_gb:.2f} GB') - print(f'Model Memory: {model_mem_gb:.2f} GB') + print() + + # Print per-GPU memory for each component + print(f'*** Per-GPU Memory') + print(f'Per-GPU Activation Memory: {per_gpu_activation_mem_gib:.2f} GiB') + print(f'Per-GPU Model Memory: {per_gpu_model_mem_gib:.2f} GiB') if args.infer: - print(f'KV Cache Memory: {kv_cache_mem_gb:.2f} GB') + print(f'Per-GPU KV Cache Memory: {per_gpu_kv_cache_mem_gib:.2f} GiB') else: - print(f'Gradient Memory: {gradient_mem_gb:.2f} GB') - print(f'Optimizer Memory: {optimizer_mem_gb:.2f} GB') - print(f'Communication Memory: {communication_mem_gb:.2f} GB') - print(f'Miscellaneous Memory: {args.misc_mem_gb:.2f} GB') - + print(f'Per-GPU Gradient Memory: {per_gpu_gradient_mem_gib:.2f} GiB') + print(f'Per-GPU Optimizer Memory: {per_gpu_optimizer_mem_gib:.2f} GiB') + print(f'Per-GPU Communication Memory: {per_gpu_communication_mem_gib:.2f} GiB') + print(f'Per-GPU Miscellaneous Memory: {args.misc_mem_gib:.2f} GiB') + # Aggregate Per-GPU Memory + if args.infer: + print(f'\nPer-GPU Memory Required for Inference: {per_gpu_mem_gib:.2f} GiB') + else: + print(f'\nPer-GPU Memory Required for Training: {per_gpu_mem_gib:.2f} GiB') + print() + + # Print total GPU memory required to store a complete model replica + print(f'*** Total GPU Memory for a Single Model Replica') + print(f'Total Activation Memory: {activation_mem_gib:.2f} GiB') + print(f'Total Model Memory: {model_mem_gib:.2f} GiB') + if args.infer: + print(f'Total KV Cache Memory: {kv_cache_mem_gib:.2f} GiB') + else: + print(f'Total Gradient Memory: {gradient_mem_gib:.2f} GiB') + print(f'Total Optimizer Memory: {optimizer_mem_gib:.2f} GiB') + print(f'Total Miscellaneous Memory: {args.num_gpus*args.misc_mem_gib:.2f} GiB') + # Aggregate GPU memory if args.infer: - print(f'Total Memory Required for Inference: {total_mem_gb:.2f} GB') + print(f'\nTotal GPU Memory Required to Store a Complete Model Replica for Inference: {single_replica_mem_gib:.2f} GiB') else: - print(f'Total Memory Required for Training: {total_mem_gb:.2f} GB') + print(f'\nTotal GPU Memory Required to Store a Complete Model Replica for Training: {single_replica_mem_gib:.2f} GiB') if __name__ == "__main__": - print('\nExample with pythia 6.9B: python transformer_mem.py --num-layers=32 --sequence-length=2048 --num-attention-heads=32 --hidden-size=4096 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=2 --num-gpus=128 --params=6900000000') - print('Example with pythia 12B: python transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256 --params=11849420800') - print('Example with default 20B: python transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1 --params=20000000000\n') + print('\nExample with pythia 6.9B: python calc_transformer_mem.py --num-layers=32 --sequence-length=2048 --num-attention-heads=32 --hidden-size=4096 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=2 --num-gpus=128') + print('Example with pythia 12B: python calc_transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256') + print('Example with default 20B: python calc_transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1\n') args = config_parser().parse_args() calc_mem(args)