diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py
index 7b1149a781d..a856ce9ea23 100644
--- a/megatron/core/ssm/gated_delta_net.py
+++ b/megatron/core/ssm/gated_delta_net.py
@@ -188,23 +188,22 @@ def __init__(
             )
         )
         setattr(self.dt_bias, "tensor_model_parallel", True)
-        # A_log parameter
+        # A_log is kept in FP32 for numerical stability
         self.A_log = nn.Parameter(
             torch.empty(
-                self.num_v_heads_local_tp,
-                dtype=config.params_dtype,
-                device=torch.cuda.current_device(),
+                self.num_v_heads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
             )
         )
         setattr(self.A_log, "tensor_model_parallel", True)

-        # Output layernorm before projection
+        # Output layernorm before projection — kept in FP32 to match HF checkpoint precision
        self.out_norm = build_module(
             submodules.out_norm,
             config=self.config,
             hidden_size=self.value_head_dim,
             eps=self.config.layernorm_epsilon,
         )
+        self.out_norm.to(torch.float32)

         self.out_proj = build_module(
             submodules.out_proj,
@@ -238,10 +237,10 @@ def reset_parameters(self):
             dtype=self.config.params_dtype,
             device=torch.cuda.current_device(),
         )
-        # A_log
+        # A_log (FP32)
         A = torch.empty(
             self.num_v_heads_local_tp,
-            dtype=self.config.params_dtype,
+            dtype=torch.float32,
             device=torch.cuda.current_device(),
         ).uniform_(*self.A_init_range)
         self.A_log.data.copy_(torch.log(A))
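
For context, a minimal standalone sketch (not part of the patch) of the precision pattern this change adopts: the exp-sensitive A_log parameter and the output norm are kept in FP32, while activations run in the configured params_dtype. The sizes, the toy decay computation, and the use of nn.LayerNorm as a stand-in for the Megatron-built out_norm are assumptions made purely for illustration.

# Illustrative sketch only; it mirrors the precision pattern the diff
# introduces, not the actual GatedDeltaNet forward pass.
import torch
import torch.nn as nn

num_heads, head_dim = 8, 64                  # hypothetical local sizes
params_dtype = torch.bfloat16                # working precision of the block

# A_log in FP32: exp(A_log) is a per-head decay rate, and bf16's short
# mantissa would visibly perturb it.
A = torch.empty(num_heads, dtype=torch.float32).uniform_(1.0, 16.0)
A_log = nn.Parameter(torch.log(A))

# Output norm cast to FP32, as in `self.out_norm.to(torch.float32)`.
out_norm = nn.LayerNorm(head_dim, eps=1e-5).to(torch.float32)

x = torch.randn(2, num_heads, head_dim, dtype=params_dtype)  # bf16 activations
decay = torch.exp(-torch.exp(A_log))                          # FP32 decay per head
y = x.float() * decay[None, :, None]                          # compute in FP32
y = out_norm(y).to(params_dtype)                              # back to bf16 for out_proj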