_HCOMM_INFO = None


-class CustomTensorParallelBase:
+class CustomTensorParallelOp:

    def __init__(self, layer):
        self.layer = layer
@@ -73,7 +73,7 @@ def apply(self, input_):
        raise NotImplementedError


-class CustomColumnParallel(CustomTensorParallelBase):
+class CustomColumnParallelOp(CustomTensorParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
@@ -84,7 +84,7 @@ def after_create_weights_hook(self):
        self.gather_output = self.layer.gather_output


-class CustomRowParallel(CustomTensorParallelBase):
+class CustomRowParallelOp(CustomTensorParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
@@ -99,7 +99,7 @@ def after_create_weights_hook(self):
        self.input_size_per_partition = self.layer.input_size_per_partition


-class MLPCustomColumnParallel(CustomColumnParallel):
+class MLPColumnParallelOp(CustomColumnParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
@@ -124,7 +124,7 @@ def apply(
        return output, output_bias


-class DenseOptimMergedColumnParallel(CustomColumnParallel):
+class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):

    def apply(
        self, input_: torch.Tensor
@@ -154,7 +154,7 @@ def apply(
        return output, output_bias


-class DenseOptimQKVParallelLinear(CustomColumnParallel):
+class DenseOptimQKVParallelOp(CustomColumnParallelOp):

    def __init__(self, layer, prefix):
        super().__init__(layer)
@@ -191,33 +191,33 @@ def apply(
        return output, output_bias


-def get_custom_tp_group_column(
+def get_column_parallel_op(
    disable_tp, prefix, layer
) -> Tuple[
-        Optional[Union[MLPCustomColumnParallel, DenseOptimMergedColumnParallel,
-                       DenseOptimQKVParallelLinear]], int, int]:
+        Optional[Union[MLPColumnParallelOp, DenseOptimMergedColumnParallelOp,
+                       DenseOptimQKVParallelOp]], int, int]:
    if disable_tp:
        return None, 0, 1

-    custom_tp_group: Optional[Union[
-        MLPCustomColumnParallel,
-        DenseOptimMergedColumnParallel,
-        DenseOptimQKVParallelLinear,
+    custom_op: Optional[Union[
+        MLPColumnParallelOp,
+        DenseOptimMergedColumnParallelOp,
+        DenseOptimQKVParallelOp,
    ]] = None
    if "gate_up_proj" in prefix and mlp_tp_enable():
-        custom_tp_group = MLPCustomColumnParallel(layer)
+        custom_op = MLPColumnParallelOp(layer)
    elif "gate_up_proj" in prefix and dense_optim_enable():
-        custom_tp_group = DenseOptimMergedColumnParallel(layer)
+        custom_op = DenseOptimMergedColumnParallelOp(layer)
    elif dense_optim_enable():
-        custom_tp_group = DenseOptimQKVParallelLinear(layer, prefix)
+        custom_op = DenseOptimQKVParallelOp(layer, prefix)

-    if custom_tp_group is not None:
-        return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
+    if custom_op is not None:
+        return custom_op, custom_op.tp_rank, custom_op.tp_size

    return None, get_tp_group().rank_in_group, get_tp_group().world_size


-class MLPCustomRowParallel(CustomRowParallel):
+class MLPRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
@@ -250,7 +250,7 @@ def apply(
        return output, output_bias


-class OProjCustomRowParallel(CustomRowParallel):
+class OProjRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
@@ -312,10 +312,10 @@ def apply(
    def after_create_weights_hook(self):
        super().after_create_weights_hook()
        self.input_is_parallel = self.layer.input_is_parallel
-        self.input_size_per_partition = self.layer.input_is_parallel
+        self.input_size_per_partition = self.layer.input_size_per_partition


-class MatmulAllreduceCustomRowParallel(CustomTensorParallelBase):
+class MatmulAllreduceRowParallelOp(CustomTensorParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
@@ -372,7 +372,7 @@ def after_create_weights_hook(self):
        self.weight_t = self.layer.weight.t()


-class DenseOptimCustomRowParallel(CustomRowParallel):
+class DenseOptimRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer, prefix):
        super().__init__(layer)
@@ -418,29 +418,28 @@ def after_create_weights_hook(self):
        self.reduce_results = self.layer.reduce_results


-def get_custom_tp_group_row(
+def get_row_parallel_op(
    disable_tp, prefix, layer
-) -> Tuple[Optional[Union[MLPCustomRowParallel, OProjCustomRowParallel,
-                          MatmulAllreduceCustomRowParallel,
-                          DenseOptimCustomRowParallel]], int, int]:
+) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
+                          MatmulAllreduceRowParallelOp,
+                          DenseOptimRowParallelOp]], int, int]:
    if disable_tp:
        return None, 0, 1

-    custom_tp_group: Optional[Union[MLPCustomRowParallel,
-                                    OProjCustomRowParallel,
-                                    MatmulAllreduceCustomRowParallel,
-                                    DenseOptimCustomRowParallel]] = None
+    custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
+                              MatmulAllreduceRowParallelOp,
+                              DenseOptimRowParallelOp]] = None
    if prefix.find("down_proj") != -1 and mlp_tp_enable():
-        custom_tp_group = MLPCustomRowParallel(layer)
+        custom_op = MLPRowParallelOp(layer)
    elif prefix.find("o_proj") != -1 and oproj_tp_enable():
-        custom_tp_group = OProjCustomRowParallel(layer)
+        custom_op = OProjRowParallelOp(layer)
    elif matmul_allreduce_enable():
-        custom_tp_group = MatmulAllreduceCustomRowParallel(layer)
+        custom_op = MatmulAllreduceRowParallelOp(layer)
    elif dense_optim_enable():
-        custom_tp_group = DenseOptimCustomRowParallel(layer, prefix)
+        custom_op = DenseOptimRowParallelOp(layer, prefix)

-    if custom_tp_group is not None:
-        return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
+    if custom_op is not None:
+        return custom_op, custom_op.tp_rank, custom_op.tp_size

    return None, get_tp_group().rank_in_group, get_tp_group().world_size

@@ -467,7 +466,7 @@ def __init__(
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
-        self.custom_group, self.tp_rank, self.tp_size = get_custom_tp_group_column(
+        self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
            disable_tp, prefix, self)

        self.input_size_per_partition = input_size
@@ -516,15 +515,15 @@ def __init__(
        else:
            self.register_parameter("bias", None)

-        if self.custom_group is not None:
-            self.custom_group.after_create_weights_hook()
+        if self.custom_op is not None:
+            self.custom_op.after_create_weights_hook()

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)

        return super().forward(input_)

@@ -550,7 +549,7 @@ def __init__(
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
-        self.custom_group, self.tp_rank, self.tp_size = get_custom_tp_group_row(
+        self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
            disable_tp, prefix, self)

        # Divide the weight matrix along the first dimension.
@@ -596,16 +595,16 @@ def __init__(
        else:
            self.register_parameter("bias", None)

-        if self.custom_group is not None:
-            self.custom_group.after_create_weights_hook()
+        if self.custom_op is not None:
+            self.custom_op.after_create_weights_hook()

    def forward(
        self,
        input_,
        is_prefill: bool = True,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)

        return super().forward(input_)

@@ -635,7 +634,7 @@ def __init__(
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
-        self.custom_group, self.tp_rank, self.tp_size = get_custom_tp_group_column(
+        self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
            disable_tp, prefix, self)

        self.output_sizes = output_sizes
@@ -657,8 +656,8 @@ def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)

        return super().forward(input_)

@@ -689,7 +688,7 @@ def __init__(
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
-        self.custom_group, _, tp_size = get_custom_tp_group_column(
+        self.custom_op, _, tp_size = get_column_parallel_op(
            disable_tp, prefix, self)
        self.hidden_size = hidden_size
        self.head_size = head_size
@@ -730,8 +729,8 @@ def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.custom_group is not None:
-            return self.custom_group.apply(input_)
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)

        return super().forward(input_)

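For orientation, here is a minimal sketch of the call pattern this rename settles on: a parallel linear layer asks get_column_parallel_op (or get_row_parallel_op) for an optional custom op plus its TP rank and size, lets the op cache layer attributes once weights exist, and routes forward through custom_op.apply when an op was selected. The SketchColumnParallelLinear class and its BaseLinear parent below are hypothetical stand-ins for illustration only; the helper, hook, and attribute names come from the diff.

# Hedged usage sketch; SketchColumnParallelLinear and BaseLinear are
# illustrative stand-ins, not classes defined in this patch.
class SketchColumnParallelLinear(BaseLinear):

    def __init__(self, input_size, output_size, prefix="", disable_tp=False):
        # Dispatch: returns (custom op or None, tp_rank, tp_size).
        self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
            disable_tp, prefix, self)
        super().__init__(input_size, output_size, prefix=prefix)
        if self.custom_op is not None:
            # Let the op snapshot layer attributes after weight creation.
            self.custom_op.after_create_weights_hook()

    def forward(self, input_):
        # A matching op (e.g. MLPColumnParallelOp for "gate_up_proj")
        # handles the whole forward; otherwise fall back to the base class.
        if self.custom_op is not None:
            return self.custom_op.apply(input_)
        return super().forward(input_)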