@@ -63,14 +63,11 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
-        self.comm_group = None
-        if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
-            self.comm_group = get_mlp_tp_group()
-        else:
+        # If `tp_size` is already set, this layer has been customized by a subclass.
+        if not hasattr(self, "tp_size"):
             self.comm_group = get_tp_group()
-
-        self.tp_size = self.comm_group.world_size
-        self.tp_rank = self.comm_group.rank_in_group
+            self.tp_size = self.comm_group.world_size
+            self.tp_rank = self.comm_group.rank_in_group
 
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
@@ -81,6 +78,8 @@ def __init__(
             divide(output_size, self.tp_size)
             for output_size in self.output_sizes
         ]
+        # Skip ColumnParallelLinear.__init__, as it would create the weight_loader with the default tp group;
+        # the weight_loader is created with the customized comm group instead.
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
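To make the guard in this hunk concrete, here is a minimal, self-contained sketch of the `hasattr(self, "tp_size")` pattern: a subclass installs its own comm group before delegating, and the base constructor only falls back to the default TP group when nothing was pre-set. `CommGroup`, `Base`, and `CustomChild` are illustrative stand-ins, not the real vllm-ascend classes.

```python
class CommGroup:
    """Illustrative stand-in for a communication group."""

    def __init__(self, world_size: int, rank_in_group: int):
        self.world_size = world_size
        self.rank_in_group = rank_in_group


DEFAULT_TP_GROUP = CommGroup(world_size=8, rank_in_group=0)


class Base:
    def __init__(self):
        # Fall back to the default TP group only when a subclass has not
        # already installed its own comm group / tp_size / tp_rank.
        if not hasattr(self, "tp_size"):
            self.comm_group = DEFAULT_TP_GROUP
            self.tp_size = self.comm_group.world_size
            self.tp_rank = self.comm_group.rank_in_group


class CustomChild(Base):
    def __init__(self, comm_group: CommGroup):
        # Install the custom group *before* delegating, so the guard in
        # Base.__init__ leaves it untouched.
        self.comm_group = comm_group
        self.tp_size = comm_group.world_size
        self.tp_rank = comm_group.rank_in_group
        super().__init__()


child = CustomChild(CommGroup(world_size=4, rank_in_group=1))
assert child.tp_size == 4 and child.tp_rank == 1
```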
@@ -138,32 +137,35 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
-        if prefix.find("down_proj") != -1 and mlp_tp_enable():
-            comm_group = get_mlp_tp_group()
-            self.forward_type = "mlp_tp"
-        elif prefix.find("o_proj") != -1 and oproj_tp_enable():
-            comm_group = get_otp_group()
-            self.forward_type = "oproj_tp"
-        elif matmul_allreduce_enable():
-            comm_group = get_tp_group()
-            self.forward_type = "matmul_allreduce"
-            self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
-        elif dense_optim_enable():
-            comm_group = get_tp_group()
-            self.forward_type = "dense_optim"
-        else:
-            comm_group = get_tp_group()
-            self.forward_type = "normal"
-        self.comm_group = comm_group
-
-        self.tp_size = self.comm_group.world_size
-        self.tp_rank = self.comm_group.rank_in_group
+        # If `tp_size` is already set, this layer has been customized by a subclass.
+        if not hasattr(self, "tp_size"):
+            if prefix.find("down_proj") != -1 and mlp_tp_enable():
+                comm_group = get_mlp_tp_group()
+                self.forward_type = "mlp_tp"
+            elif prefix.find("o_proj") != -1 and oproj_tp_enable():
+                comm_group = get_otp_group()
+                self.forward_type = "oproj_tp"
+            elif matmul_allreduce_enable():
+                comm_group = get_tp_group()
+                self.forward_type = "matmul_allreduce"
+                self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
+            elif dense_optim_enable():
+                comm_group = get_tp_group()
+                self.forward_type = "dense_optim"
+            else:
+                comm_group = get_tp_group()
+                self.forward_type = "normal"
+            self.comm_group = comm_group
+
+            self.tp_size = self.comm_group.world_size
+            self.tp_rank = self.comm_group.rank_in_group
 
         # Divide the weight matrix along the first dimension.
         self.input_size_per_partition = divide(input_size, self.tp_size)
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
 
+        # Skip RowParallelLinear.__init__, as it would create the weight_loader with the default tp group.
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
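The row-parallel branch keeps the same prefix-driven dispatch, now nested under the `hasattr` guard. Below is a hedged sketch of that selection logic, with the `*_enable()` switches reduced to plain booleans and the helper name `select_forward_type` invented purely for illustration:

```python
def select_forward_type(prefix: str,
                        mlp_tp: bool = False,
                        oproj_tp: bool = False,
                        matmul_allreduce: bool = False,
                        dense_optim: bool = False) -> str:
    # Mirrors the if/elif chain above: the layer's weight-name prefix and the
    # enabled optimizations decide which communication path forward() takes.
    if "down_proj" in prefix and mlp_tp:
        return "mlp_tp"
    if "o_proj" in prefix and oproj_tp:
        return "oproj_tp"
    if matmul_allreduce:
        return "matmul_allreduce"
    if dense_optim:
        return "dense_optim"
    return "normal"


assert select_forward_type("model.layers.0.mlp.down_proj", mlp_tp=True) == "mlp_tp"
assert select_forward_type("model.layers.0.self_attn.o_proj", oproj_tp=True) == "oproj_tp"
assert select_forward_type("model.layers.0.self_attn.o_proj") == "normal"
```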
@@ -368,7 +370,6 @@ def _forward_dense_optim(
             return output
         return output, output_bias
 
-
 class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     """Packed linear layers with column parallelism.
 
@@ -526,7 +527,7 @@ def __init__(
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size,  # q_proj
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
         AscendColumnParallelLinear.__init__(self,
                                             input_size=input_size,
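For context on the `output_sizes` list touched above (a whitespace-only change), here is a quick worked example of the arithmetic with made-up per-rank shapes; the numbers are illustrative only, not taken from any particular model:

```python
# Assumed per-rank counts: num_heads=32, num_kv_heads=8, head_size=128, tp_size=4.
num_heads, num_kv_heads, head_size, tp_size = 32, 8, 128, 4
output_sizes = [
    num_heads * head_size * tp_size,     # q_proj -> 16384
    num_kv_heads * head_size * tp_size,  # k_proj -> 4096
    num_kv_heads * head_size * tp_size,  # v_proj -> 4096
]
assert output_sizes == [16384, 4096, 4096]
```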
@@ -593,22 +594,15 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        nn.Module.__init__(self)
-
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
-        self.skip_bias_add = skip_bias_add
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-        self.quant_config = quant_config
-        self.prefix = prefix
-        if quant_config is None:
-            self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+        if hasattr(self, "tp_rank") and hasattr(self, "tp_size"):
+            tp_rank = self.tp_rank
+            tp_size = self.tp_size
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix,
+                             return_bias, disable_tp)
+            self.tp_rank = tp_rank
+            self.tp_size = tp_size
         else:
-            self.quant_method = quant_config.get_quant_method(self,
-                                                              prefix=prefix)
-        self.return_bias = return_bias
-        self.disable_tp = disable_tp
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix,
+                             return_bias, disable_tp)
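This final hunk replaces the hand-copied base-class body with a call to `super().__init__`, stashing any subclass-customized `tp_rank`/`tp_size` around it because the parent constructor would otherwise overwrite them with the default TP layout. A self-contained sketch of that save/restore pattern, with `Parent`, `Child`, and `CustomTP` as illustrative stand-ins rather than the real classes:

```python
class Parent:
    def __init__(self):
        # The upstream base class always installs the default TP layout.
        self.tp_size = 8
        self.tp_rank = 0


class Child(Parent):
    def __init__(self):
        if hasattr(self, "tp_rank") and hasattr(self, "tp_size"):
            # A subclass already customized the layout: keep it across the
            # parent constructor, which would otherwise clobber it.
            tp_rank, tp_size = self.tp_rank, self.tp_size
            super().__init__()
            self.tp_rank, self.tp_size = tp_rank, tp_size
        else:
            super().__init__()


class CustomTP(Child):
    def __init__(self):
        self.tp_size, self.tp_rank = 4, 1  # custom comm-group layout
        super().__init__()


assert CustomTP().tp_size == 4
assert Child().tp_size == 8
```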