diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 8915d8172c56..0d303e3eb8a4 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -875,6 +875,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, + quant_config=self.quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) diff --git a/vllm/model_executor/models/nemotron_h_mtp.py b/vllm/model_executor/models/nemotron_h_mtp.py index fe737438c30f..1f0693c5ae5e 100644 --- a/vllm/model_executor/models/nemotron_h_mtp.py +++ b/vllm/model_executor/models/nemotron_h_mtp.py @@ -11,6 +11,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config.parallel import ParallelConfig +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( fused_moe_make_expert_params_mapping, ) @@ -36,6 +37,8 @@ NemotronHMoEDecoderLayer, ) +logger = init_logger(__name__) + class NemotronHMTPAttentionDecoderLayer(NemotronHAttentionDecoderLayer): def __init__( @@ -242,6 +245,35 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Total number of physical layers = num_steps * pattern_len total_layers = self.num_mtp_layers * self.pattern_len + + quant_config = vllm_config.quant_config + if ( + quant_config is not None + and quant_config.get_name() == "compressed-tensors" + and hasattr(quant_config, "ignore") + ): + num_experts = getattr(config, "n_routed_experts", None) + if getattr(config, "model_type", None) == "nemotron_h_puzzle": + num_experts = getattr(config, "mtp_n_routed_experts", num_experts) + if num_experts: + extra: list[str] = [] + for i in range(total_layers): + if self.pattern_str[i % self.pattern_len] != "E": + continue + for eid in range(num_experts): + for proj in ("gate_proj", "up_proj", "down_proj"): + extra.append( + f"{prefix}.layers.{i}.mixer.experts.{eid}.{proj}" + ) + new_entries = [n for n in extra if n not in quant_config.ignore] + quant_config.ignore.extend(new_entries) + if new_entries: + logger.info( + "NemotronH-MTP: extended compressed-tensors ignore " + "with %d per-expert MTP linears (BF16 in the checkpoint)", + len(new_entries), + ) + for i in range(total_layers): step_rel_idx = i % self.pattern_len @@ -346,6 +378,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = ParallelLMHead( self.config.vocab_size, self.config.hidden_size, + quant_config=self.quant_config, prefix=maybe_prefix(prefix, "lm_head"), )