@@ -1231,24 +1231,55 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
                 down_proj_attrs,
             )
         else:
-            if self.model_format != "torch" and layer.fd_config.load_config.load_choices == "default_v1":
-                # transpose [0,2,1]
-                up_gate_proj_weight_shape = (
-                    self.up_gate_proj_weight_shape[:1] + self.up_gate_proj_weight_shape[1:][::-1]
-                )
-                up_gate_proj_scale_shape = self.up_gate_proj_scale_shape[:1] + self.up_gate_proj_scale_shape[1:][::-1]
-                down_proj_weight_shape = self.down_proj_weight_shape[:1] + self.down_proj_weight_shape[1:][::-1]
-                down_proj_scale_shape = self.down_proj_scale_shape[:1] + self.down_proj_scale_shape[1:][::-1]
-                extra_weight_attrs = {
-                    **extra_weight_attrs,
-                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
-                }
+            # 1.init shape
+            extra_weight_attrs = {**extra_weight_attrs}
+            if layer.fd_config.load_config.load_choices == "default_v1":
+                if self.model_format != "torch":
+                    # transpose [0,2,1]
+                    up_gate_proj_weight_shape = (
+                        self.up_gate_proj_weight_shape[:1] + self.up_gate_proj_weight_shape[1:][::-1]
+                    )
+                    up_gate_proj_scale_shape = (
+                        self.up_gate_proj_scale_shape[:1] + self.up_gate_proj_scale_shape[1:][::-1]
+                    )
+                    down_proj_weight_shape = self.down_proj_weight_shape[:1] + self.down_proj_weight_shape[1:][::-1]
+                    down_proj_scale_shape = self.down_proj_scale_shape[:1] + self.down_proj_scale_shape[1:][::-1]
+                    up_gate_proj_attrs = {
+                        **extra_weight_attrs,
+                        "tensor_track": TensorTracker(
+                            shape=up_gate_proj_weight_shape,
+                            output_dim=False,
+                        ),
+                    }
+                    down_proj_attrs = {
+                        **extra_weight_attrs,
+                        "tensor_track": TensorTracker(
+                            shape=down_proj_weight_shape,
+                            output_dim=False,
+                        ),
+                    }
+                else:
+                    up_gate_proj_weight_shape = self.up_gate_proj_weight_shape
+                    up_gate_proj_scale_shape = self.up_gate_proj_scale_shape
+                    down_proj_weight_shape = self.down_proj_weight_shape
+                    down_proj_scale_shape = self.down_proj_scale_shape
+                    up_gate_proj_attrs = {
+                        **extra_weight_attrs,
+                        "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                    }
+                    down_proj_attrs = {
+                        **extra_weight_attrs,
+                        "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                    }
             else:
-                # v0 loader or torch model format
+                # v0 loader
                 up_gate_proj_weight_shape = self.up_gate_proj_weight_shape
                 up_gate_proj_scale_shape = self.up_gate_proj_scale_shape
                 down_proj_weight_shape = self.down_proj_weight_shape
                 down_proj_scale_shape = self.down_proj_scale_shape
+                up_gate_proj_attrs = {}
+                down_proj_attrs = {}
+
             self.weight_dtype = paddle.float8_e4m3fn
             self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
             up_gate_proj_weight_name = self.added_weight_attrs[0]
@@ -1295,20 +1326,20 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
             )
             set_weight_attrs(
                 getattr(layer, up_gate_proj_weight_name),
-                extra_weight_attrs,
+                up_gate_proj_attrs,
             )
             set_weight_attrs(
                 getattr(layer, up_gate_proj_scale_name),
-                extra_weight_attrs,
+                up_gate_proj_attrs,
             )

             set_weight_attrs(
                 getattr(layer, down_proj_weight_name),
-                extra_weight_attrs,
+                down_proj_attrs,
             )
             set_weight_attrs(
                 getattr(layer, down_proj_scale_name),
-                extra_weight_attrs,
+                down_proj_attrs,
             )

     def process_weights_after_loading(self, layer):
@@ -1385,6 +1416,13 @@ def _process_quantize(weight_idx):
             down_proj_weight_name = self.added_weight_attrs[1]
             up_gate_proj_scale_name = self.added_scale_attrs[0]
             down_proj_scale_name = self.added_scale_attrs[1]
+            if (
+                not weight_fully_copied(getattr(layer, up_gate_proj_weight_name))
+                or not weight_fully_copied(getattr(layer, down_proj_weight_name))
+                or not weight_fully_copied(getattr(layer, up_gate_proj_scale_name))
+                or not weight_fully_copied(getattr(layer, down_proj_scale_name))
+            ):
+                return
             process_weight_transpose(layer, up_gate_proj_weight_name)
             process_weight_transpose(layer, down_proj_weight_name)
             process_weight_transpose(layer, up_gate_proj_scale_name)
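
Note on the loading guard: the new create_weights path attaches a TensorTracker to each FP8 expert weight and scale, and process_weights_after_loading now returns early until weight_fully_copied reports that every shard has been written, so the transpose runs only once after the full checkpoint is loaded. The sketch below is a hypothetical, minimal illustration of how such a tracker/guard pair could behave; it is not FastDeploy's actual TensorTracker or weight_fully_copied implementation, and the mark() method and list-based bookkeeping are assumptions made only for illustration.

class TensorTracker:
    """Illustrative only: records which expert slices of a weight have been copied."""

    def __init__(self, shape, output_dim):
        self.shape = shape
        self.output_dim = output_dim
        # assume dim 0 is the expert dimension; keep one flag per expert slice
        self._copied = [False] * shape[0]

    def mark(self, expert_id):
        # hypothetical hook the loader would call after writing one expert's slice
        self._copied[expert_id] = True

    def is_fully_copied(self):
        return all(self._copied)


def weight_fully_copied(param):
    # parameters created without a "tensor_track" attribute are treated as complete
    tracker = getattr(param, "tensor_track", None)
    return tracker is None or tracker.is_fully_copied()


if __name__ == "__main__":
    class FakeParam:  # stand-in for a paddle Parameter carrying the attribute
        pass

    p = FakeParam()
    p.tensor_track = TensorTracker(shape=[4, 128, 256], output_dim=False)
    assert not weight_fully_copied(p)
    for expert_id in range(4):
        p.tensor_track.mark(expert_id)
    assert weight_fully_copied(p)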