From ef6afa8df4470aacbe71d3b3abc9daa96e629c7d Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Sun, 11 Feb 2024 15:51:01 -0800
Subject: [PATCH] fix varname bug (#25)

---
 calc/calc_transformer_mem.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/calc/calc_transformer_mem.py b/calc/calc_transformer_mem.py
index 0a38b46..dae11c6 100644
--- a/calc/calc_transformer_mem.py
+++ b/calc/calc_transformer_mem.py
@@ -153,7 +153,7 @@ def calc_mem(args):
     per_gpu_model_mem = (EP_total_params * bytes_per_param) / (args.tensor_parallel_size * args.pipeline_parallel_size)
     # ZeRO stage 3 shards the model parameters across GPUs (plus the gradients and optimizer states)
     if args.zero_stage == 3:
-        model_mem_per_gpu /= args.num_gpus
+        per_gpu_model_mem /= args.num_gpus
 
     # --- GRADIENT MEMORY ---
     # E.g. 4 bytes in fp32, 2 bytes in fp16/bf16, 1 byte in fp8