Skip to content

Commit 10aa379

Browse files
[fix] Fix DeepSeek w4a8 weight loading
Signed-off-by: Jinyang Yuan <[email protected]>
1 parent baece56 commit 10aa379

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# SOFTWARE.
2626
# --------------------------------------------------
2727

28+
import copy
2829
import math
2930
import os
3031
import warnings
@@ -1102,6 +1103,18 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
11021103
PretrainedConfig]):
11031104

11041105
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
1106+
# Rename some keys of quant_config_dict to support legacy checkpoints
1107+
if model_config.quant_config_dict is not None:
1108+
model_config = copy.deepcopy(model_config)
1109+
quant_config_dict = {}
1110+
for key, val in model_config.quant_config_dict.items():
1111+
key_split = key.split(".")
1112+
if key_split[-1] == "fused_a":
1113+
key = ".".join(key_split[:-1] + ["kv_a_proj_with_mqa"])
1114+
quant_config_dict[key] = val
1115+
model_config._frozen = False
1116+
model_config.quant_config_dict = quant_config_dict
1117+
model_config._frozen = True
11051118
super().__init__(DeepseekV3Model(model_config),
11061119
config=model_config,
11071120
hidden_size=model_config.pretrained_config.hidden_size,

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,11 +458,11 @@ def __post_init__(self):
458458
if name + '.q_proj' in n:
459459
module.quant_config = q
460460
break
461-
elif hasattr(module, 'fused_a'):
461+
elif hasattr(module, 'kv_a_proj_with_mqa'):
462462
# DeepseekV3Attention
463463
for n, q in quant_config_dict.items():
464464
# reuse q_proj quant config as the attention quant config
465-
if name + '.fused_a' in n:
465+
if name + '.kv_a_proj_with_mqa' in n:
466466
module.quant_config = q
467467
break
468468

0 commit comments

Comments (0)