diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py
index 48628fec46e0..41dc85682b05 100644
--- a/vllm/model_executor/layers/deepseek_compressor.py
+++ b/vllm/model_executor/layers/deepseek_compressor.py
@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
+    UnquantizedLinearMethod,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
@@ -226,6 +227,11 @@ def __init__(
             prefix=f"{prefix}.fused_wkv_wgate",
         )
         self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)
+        if not isinstance(self.fused_wkv_wgate.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError(
+                "Quantization of `indexer.compressor.wkv/wgate` is not supported "
+                "due to accuracy concerns. See #42001."
+            )
         self.state_cache = CompressorStateCache(
             state_dim=2 * self.coff * self.head_dim,  # kv_state + score_state
diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py
index 494d61338084..087ef375aa5b 100644
--- a/vllm/model_executor/layers/deepseek_v4_attention.py
+++ b/vllm/model_executor/layers/deepseek_v4_attention.py
@@ -16,6 +16,7 @@
 import vllm.envs as envs
 from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
+    UnquantizedLinearMethod,
 )
 from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
 from vllm.utils.deep_gemm import fp8_einsum
@@ -284,6 +285,17 @@ def __init__(
             k_cache_prefix=self.mla_attn.prefix,
         )
+        # For now, the model requires FP8 quantization for attention,
+        # so assume that a `wo` weight scale exists.
+        if hasattr(self.wo_a, "weight_scale_inv"):
+            self.wo_scale_name = "weight_scale_inv"
+        elif hasattr(self.wo_a, "weight_scale"):
+            self.wo_scale_name = "weight_scale"
+        else:
+            raise NotImplementedError(
+                "DeepSeekV4 requires FP8 quantization of `attn.wo_a.weight`"
+            )
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -334,7 +346,7 @@
         )

         wo_a_fp8 = self.wo_a.weight
-        wo_a_scale = self.wo_a.weight_scale_inv
+        wo_a_scale = getattr(self.wo_a, self.wo_scale_name)

         z = torch.empty(
             (num_tokens, self.n_local_groups, self.o_lora_rank),
@@ -1133,6 +1145,8 @@ def __init__(
         )
         self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
         self.softmax_scale = self.head_dim**-0.5
+        if not isinstance(self.weights_proj.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError("Quantization of `attn.indexer.weights_proj` is not supported")

         self.scale_fmt = "ue8m0"
         self.quant_block_size = 128  # TODO: get from config
diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py
index cef4038dc2e6..34bb083bb4f3 100644
--- a/vllm/model_executor/models/deepseek_v4.py
+++ b/vllm/model_executor/models/deepseek_v4.py
@@ -143,6 +143,9 @@ def __init__(self, *args, **kwargs):
         # ``is_scale_e8m0`` is a property that resolves on first read,
         # by which time the current vllm_config has been set.

+        # Implicitly ignored layers for DSV4.
+        self.ignored_layers += ["weights_proj", "fused_wkv_wgate"]
+
     @property
     def expert_dtype(self) -> str:
         if self._resolved_expert_dtype is None:
@@ -1004,7 +1007,14 @@ def __init__(
             prefix=f"{prefix}.wo_b",
         )
         self.softmax_scale = self.head_dim**-0.5
-        self.scale_fmt = config.quantization_config["scale_fmt"]
+        # scale_fmt is only used in the indexer (for C4A layers), not in
+        # the main attention. Default to "ue8m0" for compatibility.
+        self.scale_fmt = (
+            config.quantization_config.get("scale_fmt", "ue8m0")
+            if hasattr(config, "quantization_config")
+            and isinstance(config.quantization_config, dict)
+            else "ue8m0"
+        )
         self.rope_parameters = config.rope_scaling
@@ -1442,7 +1452,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         first_layer = next(iter(islice(self.layers, self.start_layer, self.end_layer)))
         if first_layer.ffn.use_mega_moe:
             return make_deepseek_v4_expert_params_mapping(self.config.n_routed_experts)
-        # Params for weights, fp8 weight scales, fp8 activation scales
+        # Params for unfused MoE weights
         # (param_name, weight_name, expert_id, shard_id)
         return FusedMoE.make_expert_params_mapping(
             self,
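
For reference, a minimal standalone sketch of how the `scale_fmt` fallback added in `vllm/model_executor/models/deepseek_v4.py` resolves. The `resolve_scale_fmt` helper and the `SimpleNamespace` configs below are illustrative stand-ins, not part of the change:

```python
from types import SimpleNamespace


def resolve_scale_fmt(config) -> str:
    """Mirror of the scale_fmt fallback in this diff (sketch only)."""
    # Use quantization_config["scale_fmt"] when it exists as a dict,
    # otherwise fall back to "ue8m0".
    if hasattr(config, "quantization_config") and isinstance(
        config.quantization_config, dict
    ):
        return config.quantization_config.get("scale_fmt", "ue8m0")
    return "ue8m0"


# Hypothetical config objects for illustration only.
fp8_cfg = SimpleNamespace(quantization_config={"scale_fmt": "ue8m0"})
bf16_cfg = SimpleNamespace()  # checkpoint without a quantization_config
partial_cfg = SimpleNamespace(quantization_config={})  # dict without scale_fmt

assert resolve_scale_fmt(fp8_cfg) == "ue8m0"
assert resolve_scale_fmt(bf16_cfg) == "ue8m0"
assert resolve_scale_fmt(partial_cfg) == "ue8m0"
```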