
Commit 4b069b1

refactor linear

Signed-off-by: realliujiaxu <[email protected]>
1 parent 93e28e6


1 file changed: vllm_ascend/ops/linear.py (41 additions, 47 deletions)
@@ -63,14 +63,11 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
-        self.comm_group = None
-        if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
-            self.comm_group = get_mlp_tp_group()
-        else:
+        # if self has attr `tp_size`, this means it has been customized by subclass
+        if not hasattr(self, "tp_size"):
             self.comm_group = get_tp_group()
-
-        self.tp_size = self.comm_group.world_size
-        self.tp_rank = self.comm_group.rank_in_group
+            self.tp_size = self.comm_group.world_size
+            self.tp_rank = self.comm_group.rank_in_group
 
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
@@ -81,6 +78,8 @@ def __init__(
             divide(output_size, self.tp_size)
             for output_size in self.output_sizes
         ]
+        # skip ColumnParallelLinear.__init__, as it will create weight_loader with default tp group
+        # we will create weight_loader by customized comm group
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
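The two hunks above implement one idea: a subclass that wants a non-default communication group sets self.comm_group / self.tp_size / self.tp_rank before delegating to the parent __init__, and the `if not hasattr(self, "tp_size")` guard then leaves those values alone. Below is a minimal, self-contained sketch of that pattern; the class and group names are illustrative stand-ins, not the real vLLM-Ascend classes or helpers.

class _Group:
    def __init__(self, world_size: int, rank_in_group: int):
        self.world_size = world_size
        self.rank_in_group = rank_in_group

def get_default_group() -> _Group:     # stands in for get_tp_group()
    return _Group(world_size=8, rank_in_group=0)

def get_custom_group() -> _Group:      # stands in for get_mlp_tp_group() etc.
    return _Group(world_size=2, rank_in_group=0)

class BaseLinear:
    def __init__(self, output_size: int):
        # Fall back to the default group only when a subclass has not already
        # chosen one; mirrors `if not hasattr(self, "tp_size")` in the diff.
        if not hasattr(self, "tp_size"):
            self.comm_group = get_default_group()
            self.tp_size = self.comm_group.world_size
            self.tp_rank = self.comm_group.rank_in_group
        self.output_size_per_partition = output_size // self.tp_size

class CustomLinear(BaseLinear):
    def __init__(self, output_size: int):
        # The subclass picks its own group *before* delegating to the parent,
        # so the hasattr guard is skipped and these values survive.
        self.comm_group = get_custom_group()
        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group
        super().__init__(output_size)

print(BaseLinear(1024).output_size_per_partition)    # 128 (default group, tp=8)
print(CustomLinear(1024).output_size_per_partition)  # 512 (custom group, tp=2)

Run as-is, the base class partitions with the default group (1024 // 8 = 128) while the subclass keeps its smaller custom group (1024 // 2 = 512).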
@@ -138,32 +137,35 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
-        if prefix.find("down_proj") != -1 and mlp_tp_enable():
-            comm_group = get_mlp_tp_group()
-            self.forward_type = "mlp_tp"
-        elif prefix.find("o_proj") != -1 and oproj_tp_enable():
-            comm_group = get_otp_group()
-            self.forward_type = "oproj_tp"
-        elif matmul_allreduce_enable():
-            comm_group = get_tp_group()
-            self.forward_type = "matmul_allreduce"
-            self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
-        elif dense_optim_enable():
-            comm_group = get_tp_group()
-            self.forward_type = "dense_optim"
-        else:
-            comm_group = get_tp_group()
-            self.forward_type = "normal"
-        self.comm_group = comm_group
-
-        self.tp_size = self.comm_group.world_size
-        self.tp_rank = self.comm_group.rank_in_group
+        # if self has attr `tp_size`, this means it has been customized by subclass
+        if not hasattr(self, "tp_size"):
+            if prefix.find("down_proj") != -1 and mlp_tp_enable():
+                comm_group = get_mlp_tp_group()
+                self.forward_type = "mlp_tp"
+            elif prefix.find("o_proj") != -1 and oproj_tp_enable():
+                comm_group = get_otp_group()
+                self.forward_type = "oproj_tp"
+            elif matmul_allreduce_enable():
+                comm_group = get_tp_group()
+                self.forward_type = "matmul_allreduce"
+                self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
+            elif dense_optim_enable():
+                comm_group = get_tp_group()
+                self.forward_type = "dense_optim"
+            else:
+                comm_group = get_tp_group()
+                self.forward_type = "normal"
+            self.comm_group = comm_group
+
+            self.tp_size = self.comm_group.world_size
+            self.tp_rank = self.comm_group.rank_in_group
 
         # Divide the weight matrix along the first dimension.
         self.input_size_per_partition = divide(input_size, self.tp_size)
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
 
+        # skip RowParallelLinear.__init__, as it will create weight_loader with default tp group
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
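For the row-parallel layer the guard wraps a prefix-based dispatch: the layer name decides which communication group and forward path to use. Here is a hypothetical stand-alone sketch of just that selection logic, with the enable flags passed in as plain booleans instead of reading vLLM-Ascend's mlp_tp_enable()/oproj_tp_enable()/matmul_allreduce_enable()/dense_optim_enable() switches.

def select_forward_type(prefix: str,
                        mlp_tp: bool = False,
                        oproj_tp: bool = False,
                        matmul_allreduce: bool = False,
                        dense_optim: bool = False) -> str:
    # Same ordering as the diff: down_proj and o_proj overrides win first,
    # then the global matmul-allreduce and dense-optim modes, else "normal".
    if prefix.find("down_proj") != -1 and mlp_tp:
        return "mlp_tp"
    if prefix.find("o_proj") != -1 and oproj_tp:
        return "oproj_tp"
    if matmul_allreduce:
        return "matmul_allreduce"
    if dense_optim:
        return "dense_optim"
    return "normal"

print(select_forward_type("model.layers.0.mlp.down_proj", mlp_tp=True))       # mlp_tp
print(select_forward_type("model.layers.0.self_attn.o_proj", oproj_tp=True))  # oproj_tp
print(select_forward_type("model.layers.0.mlp.down_proj"))                    # normal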
@@ -368,7 +370,6 @@ def _forward_dense_optim(
             return output
         return output, output_bias
 
-
 class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     """Packed linear layers with column parallelism.
 
@@ -526,7 +527,7 @@ def __init__(
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size,  # q_proj
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
         AscendColumnParallelLinear.__init__(self,
                                             input_size=input_size,
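As a sanity check on the output_sizes list touched above, a worked example with assumed Llama-3-8B-like shapes (32 query heads, 8 KV heads, head size 128) and tp_size = 4. The `* tp_size` factor suggests self.num_heads and self.num_kv_heads are already per-rank counts (as in upstream vLLM's QKVParallelLinear); that is assumed below, so the list recovers the full, unpartitioned projection widths.

tp_size = 4
head_size = 128
num_heads = 32 // tp_size      # 8 query heads per rank (assumed per-rank count)
num_kv_heads = 8 // tp_size    # 2 KV heads per rank (assumed per-rank count)

output_sizes = [
    num_heads * head_size * tp_size,     # q_proj: 8 * 128 * 4 = 4096
    num_kv_heads * head_size * tp_size,  # k_proj: 2 * 128 * 4 = 1024
    num_kv_heads * head_size * tp_size,  # v_proj: 2 * 128 * 4 = 1024
]
print(output_sizes)  # [4096, 1024, 1024], the full QKV widths before partitioning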
@@ -593,22 +594,15 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        nn.Module.__init__(self)
-
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
-        self.skip_bias_add = skip_bias_add
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-        self.quant_config = quant_config
-        self.prefix = prefix
-        if quant_config is None:
-            self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+        if hasattr(self, "tp_rank") and hasattr(self, "tp_size"):
+            tp_rank = self.tp_rank
+            tp_size = self.tp_size
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix,
+                             return_bias, disable_tp)
+            self.tp_rank = tp_rank
+            self.tp_size = tp_size
         else:
-            self.quant_method = quant_config.get_quant_method(self,
-                                                              prefix=prefix)
-        self.return_bias = return_bias
-        self.disable_tp = disable_tp
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix,
+                             return_bias, disable_tp)
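The last hunk makes AscendLinearBase tolerate an upstream __init__ that unconditionally recomputes tp_rank/tp_size from the default group: if a subclass already set them, they are stashed, the parent constructor runs, and they are written back. A minimal sketch of that save-and-restore pattern with made-up class names follows; VendorBase only stands in for the upstream base class, not vLLM's actual API.

class VendorBase:
    def __init__(self, input_size: int, output_size: int):
        # Stand-in for an upstream __init__ that always rebuilds TP info
        # from the default tensor-parallel group.
        self.tp_rank = 0
        self.tp_size = 1
        self.input_size = input_size
        self.output_size = output_size

class AscendBaseSketch(VendorBase):
    def __init__(self, input_size: int, output_size: int):
        if hasattr(self, "tp_rank") and hasattr(self, "tp_size"):
            # A subclass already chose a custom comm group: stash its values,
            # let the upstream __init__ run, then write them back.
            tp_rank, tp_size = self.tp_rank, self.tp_size
            super().__init__(input_size, output_size)
            self.tp_rank, self.tp_size = tp_rank, tp_size
        else:
            super().__init__(input_size, output_size)

class CustomTPLinear(AscendBaseSketch):
    def __init__(self, input_size: int, output_size: int):
        self.tp_rank, self.tp_size = 1, 2   # the custom group's view
        super().__init__(input_size, output_size)

layer = CustomTPLinear(16, 32)
print(layer.tp_rank, layer.tp_size)  # 1 2: the custom values survive

The write-back after super().__init__() is what keeps the custom group's rank and size from being clobbered by the parent constructor.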
