25 changes: 14 additions & 11 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -30,6 +30,7 @@
DeepseekV2DecoderLayer,
DeepseekV2MixtureOfExperts,
DeepseekV2MoE,
_try_load_fp8_indexer_wk,
get_spec_layer_idx_from_weight_name,
)
from .utils import maybe_prefix
@@ -190,10 +191,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
# Set MoE hyperparameters
self.set_moe_parameters()
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)

def set_moe_parameters(self):
self.expert_weights = []
@@ -248,13 +245,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
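
For context, the mapping above only drives a checkpoint-name rewrite plus a shard id; the loader loop below does the actual copying. A minimal sketch of the matching step, using a hypothetical checkpoint key for illustration (not taken from a real checkpoint):

```python
stacked_params_mapping = [
    ("wk_weights_proj", "wk", 0),            # shard 0 = wk
    ("wk_weights_proj", "weights_proj", 1),  # shard 1 = weights_proj
]

# Hypothetical checkpoint key for illustration.
name = "model.layers.0.self_attn.indexer.wk.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        fused_name = name.replace(weight_name, param_name)
        # -> "model.layers.0.self_attn.indexer.wk_weights_proj.weight", shard 0
        break
```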

expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
self,
@@ -271,6 +267,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
_pending_wk_fp8: dict = {} # FP8 indexer wk dequant buffer
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
@@ -281,6 +278,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
name = self._rewrite_spec_layer_name(spec_layer, name)

if _try_load_fp8_indexer_wk(
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
):
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
123 changes: 70 additions & 53 deletions vllm/model_executor/models/deepseek_v2.py
@@ -66,6 +66,10 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
scaled_dequantize,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sparse_attn_indexer import (
SparseAttnIndexer,
@@ -628,10 +632,6 @@ def __init__(
self.vllm_config = vllm_config
self.config = config
self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64
@@ -646,36 +646,16 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
)
if self.is_fp4_ckpt:
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj does not get quantized,
# so we run both with quant_config=None
# wk may be upcasted from the default quant;
# experiments show fusion is always faster unless WK proj is in FP4,
# which is not the case for all known quants.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
else:
self.wk = ReplicatedLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# FP8 wk weights are upcasted to BF16 during loading to maintain fusion.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5
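
For context, the fused projection is equivalent to concatenating the two weight matrices along the output dimension and splitting the result after a single GEMM. A toy sketch of that equivalence (plain matmuls standing in for the vLLM linear layers; sizes are made up):

```python
import torch

hidden_size, head_dim, n_head = 16, 8, 4
x = torch.randn(2, hidden_size)
w_k = torch.randn(head_dim, hidden_size)  # stands in for indexer.wk.weight
w_p = torch.randn(n_head, hidden_size)    # stands in for indexer.weights_proj.weight

# Old path: two separate GEMMs.
k_ref, weights_ref = x @ w_k.t(), x @ w_p.t()

# New path: one GEMM over the concatenated weight, then a split.
fused = torch.cat([w_k, w_p], dim=0)      # what wk_weights_proj.weight holds
out = x @ fused.t()
k, weights = out[:, :head_dim], out[:, head_dim:]

assert torch.allclose(k, k_ref) and torch.allclose(weights, weights_ref)
```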

@@ -716,14 +696,10 @@ def forward(
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
if self.is_fp4_ckpt:
# Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights = kw[:, self.head_dim :]
else:
k, _ = self.wk(hidden_states)
weights, _ = self.weights_proj(hidden_states)
# Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights = kw[:, self.head_dim :]

k = self.k_norm(k)
k_pe, k_nope = torch.split(
@@ -761,6 +737,46 @@ def forward(
return self.indexer_op(hidden_states, q_fp8, k, weights)


def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
"""
We fuse the WK and weights_proj projections, but in some checkpoints WK is stored
in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16.
Upcasting (dequantizing) WK to BF16 during loading keeps the fusion intact. This
function buffers the FP8 WK weight and its scale and, once both have arrived,
dequantizes to BF16 and loads the result into the fused wk_weights_proj.weight
parameter.
"""
if "indexer.wk." not in name or "wk_weights" in name:
return False # Weight is not an isolated WK weight for the indexer, ignore.
is_weight = name.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn
is_scale = "weight_scale_inv" in name
if not is_weight and not is_scale:
return False # WK is not in FP8 format, ignore.
# Buffer this tensor (weight or scale) until both have arrived.
layer_prefix = name.rsplit(".wk.", 1)[0] # e.g. "model.layers.0.self_attn.indexer"
entry = buf.setdefault(layer_prefix, {})
entry["weight" if is_weight else "scale"] = tensor
if "weight" not in entry or "scale" not in entry:
return True # still waiting for the other param
Comment on lines +758 to +759

Contributor (severity: high):
The logic to buffer tensors until both weight and scale are present is susceptible to memory leaks if the checkpoint loading process is interrupted or if certain keys are missing from the checkpoint. Consider adding a timeout or a mechanism to clear the buffer if loading fails.

Collaborator (Author):
They would go out of scope once the weight loading function returns. The references are not stored on self.

# We have both weight and scale: dequantize FP8 to BF16.
weight_fp8, scale_inv = entry["weight"], entry["scale"]
del buf[layer_prefix]
block_size = weight_fp8.shape[1] // scale_inv.shape[1]
weight_bf16 = scaled_dequantize(
weight_fp8,
scale_inv,
group_shape=GroupShape(block_size, block_size),
out_dtype=torch.bfloat16,
)

# Load the dequantized weight into shard 0 of the fused buffer.
fused_name = f"{layer_prefix}.wk_weights_proj.weight"
param = params_dict[fused_name]
param.weight_loader(param, weight_bf16, 0)
loaded_params.add(fused_name)
return True
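
For reviewers unfamiliar with block-quantized FP8: each block_size x block_size tile of the weight shares a single entry of weight_scale_inv. A rough pure-PyTorch equivalent of the scaled_dequantize call above (illustrative sketch only, not the actual vLLM kernel):

```python
import torch

def dequant_blockwise_sketch(
    w_fp8: torch.Tensor, scale_inv: torch.Tensor, block: int
) -> torch.Tensor:
    # Broadcast each per-block scale over its (block x block) tile.
    s = scale_inv.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    s = s[: w_fp8.shape[0], : w_fp8.shape[1]]  # crop any ragged-edge padding
    return (w_fp8.to(torch.float32) * s).to(torch.bfloat16)
```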


def _min_latency_fused_qkv_a_proj_impl(
input_: torch.Tensor,
weight: torch.Tensor,
@@ -1344,10 +1360,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)

qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
@@ -1473,13 +1485,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
_pending_wk_fp8: dict = {} # When WK is in FP8, we dequant to BF16 for fusion
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)

if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
@@ -1516,6 +1528,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)

if _try_load_fp8_indexer_wk(
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
):
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: