File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -188,18 +188,19 @@ def apply(
188188 return output , output_bias
189189
190190
191- def get_custom_tp_group_column (disable_tp , prefix , layer ):
191+ def get_custom_tp_group_column (
192+ disable_tp , prefix ,
193+ layer ) -> tuple [Optional [CustomTensorParallelBase ], int , int ]:
192194 if disable_tp :
193195 return None , 0 , 1
194196
197+ custom_tp_group = None
195198 if "gate_up_proj" in prefix and mlp_tp_enable ():
196199 custom_tp_group = MLPCustomColumnParallel (layer )
197200 elif "gate_up_proj" in prefix and dense_optim_enable ():
198201 custom_tp_group = DenseOptimMergedColumnParallel (layer )
199202 elif dense_optim_enable ():
200203 custom_tp_group = DenseOptimQKVParallelLinear (layer , prefix )
201- else :
202- custom_tp_group = None
203204
204205 if custom_tp_group is not None :
205206 return custom_tp_group , custom_tp_group .tp_rank , custom_tp_group .tp_size
@@ -329,6 +330,7 @@ def apply(
329330 self .hcomm_info ,
330331 bias = bias_ )
331332 else :
333+ assert self .quant_method is not None
332334 output = self .quant_method .apply (self .layer ,
333335 input_parallel ,
334336 bias = bias_ )
You can’t perform that action at this time.
0 commit comments