Skip to content

Commit 7e0158b

Browse files
Qwen3: Fix eagle hidden states (#6199)
Signed-off-by: Izzy Putterman <[email protected]>
1 parent: a16ba64 · commit: 7e0158b

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,9 @@ def forward(
214214
if residual is None:
215215
residual = hidden_states
216216
hidden_states = self.input_layernorm(hidden_states)
217-
217+
if spec_metadata is not None and spec_metadata.is_layer_capture(
218+
self.layer_idx):
219+
self.fusion_config.POST_MOE_FUSION = False
218220
# Self Attention
219221
hidden_states = self.self_attn(
220222
position_ids=position_ids,
@@ -257,9 +259,6 @@ def forward(
257259

258260
if self.fusion_config.POST_MOE_FUSION:
259261
if do_finalize:
260-
if spec_metadata:
261-
spec_metadata.maybe_capture_hidden_states(
262-
self.layer_idx, hidden_states, residual)
263262
hidden_states, residual = self.allreduce(
264263
hidden_states,
265264
all_reduce_params=AllReduceParams(
@@ -289,12 +288,8 @@ def forward(
289288
hidden_states, residual = self.moe_allreduce(
290289
fc2_output, all_reduce_params=moe_all_reduce_params)
291290

292-
if spec_metadata:
293-
spec_metadata.maybe_capture_hidden_states(
294-
self.layer_idx, hidden_states, residual)
295-
296291
else:
297-
if spec_metadata:
292+
if spec_metadata and spec_metadata.is_layer_capture(self.layer_idx):
298293
spec_metadata.maybe_capture_hidden_states(
299294
self.layer_idx, hidden_states, residual)
300295
if self.next_layer_layernorm is not None:

0 commit comments

Comments
 (0)