@@ -363,7 +363,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
363363 [q_a_proj_scale , fused_a_scale ], dim = 0 )
364364
365365 module .weight_scale .data .copy_ (fused_a_scale )
366- # For DeepseekV32 with fuse_a_indexer_k_weight =True: kv_a_proj_with_mqa is oversized
366+ # For DeepseekV32 with fuse_a_indexer_k =True: kv_a_proj_with_mqa is oversized
367367 # to include indexer weights, which are filled in by post_load_weights.
368368 module .weight .data [0 :fused_a .shape [0 ]].copy_ (fused_a )
369369 elif names [- 1 ] in params_map :
@@ -559,9 +559,9 @@ def __init__(
559559 # DSV3.2 nvfp4 ckpt has kv_a_proj_with_mqa module in bfloat16
560560 # TODO: check it more directly/robustly, e.g., indexer_weight_quant == fuseA_quant == indexer_quant
561561 if model_config .get_quant_config ().quant_algo == QuantAlgo .NVFP4 :
562- self .fuse_a_indexer_k_weight = True
562+ self .fuse_a_indexer_k = True
563563 else :
564- self .fuse_a_indexer_k_weight = False
564+ self .fuse_a_indexer_k = False
565565
566566 super ().__init__ (hidden_size = config .hidden_size ,
567567 num_attention_heads = config .num_attention_heads ,
@@ -586,13 +586,13 @@ def __init__(
586586
587587 self .indexer = self .mqa .indexer
588588
589- if self .fuse_a_indexer_k_weight :
589+ if self .fuse_a_indexer_k :
590590 # For DeepseekV32, the kv_a_proj_with_mqa includes:
591591 # q_a_proj + kv_a_proj_with_mqa + indexer.wk
592592 self .kv_a_proj_with_mqa = DeepseekV3Linear (
593593 config .hidden_size ,
594594 self .kv_lora_rank + self .qk_rope_head_dim + self .q_lora_rank +
595- self .indexer .head_dim + self . indexer . n_heads ,
595+ self .indexer .head_dim ,
596596 bias = False ,
597597 dtype = config .torch_dtype ,
598598 quant_config = model_config .get_quant_config (),
@@ -601,21 +601,15 @@ def __init__(
601601 use_custom_cublas_mm = True )
602602
603603 def post_load_weights (self ):
604- if self .fuse_a_indexer_k_weight :
605- assert self .kv_a_proj_with_mqa .weight .data .dtype == self .indexer .wk .weight .data .dtype == self . indexer . weights_proj . weight . data . dtype , "all weights in kv_a_proj_with_mqa module must have matching dtype"
604+ if self .fuse_a_indexer_k :
605+ assert self .kv_a_proj_with_mqa .weight .data .dtype == self .indexer .wk .weight .data .dtype , "all weights in kv_a_proj_with_mqa module must have matching dtype"
606606 # Copy indexer weights into the fused kv_a_proj_with_mqa module
607607 indexer_wk_weight = self .indexer .wk .weight .data
608608 offset = self .kv_lora_rank + self .qk_rope_head_dim + self .q_lora_rank
609609 self .kv_a_proj_with_mqa .weight .data [offset :offset +
610610 self .indexer .head_dim ].copy_ (
611611 indexer_wk_weight )
612- offset += self .indexer .head_dim
613- indexer_weights_proj_weight = self .indexer .weights_proj .weight .data
614- self .kv_a_proj_with_mqa .weight .data [offset :offset +
615- self .indexer .n_heads ].copy_ (
616- indexer_weights_proj_weight )
617612 self .indexer .wk = None
618- self .indexer .weights_proj = None
619613
620614
621615class Deepseekv3RoutingImpl ():
0 commit comments