diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 126efb6f88e..a66ec7aa3e6 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -30,6 +30,7 @@ DeepseekV2DecoderLayer, DeepseekV2MixtureOfExperts, DeepseekV2MoE, + _try_load_fp8_indexer_wk, get_spec_layer_idx_from_weight_name, ) from .utils import maybe_prefix @@ -190,10 +191,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) # Set MoE hyperparameters self.set_moe_parameters() - self.is_fp4_ckpt = ( - self.quant_config is not None - and self.quant_config.get_name() == "modelopt_fp4" - ) def set_moe_parameters(self): self.expert_weights = [] @@ -248,13 +245,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] - if self.is_fp4_ckpt: - # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) - indexer_fused_mapping = [ - ("wk_weights_proj", "wk", 0), - ("wk_weights_proj", "weights_proj", 1), - ] - stacked_params_mapping.extend(indexer_fused_mapping) + # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) + indexer_fused_mapping = [ + ("wk_weights_proj", "wk", 0), + ("wk_weights_proj", "weights_proj", 1), + ] + stacked_params_mapping.extend(indexer_fused_mapping) expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( self, @@ -271,6 +267,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + _pending_wk_fp8: dict = {} # FP8 indexer wk dequant buffer for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -281,6 +278,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) ) name = self._rewrite_spec_layer_name(spec_layer, name) + + if _try_load_fp8_indexer_wk( + name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params + ): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 17ddd5edece..cd28fb0192f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -66,6 +66,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + scaled_dequantize, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sparse_attn_indexer import ( SparseAttnIndexer, @@ -628,10 +632,6 @@ def __init__( self.vllm_config = vllm_config self.config = config self.quant_config = quant_config - self.is_fp4_ckpt = ( - self.quant_config is not None - and self.quant_config.get_name() == "modelopt_fp4" - ) # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] self.topk_tokens = config.index_topk self.n_head = config.index_n_heads # 64 @@ -646,36 +646,16 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.wq_b", ) - if self.is_fp4_ckpt: - # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. 
- # weights_proj does not get quantized, - # so we run both with quant_config=None - # wk may be upcasted from the default quant; - # experiments show fusion is always faster unless WK proj is in FP4, - # which is not the case for all known quants. - self.wk_weights_proj = MergedColumnParallelLinear( - hidden_size, - [self.head_dim, self.n_head], - bias=False, - quant_config=None, - disable_tp=True, - prefix=f"{prefix}.wk_weights_proj", - ) - else: - self.wk = ReplicatedLinear( - hidden_size, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk", - ) - self.weights_proj = ReplicatedLinear( - hidden_size, - self.n_head, - bias=False, - quant_config=None, - prefix=f"{prefix}.weights_proj", - ) + # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. + # FP8 wk weights are upcasted to BF16 during loading to maintain fusion. + self.wk_weights_proj = MergedColumnParallelLinear( + hidden_size, + [self.head_dim, self.n_head], + bias=False, + quant_config=None, + disable_tp=True, + prefix=f"{prefix}.wk_weights_proj", + ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.softmax_scale = self.head_dim**-0.5 @@ -716,14 +696,10 @@ def forward( q_pe, q_nope = torch.split( q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) - if self.is_fp4_ckpt: - # Fused wk + weights_proj: one GEMM, then split - kw, _ = self.wk_weights_proj(hidden_states) - k = kw[:, : self.head_dim] - weights = kw[:, self.head_dim :] - else: - k, _ = self.wk(hidden_states) - weights, _ = self.weights_proj(hidden_states) + # Fused wk + weights_proj: one GEMM, then split + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights = kw[:, self.head_dim :] k = self.k_norm(k) k_pe, k_nope = torch.split( @@ -761,6 +737,46 @@ def forward( return self.indexer_op(hidden_states, q_fp8, k, weights) +def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params): + """ + We fuse the WK and weights_proj projections, but in some checkpoints WK is stored + in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16. + Upcasting to BF16 during loading enables the fusion. This function loads the FP8 WK + weights and scale, and when both are available, dequantizes to BF16 and stores into + the fused wk_weights_proj.weight parameter. + """ + if "indexer.wk." not in name or "wk_weights" in name: + return False # Weight is not an isolated WK weight for the indexer, ignore. + is_weight = name.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn + is_scale = "weight_scale_inv" in name + if not is_weight and not is_scale: + return False # WK is not in FP8 format, ignore. + # Buffer this tensor (weight or scale) until both have arrived. + layer_prefix = name.rsplit(".wk.", 1)[0] # e.g. "model.layers.0.self_attn.indexer" + entry = buf.setdefault(layer_prefix, {}) + entry["weight" if is_weight else "scale"] = tensor + if "weight" not in entry or "scale" not in entry: + return True # still waiting for the other param + + # We have both weight and scale: dequantize FP8 to BF16. + weight_fp8, scale_inv = entry["weight"], entry["scale"] + del buf[layer_prefix] + block_size = weight_fp8.shape[1] // scale_inv.shape[1] + weight_bf16 = scaled_dequantize( + weight_fp8, + scale_inv, + group_shape=GroupShape(block_size, block_size), + out_dtype=torch.bfloat16, + ) + + # Load the dequantized weight into shard 0 of the fused buffer. 
+ fused_name = f"{layer_prefix}.wk_weights_proj.weight" + param = params_dict[fused_name] + param.weight_loader(param, weight_bf16, 0) + loaded_params.add(fused_name) + return True + + def _min_latency_fused_qkv_a_proj_impl( input_: torch.Tensor, weight: torch.Tensor, @@ -1344,10 +1360,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.is_fp4_ckpt = ( - self.quant_config is not None - and self.quant_config.get_name() == "modelopt_fp4" - ) qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) @@ -1473,13 +1485,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] - if self.is_fp4_ckpt: - # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) - indexer_fused_mapping = [ - ("wk_weights_proj", "wk", 0), - ("wk_weights_proj", "weights_proj", 1), - ] - stacked_params_mapping.extend(indexer_fused_mapping) + # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) + _pending_wk_fp8: dict = {} # When WK is in FP8, we dequant to BF16 for fusion + indexer_fused_mapping = [ + ("wk_weights_proj", "wk", 0), + ("wk_weights_proj", "weights_proj", 1), + ] + stacked_params_mapping.extend(indexer_fused_mapping) if self.use_mha: stacked_params_mapping.extend(mha_params_mapping) @@ -1516,6 +1528,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) ) + if _try_load_fp8_indexer_wk( + name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params + ): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name:
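Note on the fusion above (not part of the patch): `wk` and `weights_proj` both consume the same `hidden_states`, so concatenating their weight matrices along the output dimension turns two small GEMMs into one, and the two results are recovered by slicing the fused output, exactly as the forward pass in the diff does. Below is a minimal standalone sketch of that equivalence; all shapes are illustrative placeholders, not values read from the model config.

```python
# Illustrative sketch of the wk + weights_proj fusion: one GEMM over the
# concatenated weights, followed by a split, matches two separate GEMMs.
# Shapes are placeholders chosen for the example only.
import torch

hidden_size, head_dim, n_head = 1024, 128, 64
x = torch.randn(4, hidden_size)                      # hidden_states for 4 tokens

wk = torch.randn(head_dim, hidden_size)              # indexer key projection
weights_proj = torch.randn(n_head, hidden_size)      # per-head weight projection

# Unfused path: two GEMMs over the same input.
k_ref = x @ wk.t()
w_ref = x @ weights_proj.t()

# Fused path: concatenate along the output dim, run one GEMM, then split.
fused = torch.cat([wk, weights_proj], dim=0)         # [head_dim + n_head, hidden_size]
kw = x @ fused.t()                                   # [tokens, head_dim + n_head]
k, w = kw[:, :head_dim], kw[:, head_dim:]

assert torch.allclose(k, k_ref) and torch.allclose(w, w_ref)
```

Because the fused `wk_weights_proj` layer is built with `quant_config=None`, both shards have to land in the parameter in the model dtype, which is why FP8 `wk` checkpoints need the load-time dequantization handled by `_try_load_fp8_indexer_wk`.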
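On the loading side, `_try_load_fp8_indexer_wk` must cope with the FP8 weight and its `weight_scale_inv` arriving in either order, so it buffers whichever shows up first in `_pending_wk_fp8` and only dequantizes once both are present. The patch does the rescale with `scaled_dequantize` and a square `GroupShape`; the sketch below re-derives that block-wise dequantization in plain torch to show what it computes. The 128x128 tile size and tensor shapes here are assumptions for the example, not values read from a real checkpoint.

```python
# Standalone approximation of the block-wise FP8 -> BF16 dequantization used
# when loading the indexer's wk weight: scale_inv holds one scale per
# (block x block) tile of the weight, and each tile is multiplied by its scale.
import torch

def dequant_block_fp8(weight_fp8: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # Derive the tile size from the weight/scale shapes (the patch derives it
    # from dim 1 only and assumes square tiles).
    block_h = weight_fp8.shape[0] // scale_inv.shape[0]
    block_w = weight_fp8.shape[1] // scale_inv.shape[1]
    # Expand the per-tile scale grid to per-element granularity, then rescale.
    scales = scale_inv.repeat_interleave(block_h, dim=0).repeat_interleave(block_w, dim=1)
    return (weight_fp8.to(torch.float32) * scales).to(torch.bfloat16)

head_dim, hidden_size, block = 128, 7168, 128        # illustrative shapes only
w_fp8 = torch.randn(head_dim, hidden_size).to(torch.float8_e4m3fn)
s_inv = torch.rand(head_dim // block, hidden_size // block)

w_bf16 = dequant_block_fp8(w_fp8, s_inv)
# w_bf16 is what gets handed to wk_weights_proj's weight_loader as shard 0.
```

Doing this once at load time keeps the runtime path uniform across checkpoints: forward always executes the single fused BF16 GEMM shown in the first sketch, regardless of how `wk` was stored on disk.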