diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index bb91fd601e70..b724fa71968c 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -1144,11 +1144,6 @@ def _make_empty_intermediate_tensors( dtype=dtype, device=device, ), - "residual": torch.zeros( - (batch_size, hidden_size), - dtype=dtype, - device=device, - ), } if ple_dim and ple_dim > 0: tensors["per_layer_inputs"] = torch.zeros( @@ -1312,13 +1307,12 @@ def forward( per_layer_inputs = self.project_per_layer_inputs( hidden_states, per_layer_embeds ) - residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - per_layer_inputs = intermediate_tensors.get("per_layer_inputs") - + if per_layer_inputs is not None: + per_layer_inputs = intermediate_tensors["per_layer_inputs"] + residual = None aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual) for layer_idx, layer in enumerate( islice(self.layers, self.start_layer, self.end_layer) @@ -1342,13 +1336,12 @@ def forward( aux_hidden_states, layer_idx + 1, hidden_states, residual ) if not get_pp_group().is_last_rank: - return IntermediateTensors( - { - "hidden_states": hidden_states, - "residual": residual, - "per_layer_inputs": per_layer_inputs, - } - ) + tensors: dict[str, torch.Tensor] = { + "hidden_states": hidden_states, + } + if per_layer_inputs is not None: + tensors["per_layer_inputs"] = per_layer_inputs + return IntermediateTensors(tensors) # Gemma4 incorporates residual into hidden_states directly # Apply norm without residual fusion when possible. if residual is None: