56 changes: 54 additions & 2 deletions src/transformers/integrations/bitnet.py
@@ -9,6 +9,13 @@
import torch.nn as nn
import torch.nn.functional as F

# We reuse the RMSNorm implementation shipped with the reference BitNet model to avoid code duplication
# and guarantee consistency with the official implementation.
try:
from ..models.bitnet.modeling_bitnet import BitNetRMSNorm
except (ModuleNotFoundError, ImportError):
BitNetRMSNorm = None # BitNet model might not be available in minimal installations

logger = logging.get_logger(__name__)


@@ -124,7 +131,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:


class BitLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
device=None,
dtype=None,
use_rms_norm: bool = False,
rms_norm_eps: float = 1e-6,
):
super().__init__()
self.dtype = dtype
self.in_features = in_features
@@ -150,6 +166,17 @@ def __init__(self, in_features: int, out_features: int, bias: bool, device=None,
else:
self.bias = None

# Optional RMSNorm (applied on the activations before quantization).
self.rms_norm = None
if use_rms_norm:
if BitNetRMSNorm is None:
raise ImportError(
"`use_rms_norm=True` requires the BitNet model code to be available in the current Transformers"
" installation. Please install the full `transformers` package or ensure the BitNet model files"
" are accessible."
)
self.rms_norm = BitNetRMSNorm(in_features, eps=rms_norm_eps)

@torch.compile
def activation_quant(self, input, num_bits=8):
"""
@@ -180,6 +207,10 @@ def post_quant_process(self, input, input_scale, weight_scale):
return out

def forward(self, input):
# Apply RMSNorm on the input if requested.
if self.rms_norm is not None:
input = self.rms_norm(input)

w = self.weight
w_quant = unpack_weights(w, dtype=self.dtype)
input_quant, input_scale = self.activation_quant(input)
@@ -245,9 +276,21 @@ def __init__(
device=None,
dtype=None,
online_quant: bool = False,
use_rms_norm: bool = False,
rms_norm_eps: float = 1e-6,
):
super().__init__(in_features, out_features, bias)
self.online_quant = online_quant
# Optional RMSNorm
self.rms_norm = None
if use_rms_norm:
if BitNetRMSNorm is None:
raise ImportError(
"`use_rms_norm=True` requires the BitNet model code to be available in the current Transformers"
" installation. Please install the full `transformers` package or ensure the BitNet model files"
" are accessible."
)
self.rms_norm = BitNetRMSNorm(in_features, eps=rms_norm_eps)
if not online_quant:
self.register_buffer(
"weight_scale",
@@ -271,6 +314,10 @@ def load_hook(
return state_dict

def forward(self, input):
# Optional RMSNorm on activations prior to quantization.
if self.rms_norm is not None:
input = self.rms_norm(input)

if self.online_quant:
weight = WeightQuant.apply(self.weight)
else:
@@ -318,6 +365,8 @@ def _replace_with_bitnet_linear(
device=module.weight.device,
dtype=module.weight.dtype,
online_quant=(quantization_config.quantization_mode == "online"),
use_rms_norm=getattr(quantization_config, "use_rms_norm", False),
rms_norm_eps=getattr(quantization_config, "rms_norm_eps", 1e-6),
)
if quantization_config.quantization_mode == "offline":
model._modules[name].requires_grad_(False)
@@ -328,6 +377,8 @@
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
use_rms_norm=getattr(quantization_config, "use_rms_norm", False),
rms_norm_eps=getattr(quantization_config, "rms_norm_eps", 1e-6),
)
model._modules[name].requires_grad_(False)
has_been_replaced = True
@@ -363,7 +414,7 @@ def replace_with_bitnet_linear(
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
@@ -390,3 +441,4 @@ def replace_with_bitnet_linear(
)

return model
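
The hunks above route activations through the reused `BitNetRMSNorm` layer before they are quantized. For readers who don't have the BitNet modeling file at hand, here is a minimal sketch of the standard RMSNorm computation that layer is assumed to mirror (LlamaRMSNorm-style normalization in float32 with a learned scale); the class and variable names below are illustrative and not part of this patch.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Illustrative RMSNorm; assumed to match what BitNetRMSNorm computes."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learned per-channel scale
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```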

9 changes: 9 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -1791,6 +1791,11 @@ class BitNetQuantConfig(QuantizationConfigMixin):
In `offline` mode, quantization parameters are pre-calculated *before* inference.
These parameters are then fixed and loaded into the quantized model. This
generally results in lower runtime overhead compared to online quantization.
use_rms_norm (`bool`, *optional*, defaults to `False`):
Whether to apply RMSNorm on the activations before quantization. This matches the original BitNet paper's approach
of normalizing activations before quantization/packing.
rms_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon value used in the RMSNorm layer for numerical stability.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments that may be used by specific quantization
backends or future versions.
@@ -1801,6 +1806,8 @@ def __init__(
modules_to_not_convert: Optional[List] = None,
linear_class: Optional[str] = "bitlinear",
quantization_mode: Optional[str] = "offline",
use_rms_norm: bool = False,
rms_norm_eps: float = 1e-6,
**kwargs,
):
if linear_class not in ["bitlinear", "autobitlinear"]:
@@ -1811,6 +1818,8 @@
self.modules_to_not_convert = modules_to_not_convert
self.linear_class = linear_class
self.quantization_mode = quantization_mode
self.use_rms_norm = use_rms_norm
self.rms_norm_eps = rms_norm_eps
self.post_init()

def post_init(self):
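
End-to-end, the two new fields are surfaced through `BitNetQuantConfig` and forwarded to the replaced linear layers by `_replace_with_bitnet_linear`. A minimal usage sketch, assuming `BitNetQuantConfig` is importable from the top-level `transformers` namespace (it is defined in `src/transformers/utils/quantization_config.py`) and using a placeholder checkpoint id:

```python
from transformers import AutoModelForCausalLM, BitNetQuantConfig

quant_config = BitNetQuantConfig(
    linear_class="autobitlinear",
    quantization_mode="online",
    use_rms_norm=True,   # normalize activations with RMSNorm before quantizing them
    rms_norm_eps=1e-6,
)

# "your-org/your-bitnet-checkpoint" is a placeholder, not a real model id.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-bitnet-checkpoint",
    quantization_config=quant_config,
)
```

With `use_rms_norm=False` (the default), `self.rms_norm` stays `None` and the replaced layers behave exactly as before this change.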