
Commit 9ec8257

cloud11665 and Bartholomew Sabat authored
[Model] Add module name prefixes to gemma3 (#15889)
Signed-off-by: Bartholomew Sabat <[email protected]>
Co-authored-by: Bartholomew Sabat <[email protected]>
1 parent 38327cf commit 9ec8257

File tree

1 file changed (+17 lines, -6 lines)

vllm/model_executor/models/gemma3.py

Lines changed: 17 additions & 6 deletions
@@ -59,16 +59,23 @@ def __init__(
         intermediate_size: int,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
         if hidden_activation != "gelu_pytorch_tanh":
             raise ValueError(
                 "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
@@ -125,12 +132,14 @@ def __init__(self,
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -293,6 +302,7 @@ def __init__(
             intermediate_size=config.intermediate_size,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
@@ -344,6 +354,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
344354
self.embed_tokens = VocabParallelEmbedding(
345355
config.vocab_size,
346356
config.hidden_size,
357+
prefix=f"{prefix}.embed_tokens",
347358
)
348359
self.start_layer, self.end_layer, self.layers = make_layers(
349360
config.num_hidden_layers,
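
For context, the pattern in this diff threads a prefix string from each parent module into its children, so every submodule ends up with a fully qualified name such as model.layers.0.mlp.gate_up_proj, which name-based logic (for example per-module quantization overrides or weight loading) can then match. Below is a minimal, self-contained sketch of that naming pattern; the Linear, MLP, and DecoderLayer classes here are hypothetical stand-ins for illustration, not vLLM's actual layer implementations.

# Minimal sketch of the prefix-threading pattern used in this diff.
# These classes are illustrative stand-ins, not vLLM's real layers.

class Linear:
    def __init__(self, prefix: str):
        # A real layer would use this name to look up per-module
        # quantization settings or match checkpoint weight names.
        self.prefix = prefix


class MLP:
    def __init__(self, prefix: str = ""):
        # Each child appends its own attribute name to the parent's prefix,
        # mirroring prefix=f"{prefix}.gate_up_proj" in the diff above.
        self.gate_up_proj = Linear(prefix=f"{prefix}.gate_up_proj")
        self.down_proj = Linear(prefix=f"{prefix}.down_proj")


class DecoderLayer:
    def __init__(self, prefix: str = ""):
        self.mlp = MLP(prefix=f"{prefix}.mlp")


layer = DecoderLayer(prefix="model.layers.0")
print(layer.mlp.gate_up_proj.prefix)  # model.layers.0.mlp.gate_up_proj
print(layer.mlp.down_proj.prefix)     # model.layers.0.mlp.down_proj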
