43 changes: 43 additions & 0 deletions fastdeploy/model_executor/layers/linear.py
@@ -297,6 +297,49 @@ def __init__(
)


class MergedReplicatedLinear(ReplicatedLinear):
"""
Merged replicated linear layer that fuses several logical projections into a single replicated weight.
"""

def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_sizes: list[int] = None,
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
weight_dtype: str = "",
weight_key: str = "",
):
"""
Initializes a merged replicated linear layer.

Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_sizes (list[int]): Output sizes of the fused sub-projections; the merged layer's total output size is their sum. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
weight_dtype (str): Data type of the weight; empty string uses the layer default. Defaults to "".
weight_key (str): Checkpoint key used to locate the weight when loading. Defaults to "".
"""
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=sum(output_sizes),
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
weight_dtype=weight_dtype,
weight_key=weight_key,
)


class ColumnParallelLinear(LinearBase):
"""
ColumnParallelLinear Layer.
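For intuition, here is a minimal toy sketch of what the merged layer buys: one fused GEMM whose output is the concatenation of several logical projections, recovered downstream by slicing. The names and sizes below are illustrative assumptions, not the FastDeploy classes themselves.

```python
import paddle

# Toy analogue of MergedReplicatedLinear (illustrative, not FastDeploy code).
hidden_size = 16
output_sizes = [4, 6]  # two logical projections fused into one weight

merged = paddle.nn.Linear(hidden_size, sum(output_sizes), bias_attr=False)

x = paddle.randn([2, hidden_size])
fused_out = merged(x)  # a single GEMM instead of two

# Consumers recover the individual projections by slicing the fused output.
a, b = paddle.split(fused_out, num_or_sections=output_sizes, axis=-1)
assert a.shape == [2, 4] and b.shape == [2, 6]
```

Tracking the sub-projection sizes via `output_sizes` (rather than a single pre-summed `output_size`) keeps the partition of the fused weight explicit for weight loading and downstream splitting.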
5 changes: 3 additions & 2 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -38,6 +38,7 @@
ColumnParallelLinear,
KVBatchLinear,
MergedColumnParallelLinear,
+ MergedReplicatedLinear,
ReplicatedLinear,
RowParallelLinear,
)
@@ -211,11 +212,11 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None

if self.q_lora_rank is not None:
# NOTE: (changwenbin) qkv_a_proj horizontal fusion
- self.qkv_a_proj_with_mqa = ReplicatedLinear(
+ self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.qkv_a_proj_with_mqa",
input_size=self.hidden_size,
- output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+ output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
with_bias=False,
)

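Downstream, the attention code splits the fused `qkv_a_proj_with_mqa` output back into the query LoRA component and the compressed KV component. A hedged sketch of that split, using the published DeepSeek-V3 sizes (q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64) purely for illustration:

```python
import paddle

# Illustrative sizes from the DeepSeek-V3 config; treat them as assumptions.
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64

fused = paddle.randn([2, q_lora_rank + kv_lora_rank + qk_rope_head_dim])

# Same partition as output_sizes in the constructor call above.
q_a, compressed_kv = paddle.split(
    fused,
    num_or_sections=[q_lora_rank, kv_lora_rank + qk_rope_head_dim],
    axis=-1,
)
assert q_a.shape == [2, q_lora_rank]
assert compressed_kv.shape == [2, kv_lora_rank + qk_rope_head_dim]
```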