
[Bugfix] CustomAR + TritonAttn[AMPERE] + FULL_CG - gpt-oss #30650

Closed
bbrowning wants to merge 1 commit into vllm-project:main from bbrowning:custom-all-reduce-fix

Conversation

@bbrowning
Contributor

@bbrowning bbrowning commented Dec 14, 2025

Purpose

Full CUDA graph capture + replay (the default since vLLM v0.11.0) could lead to memory correctness issues in the custom_all_reduce implementation, because previous logic skipped running the actual all_reduce operation during the compile and warmup phase. Compile and warmup therefore saw a memory allocation but never observed the custom op call or the reduce itself, which could lead to optimizations that assume purely functional code with no side effects, when in fact the reduce operation inherently has side effects.

To fix this, instead of allocating an empty tensor to mimic the reduce operation, we now call the actual all_reduce during compile and warmup as well.
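The control-flow change can be sketched in plain Python (no torch required; the function names, `is_warmup` flag, and `do_reduce` callback are illustrative stand-ins, not vLLM's actual custom_all_reduce API):

```python
# Illustrative sketch only: names and signatures are hypothetical, not
# vLLM's actual API. `do_reduce` stands in for the real collective kernel.

def custom_all_reduce_old(x, is_warmup, do_reduce):
    if is_warmup:
        # Old behavior: only mimic the allocation pattern. Warmup/compile
        # sees an allocation with no side effects, so the op can be
        # treated as purely functional and optimized around.
        return [0.0] * len(x)
    return do_reduce(x)

def custom_all_reduce_new(x, is_warmup, do_reduce):
    # Fixed behavior: run the real reduce in every phase, so warmup,
    # capture, and replay all observe the same allocation, copy, and
    # collective side effects.
    return do_reduce(x)
```

With a stand-in two-rank reduce such as `do_reduce = lambda x: [2 * v for v in x]`, the old path returns placeholder values during warmup while the new path performs the reduce in every phase.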

Multiple users reported an issue that tracks back to this root cause on Ampere hardware, which is one platform where we use custom_all_reduce by default if tensor parallelism is used and NVLink is available.

Examples:

- #26480
- #29998
- #30498

Note that until I rebased this on top of a very recent main, I also had to tag the all_reduce custom op registration in parallel_state.py with torch.Tag.maybe_aliasing_or_mutating to fix the reproducer described in #29998. The tag is no longer required on latest main, although I have not identified exactly which change fixed that.

Test Plan

On Ampere hardware with NVLink enabled (2x A5500 in my case), start gpt-oss-20b with tensor-parallel-size 2:

vllm serve openai/gpt-oss-20b \
  --tool-call-parser openai \
  --enable-auto-tool-choice \
  --tensor-parallel-size 2

Then, execute the curl reproducer from #29998 in a loop:

while true; do curl -X POST http://localhost:8000/v1/chat/completions  -H "Content-Type: application/json"   -d '{
    "model": "openai/gpt-oss-20b",
    "stream": false,
    "messages": [
      {
        "role": "system",
        "content": "Be a helpful assistant."
      },
      {
        "role": "user",
        "content": "Hi"
      },
      {
        "role": "assistant",
        "content": "How can I help you?"
      },
      {
        "role": "user",
        "content": "Do you like Monty Python?"
      }
    ],
    "tools": [
      {
        "type": "function",
        "function": {
          "name": "CHANGE-NAME-BEFORE-SENDING",
          "description": "Use this tool if you need to extract information from a website.",
          "parameters": {
            "type": "object",
            "properties": {
              "url": {
                "type": "string",
                "description": "The URL to search or extract information from."
              }
            },
            "required": ["url"]
          }
        }
      }
    ]
  }'; done
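The same reproducer can be scripted with only the Python standard library (the endpoint, model name, and payload mirror the curl command above; this is a convenience sketch, not the exact requests script from #30498):

```python
import json
import urllib.request

# Payload mirrors the curl reproducer above, including the placeholder
# tool name that is meant to be changed before sending.
PAYLOAD = {
    "model": "openai/gpt-oss-20b",
    "stream": False,
    "messages": [
        {"role": "system", "content": "Be a helpful assistant."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "How can I help you?"},
        {"role": "user", "content": "Do you like Monty Python?"},
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "CHANGE-NAME-BEFORE-SENDING",
                "description": "Use this tool if you need to extract information from a website.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "The URL to search or extract information from.",
                        }
                    },
                    "required": ["url"],
                },
            },
        }
    ],
}

def send_once(url: str = "http://localhost:8000/v1/chat/completions") -> dict:
    """POST the reproducer payload once and return the parsed response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(PAYLOAD).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

if __name__ == "__main__":
    while True:  # loop like the shell reproducer
        print(send_once())
```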

I also exercised the reproducer from #30498, slightly modified to point to gpt-oss-20b (instead of 120b) and to raise the max_tokens to 8192 as it was generating valid but verbose output for the first few iterations of the loop.

Test Result

Before this change, the 2nd time the curl loop from #29998 executed, it always hung and the vLLM server continued to generate token id 0 indefinitely. After this change, the 2nd and all subsequent curl requests succeed as expected.

Before this change, the python requests loop from #30498 always hung by the 2nd or 3rd request. After this change, it continues to work for multiple subsequent requests.

I was able to successfully reproduce the test cases in both of those issues on my hardware and this change fixes both of those.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a memory correctness issue in custom_all_reduce when using full CUDA graph capture. The previous implementation skipped the actual all-reduce operation during the warmup phase, leading to inconsistencies between warmup, capture, and replay, which could cause hangs. The fix correctly replaces the dummy tensor allocation with a call to the actual all_reduce operation during warmup. This ensures that the memory allocation, copy, and collective behaviors are consistent across all phases. The change is logical, well-explained, and directly resolves the reported bug. The code is clean and the fix is appropriate.

@bbrowning
Contributor Author

While all user reports of this that I've come across have been on Ampere hardware (A5500, A6000, A1000 80GB), I have also been able to reproduce this on H100 GPUs when Triton ATTN is in use. After reproducing the same issue on an H100 (by setting VLLM_ATTENTION_BACKEND=TRITON_ATTN), I can confirm this fixes that issue on Hopper GPUs as well when Triton ATTN is in use.

@robertgshaw2-redhat robertgshaw2-redhat changed the title [Bugfix] Issue with custom_all_reduce and cudagraph_mode FULL [Bugfix] CustomAR + TritonAttn[AMPERE] + gpt-oss Dec 14, 2025
@robertgshaw2-redhat robertgshaw2-redhat changed the title [Bugfix] CustomAR + TritonAttn[AMPERE] + gpt-oss [Bugfix] CustomAR + TritonAttn[AMPERE] + FULL_CG - gpt-oss Dec 14, 2025
@mergify mergify bot added the gpt-oss Related to GPT-OSS models label Dec 14, 2025
@xyang16
Contributor

xyang16 commented Dec 15, 2025

While all user reports of this that I've come across have been on Ampere hardware (A5500, A6000, A1000 80GB), I have also been able to reproduce this on H100 GPUs when Triton ATTN is in use. After reproducing the same issue on an H100 (by setting VLLM_ATTENTION_BACKEND=TRITON_ATTN), I can confirm this fixes that issue on Hopper GPUs as well when Triton ATTN is in use.

Thanks for the fix! Curious why skipping the all_reduce operation would only cause problems for Triton ATTN?

@bbrowning
Contributor Author

While all user reports of this that I've come across have been on Ampere hardware (A5500, A6000, A1000 80GB), I have also been able to reproduce this on H100 GPUs when Triton ATTN is in use. After reproducing the same issue on an H100 (by setting VLLM_ATTENTION_BACKEND=TRITON_ATTN), I can confirm this fixes that issue on Hopper GPUs as well when Triton ATTN is in use.

Thanks for the fix! Curious why skipping the all_reduce operation would only cause problems for Triton ATTN on H100?

I'm not 100% sure that doesn't cause problems for other attention backends. However, the reproducers we've had reported from users all had Triton ATTN in common, and were all on Ampere hardware (where we use Triton ATTN by default). By manually forcing Triton ATTN on H100 hardware, I was also able to reproduce it there.

There could theoretically be correctness issues with other attention backends in certain scenarios that this fixes, but I have no direct evidence of that myself.

@ApostaC
Collaborator

ApostaC commented Dec 16, 2025

Hey @mgoin , could you please help review this PR? Thanks!

@msrodlab

I have been reliably reproducing this issue on 4x 40 GB A100 but not on 2x 80 GB H100 with the current master branch. I will let you know if this PR fixes the issue.

else:
    # If warm up, mimic the allocation pattern since custom
    # allreduce is out-of-place.
    return torch.empty_like(input)
Contributor Author

This could maybe accomplish the same thing if it were torch.zeros_like(input)? Initial tests seem promising, but I'm not sure how to interpret what that means about the root issue and what this actually fixes.
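For context on that suggestion: torch.empty_like returns uninitialized memory while torch.zeros_like returns deterministic zeros, which may explain why the latter appears to behave better during warmup even though neither performs the collective (a quick illustration, not a fix):

```python
import torch

x = torch.ones(4)

# empty_like allocates without initializing: contents are arbitrary,
# and can differ between warmup and replay.
garbage = torch.empty_like(x)

# zeros_like allocates and zero-fills: contents are deterministic,
# but there is still no all_reduce side effect for the compiler to see.
zeroed = torch.zeros_like(x)

assert zeroed.sum().item() == 0.0
assert garbage.shape == x.shape == zeroed.shape
```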

@msrodlab

msrodlab commented Dec 16, 2025

So I was able to test this PR, and while it does prevent GPT OSS from simply generating nothing now, it appears that some bugs that were fixed in #28729 are back, particularly the one I described in #28729 (comment), where no final answer is produced.

I have also noticed a few scenarios where it seems like the model is outputting reasoning content in the content field.

Attached is a request body with which you can reproduce the first issue, where no final answer is produced:
request-body.txt

Here are the startup args I use for 4x A100 40gb, I use --async-scheduling as mentioned in the GPT OSS vLLM cookbook:

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m vllm.entrypoints.openai.api_server --model openai/gpt-oss-120b --tensor-parallel-size 4 --max-model-len 128000 --seed 42 --gpu-memory-utilization 0.90 --kv-cache-dtype auto --enable-auto-tool-choice --tool-call-parser openai --async-scheduling

And for 2x H100 80gb

CUDA_VISIBLE_DEVICES=0,1 python -m vllm.entrypoints.openai.api_server --model openai/gpt-oss-120b --tensor-parallel-size 2 --max-model-len 128000 --seed 42 --gpu-memory-utilization 0.90 --kv-cache-dtype auto --enable-auto-tool-choice --tool-call-parser openai

I have attached the raw output for the two different machines as well.
a100-output.txt
h100-output.txt

My H100s are using the commit based on #28729

I did not specify my attention backend in either scenario, so I am using the default for H100, and for the A100, TRITON_ATTN is used.

@chaunceyjiang
Collaborator

@bbrowning
Contributor Author

@chaunceyjiang This fix is likely unrelated to that error. That error looks like maybe someone changed the way we generate harmony messages and forgot to update a test? But, those tests are testing for correctness, so it may be that the change itself is wrong and the test is right.

[2025-12-18T06:38:30Z] messages = [Message(author=Author(role=<Role.SYSTEM: 'system'>, name=None), content=[SystemContent(model_identity='You are ChatGP...SER: 'user'>, name=None), content=[TextContent(text='what is 1+1?')], channel=None, recipient=None, content_type=None)]
--
[2025-12-18T06:38:30Z] expected_messages = [{'role': 'system'}, {'role': 'developer'}, {'content': 'what is 1+1?', 'role': 'user'}]

The test wants 3 messages - one with a system role, one with a developer role, and one with a user role and specific content. It looks like something wrongly converted those to a single message instead.

@bbrowning
Contributor Author

I believe this is just a bandaid over the underlying issue; #30887 fixes the actual root cause with Triton attention and sliding window models (like gpt-oss).

@bbrowning
Contributor Author

With the merging of #30887, I believe this fix is no longer needed.

@bbrowning bbrowning closed this Dec 19, 2025
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Dec 19, 2025
@github-project-automation github-project-automation bot moved this from To Triage to Done in gpt-oss Issues & Enhancements Dec 19, 2025
@bbrowning
Contributor Author

So I was able to test this PR, and while it does prevent GPT OSS from simply generating nothing now, it appears that some bugs that were fixed in #28729 are back, particularly the one that I described where no final answer is produced #28729 (comment)

...

@bbRLdev I am looking into this using the sample request you gave to see if I can reproduce it on top of the latest main. I appreciate the detailed report, and it might be best to pull your comment out to a new issue so that we can properly track finding the root cause and solving what you're seeing. I believe you've hit on a separate real bug that we need to fix, as with your reproducer even in non-streaming mode I see the model is generating tool calls but we're not properly getting those back to the caller for whatever reason.

@bbrowning
Contributor Author

@bbRLdev To fully track this down, I'll need to find the full parameters Cline is sending to vLLM here, including the tool_choice and tools parameters. The request body you gave had no tools, but the model is generating a tool call internally in response to this request.

@wangln19
Contributor

Hello, I'm hitting something similar: I'm trying to run gpt-oss-120b with EAGLE3 speculative decoding on vLLM, but getting CUDA errors during graph capture.

Environment:

  • vLLM version: vllm-main-20260121
  • GPU: 4x H20 (tensor_parallel_size=4)
  • Model: gpt-oss-120b
  • Draft model: EAGLE3 (draft_tensor_parallel_size=1)

Command:

vllm serve --model /path/to/gpt-oss-120b \
  --tensor-parallel-size 4 \
  --max-model-len 131072 \
  --gpu-memory-utilization 0.85 \
  --speculative-config '{"model":"/path/to/eagle3","method":"eagle3","num_speculative_tokens":5,"draft_tensor_parallel_size":1}'

Error (during the CUDA graph capture phase, at 0%):

Failed: Cuda error /workspace/csrc/custom_all_reduce.cuh:455 'an illegal memory access was encountered'

What I've tried:

  1. --enforce-eager: works, but with a performance impact
  2. --disable-custom-all-reduce: works

Question: Is there a known incompatibility between EAGLE3 speculative decoding and CUDA graph capture with tensor parallelism > 1? What's the recommended configuration?


Labels

gpt-oss (Related to GPT-OSS models), nvidia

Projects

Status: Done

6 participants