 limitations under the License.
 """
 
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -69,6 +69,9 @@ def after_create_weights_hook(self):
         self.return_bias = self.layer.return_bias
         self.quant_method = self.layer.quant_method
 
+    def apply(self, input_):
+        raise NotImplementedError
+
 
 class CustomColumnParallel(CustomTensorParallelBase):
 
@@ -188,18 +191,25 @@ def apply(
         return output, output_bias
 
 
-def get_custom_tp_group_column(disable_tp, prefix, layer):
+def get_custom_tp_group_column(
+    disable_tp, prefix, layer
+) -> Tuple[
+    Optional[Union[MLPCustomColumnParallel, DenseOptimMergedColumnParallel,
+                   DenseOptimQKVParallelLinear]], int, int]:
     if disable_tp:
         return None, 0, 1
 
+    custom_tp_group: Optional[Union[
+        MLPCustomColumnParallel,
+        DenseOptimMergedColumnParallel,
+        DenseOptimQKVParallelLinear,
+    ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
         custom_tp_group = MLPCustomColumnParallel(layer)
     elif "gate_up_proj" in prefix and dense_optim_enable():
         custom_tp_group = DenseOptimMergedColumnParallel(layer)
     elif dense_optim_enable():
         custom_tp_group = DenseOptimQKVParallelLinear(layer, prefix)
-    else:
-        custom_tp_group = None
 
     if custom_tp_group is not None:
         return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
@@ -329,6 +339,7 @@ def apply(
                 self.hcomm_info,
                 bias=bias_)
         else:
+            assert self.quant_method is not None
             output = self.quant_method.apply(self.layer,
                                              input_parallel,
                                              bias=bias_)
@@ -407,10 +418,18 @@ def after_create_weights_hook(self):
         self.reduce_results = self.layer.reduce_results
 
 
-def get_custom_tp_group_row(disable_tp, prefix, layer):
+def get_custom_tp_group_row(
+    disable_tp, prefix, layer
+) -> Tuple[Optional[Union[MLPCustomRowParallel, OProjCustomRowParallel,
+                          MatmulAllreduceCustomRowParallel,
+                          DenseOptimCustomRowParallel]], int, int]:
     if disable_tp:
         return None, 0, 1
 
+    custom_tp_group: Optional[Union[MLPCustomRowParallel,
+                                    OProjCustomRowParallel,
+                                    MatmulAllreduceCustomRowParallel,
+                                    DenseOptimCustomRowParallel]] = None
     if prefix.find("down_proj") != -1 and mlp_tp_enable():
         custom_tp_group = MLPCustomRowParallel(layer)
     elif prefix.find("o_proj") != -1 and oproj_tp_enable():
@@ -419,8 +438,6 @@ def get_custom_tp_group_row(disable_tp, prefix, layer):
         custom_tp_group = MatmulAllreduceCustomRowParallel(layer)
     elif dense_optim_enable():
         custom_tp_group = DenseOptimCustomRowParallel(layer, prefix)
-    else:
-        custom_tp_group = None
 
     if custom_tp_group is not None:
         return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
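The new return annotations describe a (group, tp_rank, tp_size) contract for the two dispatch helpers. The following is a minimal standalone sketch of that contract, not code from the patched module: the class, function, and flag names below are hypothetical stand-ins for the real wrappers (MLPCustomColumnParallel, DenseOptimCustomRowParallel, etc.), which live in the file being changed.

from typing import Optional, Tuple


class _SketchGroup:
    """Hypothetical stand-in for a custom TP wrapper exposing tp_rank/tp_size."""

    def __init__(self, tp_rank: int = 0, tp_size: int = 1):
        self.tp_rank = tp_rank
        self.tp_size = tp_size


def get_custom_tp_group_sketch(
    disable_tp: bool, prefix: str, mlp_tp: bool, dense_optim: bool
) -> Tuple[Optional[_SketchGroup], int, int]:
    # Mirrors the shape of get_custom_tp_group_column/_row in the diff: bail
    # out when TP is disabled, otherwise pick a wrapper by weight-name prefix
    # plus feature flag and report that wrapper's (rank, size). The fallback
    # (None, 0, 1) is an assumption of this sketch.
    if disable_tp:
        return None, 0, 1

    custom_tp_group: Optional[_SketchGroup] = None
    if "gate_up_proj" in prefix and mlp_tp:
        custom_tp_group = _SketchGroup(tp_rank=0, tp_size=2)
    elif dense_optim:
        custom_tp_group = _SketchGroup(tp_rank=0, tp_size=4)

    if custom_tp_group is not None:
        return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
    return None, 0, 1


print(get_custom_tp_group_sketch(False, "model.layers.0.mlp.gate_up_proj",
                                 mlp_tp=True, dense_optim=False))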