Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions vllm/model_executor/models/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
GateLinear,
)
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
Expand Down Expand Up @@ -1144,11 +1141,6 @@ def _make_empty_intermediate_tensors(
dtype=dtype,
device=device,
),
"residual": torch.zeros(
(batch_size, hidden_size),
dtype=dtype,
device=device,
),
}
if ple_dim and ple_dim > 0:
tensors["per_layer_inputs"] = torch.zeros(
Expand Down Expand Up @@ -1312,13 +1304,12 @@ def forward(
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_embeds
)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
per_layer_inputs = intermediate_tensors.get("per_layer_inputs")

if per_layer_inputs is not None:
per_layer_inputs = intermediate_tensors["per_layer_inputs"]
Comment on lines +1313 to +1314

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for loading per_layer_inputs from intermediate_tensors is incorrect. Since per_layer_inputs is an argument to the forward method with a default value of None, the condition if per_layer_inputs is not None: will evaluate to False on all pipeline ranks greater than 0 (where it is not explicitly passed by the caller). This prevents the model from loading the per-layer embeddings sent from the previous rank, effectively disabling PLE on those ranks.

Additionally, IntermediateTensors does not implement a .get() method (which is likely why the previous code was changed), so you should access the underlying .tensors dictionary to safely retrieve the optional key.

Suggested change
if per_layer_inputs is not None:
per_layer_inputs = intermediate_tensors["per_layer_inputs"]
per_layer_inputs = intermediate_tensors.tensors.get(
"per_layer_inputs")

residual = None
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
Expand All @@ -1342,13 +1333,12 @@ def forward(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{
"hidden_states": hidden_states,
"residual": residual,
"per_layer_inputs": per_layer_inputs,
}
)
tensors: dict[str, torch.Tensor] = {
"hidden_states": hidden_states,
}
if per_layer_inputs is not None:
tensors["per_layer_inputs"] = per_layer_inputs
return IntermediateTensors(tensors)
# Gemma4 incorporates residual into hidden_states directly
# Apply norm without residual fusion when possible.
if residual is None:
Expand Down
Loading