diff --git a/calc/calc_transformer_mem.py b/calc/calc_transformer_mem.py index 0a38b46..dae11c6 100644 --- a/calc/calc_transformer_mem.py +++ b/calc/calc_transformer_mem.py @@ -153,7 +153,7 @@ def calc_mem(args): per_gpu_model_mem = (EP_total_params * bytes_per_param) / (args.tensor_parallel_size * args.pipeline_parallel_size) # ZeRO stage 3 shards the model parameters across GPUs (plus the gradients and optimizer states) if args.zero_stage == 3: - model_mem_per_gpu /= args.num_gpus + per_gpu_model_mem /= args.num_gpus # --- GRADIENT MEMORY --- # E.g. 4 bytes in fp32, 2 bytes in fp16/bf16, 1 byte in fp8