Skip to content

Commit 8b5b161

Browse files
[fix] Fix DeepSeek w4a8 weight loading
Signed-off-by: Jinyang Yuan <[email protected]>
1 parent baece56 commit 8b5b161

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
# SOFTWARE.
2626
# --------------------------------------------------
2727

28+
import copy
2829
import math
2930
import os
3031
import warnings
@@ -1102,6 +1103,17 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
11021103
PretrainedConfig]):
11031104

11041105
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
1106+
# Rename some keys of quant_config_dict to support legacy checkpoints
1107+
model_config = copy.deepcopy(model_config)
1108+
quant_config_dict = {}
1109+
for key, val in model_config.quant_config_dict.items():
1110+
key_split = key.split(".")
1111+
if key_split[-1] == "fused_a":
1112+
key = ".".join(key_split[:-1] + ["kv_a_proj_with_mqa"])
1113+
quant_config_dict[key] = val
1114+
model_config._frozen = False
1115+
model_config.quant_config_dict = quant_config_dict
1116+
model_config._frozen = True
11051117
super().__init__(DeepseekV3Model(model_config),
11061118
config=model_config,
11071119
hidden_size=model_config.pretrained_config.hidden_size,

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -458,11 +458,11 @@ def __post_init__(self):
458458
if name + '.q_proj' in n:
459459
module.quant_config = q
460460
break
461-
elif hasattr(module, 'fused_a'):
461+
elif hasattr(module, 'kv_a_proj_with_mqa'):
462462
# DeepseekV3Attention
463463
for n, q in quant_config_dict.items():
464464
# reuse q_proj quant config as the attention quant config
465-
if name + '.fused_a' in n:
465+
if name + '.kv_a_proj_with_mqa' in n:
466466
module.quant_config = q
467467
break
468468

0 commit comments

Comments
 (0)