 
 from keras.src import activations
 from keras.src import constraints
-from keras.src import dtype_policies
 from keras.src import initializers
 from keras.src import ops
 from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
             self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
-            # If the layer is quantized to int8, `self._kernel` will be added
-            # in `self._int8_build`. Therefore, we skip it here.
+        if self.quantization_mode not in ("int8", "int4"):
+            # If the layer is quantized to int8 or int4, `self._kernel` will be
+            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
+            # it here.
             self._kernel = self.add_weight(
                 name="kernel",
                 shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+        # Determine the correct input dimension for the LoRA A matrix. When
+        # the layer has been int4-quantized, `self._kernel` stores a *packed*
+        # representation whose first dimension is `ceil(input_dim / 2)`. We
+        # saved the true, *unpacked* input dimension in `self._orig_input_dim`
+        # during quantization. Use it if available; otherwise fall back to the
+        # first dimension of `self.kernel`.
+        if self.quantization_mode == "int4" and hasattr(
+            self, "_orig_input_dim"
+        ):
+            input_dim_for_lora = self._orig_input_dim
+        else:
+            input_dim_for_lora = self.kernel.shape[0]
+
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(self.kernel.shape[0], rank),
+            shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
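To see why the unpacked dimension matters here: with two int4 values packed per int8 byte, a kernel built for `input_dim = 5` is stored with only 3 rows, but `lora_kernel_a` must still have 5 rows so that `inputs @ lora_kernel_a` is well defined. A rough standalone sketch in plain NumPy (hypothetical shapes, not the layer's own code):

import numpy as np

input_dim, units, rank = 5, 16, 4
packed_rows = (input_dim + 1) // 2                        # 3 rows of packed int4 pairs
packed_kernel = np.zeros((packed_rows, units), dtype=np.int8)

# LoRA operates on the float inputs, so A must match the *unpacked* dim.
lora_a = np.zeros((input_dim, rank), dtype=np.float32)    # (5, 4), not (3, 4)
lora_b = np.zeros((rank, units), dtype=np.float32)

x = np.random.rand(2, input_dim).astype("float32")
lora_out = x @ lora_a @ lora_b                            # (2, 16), as expected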
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(self.kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
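For an int8- or int4-quantized layer, the store therefore holds the (possibly packed) kernel, the optional bias, and then `kernel_scale`. A hypothetical layout for an int4-quantized `Dense(16)` built on an input dimension of 8, assuming the enumerate-style `store[str(i)]` keys this pair of methods uses:

# store["0"] -> packed kernel, int8, shape (4, 16)   i.e. ceil(8 / 2) rows
# store["1"] -> bias, shape (16,)                    only if use_bias=True
# store["2"] -> kernel_scale, shape (16,)            one scale per output unit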
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
+        elif mode == "int4":
+            self._int4_build(kernel_shape)
         elif mode == "float8":
             self._float8_build()
         else:
@@ -337,6 +352,39 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )
 
+    def _int4_build(self, kernel_shape):
+        """Build variables for int4 quantization.
+
+        `kernel_shape` is the *original* float32 kernel shape
+        `(input_dim, units)`. The stored kernel is allocated with
+        `ceil(input_dim / 2)` rows because two int4 values are packed
+        into a single int8 byte.
+        """
+        # Per-channel int8 quantizer for the last axis (features).
+        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
+            axis=-1,
+        )
+        input_dim, output_dim = kernel_shape
+        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
+
+        # The kernel is stored *packed*: each int8 byte holds two int4 values.
+        self._kernel = self.add_weight(
+            name="kernel",
+            shape=(packed_rows, output_dim),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        # One scale per output unit (per-channel).
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record the original input_dim for unpacking at runtime.
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 
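As a rough standalone illustration of the nibble-packing scheme used by `_int4_build` above (plain NumPy; the layer itself relies on Keras' `pack_int4`/`unpack_int4` helpers, whose exact bit layout and signatures may differ):

import numpy as np

def pack_int4_pairs(w):
    # w: int8 values in [-8, 7], shape (rows, cols). Pad odd row counts so
    # rows can be paired, which is where ceil(rows / 2) comes from.
    if w.shape[0] % 2:
        w = np.concatenate([w, np.zeros((1, w.shape[1]), dtype=np.int8)])
    nib = w.astype(np.uint8) & 0x0F                   # two's-complement nibbles
    return ((nib[1::2] << 4) | nib[0::2]).astype(np.int8)

def unpack_int4_pairs(packed, orig_rows):
    u = packed.astype(np.uint8)
    low, high = u & 0x0F, u >> 4
    sext = lambda n: (n.astype(np.int8) ^ 8) - 8      # sign-extend back to [-8, 7]
    out = np.empty((2 * u.shape[0], u.shape[1]), dtype=np.int8)
    out[0::2], out[1::2] = sext(low), sext(high)
    return out[:orig_rows]

w = np.array([[-8, 7], [3, -2], [1, 0]], dtype=np.int8)   # 3 rows -> 2 packed rows
assert np.array_equal(unpack_int4_pairs(pack_int4_pairs(w), 3), w)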
@@ -383,6 +431,16 @@ def _float8_build(self):
     def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
         def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function for the int8-quantized kernel.
+
+            Automatic differentiation cannot propagate through the int8
+            quantized weights, so a custom gradient is needed to handle
+            the quantized kernel.
+
+            The gradient with respect to the inputs is computed from the
+            dequantized kernel.
+            """
+
             def grad_fn(*args, upstream=None):
                 if upstream is None:
                     (upstream,) = args
@@ -415,6 +473,59 @@ def grad_fn(*args, upstream=None):
             x = self.activation(x)
         return x
 
+    def _int4_call(self, inputs, training=None):
+        """Forward pass for the int4-quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function for the int4-quantized kernel.
+
+            As in the int8 path, automatic differentiation cannot
+            propagate through the packed int4 weights, so a custom
+            gradient is needed.
+
+            The gradient with respect to the inputs is computed from the
+            unpacked, dequantized kernel.
+            """
+
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation.
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
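The arithmetic in the int8 and int4 call paths above is the usual abs-max scheme: the integer matmul is rescaled by the product of the input scale and the per-output-channel kernel scale, and the gradient with respect to the inputs uses the dequantized kernel, just as `grad_fn` does. A minimal NumPy sketch of the same math (not the layer code itself; scales follow the "divide by scale to dequantize" convention used above):

import numpy as np

x = np.random.randn(4, 6).astype("float32")
w = np.random.randn(6, 8).astype("float32")

# Abs-max quantize the inputs to int8 along the last axis.
x_scale = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)   # (4, 1)
x_q = np.round(x * x_scale).astype(np.int8)

# Abs-max quantize the kernel to the int4 range, per output channel.
w_scale = 7.0 / np.max(np.abs(w), axis=0, keepdims=True)      # (1, 8)
w_q = np.clip(np.round(w * w_scale), -8, 7).astype(np.int8)

# Forward: integer matmul, then undo both scales.
y = (x_q.astype(np.int32) @ w_q.astype(np.int32)) / (x_scale * w_scale)

# Backward w.r.t. the inputs: use the dequantized kernel.
upstream = np.ones((4, 8), dtype="float32")
dx = upstream @ (w_q / w_scale).T

print(np.abs(y - x @ w).max())   # small quantization error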
@@ -518,32 +629,117 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for the int8 mode.
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still stored as int8, range [-8, 7]).
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
+            del self._kernel
+            # Build variables using the original kernel shape; `_int4_build`
+            # computes the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign the packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Only set a new dtype policy if the current one is not quantized.
         if self.dtype_policy.quantization_mode is None:
+            from keras.src import dtype_policies  # local import to avoid cycle
+
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
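A sketch of how the new mode would be exercised end to end once this change lands, assuming the public `quantize()` API behaves as in the existing int8 path (only the "int4" mode string is new; `_kernel` is the layer's private packed storage):

import numpy as np
from keras import layers

layer = layers.Dense(units=16)
layer.build((None, 8))
layer.quantize("int4")

print(layer._kernel.shape)       # (4, 16): ceil(8 / 2) packed rows
print(layer.kernel_scale.shape)  # (16,): one scale per output unit
y = layer(np.random.rand(2, 8))  # runs the _int4_call path above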
     def _get_kernel_with_merged_lora(self):
+        """Returns the kernel with LoRA matrices merged, for serialization.
+
+        This method is called by `save_own_variables` to produce a single
+        kernel tensor that includes the adaptations from LoRA. This is
+        useful for deploying the model or for continuing training after
+        permanently applying the LoRA update.
+
+        If the layer is quantized (`int8` or `int4`), the process is:
+        1. Dequantize the base kernel to float.
+        2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add
+           it to the dequantized kernel.
+        3. Re-quantize the merged result back to the original quantized
+           type (`int8` or packed `int4`), calculating a new scale factor.
+
+        If the layer is not quantized, this method returns the result of the
+        `kernel` property (which computes the merge in floating point) and a
+        scale of `None`.
+
+        If LoRA is not enabled, it returns the original kernel and scale
+        without modification.
+
+        Returns:
+            A tuple `(kernel_value, kernel_scale)`:
+                `kernel_value`: The merged kernel. A quantized tensor if
+                    quantization is active, otherwise a high-precision tensor.
+                `kernel_scale`: The quantization scale for the merged kernel.
+                    This is `None` if the layer is not quantized.
+        """
         if self.dtype_policy.quantization_mode is not None:
             kernel_value = self._kernel
             kernel_scale = self.kernel_scale
             if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into int8 kernel
-                # Note that this is a lossy compression
-                kernel_value = ops.divide(kernel_value, kernel_scale)
-                kernel_value = ops.add(
-                    kernel_value,
-                    (self.lora_alpha / self.lora_rank)
-                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                # Dequantize the kernel to float.
+                if self.quantization_mode == "int4":
+                    unpacked_kernel = quantizers.unpack_int4(
+                        kernel_value, self._orig_input_dim
+                    )
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, self.compute_dtype),
+                        kernel_scale,
+                    )
+                    quant_range = (-8, 7)
+                elif self.quantization_mode == "int8":
+                    float_kernel = ops.divide(
+                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
+                    )
+                    quant_range = (-127, 127)
+                else:
+                    raise ValueError(
+                        "Unsupported quantization mode: "
+                        f"{self.quantization_mode}"
+                    )
+
+                # Merge the LoRA weights in the float domain.
+                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+                    self.lora_kernel_a, self.lora_kernel_b
                 )
-                kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0, to_numpy=True
+                merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+                # Requantize (lossy: a new scale is computed for the merge).
+                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                    merged_float_kernel,
+                    axis=0,
+                    value_range=quant_range,
+                    dtype="int8",
+                    to_numpy=True,
                 )
                 kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+                # Pack back into int4 if needed.
+                if self.quantization_mode == "int4":
+                    kernel_value, _, _ = quantizers.pack_int4(
+                        requantized_kernel
+                    )
+                else:
+                    kernel_value = requantized_kernel
             return kernel_value, kernel_scale
         return self.kernel, None
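A condensed NumPy sketch of the merge-and-requantize path described in the docstring above, for the int8 case (the int4 case only adds the unpack/pack steps and uses the [-8, 7] range); names and shapes are illustrative only:

import numpy as np

rank, alpha = 4, 4.0
w_q = np.random.randint(-127, 128, size=(6, 8)).astype(np.int8)   # stored kernel
w_scale = np.random.rand(8).astype("float32") + 0.5               # per-output scale
lora_a = np.random.randn(6, rank).astype("float32")
lora_b = np.random.randn(rank, 8).astype("float32")

# 1. Dequantize the stored kernel (dequant = q / scale).
w_float = w_q.astype("float32") / w_scale
# 2. Add the LoRA delta in float.
w_merged = w_float + (alpha / rank) * (lora_a @ lora_b)
# 3. Requantize per output channel; the new scale replaces the old one,
#    so the merge is lossy.
new_scale = 127.0 / np.abs(w_merged).max(axis=0)
w_q_new = np.round(w_merged * new_scale).astype(np.int8)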