@@ -2641,18 +2641,47 @@ def prepare_tensors(self):
         super().prepare_tensors()
 
 
-@ModelBase.register("BitnetForCausalLM")
+@ModelBase.register("BitnetForCausalLM", "BitNetForCausalLM")
 class BitnetModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BITNET
 
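+    # Scalar scales collected from ".weight_scale" tensors, keyed by the
+    # mapped name of the ".weight" tensor they belong to (see modify_tensors).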
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._bitnet_weight_scales: dict[str, torch.Tensor] = {}
+
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        else:
+            self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
+    @staticmethod
+    def _unpack_bitnet_weights(packed: torch.Tensor) -> torch.Tensor:
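+        # Each uint8 packs four 2-bit ternary codes; unpack them into an
+        # fp32 tensor of {-1.0, 0.0, +1.0} values.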
+        if packed.dtype != torch.uint8:
+            raise ValueError(f"Expected packed BitNet weights to be torch.uint8, got {packed.dtype}")
+
+        values_per_item = 4
+        rows = packed.shape[0]
+        rest = packed.shape[1:]
+
+        unpacked_chunks: list[torch.Tensor] = []
+        mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0], dtype=torch.float32, device=packed.device)
+
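+        # Extract the i-th 2-bit field of every byte (LSB first), map codes
+        # 0/1/2 to -1.0/0.0/+1.0 (code 3 falls back to 0.0), and stack the
+        # four planes along dim 0, so the output has 4x the packed row count.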
+        for i in range(values_per_item):
+            chunk = (packed >> (2 * i)) & 0x03
+            chunk = mapping[chunk.long()].reshape((rows, *rest))
+            unpacked_chunks.append(chunk)
+
+        if not unpacked_chunks:
+            raise ValueError("Failed to unpack BitNet weights: no chunks produced")
+
+        return torch.cat(unpacked_chunks, dim=0)
+
     def weight_quant(self, weight: Tensor) -> Tensor:
         dtype = weight.dtype
         weight = weight.float()
@@ -2665,8 +2694,36 @@ def weight_quant(self, weight: Tensor) -> Tensor:
         return result.type(dtype)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
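+        # Intercept ".weight_scale" tensors: buffer the scale under the mapped
+        # name of the companion ".weight" and emit nothing for them. Note this
+        # relies on the scale being iterated before its weight.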
+        if name.endswith(".weight_scale"):
+            weight_name = name[:-13] + ".weight"
+            mapped_weight_name = self.map_tensor_name(weight_name)
+            if isinstance(data_torch, LazyTorchTensor):
+                data_torch = LazyTorchTensor.to_eager(data_torch)
+
+            scale_tensor = data_torch.to(torch.float32)
+            self._bitnet_weight_scales[mapped_weight_name] = scale_tensor
+            return []
+
         new_name = self.map_tensor_name(name)
 
+        ternary_weight = False
+
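+        # If a buffered scale exists for this weight, the tensor holds packed
+        # ternary data: unpack to fp32 and apply the scalar scale, then skip
+        # the re-quantization pass below.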
+        if name.endswith(".weight"):
+            scale_tensor = self._bitnet_weight_scales.pop(new_name, None)
+            if scale_tensor is not None:
+                scale_tensor = scale_tensor.to(torch.float32)
+                if scale_tensor.numel() != 1:
+                    raise ValueError(f"Expected scalar weight_scale for '{name}', got shape {tuple(scale_tensor.shape)}")
+
+                if isinstance(data_torch, LazyTorchTensor):
+                    data_torch = LazyTorchTensor.to_eager(data_torch)
+
+                packed = data_torch.to(torch.uint8)
+                unpacked = self._unpack_bitnet_weights(packed)
+                scale_value = scale_tensor.reshape(-1)[0].item()
+                data_torch = unpacked * scale_value
+                ternary_weight = True
+
         if any(self.match_model_tensor_name(new_name, key, bid) for key in [
             gguf.MODEL_TENSOR.ATTN_Q,
             gguf.MODEL_TENSOR.ATTN_K,
@@ -2675,7 +2732,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
             gguf.MODEL_TENSOR.FFN_UP,
             gguf.MODEL_TENSOR.FFN_DOWN,
             gguf.MODEL_TENSOR.FFN_GATE,
-        ]):
+        ]) and not ternary_weight:
             # transform weight into 1/0/-1 (in fp32)
             data_torch = self.weight_quant(data_torch)
 