Commit e665f8e

rename and fix ut
Signed-off-by: realliujiaxu <[email protected]>
1 parent: 66a7d6e

File tree

1 file changed: +52 −53 lines

vllm_ascend/ops/linear.py

Lines changed: 52 additions & 53 deletions
@@ -41,7 +41,7 @@
 _HCOMM_INFO = None
 
 
-class CustomTensorParallelBase:
+class CustomTensorParallelOp:
 
     def __init__(self, layer):
         self.layer = layer
@@ -73,7 +73,7 @@ def apply(self, input_):
         raise NotImplementedError
 
 
-class CustomColumnParallel(CustomTensorParallelBase):
+class CustomColumnParallelOp(CustomTensorParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -84,7 +84,7 @@ def after_create_weights_hook(self):
         self.gather_output = self.layer.gather_output
 
 
-class CustomRowParallel(CustomTensorParallelBase):
+class CustomRowParallelOp(CustomTensorParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -99,7 +99,7 @@ def after_create_weights_hook(self):
         self.input_size_per_partition = self.layer.input_size_per_partition
 
 
-class MLPCustomColumnParallel(CustomColumnParallel):
+class MLPColumnParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -124,7 +124,7 @@ def apply(
         return output, output_bias
 
 
-class DenseOptimMergedColumnParallel(CustomColumnParallel):
+class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
 
     def apply(
         self, input_: torch.Tensor
@@ -154,7 +154,7 @@ def apply(
         return output, output_bias
 
 
-class DenseOptimQKVParallelLinear(CustomColumnParallel):
+class DenseOptimQKVParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer, prefix):
         super().__init__(layer)
@@ -191,33 +191,33 @@ def apply(
         return output, output_bias
 
 
-def get_custom_tp_group_column(
+def get_column_parallel_op(
     disable_tp, prefix, layer
 ) -> Tuple[
-        Optional[Union[MLPCustomColumnParallel, DenseOptimMergedColumnParallel,
-                       DenseOptimQKVParallelLinear]], int, int]:
+        Optional[Union[MLPColumnParallelOp, DenseOptimMergedColumnParallelOp,
+                       DenseOptimQKVParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1
 
-    custom_tp_group: Optional[Union[
-        MLPCustomColumnParallel,
-        DenseOptimMergedColumnParallel,
-        DenseOptimQKVParallelLinear,
+    custom_op: Optional[Union[
+        MLPColumnParallelOp,
+        DenseOptimMergedColumnParallelOp,
+        DenseOptimQKVParallelOp,
     ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
-        custom_tp_group = MLPCustomColumnParallel(layer)
+        custom_op = MLPColumnParallelOp(layer)
     elif "gate_up_proj" in prefix and dense_optim_enable():
-        custom_tp_group = DenseOptimMergedColumnParallel(layer)
+        custom_op = DenseOptimMergedColumnParallelOp(layer)
     elif dense_optim_enable():
-        custom_tp_group = DenseOptimQKVParallelLinear(layer, prefix)
+        custom_op = DenseOptimQKVParallelOp(layer, prefix)
 
-    if custom_tp_group is not None:
-        return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
+    if custom_op is not None:
+        return custom_op, custom_op.tp_rank, custom_op.tp_size
 
     return None, get_tp_group().rank_in_group, get_tp_group().world_size
 
 
-class MLPCustomRowParallel(CustomRowParallel):
+class MLPRowParallelOp(CustomRowParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -250,7 +250,7 @@ def apply(
         return output, output_bias
 
 
-class OProjCustomRowParallel(CustomRowParallel):
+class OProjRowParallelOp(CustomRowParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -312,10 +312,10 @@ def apply(
     def after_create_weights_hook(self):
         super().after_create_weights_hook()
         self.input_is_parallel = self.layer.input_is_parallel
-        self.input_size_per_partition = self.layer.input_is_parallel
+        self.input_size_per_partition = self.layer.input_size_per_partition
 
 
-class MatmulAllreduceCustomRowParallel(CustomTensorParallelBase):
+class MatmulAllreduceRowParallelOp(CustomTensorParallelOp):
 
     def __init__(self, layer):
         super().__init__(layer)
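Note: besides the renames, the `@@ -312,10 +312,10 @@` hunk above carries the only non-rename change visible in this diff: `input_size_per_partition` had been copied from the boolean `input_is_parallel` instead of the integer `input_size_per_partition`. A minimal sketch of the difference, using a stand-in layer object (the attribute names come from the diff, the values are illustrative):

```python
# Stand-in layer: attribute names match the diff, values are made up.
class _FakeLayer:
    input_is_parallel = True          # bool flag
    input_size_per_partition = 4096   # int, per-TP-rank input width


layer = _FakeLayer()

buggy = layer.input_is_parallel         # the old code stored this bool ...
fixed = layer.input_size_per_partition  # ... where this int was expected
print(buggy, fixed)                     # True 4096
```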
@@ -372,7 +372,7 @@ def after_create_weights_hook(self):
         self.weight_t = self.layer.weight.t()
 
 
-class DenseOptimCustomRowParallel(CustomRowParallel):
+class DenseOptimRowParallelOp(CustomRowParallelOp):
 
     def __init__(self, layer, prefix):
         super().__init__(layer)
@@ -418,29 +418,28 @@ def after_create_weights_hook(self):
         self.reduce_results = self.layer.reduce_results
 
 
-def get_custom_tp_group_row(
+def get_row_parallel_op(
     disable_tp, prefix, layer
-) -> Tuple[Optional[Union[MLPCustomRowParallel, OProjCustomRowParallel,
-                          MatmulAllreduceCustomRowParallel,
-                          DenseOptimCustomRowParallel]], int, int]:
+) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
+                          MatmulAllreduceRowParallelOp,
+                          DenseOptimRowParallelOp]], int, int]:
     if disable_tp:
         return None, 0, 1
 
-    custom_tp_group: Optional[Union[MLPCustomRowParallel,
-                                    OProjCustomRowParallel,
-                                    MatmulAllreduceCustomRowParallel,
-                                    DenseOptimCustomRowParallel]] = None
+    custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
+                              MatmulAllreduceRowParallelOp,
+                              DenseOptimRowParallelOp]] = None
     if prefix.find("down_proj") != -1 and mlp_tp_enable():
-        custom_tp_group = MLPCustomRowParallel(layer)
+        custom_op = MLPRowParallelOp(layer)
     elif prefix.find("o_proj") != -1 and oproj_tp_enable():
-        custom_tp_group = OProjCustomRowParallel(layer)
+        custom_op = OProjRowParallelOp(layer)
     elif matmul_allreduce_enable():
-        custom_tp_group = MatmulAllreduceCustomRowParallel(layer)
+        custom_op = MatmulAllreduceRowParallelOp(layer)
     elif dense_optim_enable():
-        custom_tp_group = DenseOptimCustomRowParallel(layer, prefix)
+        custom_op = DenseOptimRowParallelOp(layer, prefix)
 
-    if custom_tp_group is not None:
-        return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
+    if custom_op is not None:
+        return custom_op, custom_op.tp_rank, custom_op.tp_size
 
     return None, get_tp_group().rank_in_group, get_tp_group().world_size
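Note: the selection order in `get_row_parallel_op` is unchanged by the rename; only the class and variable names differ. A self-contained sketch of that prefix-driven dispatch, with the feature toggles (`mlp_tp_enable()`, `oproj_tp_enable()`, `matmul_allreduce_enable()`, `dense_optim_enable()`) replaced by plain booleans and the op classes by their names, since the real ops need a live layer and TP group:

```python
from typing import Optional


def pick_row_op_name(prefix: str, mlp_tp: bool, oproj_tp: bool,
                     matmul_allreduce: bool, dense_optim: bool) -> Optional[str]:
    """Mirror of the branch order in get_row_parallel_op (names only)."""
    if "down_proj" in prefix and mlp_tp:
        return "MLPRowParallelOp"
    if "o_proj" in prefix and oproj_tp:
        return "OProjRowParallelOp"
    if matmul_allreduce:
        return "MatmulAllreduceRowParallelOp"
    if dense_optim:
        return "DenseOptimRowParallelOp"
    return None  # caller falls back to the default TP group


# e.g. a down_proj layer with only MLP tensor parallelism enabled:
print(pick_row_op_name("model.layers.0.mlp.down_proj", True, False, False, False))
# -> MLPRowParallelOp
```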

@@ -467,7 +466,7 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_group, self.tp_rank, self.tp_size = get_custom_tp_group_column(
+        self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
             disable_tp, prefix, self)
 
         self.input_size_per_partition = input_size
@@ -516,15 +515,15 @@ def __init__(
         else:
             self.register_parameter("bias", None)
 
-        if self.custom_group is not None:
-            self.custom_group.after_create_weights_hook()
+        if self.custom_op is not None:
+            self.custom_op.after_create_weights_hook()
 
     def forward(
         self,
         input_,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)
 
         return super().forward(input_)
 
@@ -550,7 +549,7 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_group, self.tp_rank, self.tp_size = get_custom_tp_group_row(
+        self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
             disable_tp, prefix, self)
 
         # Divide the weight matrix along the first dimension.
@@ -596,16 +595,16 @@ def __init__(
         else:
             self.register_parameter("bias", None)
 
-        if self.custom_group is not None:
-            self.custom_group.after_create_weights_hook()
+        if self.custom_op is not None:
+            self.custom_op.after_create_weights_hook()
 
     def forward(
         self,
         input_,
         is_prefill: bool = True,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)
 
         return super().forward(input_)
 
@@ -635,7 +634,7 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_group, self.tp_rank, self.tp_size = get_custom_tp_group_column(
+        self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
             disable_tp, prefix, self)
 
         self.output_sizes = output_sizes
@@ -657,8 +656,8 @@ def forward(
         self,
         input_,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)
 
         return super().forward(input_)
 
@@ -689,7 +688,7 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_group, _, tp_size = get_custom_tp_group_column(
+        self.custom_op, _, tp_size = get_column_parallel_op(
             disable_tp, prefix, self)
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -730,8 +729,8 @@ def forward(
         self,
         input_,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)
 
         return super().forward(input_)
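Note: the "fix ut" part of the commit message refers to unit tests that are not included in this diff. A minimal sketch of the kind of test the rename implies, assuming a pytest-style runner and that the helpers are importable from `vllm_ascend.ops.linear`; the function names and the `disable_tp` early-return contract are taken from the diff above, everything else is illustrative:

```python
# Sketch only, not the project's actual test file.
from vllm_ascend.ops.linear import get_column_parallel_op, get_row_parallel_op


def test_disable_tp_returns_no_custom_op():
    # From the diff: with disable_tp=True both helpers short-circuit to
    # (None, rank 0, world size 1) without touching the TP group or the layer.
    op, tp_rank, tp_size = get_column_parallel_op(True, "gate_up_proj", None)
    assert op is None and (tp_rank, tp_size) == (0, 1)

    op, tp_rank, tp_size = get_row_parallel_op(True, "down_proj", None)
    assert op is None and (tp_rank, tp_size) == (0, 1)
```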