@@ -1227,3 +1227,98 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class QKVCrossParallelLinear(torch.nn.Module):
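+    """Q/KV projection for cross-attention in encoder-decoder models.
+
+    Q is projected from the decoder hidden states and K/V from the encoder
+    hidden states, while the placeholder `weight`/`bias` parameters allow the
+    checkpoint to be loaded as a single `QKVParallelLinear`-style module.
+    """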
+
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Empty placeholders for loading as a single module.
+        self.weight = torch.nn.Parameter()
+        set_weight_attrs(self.weight, {
+            "weight_loader": self.weight_loader_weight,
+        })
+        # Use a dictionary to avoid auto-registration of the submodules'
+        # parameters: drop-in replacement for a `QKVParallelLinear` module.
+        self.proj = dict()
+        self.proj["q_proj_decoder"] = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=total_num_heads * head_size,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.q_proj_decoder")
+
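+        # K/V come from a `QKVParallelLinear` configured with zero Q heads,
+        # so its output (and its weight loading) covers only K and V.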
+        self.proj["kv_proj_encoder"] = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=head_size,
+            total_num_heads=0,
+            total_num_kv_heads=total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.kv_proj_encoder")
+
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
+
+        if bias:
+            self.bias = torch.nn.Parameter()
+            set_weight_attrs(self.bias, {
+                "weight_loader": self.weight_loader_bias,
+            })
+
+    @property
+    def q_proj_decoder(self):
+        return self.proj["q_proj_decoder"]
+
+    @property
+    def kv_proj_encoder(self):
+        return self.proj["kv_proj_encoder"]
+
+    def forward(self, decoder_hidden_states, encoder_hidden_states):
+        q, _ = self.q_proj_decoder(decoder_hidden_states)
+        if encoder_hidden_states is None:
+            # Encoder KV already cached.
+            k = None
+            v = None
+        else:
+            # Prefill phase, encoder KV is computed (and cached) here.
+            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
+            # Split the fused projection into K and V along the last dim.
+            k, v = kv_enc.split(self.kv_size, dim=-1)
+        return q, k, v
+
+    def weight_loader_weight(self,
+                             param: torch.nn.Parameter,
+                             loaded_weight: torch.Tensor,
+                             loaded_shard_id: Optional[str] = None):
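+        # Shard id "q" targets the decoder Q projection; any other shard id
+        # (typically "k"/"v") targets the encoder KV projection.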
+        # NOTE: Use the QKV/ColumnParallel weight_loader of the target
+        # submodule; the placeholder param passed in is ignored.
+        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.weight
+        if loaded_shard_id == "q":
+            param.weight_loader(param, loaded_weight)
+        else:
+            param.weight_loader(param, loaded_weight, loaded_shard_id)
+
+    def weight_loader_bias(self,
+                           param: torch.nn.Parameter,
+                           loaded_weight: torch.Tensor,
+                           loaded_shard_id: Optional[str] = None):
+        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.bias
+        if loaded_shard_id == "q":
+            param.weight_loader(param, loaded_weight)
+        else:
+            param.weight_loader(param, loaded_weight, loaded_shard_id)
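
A minimal call-site sketch, assuming an instance `cross_qkv = QKVCrossParallelLinear(...)` built elsewhere (the instance name and the two call patterns are illustrative assumptions, not part of this diff):

    # Prefill: encoder hidden states are available, so K and V are projected.
    q, k, v = cross_qkv(decoder_hidden_states, encoder_hidden_states)
    # Decode: encoder K/V were cached during prefill, so only Q is projected.
    q, k, v = cross_qkv(decoder_hidden_states, None)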