diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 999949fa1b8f..e50511ea7d3a 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -145,11 +145,12 @@ def __init__( self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + router_logits_dtype = torch.float32 self.gate = ReplicatedLinear( config.hidden_size, config.n_routed_experts, bias=False, - params_dtype=torch.float32, + params_dtype=router_logits_dtype, quant_config=None, prefix=f"{prefix}.gate", ) @@ -209,6 +210,7 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + router_logits_dtype=router_logits_dtype, ) if self.use_latent_moe: