Commits (23)
c02e302  int4 quantization support (JyotinderSingh, Jun 28, 2025)
dd11851  refactor packing utils into quantizers (JyotinderSingh, Jun 29, 2025)
777b5e6  generalize int4 packing (JyotinderSingh, Jun 29, 2025)
72a8cbc  restored pytest skip conditions (JyotinderSingh, Jun 30, 2025)
efe244e  fixes 'tuple' object has no attribute 'rank' error (JyotinderSingh, Jun 30, 2025)
7297410  fix dtype check to work across backends (JyotinderSingh, Jun 30, 2025)
3a9e26c  fixed torch compatibility (JyotinderSingh, Jun 30, 2025)
9e25042  fixed jax compatibility (JyotinderSingh, Jun 30, 2025)
1aa86de  removes redundant self._orig_input_dim initialization (JyotinderSingh, Jun 30, 2025)
f9013ae  improves readability (JyotinderSingh, Jun 30, 2025)
f334156  W4A8 (JyotinderSingh, Jul 3, 2025)
f187306  added _int4_call stub (JyotinderSingh, Jul 4, 2025)
eed432b  Fix bug in unpack that promoted tensor to fp32 (JyotinderSingh, Jul 8, 2025)
248fcc8  add missing dtype assertion to quantizer test (JyotinderSingh, Jul 8, 2025)
6495076  docstring fixes (JyotinderSingh, Jul 9, 2025)
0413b36  docstring fixes (JyotinderSingh, Jul 9, 2025)
052f7b6  introduces fastpath for dense unpack (JyotinderSingh, Jul 9, 2025)
a87687d  handle negative axis for pack/unpack (JyotinderSingh, Jul 9, 2025)
9e2901c  standardize docs formatting (JyotinderSingh, Jul 10, 2025)
519e6d7  fix docstring format (JyotinderSingh, Jul 10, 2025)
41cac4b  Reduce duplication in _get_kernel_with_merged_lora (JyotinderSingh, Jul 11, 2025)
0d5c3bd  remove unnecessary cast ops (JyotinderSingh, Jul 11, 2025)
98fa1ed  removes unused var (JyotinderSingh, Jul 11, 2025)
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -19,6 +19,8 @@
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
)
from keras.src.quantizers.quantizers import pack_int4 as pack_int4
from keras.src.quantizers.quantizers import (
quantize_and_dequantize as quantize_and_dequantize,
)
from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4
2 changes: 2 additions & 0 deletions keras/api/quantizers/__init__.py
@@ -19,6 +19,8 @@
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
)
from keras.src.quantizers.quantizers import pack_int4 as pack_int4
from keras.src.quantizers.quantizers import (
quantize_and_dequantize as quantize_and_dequantize,
)
from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4
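
These re-exports make the int4 packing helpers part of the public `keras.quantizers` namespace. A minimal round-trip sketch, assuming the call pattern used in `dense.py` below (`pack_int4` returns the packed tensor plus two bookkeeping values, and `unpack_int4` needs the original length along the packed axis):

```python
import numpy as np
from keras import ops, quantizers

# Five rows of int4 values (range [-8, 7]) held in an int8 array.
values = np.array([[-8, 7], [3, -2], [0, 5], [-1, 1], [4, -4]], dtype="int8")

# Two int4 values share one int8 byte, so packing the 5 rows yields
# ceil(5 / 2) = 3 packed rows; the original length is returned as well.
packed, _, orig_len = quantizers.pack_int4(values)

# Unpacking needs the original length so it can drop the padding nibble.
restored = quantizers.unpack_int4(packed, orig_len)
np.testing.assert_array_equal(ops.convert_to_numpy(restored), values)
```
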
4 changes: 2 additions & 2 deletions keras/src/dtype_policies/dtype_policy.py
@@ -3,7 +3,7 @@
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state

QUANTIZATION_MODES = ("int8", "float8")
QUANTIZATION_MODES = ("int8", "float8", "int4")


@keras_export(
@@ -350,7 +350,7 @@ def _get_quantized_dtype_policy_by_str(policy):
f"Received: policy={policy}"
)
mode, source_name = split_name
if policy.startswith("int8"):
if policy.startswith("int8") or policy.startswith("int4"):
return QuantizedDTypePolicy(mode, source_name)
elif policy.startswith("float8"):
return QuantizedFloat8DTypePolicy(mode, source_name)
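
With `"int4"` added to `QUANTIZATION_MODES`, policy strings such as `"int4_from_float32"` now resolve the same way the existing int8 ones do. A small sketch through the public API (printed values assume a plain float32 source policy):

```python
from keras import dtype_policies

# Resolves to a QuantizedDTypePolicy, mirroring "int8_from_float32".
policy = dtype_policies.get("int4_from_float32")

print(policy.quantization_mode)  # "int4"
print(policy.compute_dtype)      # "float32", inherited from the source policy
```
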
209 changes: 187 additions & 22 deletions keras/src/layers/core/dense.py
@@ -2,7 +2,6 @@

from keras.src import activations
from keras.src import constraints
from keras.src import dtype_policies
from keras.src import initializers
from keras.src import ops
from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
kernel_shape = (input_shape[-1], self.units)
if self.quantization_mode:
self.quantized_build(kernel_shape, mode=self.quantization_mode)
if self.quantization_mode != "int8":
# If the layer is quantized to int8, `self._kernel` will be added
# in `self._int8_build`. Therefore, we skip it here.
if self.quantization_mode not in ("int8", "int4"):
# If the layer is quantized to int8 or int4, `self._kernel` will be
# added in `self._int8_build` or `_int4_build`. Therefore, we skip
# it here.
self._kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
"lora is already enabled. This can only be done once per layer."
)
self._tracker.unlock()
# Determine the correct input dimension for the LoRA A matrix. When
# the layer has been int4-quantized, `self._kernel` stores a *packed*
# representation whose first dimension is `ceil(input_dim/2)`. We
# saved the true, *unpacked* input dimension in `self._orig_input_dim`
# during quantization. Use it if available; otherwise fall back to the
# first dimension of `self.kernel`.
if self.quantization_mode == "int4" and hasattr(
self, "_orig_input_dim"
):
input_dim_for_lora = self._orig_input_dim
else:
input_dim_for_lora = self.kernel.shape[0]

self.lora_kernel_a = self.add_weight(
name="lora_kernel_a",
shape=(self.kernel.shape[0], rank),
shape=(input_dim_for_lora, rank),
initializer=initializers.get(a_initializer),
regularizer=self.kernel_regularizer,
)
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
if self.use_bias:
target_variables.append(self.bias)
if self.quantization_mode is not None:
if self.quantization_mode == "int8":
if self.quantization_mode in ("int8", "int4"):
target_variables.append(kernel_scale)
elif self.quantization_mode == "float8":
target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
if self.use_bias:
target_variables.append(self.bias)
if self.quantization_mode is not None:
if self.quantization_mode == "int8":
if self.quantization_mode in ("int8", "int4"):
target_variables.append(self.kernel_scale)
elif self.quantization_mode == "float8":
target_variables.append(self.inputs_scale)
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
def quantized_build(self, kernel_shape, mode):
if mode == "int8":
self._int8_build(kernel_shape)
elif mode == "int4":
self._int4_build(kernel_shape)
elif mode == "float8":
self._float8_build()
else:
@@ -337,6 +352,38 @@ def _int8_build(self, kernel_shape):
trainable=False,
)

def _int4_build(self, kernel_shape):
"""Build variables for int4 quantization.
`kernel_shape` is the *original* float32 kernel shape
`(input_dim, units)`. We allocate the stored kernel with rows
`ceil(input_dim/2)` because two int4 values are packed into a single
int8 byte.
"""
# Per-channel int8 quantizer for the last axis (features).
self.inputs_quantizer = quantizers.AbsMaxQuantizer(
axis=-1,
)
input_dim, output_dim = kernel_shape
packed_rows = (input_dim + 1) // 2 # ceil for odd dims

# Kernel is stored *packed*: each int8 byte contains two int4 values.
self._kernel = self.add_weight(
name="kernel",
shape=(packed_rows, output_dim),
initializer="zeros",
dtype="int8",
trainable=False,
)
# One scale per output unit (per-channel).
self.kernel_scale = self.add_weight(
name="kernel_scale",
shape=(self.units,),
initializer="ones",
trainable=False,
)
# Record original input_dim for unpacking at runtime.
self._orig_input_dim = input_dim

def _float8_build(self):
from keras.src.dtype_policies import QuantizedFloat8DTypePolicy

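
`_int4_build` allocates the packed kernel and a per-channel scale. A small shape sketch for a hypothetical layer with `input_dim=5` and `units=16` (the numbers are illustrative, not taken from the diff):

```python
input_dim, units = 5, 16            # hypothetical layer dimensions
packed_rows = (input_dim + 1) // 2  # ceil(5 / 2) = 3; odd dims leave one padding nibble

# Variables created by `_int4_build` for this layer:
#   _kernel:        shape (3, 16), dtype int8, two int4 values per byte
#   kernel_scale:   shape (16,), one scale per output unit
#   _orig_input_dim stores 5 so `unpack_int4` can strip the padding later
assert packed_rows == 3
```
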
@@ -383,6 +430,16 @@ def _float8_build(self):
def _int8_call(self, inputs, training=None):
@ops.custom_gradient
def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
"""Custom gradient function to handle the int8 quantized weights.

Automatic differentiation will not know how to handle the int8
quantized weights. So a custom gradient function is needed to
handle the int8 quantized weights.

The custom gradient function will use the dequantized kernel to
compute the gradient.
"""

def grad_fn(*args, upstream=None):
if upstream is None:
(upstream,) = args
@@ -415,6 +472,59 @@ def grad_fn(*args, upstream=None):
x = self.activation(x)
return x

def _int4_call(self, inputs, training=None):
"""Forward pass for int4 quantized Dense layer."""

@ops.custom_gradient
def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
"""Custom gradient function for int4 quantized weights.

Automatic differentiation will not know how to handle the
int4 quantized weights. So a custom gradient function is needed
to handle the int4 quantized weights.

The custom gradient function will use the dequantized kernel to
compute the gradient.
"""

unpacked_kernel = quantizers.unpack_int4(
kernel, self._orig_input_dim
)

def grad_fn(*args, upstream=None):
if upstream is None:
(upstream,) = args
float_kernel = ops.divide(
ops.cast(unpacked_kernel, dtype=self.compute_dtype),
kernel_scale,
)
inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
return (inputs_grad, None, None)

inputs, inputs_scale = self.inputs_quantizer(inputs)
x = ops.matmul(inputs, unpacked_kernel)
x = ops.cast(x, self.compute_dtype)
x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
return x, grad_fn

x = matmul_with_inputs_gradient(
inputs,
ops.convert_to_tensor(self._kernel),
ops.convert_to_tensor(self.kernel_scale),
)

if self.lora_enabled:
lora_x = ops.matmul(inputs, self.lora_kernel_a)
lora_x = ops.matmul(lora_x, self.lora_kernel_b)
x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)

# Add bias and activation
if self.bias is not None:
x = ops.add(x, self.bias)
if self.activation is not None:
x = self.activation(x)
return x

def _float8_call(self, inputs, training=None):
if self.lora_enabled:
raise NotImplementedError(
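
`_int4_call` follows the same recipe as the int8 path: quantize the inputs, run the matmul on quantized values, then divide once by the product of the two scales to return to float. A toy NumPy sketch of that identity with absmax-style scales (all values are made up for illustration):

```python
import numpy as np

x = np.array([[0.5, -1.0, 2.0]], dtype="float32")       # activations, shape (1, 3)
w = np.array([[0.7], [-0.4], [0.1]], dtype="float32")   # kernel, shape (3, 1)

# AbsMax-style scales: quantized = round(value * scale), dequantized = q / scale.
x_scale = 127.0 / np.abs(x).max()        # int8 range for the activations
w_scale = 7.0 / np.abs(w).max(axis=0)    # int4 range [-8, 7], per output unit

xq = np.round(x * x_scale)
wq = np.round(w * w_scale)

# Quantized matmul, then a single divide by the product of the scales.
y = (xq @ wq) / (x_scale * w_scale)
print(y)      # ~0.956
print(x @ w)  # 0.95, i.e. equal up to rounding error
```
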
@@ -518,13 +628,40 @@ def quantize(self, mode, type_check=True):
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
del self._kernel
self.quantized_build(kernel_shape, mode)
if mode == "int8":
# Build variables for int8 mode
self.quantized_build(kernel_shape, mode)
self._kernel.assign(kernel_value)
self.kernel_scale.assign(kernel_scale)
elif mode == "int4":
# 1. Quantize to int4 values (still int8 dtype, range [-8,7])
kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
self._kernel,
axis=0,
value_range=(-8, 7),
dtype="int8",
to_numpy=True,
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
# 2. Pack two int4 values into a single int8 byte.
packed_kernel_value, _, orig_rows = quantizers.pack_int4(
kernel_value_int4
)
del self._kernel
# Build variables using the original kernel shape; _int4_build will
# compute the packed shape internally.
self.quantized_build(kernel_shape, mode)
# Assign packed values.
self._kernel.assign(packed_kernel_value)
self.kernel_scale.assign(kernel_scale)
elif mode == "float8":
self.quantized_build(kernel_shape, mode)
else:
raise self._quantization_mode_error(mode)

# Set new dtype policy
        # Set the new dtype policy only if the layer does not already
        # use a quantized dtype policy.
if self.dtype_policy.quantization_mode is None:
from keras.src import dtype_policies # local import to avoid cycle

policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy

@@ -533,17 +670,45 @@ def _get_kernel_with_merged_lora(self):
kernel_value = self._kernel
kernel_scale = self.kernel_scale
if self.lora_enabled:
# Dequantize & quantize to merge lora weights into int8 kernel
# Note that this is a lossy compression
kernel_value = ops.divide(kernel_value, kernel_scale)
kernel_value = ops.add(
kernel_value,
(self.lora_alpha / self.lora_rank)
* ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
kernel_value, kernel_scale = quantizers.abs_max_quantize(
kernel_value, axis=0, to_numpy=True
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
# For int4, `_kernel` is stored in a packed representation
# (two int4 values per int8 byte). We need to unpack it to the
# original float representation before merging with the LoRA
# update, and then pack it again after requantization.
if self.quantization_mode == "int4":
# 1. Unpack packed int4 tensor to int8 range [-8, 7].
unpacked_kernel = quantizers.unpack_int4(
kernel_value, self._orig_input_dim
)
# 2. De-scale to recover float32 kernel.
kernel_value_fp = ops.divide(unpacked_kernel, kernel_scale)
# 3. Merge LoRA delta in float32 domain.
kernel_value_fp = ops.add(
kernel_value_fp,
(self.lora_alpha / self.lora_rank)
* ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
# 4. Re-quantize to int4 (values still held in int8 dtype).
kernel_int4, kernel_scale = quantizers.abs_max_quantize(
kernel_value_fp,
axis=0,
value_range=(-8, 7),
dtype="int8",
to_numpy=True,
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
# 5. Pack the int4 values back into the compact int8 layout.
kernel_value, _, _ = quantizers.pack_int4(kernel_int4)
else:
# int8 path (regular): unpacking not required.
kernel_value = ops.divide(kernel_value, kernel_scale)
kernel_value = ops.add(
kernel_value,
(self.lora_alpha / self.lora_rank)
* ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
kernel_value, kernel_scale = quantizers.abs_max_quantize(
kernel_value, axis=0, to_numpy=True
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
return kernel_value, kernel_scale
return self.kernel, None
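
Taken together, the new code paths can be exercised end to end through the public layer API. A hedged usage sketch (layer sizes are illustrative; the printed shapes follow from the packing scheme in `_int4_build` and the LoRA sizing in `enable_lora`):

```python
import numpy as np
import keras

layer = keras.layers.Dense(16)
layer.build((None, 5))

# Replaces the float kernel with a packed int8 kernel of shape
# (ceil(5 / 2), 16) = (3, 16) plus a per-channel scale of shape (16,).
layer.quantize("int4")

# LoRA adapters are sized with the original (unpacked) input dimension.
layer.enable_lora(2)
print(layer.lora_kernel_a.shape)  # (5, 2)

# Calling the layer now dispatches to `_int4_call`.
y = layer(np.random.rand(4, 5).astype("float32"))
print(y.shape)  # (4, 16)
```
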