From 8ce859bc2b9ce5ddbae9bdc5b1deacce1cfc156e Mon Sep 17 00:00:00 2001 From: Rohit kumar Singh Date: Fri, 24 Apr 2026 10:49:31 +0300 Subject: [PATCH 1/2] Fix PP in Gemma4 --- vllm/model_executor/models/gemma4.py | 30 ++++++++++------------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index bb91fd601e70..e56238cd9925 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -37,10 +37,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import ( - FusedMoE, - GateLinear, -) +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -1144,11 +1141,6 @@ def _make_empty_intermediate_tensors( dtype=dtype, device=device, ), - "residual": torch.zeros( - (batch_size, hidden_size), - dtype=dtype, - device=device, - ), } if ple_dim and ple_dim > 0: tensors["per_layer_inputs"] = torch.zeros( @@ -1312,13 +1304,12 @@ def forward( per_layer_inputs = self.project_per_layer_inputs( hidden_states, per_layer_embeds ) - residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - per_layer_inputs = intermediate_tensors.get("per_layer_inputs") - + if per_layer_inputs is not None: + per_layer_inputs = intermediate_tensors["per_layer_inputs"] + residual = None aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual) for layer_idx, layer in enumerate( islice(self.layers, self.start_layer, self.end_layer) @@ -1342,13 +1333,12 @@ def forward( aux_hidden_states, layer_idx + 1, hidden_states, residual ) if not get_pp_group().is_last_rank: - return IntermediateTensors( - { - "hidden_states": hidden_states, - "residual": residual, - "per_layer_inputs": per_layer_inputs, - } - ) + tensors: dict[str, torch.Tensor] = { + "hidden_states": hidden_states, + } + if per_layer_inputs is not None: + tensors["per_layer_inputs"] = per_layer_inputs + return IntermediateTensors(tensors) # Gemma4 incorporates residual into hidden_states directly # Apply norm without residual fusion when possible. if residual is None: From b9f49801082c3c67349bfab8cab6f3ee37f2713a Mon Sep 17 00:00:00 2001 From: Rohit kumar Singh Date: Fri, 24 Apr 2026 11:02:54 +0300 Subject: [PATCH 2/2] fix indentation Signed-off-by: Rohit kumar Singh --- vllm/model_executor/models/gemma4.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index e56238cd9925..b724fa71968c 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -37,7 +37,10 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + GateLinear, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear,