16 changes: 13 additions & 3 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -184,11 +184,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.model = DeepSeekMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
# Set MoE hyperparameters
self.set_moe_parameters()
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)

def set_moe_parameters(self):
self.expert_weights = []
@@ -241,11 +246,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
# Fused indexer wk + weights_proj
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]

if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
Member:
It looks like this guards against cases where we know the wk_weights_proj.wk and wk_weights_proj.weights_proj don't exist. Would it be more brittle to only put them in the mapping when we know they do exist instead?

Member Author:
I think the problem is that wk and weights_proj always exist. But in an FP8 checkpoint one of them is FP8-quantized and the other is unquantized, so it's not good to fuse them. On the other hand, with NVFP4 they are both unquantized, so we can fuse them into a single GEMM. This just keeps the fused and un-fused paths separate.
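
For illustration, a minimal standalone sketch of the two paths (plain PyTorch, no vLLM layers; the hidden_size/head_dim/n_head values are made up) showing that the fused single GEMM plus split computes the same thing as the two separate projections:

```python
import torch
import torch.nn as nn

hidden_size, head_dim, n_head = 7168, 128, 64
x = torch.randn(4, hidden_size)

# Un-fused path (FP8-style checkpoints): two separate projections,
# so each one can keep its own quantization scheme.
wk = nn.Linear(hidden_size, head_dim, bias=False)
weights_proj = nn.Linear(hidden_size, n_head, bias=False)
k_separate = wk(x)            # [4, head_dim]
w_separate = weights_proj(x)  # [4, n_head]

# Fused path (NVFP4-style checkpoints, both unquantized): one GEMM
# producing [head_dim + n_head] outputs, then split.
wk_weights_proj = nn.Linear(hidden_size, head_dim + n_head, bias=False)
with torch.no_grad():
    # Stack the two weight matrices so both paths are numerically identical.
    wk_weights_proj.weight.copy_(
        torch.cat([wk.weight, weights_proj.weight], dim=0)
    )
kw = wk_weights_proj(x)
k_fused, w_fused = kw[:, :head_dim], kw[:, head_dim:]

assert torch.allclose(k_separate, k_fused, atol=1e-5)
assert torch.allclose(w_separate, w_fused, atol=1e-5)
```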

Member:
Makes sense, thanks for the explanation
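
As context for the mapping entries above, a rough standalone sketch (not vLLM's actual weight_loader machinery; the checkpoint names and shapes are made up) of how ("wk_weights_proj", "wk", 0) and ("wk_weights_proj", "weights_proj", 1) route the two checkpoint tensors into the fused parameter:

```python
import torch

# Hypothetical checkpoint tensors for one indexer (shapes are illustrative).
head_dim, n_head, hidden = 128, 64, 7168
ckpt = {
    "indexer.wk.weight": torch.randn(head_dim, hidden),
    "indexer.weights_proj.weight": torch.randn(n_head, hidden),
}

# The single fused parameter the model actually owns.
fused = torch.empty(head_dim + n_head, hidden)
# shard id -> (row offset, row count) inside the fused weight
shard_offsets = {0: (0, head_dim), 1: (head_dim, n_head)}

# Same spirit as the stacked_params_mapping entries:
# (fused param name, checkpoint weight substring, shard id).
mapping = [("wk_weights_proj", "wk", 0), ("wk_weights_proj", "weights_proj", 1)]

for name, loaded in ckpt.items():
    for param_name, weight_name, shard_id in mapping:
        if weight_name not in name:
            continue
        start, rows = shard_offsets[shard_id]
        fused[start : start + rows].copy_(loaded)  # write this shard's slice
        break
```

In vLLM itself the slicing is handled by the fused parameter's weight_loader given the shard id; the sketch only shows why the two entries carry shard ids 0 and 1.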


expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
79 changes: 55 additions & 24 deletions vllm/model_executor/models/deepseek_v2.py
@@ -625,6 +625,11 @@ def __init__(
super().__init__()
self.vllm_config = vllm_config
self.config = config
self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64
@@ -639,18 +644,36 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
)
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj does not get quantized, so we run both with quant_config=None
# wk may be upcasted from the default quant; experiments show fusion is always
# faster unless WK proj is in FP4, which is not the case for all known quants.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
if self.is_fp4_ckpt:
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj does not get quantized,
# so we run both with quant_config=None
# wk may be upcasted from the default quant;
# experiments show fusion is always faster unless WK proj is in FP4,
# which is not the case for all known quants.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
else:
self.wk = ReplicatedLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5

@@ -691,11 +714,14 @@ def forward(
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)

# Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
if self.is_fp4_ckpt:
# Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights = kw[:, self.head_dim :]
else:
k, _ = self.wk(hidden_states)
weights, _ = self.weights_proj(hidden_states)

k = self.k_norm(k)
k_pe, k_nope = torch.split(
@@ -726,7 +752,7 @@ def forward(
q_scale = q_scale.view(-1, self.n_head, 1)

weights = (
weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)

@@ -1314,6 +1340,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)

qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
@@ -1439,12 +1469,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)

if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)