estimate_zero2_model_states_mem_needs: fixing memory estimation (#5099)
The estimate was using 4 bytes per model param and 4 bytes per gradient.
Fixed it to 2 bytes each, under the assumption of FP16/BF16 training.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
nelyahu and tjruwase authored Jun 4, 2024
1 parent e7dd28a commit f4cb866
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -2432,7 +2432,9 @@ def estimate_zero2_model_states_mem_needs(total_params,
         gpu_mem = 2 * total_params
         cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor
     else:
-        gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
+        # GPU's total_params multipliers: 2 = params_16bit,
+        # 18 = 2_grads_16bit + 4_grads_32bit + 4_params_32bit + 8_optimizer_states_32bit (momentum and variance)
+        gpu_mem = 2 * total_params + int(18 * total_params / total_gpus)
         cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

     return int(cpu_mem), int(gpu_mem)
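The patched arithmetic can be checked in isolation. The sketch below re-derives the estimator as a standalone function, following the diffed formulas; the default values for `num_gpus_per_node` and `additional_buffer_factor` are assumptions for illustration and may differ from DeepSpeed's actual signature.

```python
def estimate_zero2_model_states_mem_needs(total_params,
                                          total_gpus,
                                          num_gpus_per_node=8,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5):
    # Standalone sketch of the patched estimator (not the DeepSpeed source itself),
    # assuming FP16/BF16 params/grads with FP32 Adam optimizer states.
    if cpu_offload:
        # Only the 16-bit params stay on GPU; optimizer state lives in CPU RAM.
        gpu_mem = 2 * total_params
        cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor
    else:
        # 2 bytes/param for the 16-bit params, plus a sharded 18 bytes/param:
        # 2 (grads_16bit) + 4 (grads_32bit) + 4 (params_32bit)
        # + 8 (optimizer momentum + variance, 32-bit each)
        gpu_mem = 2 * total_params + int(18 * total_params / total_gpus)
        cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
    return int(cpu_mem), int(gpu_mem)


# Example: a 1B-param model on 8 GPUs without CPU offload.
cpu, gpu = estimate_zero2_model_states_mem_needs(1_000_000_000, 8,
                                                 cpu_offload=False)
# gpu = 2e9 + 18e9/8 = 4_250_000_000 bytes; cpu = 1e9 * 4 * 8 * 1.5 = 48e9 bytes
```

Note how the fix moves the unsharded per-GPU term from 4 bytes/param down to 2, while the per-shard term grows from 16 to 18 bytes/param to account for the 16-bit gradient copy.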
