 
 from keras.src import activations
 from keras.src import constraints
-from keras.src import dtype_policies
 from keras.src import initializers
 from keras.src import ops
 from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
             self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
-            # If the layer is quantized to int8, `self._kernel` will be added
-            # in `self._int8_build`. Therefore, we skip it here.
+        if self.quantization_mode not in ("int8", "int4"):
+            # If the layer is quantized to int8 or int4, `self._kernel` will be
+            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
+            # it here.
             self._kernel = self.add_weight(
                 name="kernel",
                 shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+        # Determine the correct input dimension for the LoRA A matrix. When
+        # the layer has been int4-quantized, `self._kernel` stores a *packed*
+        # representation whose first dimension is `ceil(input_dim/2)`. We
+        # saved the true, *unpacked* input dimension in `self._orig_input_dim`
+        # during quantization. Use it if available; otherwise fall back to the
+        # first dimension of `self.kernel`.
+        if self.quantization_mode == "int4" and hasattr(
+            self, "_orig_input_dim"
+        ):
+            input_dim_for_lora = self._orig_input_dim
+        else:
+            input_dim_for_lora = self.kernel.shape[0]
+
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(self.kernel.shape[0], rank),
+            shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
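A quick illustration of the shape issue described in the comment above (hypothetical numbers, not part of the commit): the packed int4 kernel only has `ceil(input_dim / 2)` rows, so building `lora_kernel_a` from `self._kernel.shape[0]` would give the wrong shape.

import math

# Illustrative sketch with hypothetical sizes: the LoRA A matrix must use the
# original, unpacked input_dim, not the packed row count of the int8 storage.
input_dim, units, rank = 5, 16, 4
packed_rows = math.ceil(input_dim / 2)   # rows of the packed int8 kernel -> 3
lora_a_shape = (input_dim, rank)         # correct: matches the float inputs
lora_b_shape = (rank, units)
print(packed_rows, lora_a_shape, lora_b_shape)   # 3 (5, 4) (4, 16)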
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(self.kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
+        elif mode == "int4":
+            self._int4_build(kernel_shape)
         elif mode == "float8":
             self._float8_build()
         else:
@@ -337,6 +352,39 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )
 
+    def _int4_build(self, kernel_shape):
+        """Build variables for int4 quantization.
+
+        `kernel_shape` is the *original* float32 kernel shape
+        `(input_dim, units)`. We allocate the stored kernel with rows
+        `ceil(input_dim/2)` because two int4 values are packed into a single
+        int8 byte.
+        """
+        # Per-channel int8 quantizer for the last axis (features).
+        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
+            axis=-1,
+        )
+        input_dim, output_dim = kernel_shape
+        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
+
+        # Kernel is stored *packed*: each int8 byte contains two int4 values.
+        self._kernel = self.add_weight(
+            name="kernel",
+            shape=(packed_rows, output_dim),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        # One scale per output unit (per-channel).
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record original input_dim for unpacking at runtime.
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 
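To make the packing scheme concrete, here is a plain-NumPy sketch of the idea (illustrative only; the exact bit layout used by `quantizers.pack_int4` / `unpack_int4` may differ): two signed int4 values in [-8, 7] share one int8 byte, and an odd input_dim is padded so the packed kernel has `ceil(input_dim / 2)` rows.

import numpy as np

# Illustrative sketch, not the Keras implementation: pack pairs of rows into
# the low/high nibbles of one int8 byte, then invert the packing.
def pack_pairs(v):                      # v: int8 array, values in [-8, 7]
    if v.shape[0] % 2:                  # pad odd input_dim -> ceil(n / 2) rows
        v = np.concatenate([v, np.zeros((1,) + v.shape[1:], dtype=v.dtype)])
    u = v.astype(np.uint8) & 0x0F       # keep the low nibble of each value
    return (u[0::2] | (u[1::2] << 4)).astype(np.int8)

def unpack_pairs(p, orig_rows):         # invert the packing, restore signs
    u = p.astype(np.uint8)
    lo, hi = (u & 0x0F).astype(np.int16), (u >> 4).astype(np.int16)
    lo, hi = np.where(lo > 7, lo - 16, lo), np.where(hi > 7, hi - 16, hi)
    out = np.empty((u.shape[0] * 2,) + u.shape[1:], dtype=np.int8)
    out[0::2], out[1::2] = lo, hi
    return out[:orig_rows]

w = np.array([[-8, 7], [3, -1], [5, 2]], dtype=np.int8)   # input_dim = 3
packed = pack_pairs(w)                                     # shape (2, 2)
assert np.array_equal(unpack_pairs(packed, 3), w)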
@@ -383,6 +431,16 @@ def _float8_build(self):
     def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
         def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function for the int8 quantized weights.
+
+            Automatic differentiation cannot handle the int8 quantized
+            weights directly, so a custom gradient is needed. It computes
+            the gradient with respect to the inputs using the dequantized
+            kernel.
+            """
+
             def grad_fn(*args, upstream=None):
                 if upstream is None:
                     (upstream,) = args
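A hedged NumPy sketch of the arithmetic the docstring refers to (illustrative only, not the Keras kernels): the forward pass rescales an integer matmul by both scales, and the gradient with respect to the inputs is taken against the dequantized kernel.

import numpy as np

# Illustrative sketch of the int8 path: quantize, integer matmul, rescale,
# and compute the input gradient from the dequantized float kernel.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype("float32")
w = rng.standard_normal((4, 3)).astype("float32")

w_scale = 127.0 / np.abs(w).max(axis=0)                  # per output channel
w_q = np.round(w * w_scale).astype("int8")
x_scale = 127.0 / np.abs(x).max(axis=-1, keepdims=True)  # per input row
x_q = np.round(x * x_scale).astype("int8")

y = (x_q.astype("int32") @ w_q.astype("int32")) / (x_scale * w_scale)
upstream = np.ones_like(y)
inputs_grad = upstream @ (w_q / w_scale).T               # dequantized kernel
print(np.abs(y - x @ w).max(), inputs_grad.shape)        # small error, (2, 4)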
@@ -415,6 +473,59 @@ def grad_fn(*args, upstream=None):
             x = self.activation(x)
         return x
 
+    def _int4_call(self, inputs, training=None):
+        """Forward pass for the int4 quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function for the int4 quantized weights.
+
+            Automatic differentiation cannot handle the int4 quantized
+            weights directly, so a custom gradient is needed. It computes
+            the gradient with respect to the inputs using the dequantized
+            kernel.
+            """
+
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
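For comparison with the int8 path, a similar NumPy sketch for int4 (illustrative only; the `7 / absmax` scale is an assumption about how `abs_max_quantize` behaves with `value_range=(-8, 7)`): the narrower range leaves a larger quantization error, but the same rescaling identity applies after the kernel is unpacked.

import numpy as np

# Illustrative sketch of the int4 arithmetic: per-column weight scales for the
# [-8, 7] range, int8 input scales, integer matmul, then divide by both scales.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5)).astype("float32")
w = rng.standard_normal((5, 3)).astype("float32")

w_scale = 7.0 / np.abs(w).max(axis=0)                    # int4 range [-8, 7]
w_q = np.clip(np.round(w * w_scale), -8, 7).astype("int8")
x_scale = 127.0 / np.abs(x).max(axis=-1, keepdims=True)
x_q = np.round(x * x_scale).astype("int8")

y = (x_q.astype("int32") @ w_q.astype("int32")) / (x_scale * w_scale)
print(np.abs(y - x @ w).max())  # larger error than int8, but still small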
@@ -518,32 +629,117 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for int8 mode.
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still int8 dtype, range [-8, 7]).
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
+            del self._kernel
+            # Build variables using the original kernel shape; _int4_build will
+            # compute the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Only set a new dtype policy if the layer is not already quantized.
         if self.dtype_policy.quantization_mode is None:
+            from keras.src import dtype_policies  # local import to avoid cycle
+
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
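A hypothetical usage sketch of the new branch (assuming the "int4" mode is wired up end to end, including an `int4_from_float32` dtype policy); it only illustrates the variable shapes expected after quantization.

from keras import layers

# Hypothetical example: quantize a built Dense layer to int4 and inspect the
# packed kernel and per-unit scale shapes.
layer = layers.Dense(units=16)
layer.build((None, 5))
layer.quantize("int4")
print(layer._kernel.shape)       # (3, 16): ceil(5 / 2) packed rows
print(layer.kernel_scale.shape)  # (16,): one scale per output unit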
     def _get_kernel_with_merged_lora(self):
+        """Returns the kernel with LoRA matrices merged, for serialization.
+
+        This method is called by `save_own_variables` to produce a single
+        kernel tensor that includes the adaptations from LoRA. This is useful
+        for deploying the model or for continuing training after permanently
+        applying the LoRA update.
+
+        If the layer is quantized (`int8` or `int4`), the process is:
+        1. Dequantize the base kernel to float.
+        2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add
+           it to the dequantized kernel.
+        3. Re-quantize the merged result back to the original quantized
+           type (`int8` or packed `int4`), calculating a new scale factor.
+
+        If the layer is not quantized, this method returns the result of the
+        `kernel` property (which computes the merge in floating-point) and a
+        scale of `None`.
+
+        If LoRA is not enabled, it returns the original kernel and scale
+        without modification.
+
+        Returns:
+            A tuple `(kernel_value, kernel_scale)`:
+                `kernel_value`: The merged kernel. A quantized tensor if
+                    quantization is active, otherwise a high precision tensor.
+                `kernel_scale`: The quantization scale for the merged kernel.
+                    This is `None` if the layer is not quantized.
+        """
         if self.dtype_policy.quantization_mode is not None:
             kernel_value = self._kernel
             kernel_scale = self.kernel_scale
             if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into int8 kernel
-                # Note that this is a lossy compression
-                kernel_value = ops.divide(kernel_value, kernel_scale)
-                kernel_value = ops.add(
-                    kernel_value,
-                    (self.lora_alpha / self.lora_rank)
-                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                # Dequantize kernel to float
+                if self.quantization_mode == "int4":
+                    unpacked_kernel = quantizers.unpack_int4(
+                        kernel_value, self._orig_input_dim
+                    )
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, self.compute_dtype),
+                        kernel_scale,
+                    )
+                    quant_range = (-8, 7)
+                elif self.quantization_mode == "int8":
+                    float_kernel = ops.divide(
+                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
+                    )
+                    quant_range = (-127, 127)
+                else:
+                    raise ValueError(
+                        "Unsupported quantization mode: "
+                        f"{self.quantization_mode}"
+                    )
+
+                # Merge LoRA weights in float domain
+                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+                    self.lora_kernel_a, self.lora_kernel_b
                 )
-                kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0, to_numpy=True
+                merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+                # Requantize
+                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                    merged_float_kernel,
+                    axis=0,
+                    value_range=quant_range,
+                    dtype="int8",
+                    to_numpy=True,
                 )
                 kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+                # Pack if int4
+                if self.quantization_mode == "int4":
+                    kernel_value, _, _ = quantizers.pack_int4(
+                        requantized_kernel
+                    )
+                else:
+                    kernel_value = requantized_kernel
             return kernel_value, kernel_scale
         return self.kernel, None
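Finally, a NumPy sketch of the merge steps listed in the docstring, for the int8 case (illustrative only, not the Keras implementation): dequantize, add the scaled LoRA delta, then requantize each output column with a fresh abs-max scale. The round trip is lossy, which is why a new `kernel_scale` has to be stored alongside the merged kernel.

import numpy as np

# Illustrative sketch: merge LoRA into a per-column int8-quantized kernel.
rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4)).astype("float32")
scale = 127.0 / np.abs(w).max(axis=0)                   # per-column scale
w_q = np.round(w * scale).astype("int8")

lora_a = rng.standard_normal((8, 2)).astype("float32") * 0.01
lora_b = rng.standard_normal((2, 4)).astype("float32") * 0.01
alpha, rank = 2.0, 2

float_kernel = w_q / scale                              # 1. dequantize
merged = float_kernel + (alpha / rank) * (lora_a @ lora_b)   # 2. merge delta
new_scale = 127.0 / np.abs(merged).max(axis=0)          # 3. requantize
merged_q = np.round(merged * new_scale).astype("int8")
print(np.abs(merged_q / new_scale - merged).max())      # small, nonzero error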