@@ -8102,15 +8102,13 @@ def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         blocks0: Tensor = torch.zeros(1)
         blocks1: Tensor = torch.zeros(1)
-        found_mxfp4_tensors = False
         # we assume that tensors are loaded in the correct order
         for name, data_torch in self.get_tensors():
             if "mlp.experts.down_proj_blocks" in name:
                 blocks0 = data_torch
             elif "mlp.experts.down_proj_scales" in name:
                 new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
                 self.repack_mxfp4(new_name, blocks0, data_torch)
-                found_mxfp4_tensors = True
             elif "mlp.experts.gate_up_proj_blocks" in name:
                 blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
             elif "mlp.experts.gate_up_proj_scales" in name:
@@ -8119,9 +8117,6 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
                 self.repack_mxfp4(new_name_gate, blocks0, scales0)
                 self.repack_mxfp4(new_name_up, blocks1, scales1)
-                found_mxfp4_tensors = True
-        if not found_mxfp4_tensors:
-            raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
         return []
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
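Note on the hunks above: `gate_up_proj_blocks` stores the gate and up projections interleaved along dim 1, so the even/odd slicing splits them apart (even rows feed gate_proj, odd rows feed up_proj). A toy sketch of that de-interleave, with invented shapes that only illustrate the slicing, not the real MXFP4 block layout:

```python
import torch

# invented shape: (n_experts, interleaved_rows, blocks_per_row, block_bytes)
n_experts, n_rows, n_blk, blk = 2, 6, 4, 16
gate_up = torch.arange(n_experts * n_rows * n_blk * blk).reshape(n_experts, n_rows, n_blk, blk)

# even rows along dim 1 -> gate_proj, odd rows -> up_proj
blocks0, blocks1 = gate_up[:, ::2, :, :], gate_up[:, 1::2, :, :]
assert blocks0.shape == blocks1.shape == (n_experts, n_rows // 2, n_blk, blk)
```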
@@ -8134,7 +8129,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if "down_proj" in name:
             if name.endswith("_bias"):
                 name = name.replace("down_proj_bias", "down_proj.bias")
+            elif "_blocks" not in name and "_scales" not in name:
+                logger.warning(f"{name} is not in MXFP4, performance may be degraded")
+                name = name.replace("down_proj", "down_proj.weight")
+                data_torch = data_torch.transpose(-1, -2)
             else:
+                # otherwise, it should already be repacked to ggml MXFP4 format
                 return []
 
         # split the gate_up into gate and up
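Note on the hunk above: the new `elif` catches checkpoints whose expert weights were exported without the `_blocks`/`_scales` suffixes, i.e. as plain unquantized tensors, and converts them with a transpose instead of an MXFP4 repack. A minimal sketch of that transpose, under the assumption that such tensors are laid out `(n_experts, in_features, out_features)` while the converter wants the last two dims swapped:

```python
import torch

# invented sizes; layout assumption stated in the note above
n_experts, n_in, n_out = 4, 8, 3
hf_down = torch.randn(n_experts, n_in, n_out)
ggml_down = hf_down.transpose(-1, -2)
assert ggml_down.shape == (n_experts, n_out, n_in)
```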
@@ -8147,7 +8147,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
                     (self.map_tensor_name(name_gate), gate_proj_bias),
                     (self.map_tensor_name(name_up), up_proj_bias)
                 ]
+            elif "_blocks" not in name and "_scales" not in name:
+                logger.warning(f"{name} is not in MXFP4, performance may be degraded")
+                name_up = name.replace("gate_up_proj", "up_proj.weight")
+                name_gate = name.replace("gate_up_proj", "gate_proj.weight")
+                data_torch = data_torch.transpose(-1, -2)
+                gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :]
+                return [
+                    (self.map_tensor_name(name_gate), gate_proj_weight),
+                    (self.map_tensor_name(name_up), up_proj_weight)
+                ]
             else:
+                # otherwise, it should already be repacked to ggml MXFP4 format
                 return []
 
         return [(self.map_tensor_name(name), data_torch)]
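Note on the hunk above: the unquantized `gate_up_proj` fallback transposes first, then de-interleaves even/odd rows of the resulting dim 1, mirroring the blocks/scales split in `generate_extra_tensors`. A toy check of that order of operations, with invented sizes:

```python
import torch

n_experts, n_in, n_gate_up = 2, 8, 6
gate_up = torch.randn(n_experts, n_in, n_gate_up)

t = gate_up.transpose(-1, -2)                # -> (n_experts, n_gate_up, n_in)
gate_w, up_w = t[:, ::2, :], t[:, 1::2, :]   # even rows = gate, odd rows = up
assert gate_w.shape == up_w.shape == (n_experts, n_gate_up // 2, n_in)
```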