
Commit e84d056

Unfuse weight_proj and promote to fp32
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent: fc088e6

File tree: 5 files changed (+1005 −25 lines)


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 4 additions & 7 deletions

@@ -715,7 +715,7 @@ def __init__(self,
             self.hidden_size,
             self.n_heads,
             bias=False,
-            dtype=dtype,
+            dtype=torch.float32,
             quant_config=None,
             skip_create_weights_in_init=skip_create_weights_in_init,
             use_custom_cublas_mm=True)
@@ -1234,18 +1234,15 @@ def sparse_attn_indexer(
         return topk_indices_buffer

     def weight_scale(self, hidden_states: torch.Tensor,
-                     indexer_weights: Optional[torch.Tensor],
                      q_scale: torch.Tensor) -> torch.Tensor:
-        weights = indexer_weights if indexer_weights is not None else self.weights_proj(
-            hidden_states)
+        weights = self.weights_proj(hidden_states.float())
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

     @torch.inference_mode()
     def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
                 metadata: DSAtrtllmAttentionMetadata,
-                position_ids: torch.Tensor, indexer_k: Optional[torch.Tensor],
-                indexer_weights: Optional[torch.Tensor]):
+                position_ids: torch.Tensor, indexer_k: Optional[torch.Tensor]):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"

@@ -1308,7 +1305,7 @@ def _prep_q_or_k(qk_pe, qk_nope):
         q_scale = q_scale.view(-1, self.n_heads, 1)

         weights, _ = maybe_execute_in_parallel(
-            lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
+            lambda: self.weight_scale(hidden_states, q_scale),
             lambda: self._update_k_cache(
                 k_fp8, k_scale, metadata),  # store k_fp8 and k_scale in k cache
             self.ln_events[0],
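The net effect in dsa.py: the indexer no longer accepts a pre-fused indexer_weights tensor, it always runs its own weights_proj, and that projection is created and evaluated in fp32. Below is a minimal standalone sketch of this path with made-up sizes; the per-head multiply and weight_scale_factor value are placeholders standing in for the _scale helper, not the TensorRT-LLM internals.

import torch

# Hypothetical sizes; the real values come from the DeepSeek-V3.2 indexer config.
hidden_size, n_heads, num_tokens = 7168, 64, 4

# weights_proj is now created in fp32 (dtype=torch.float32 in __init__ above).
weights_proj = torch.nn.Linear(hidden_size, n_heads, bias=False, dtype=torch.float32)

hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)
q_scale = torch.rand(num_tokens, n_heads, 1, dtype=torch.float32)
weight_scale_factor = n_heads ** -0.5  # placeholder for self.weight_scale_factor

# Promote activations to fp32 before the projection, then apply the per-head
# q_scale and the scalar factor (a stand-in for _scale in dsa.py).
weights = weights_proj(hidden_states.float())              # [num_tokens, n_heads]
weights = weights.unsqueeze(-1) * q_scale * weight_scale_factor
print(weights.shape)  # torch.Size([4, 64, 1])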

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 7 additions & 13 deletions

@@ -363,7 +363,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                     [q_a_proj_scale, fused_a_scale], dim=0)

                 module.weight_scale.data.copy_(fused_a_scale)
-            # For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized
+            # For DeepseekV32 with fuse_a_indexer_k=True: kv_a_proj_with_mqa is oversized
             # to include indexer weights, which is filled in post_load_weights.
             module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
         elif names[-1] in params_map:
@@ -559,9 +559,9 @@ def __init__(
         # DSV3.2 nvfp4 ckpt has kv_a_proj_with_mqa module in bfloat16
         # TODO: check it more directly/robustly, e.g., indexer_weight_quant == fuseA_quant == indexer_quant
         if model_config.get_quant_config().quant_algo == QuantAlgo.NVFP4:
-            self.fuse_a_indexer_k_weight = True
+            self.fuse_a_indexer_k = True
         else:
-            self.fuse_a_indexer_k_weight = False
+            self.fuse_a_indexer_k = False

         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
@@ -586,13 +586,13 @@ def __init__(

         self.indexer = self.mqa.indexer

-        if self.fuse_a_indexer_k_weight:
+        if self.fuse_a_indexer_k:
             # For DeepseekV32, the kv_a_proj_with_mqa includes:
             # q_a_proj + kv_a_proj_with_mqa + indexer.wk + indexer.weights_proj
             self.kv_a_proj_with_mqa = DeepseekV3Linear(
                 config.hidden_size,
                 self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
-                self.indexer.head_dim + self.indexer.n_heads,
+                self.indexer.head_dim,
                 bias=False,
                 dtype=config.torch_dtype,
                 quant_config=model_config.get_quant_config(),
@@ -601,21 +601,15 @@ def __init__(
                 use_custom_cublas_mm=True)

     def post_load_weights(self):
-        if self.fuse_a_indexer_k_weight:
-            assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype == self.indexer.weights_proj.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
+        if self.fuse_a_indexer_k:
+            assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
             # Copy indexer weights into the fused kv_a_proj_with_mqa module
             indexer_wk_weight = self.indexer.wk.weight.data
             offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
             self.kv_a_proj_with_mqa.weight.data[offset:offset +
                                                 self.indexer.head_dim].copy_(
                                                     indexer_wk_weight)
-            offset += self.indexer.head_dim
-            indexer_weights_proj_weight = self.indexer.weights_proj.weight.data
-            self.kv_a_proj_with_mqa.weight.data[offset:offset +
-                                                self.indexer.n_heads].copy_(
-                                                    indexer_weights_proj_weight)
             self.indexer.wk = None
-            self.indexer.weights_proj = None


 class Deepseekv3RoutingImpl():
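For reference, a rough standalone sketch of what the fusion now covers, using placeholder DeepSeek-V3-like dimensions (illustrative only, not the real config): only indexer.wk is copied into the oversized kv_a_proj_with_mqa weight, while weights_proj stays a separate fp32 module.

import torch

# Placeholder dimensions, not taken from the actual model config.
kv_lora_rank, qk_rope_head_dim, q_lora_rank = 512, 64, 1536
indexer_head_dim, hidden_size = 128, 2048

# The fused projection is oversized by indexer_head_dim rows only
# (it no longer reserves indexer.n_heads extra rows for weights_proj).
fused_rows = kv_lora_rank + qk_rope_head_dim + q_lora_rank + indexer_head_dim
fused_weight = torch.empty(fused_rows, hidden_size)

# Same offset arithmetic as post_load_weights above: indexer.wk lands right
# after the q_a_proj + kv_a_proj_with_mqa rows, and nothing follows it.
indexer_wk_weight = torch.randn(indexer_head_dim, hidden_size)
offset = kv_lora_rank + qk_rope_head_dim + q_lora_rank
fused_weight[offset:offset + indexer_head_dim].copy_(indexer_wk_weight)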
