@@ -59,16 +59,23 @@ def __init__(
         intermediate_size: int,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
         if hidden_activation != "gelu_pytorch_tanh":
             raise ValueError(
                 "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
@@ -125,12 +132,14 @@ def __init__(self,
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -293,6 +302,7 @@ def __init__(
             intermediate_size=config.intermediate_size,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
@@ -344,6 +354,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,