@@ -158,25 +158,28 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
 
 class Gemma3MLP(nn.Module):
 
-    def __init__(self, config: Gemma3TextConfig):
+    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
+        self.config = model_config.pretrained_config
+        self.hidden_size = self.config.hidden_size
+        self.intermediate_size = self.config.intermediate_size
+        self.dtype = self.config.torch_dtype
         self.gate_proj = Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False,
-                                dtype=self.dtype)
+                                dtype=self.dtype,
+                                quant_config=model_config.get_quant_config())
         self.up_proj = Linear(self.hidden_size,
                               self.intermediate_size,
                               bias=False,
-                              dtype=self.dtype)
+                              dtype=self.dtype,
+                              quant_config=model_config.get_quant_config())
         self.down_proj = Linear(self.intermediate_size,
                                 self.hidden_size,
                                 bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
+                                dtype=self.dtype,
+                                quant_config=model_config.get_quant_config())
+        self.act_fn = ACT2FN[self.config.hidden_activation]
 
     @torch.inference_mode()
     def forward(self, x):
@@ -202,7 +205,7 @@ def __init__(
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(config)
+        self.mlp = Gemma3MLP(model_config=model_config)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
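
For context, a minimal sketch (not part of the diff) of how the refactored `Gemma3MLP` would now be constructed: the Hugging Face `Gemma3TextConfig` is wrapped in a `ModelConfig`, whose `pretrained_config` and `get_quant_config()` the module reads internally. The import paths and the `ModelConfig` constructor signature below are assumptions for illustration, inferred from how the diff uses them.

```python
# Hedged usage sketch; import locations and ModelConfig(...) kwargs are
# assumptions based on how the diff reads pretrained_config / get_quant_config().
import torch
from transformers import Gemma3TextConfig

# Assumed import paths, not confirmed by the diff:
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3MLP

pretrained = Gemma3TextConfig(hidden_size=1152,
                              intermediate_size=6912,
                              hidden_activation="gelu_pytorch_tanh",
                              torch_dtype=torch.bfloat16)

# Assumed: ModelConfig wraps the HF config; leaving quantization unset means
# get_quant_config() yields a no-op config and the Linear layers stay unquantized.
model_config = ModelConfig(pretrained_config=pretrained)

mlp = Gemma3MLP(model_config=model_config)  # new signature introduced by this diff
out = mlp(torch.randn(2, pretrained.hidden_size, dtype=torch.bfloat16))
```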