Skip to content

Commit 41aee08

Browse files
【Inference Optimize】Update MergedReplicatedLinear for DSK qkv_a_proj_with_mqa. (#3673)
* support MergedReplicatedLinear
* update MergedReplicatedLinear to support DSK_wint4 V1_load
* update model name
* update linear class
* fix
* fix v0 moe_bias load

---------

Co-authored-by: bukejiyu <[email protected]>
1 parent b23fc65 commit 41aee08

File tree

4 files changed

+102
-4
lines changed

4 files changed

+102
-4
lines changed

fastdeploy/model_executor/layers/linear.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,76 @@ def __init__(
298298
)
299299

300300

301+
class MergedReplicatedLinear(ReplicatedLinear):
    """
    Replicated linear layer whose output dimension fuses several projections.

    The layer is built as a single ``ReplicatedLinear`` with
    ``output_size == sum(output_sizes)``. Checkpoint weights for the individual
    projections are copied into their own slice of the fused parameter by
    ``weight_loader``.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: Optional[int] = None,
        output_sizes: Optional[list[int]] = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a merged replicated linear layer.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_sizes (list[int]): Number of output features of each fused
                projection, in the order they are laid out in the fused weight.
                Required, although it defaults to None for signature compatibility.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer.
                Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Forwarded unchanged to ``ReplicatedLinear``. Defaults to "".
            weight_key (str): Forwarded unchanged to ``ReplicatedLinear``. Defaults to "".

        Raises:
            ValueError: If ``output_sizes`` is None or empty.
        """
        # Fail with a clear message instead of the opaque TypeError that
        # ``sum(None)`` would raise further down.
        if not output_sizes:
            raise ValueError("output_sizes must be a non-empty list of output feature counts")

        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=sum(output_sizes),
            with_bias=with_bias,
            add_bias=add_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )
        self.output_sizes = output_sizes

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """
        Copy one projection's checkpoint weight into its slice of the fused parameter.

        Args:
            param: Fused destination parameter. It is initialized on first use.
            loaded_weight: Checkpoint weight for a single projection; converted
                with ``get_tensor`` before use.
            loaded_shard_id (str): Which projection is being loaded. Must be
                ``"q_a"`` (first slice) or ``"kv_a"`` (second slice).

        Raises:
            ValueError: If ``loaded_shard_id`` is not a known shard id, or if the
                loaded weight's shape differs from the destination slice's shape.
        """
        # Shard ids in the order their slices appear in the fused weight.
        shard_ids = ("q_a", "kv_a")
        # Explicit raise rather than ``assert``: under ``python -O`` asserts are
        # stripped and an unknown id (including the default None) would silently
        # be loaded into the "kv_a" slice.
        if loaded_shard_id not in shard_ids:
            raise ValueError(f"loaded_shard_id must be one of {list(shard_ids)}, got {loaded_shard_id!r}")

        model_format = getattr(param, "model_format", "")
        loaded_weight = get_tensor(loaded_weight)

        if model_format == "torch":
            # Swap the two axes of torch-format weights before copying.
            loaded_weight = loaded_weight.transpose([1, 0])

        if not param._is_initialized():
            param.initialize()

        # Each shard starts where the preceding shards end.
        shard_index = shard_ids.index(loaded_shard_id)
        param_shard_offset = sum(self.output_sizes[:shard_index])
        param_shard_size = self.output_sizes[shard_index]
        param_shard_end = param_shard_offset + param_shard_size

        if hasattr(param, "tensor_track"):
            # Record which range of the fused parameter has been filled.
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_end)

        # NOTE(review): the second positional argument (True) presumably selects
        # the output dimension for slicing; confirm against slice_fn's definition.
        param = slice_fn(param, True, start=param_shard_offset, end=param_shard_end)
        if param.shape != loaded_weight.shape:
            raise ValueError(
                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
            )
        param.copy_(loaded_weight, False)
369+
370+
301371
class ColumnParallelLinear(LinearBase):
302372
"""
303373
ColumnParallelLinear Layer.

fastdeploy/model_executor/layers/quantization/weight_only.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from fastdeploy import envs
2525
from fastdeploy.model_executor.layers.linear import (
2626
MergedColumnParallelLinear,
27+
MergedReplicatedLinear,
2728
QKVParallelLinear,
2829
)
2930
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
@@ -203,11 +204,15 @@ def create_weights(self, layer, **extra_weight_attrs):
203204
default_initializer=paddle.nn.initializer.Constant(0),
204205
)
205206
quant_attrs = extra_weight_attrs
206-
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
207+
if (
208+
isinstance(layer, MergedColumnParallelLinear)
209+
or isinstance(layer, QKVParallelLinear)
210+
or isinstance(layer, MergedReplicatedLinear)
211+
):
207212
quant_attrs = {
208213
**extra_weight_attrs,
209214
"tensor_track": TensorTracker(
210-
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
215+
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
211216
),
212217
}
213218
set_weight_attrs(

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ColumnParallelLinear,
3939
KVBatchLinear,
4040
MergedColumnParallelLinear,
41+
MergedReplicatedLinear,
4142
ReplicatedLinear,
4243
RowParallelLinear,
4344
)
@@ -169,6 +170,13 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
169170

170171
def load_state_dict(self, state_dict):
171172
""" """
173+
if self.experts.gate_correction_bias is not None:
174+
gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
175+
if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
176+
gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
177+
self.experts.gate_correction_bias.shape
178+
)
179+
self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
172180
self.gate.load_state_dict(state_dict)
173181
self.experts.load_state_dict(state_dict)
174182
self.shared_experts.load_state_dict(state_dict)
@@ -211,11 +219,11 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
211219

212220
if self.q_lora_rank is not None:
213221
# NOTE: (changwenbin) qkv_a_proj horizontal fusion
214-
self.qkv_a_proj_with_mqa = ReplicatedLinear(
222+
self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
215223
fd_config=fd_config,
216224
prefix=f"{prefix}.qkv_a_proj_with_mqa",
217225
input_size=self.hidden_size,
218-
output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
226+
output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
219227
with_bias=False,
220228
)
221229

@@ -636,6 +644,8 @@ def load_weights(self, weights_iterator) -> None:
636644
("embed_tokens.embeddings", "embed_tokens", None),
637645
("lm_head.linear", "lm_head", None),
638646
("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
647+
("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
648+
("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
639649
]
640650
# (param_name, weight_name, expert_id, shard_id)
641651
expert_params_mapping = FusedMoE.make_expert_params_mapping(

tests/model_loader/test_common_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@
5858
{"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
5959
],
6060
},
61+
"DeepSeek-V3-0324": {
62+
"tensor_parallel_size": 2,
63+
"quantizations": [
64+
{
65+
"quant_type": "wint4",
66+
"env": {
67+
"FD_ATTENTION_BACKEND": "MLA_ATTN",
68+
"FLAGS_mla_use_tensorcore": "1",
69+
"FLAGS_flash_attn_version": "3",
70+
},
71+
},
72+
],
73+
},
6174
}
6275

6376

0 commit comments

Comments
 (0)