70 changes: 70 additions & 0 deletions fastdeploy/model_executor/layers/linear.py
@@ -298,6 +298,76 @@ def __init__(
)


class MergedReplicatedLinear(ReplicatedLinear):
    """
    MergedReplicatedLinear layer: a ReplicatedLinear whose weight fuses several
    sub-projections, one contiguous output slice per entry of `output_sizes`.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_sizes: list[int] = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a MergedReplicatedLinear layer.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_sizes (list[int]): Output size of each fused sub-projection; the merged
                weight uses their sum. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Data type of the weight. Defaults to "".
            weight_key (str): Key used to locate the weight in the checkpoint. Defaults to "".
        """
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=sum(output_sizes),
            with_bias=with_bias,
            add_bias=add_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )
        self.output_sizes = output_sizes

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        model_format = getattr(param, "model_format", "")
        loaded_weight = get_tensor(loaded_weight)

        if model_format == "torch":
            # Torch checkpoints store linear weights as [out, in]; Paddle expects [in, out].
            loaded_weight = loaded_weight.transpose([1, 0])

        # Only the q_a and kv_a projections are fused into this layer.
        assert loaded_shard_id in ["q_a", "kv_a"]
        if not param._is_initialized():
            param.initialize()

        if loaded_shard_id == "q_a":
            # q_a occupies the first output_sizes[0] output features.
            param_shard_offset = 0
            param_shard_size = self.output_sizes[0]
        else:
            # loaded_shard_id == "kv_a": the remaining output_sizes[1] output features.
            param_shard_offset = self.output_sizes[0]
            param_shard_size = self.output_sizes[1]

        if hasattr(param, "tensor_track"):
            # Record which output range has been written so quantization can wait for all shards.
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
        param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size)
        assert param.shape == loaded_weight.shape, (
            f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)


class ColumnParallelLinear(LinearBase):
    """
    ColumnParallelLinear Layer.
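The weight_loader above copies each checkpoint shard into a contiguous slice of the fused weight, with offsets derived from output_sizes. A minimal sketch of that slicing logic, using NumPy arrays as stand-ins for the Paddle tensors (the shard names and sizes below are illustrative, not read from a real checkpoint):

import numpy as np

# Paddle linear weights are laid out as [input_size, output_size], so each fused
# sub-projection owns a contiguous block of columns.
def load_merged_shard(param, loaded_weight, output_sizes, shard_id):
    shard_index = {"q_a": 0, "kv_a": 1}[shard_id]
    offset = sum(output_sizes[:shard_index])  # columns written by earlier shards
    size = output_sizes[shard_index]          # columns owned by this shard
    assert loaded_weight.shape == (param.shape[0], size)
    param[:, offset:offset + size] = loaded_weight

output_sizes = [4, 6]                         # e.g. q_lora_rank, kv_lora_rank + qk_rope_head_dim
fused = np.zeros((8, sum(output_sizes)))      # input_size = 8
load_merged_shard(fused, np.ones((8, 4)), output_sizes, "q_a")
load_merged_shard(fused, 2 * np.ones((8, 6)), output_sizes, "kv_a")
assert fused[:, :4].mean() == 1.0 and fused[:, 4:].mean() == 2.0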
9 changes: 7 additions & 2 deletions fastdeploy/model_executor/layers/quantization/weight_only.py
@@ -24,6 +24,7 @@
from fastdeploy import envs
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    MergedReplicatedLinear,
    QKVParallelLinear,
)
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
@@ -203,11 +204,15 @@ def create_weights(self, layer, **extra_weight_attrs):
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        quant_attrs = extra_weight_attrs
        if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
        if (
            isinstance(layer, MergedColumnParallelLinear)
            or isinstance(layer, QKVParallelLinear)
            or isinstance(layer, MergedReplicatedLinear)
        ):
            quant_attrs = {
                **extra_weight_attrs,
                "tensor_track": TensorTracker(
                    shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
                    shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
                ),
            }
        set_weight_attrs(
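With weight-only quantization, the fused weight cannot be quantized until every shard has been copied in; the TensorTracker attached above records the output ranges that weight_loader marks. A rough, simplified stand-in for that bookkeeping (the real TensorTracker lives in fastdeploy.model_executor.utils and its interface may differ):

class ShardTracker:
    """Simplified illustration of shard-completion tracking; not the real TensorTracker."""

    def __init__(self, output_size: int):
        self.filled = [False] * output_size

    def mark(self, start: int, end: int) -> None:
        # Record that output features [start, end) of the fused weight were written.
        for i in range(start, end):
            self.filled[i] = True

    def is_fully_loaded(self) -> bool:
        # Quantization (e.g. wint4/wint8 repacking) should only run once every shard landed.
        return all(self.filled)

tracker = ShardTracker(output_size=10)
tracker.mark(start=0, end=4)   # q_a shard
assert not tracker.is_fully_loaded()
tracker.mark(start=4, end=10)  # kv_a shard
assert tracker.is_fully_loaded()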
14 changes: 12 additions & 2 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -38,6 +38,7 @@
    ColumnParallelLinear,
    KVBatchLinear,
    MergedColumnParallelLinear,
    MergedReplicatedLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
@@ -169,6 +170,13 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:

    def load_state_dict(self, state_dict):
        """ """
        if self.experts.gate_correction_bias is not None:
            gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
            # Align the checkpoint tensor's shape with the model parameter before copying.
            if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
                gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
                    self.experts.gate_correction_bias.shape
                )
            self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
        self.gate.load_state_dict(state_dict)
        self.experts.load_state_dict(state_dict)
        self.shared_experts.load_state_dict(state_dict)
@@ -211,11 +219,11 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None

        if self.q_lora_rank is not None:
            # NOTE: (changwenbin) qkv_a_proj horizontal fusion
            self.qkv_a_proj_with_mqa = ReplicatedLinear(
            self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
                fd_config=fd_config,
                prefix=f"{prefix}.qkv_a_proj_with_mqa",
                input_size=self.hidden_size,
                output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                with_bias=False,
            )

@@ -636,6 +644,8 @@ def load_weights(self, weights_iterator) -> None:
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
]
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
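The two new entries in stacked_params_mapping rename the per-projection checkpoint weights onto the fused qkv_a_proj_with_mqa parameter and carry the shard id that MergedReplicatedLinear.weight_loader expects. A hedged sketch of how such a mapping is typically applied while iterating checkpoint names (the loop here is illustrative; the real load_weights handles many more cases):

stacked_params_mapping = [
    # (param_name, weight_name, shard_id)
    ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
    ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
]

def route(checkpoint_name: str):
    """Map a checkpoint weight name to (model parameter name, shard id)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None

print(route("model.layers.0.self_attn.q_a_proj.weight"))
# ('model.layers.0.self_attn.qkv_a_proj_with_mqa.weight', 'q_a')
print(route("model.layers.0.self_attn.kv_a_proj_with_mqa.weight"))
# ('model.layers.0.self_attn.qkv_a_proj_with_mqa.weight', 'kv_a')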
13 changes: 13 additions & 0 deletions tests/model_loader/test_common_model.py
@@ -58,6 +58,19 @@
{"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
],
},
"DeepSeek-V3-0324": {
"tensor_parallel_size": 2,
"quantizations": [
{
"quant_type": "wint4",
"env": {
"FD_ATTENTION_BACKEND": "MLA_ATTN",
"FLAGS_mla_use_tensorcore": "1",
"FLAGS_flash_attn_version": "3",
},
},
],
},
}


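The new test case pins the MLA attention backend through environment variables. Presumably the harness in test_common_model.py exports each case's env block before launching the model; a minimal illustration of that pattern (apply_case_env is a made-up helper, not part of the test file):

import os

case = {
    "quant_type": "wint4",
    "env": {
        "FD_ATTENTION_BACKEND": "MLA_ATTN",
        "FLAGS_mla_use_tensorcore": "1",
        "FLAGS_flash_attn_version": "3",
    },
}

def apply_case_env(case: dict) -> None:
    # Export the per-case environment so the launched model process inherits it.
    for key, value in case.get("env", {}).items():
        os.environ[key] = value

apply_case_env(case)
assert os.environ["FD_ATTENTION_BACKEND"] == "MLA_ATTN"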