diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 451b332eda23..cfd4c7a56b43 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -137,7 +137,7 @@ def fi_chunk_gated_delta_rule( fi_state = initial_state.to(torch.float32) fi_g = g.to(torch.float32) fi_beta = beta.to(torch.float32) - output, final_state = chunk_gated_delta_rule_fi( + result = chunk_gated_delta_rule_fi( q=q, k=k, v=v, @@ -147,8 +147,14 @@ def fi_chunk_gated_delta_rule( output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) + # FlashInfer returns (output, state) when output_final_state=True, + # or just output when output_final_state=False. # Unsqueeze back to 4D (1, L, H, D) to match fla output format - return output.unsqueeze(0), final_state + if output_final_state: + output, final_state = result + return output.unsqueeze(0), final_state + else: + return result.unsqueeze(0), None @CustomOp.register("chunk_gated_delta_rule")