
[ROCm] Optimize redundant d2d copy of MoE #38597

Closed
benenzhu wants to merge 4 commits into vllm-project:main from benenzhu:amd/opt_moe

Conversation

@benenzhu
Contributor

@benenzhu benenzhu commented Mar 31, 2026

Purpose

When running MiniMax M2.5, the MoE kernel's output currently incurs two d2d copies, which can be optimized away.

Test Plan

Run MiniMax M2.5 inference on MI355.
Confirm output correctness matches the non-AITER fallback path by comparing model accuracy.

Benchmarking

export VLLM_ROCM_USE_AITER=1
vllm serve MiniMaxAI/MiniMax-M2.5 \
    --tensor-parallel-size 8 \
    --enable-expert-parallel \
    --max-num-batched-tokens 196608 \
    --max-model-len=10240 \
    --max-num-seqs 512 \
    --block-size=32 \
    --trust-remote-code \
    --no-enable-prefix-caching \
    --port=30000
vllm bench serve \
  --model MiniMaxAI/MiniMax-M2.5 \
  --dataset-name sharegpt \
  --dataset-path /A/datasets/ShareGPT_V3_unfiltered_cleaned_split.json \
  --sharegpt-output-len 300 \
  --port 30000 \
  --max-concurrency 8 \
  --num-prompts 1000 \
  --num-warmups 50 \
  --ignore-eos \
  --temperature 0

Before:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  442.59    
Total input tokens:                      225133    
Total generated tokens:                  300000    
Request throughput (req/s):              2.26      
Output token throughput (tok/s):         677.83    
Peak output token throughput (tok/s):    704.00    
Peak concurrent requests:                16.00     
Total token throughput (tok/s):          1186.49   
---------------Time to First Token----------------
Mean TTFT (ms):                          57.49     
Median TTFT (ms):                        50.00     
P99 TTFT (ms):                           106.37    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          11.65     
Median TPOT (ms):                        11.65     
P99 TPOT (ms):                           11.90     
---------------Inter-token Latency----------------
Mean ITL (ms):                           11.65     
Median ITL (ms):                         11.54     
P99 ITL (ms):                            13.49     
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  427.85    
Total input tokens:                      225133    
Total generated tokens:                  300000    
Request throughput (req/s):              2.34      
Output token throughput (tok/s):         701.18    
Peak output token throughput (tok/s):    736.00    
Peak concurrent requests:                16.00     
Total token throughput (tok/s):          1227.38   
---------------Time to First Token----------------
Mean TTFT (ms):                          57.12     
Median TTFT (ms):                        47.68     
P99 TTFT (ms):                           107.23    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          11.25     
Median TPOT (ms):                        11.26     
P99 TPOT (ms):                           11.54     
---------------Inter-token Latency----------------
Mean ITL (ms):                           11.25     
Median ITL (ms):                         11.12     
P99 ITL (ms):                            14.14     
==================================================

End-to-end TPOT drops by about 3.4%.

Accuracy Testing

python3 -m lm_eval --model local-completions \
  --model_args model=MiniMaxAI/MiniMax-M2.5,base_url=http://127.0.0.1:30000/v1/completions,num_concurrent=64 \
  --tasks gsm8k

main:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9234|±  |0.0073|
|     |       |strict-match    |     5|exact_match|↑  |0.9212|±  |0.0074|

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9318|±  |0.0069|
|     |       |strict-match    |     5|exact_match|↑  |0.9280|±  |0.0071|

There is no significant difference in accuracy.

cc @gshtras @chunfangamd
prev: #38346


@mergify mergify Bot added the rocm Related to AMD ROCm label Mar 31, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Mar 31, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces optimizations to the MoE layers by replacing tensor copies with pointer reassignments or conditional copies. However, the review identifies critical safety concerns: reassigning the output to a transient workspace buffer in modular_kernel.py could lead to data corruption as the memory may be overwritten by subsequent operations, and using .data reassignment in rocm_aiter_fused_moe.py is unsafe for CUDA graph capture and bypasses proper buffer management.

Comment on lines +1380 to +1385
if (
not self.inplace
and fused_out.shape == output.shape
and fused_out.is_contiguous()
):
output = fused_out
Contributor

critical

Reassigning output = fused_out is highly dangerous because fused_out is a view into the transient workspace memory (allocated via current_workspace_manager). Returning this tensor means the MoE layer's output can be silently overwritten by any subsequent operation that requests workspace buffers, leading to data corruption or non-deterministic results. The copy from the workspace buffer to the persistent output tensor (allocated at line 1350) is mandatory to ensure the result persists correctly throughout the model's execution.

Collaborator

Reusing fused_out for output is dangerous beyond the fact that it might point to the temporary buffers. Doing this basically forces all the finalize methods to operate inplace which may or may not be supported.

Contributor Author

Added a platform.is_rocm() check, so it only takes effect in ROCm's finalize.

Contributor

@Rohan138 Rohan138 Apr 1, 2026

I think this check should be rocm_aiter_ops.is_fused_moe_enabled(). For other MOE backends, ROCm devices that don't support AITER, AITER fused_moe explicitly disabled, etc. we still want to keep the default behavior.

Contributor Author

Yeah, thanks. It should only take effect when AITER is enabled.

output_dtype=output.dtype,
)
output.copy_(result)
output.data = result
Contributor

high

Using output.data = result is unsafe and discouraged. This shallow pointer swap is incompatible with CUDA graphs, as the graph capture records the memory address of the tensor's storage. If result is a new allocation (which it appears to be, as rocm_aiter_fused_experts returns a new tensor), its memory address will change in subsequent iterations, invalidating the captured graph. Additionally, it detaches the tensor from the pre-allocated buffer provided by the modular kernel's workspace management. To safely avoid a copy, the underlying AITER kernel should be modified to accept the destination buffer as an argument and write into it directly.

Contributor Author

AITER uses torch.empty to create a new tensor, and I think it's compatible with CUDA graphs.

@Rohan138
Contributor

Rohan138 commented Mar 31, 2026

@benenzhu can you please run pre-commit? We'd like to merge this and get it cherry-picked into 0.19.0 to fix ROCm regressions on gpt-oss and deepseek, cc @zyongye

@bnellnm
Collaborator

bnellnm commented Mar 31, 2026

@benenzhu can you please run pre-commit? We'd like to merge this and get it cherry-picked into 0.19.0 to fix ROCm regressions on gpt-oss and deepseek, cc @zyongye

This will likely break things if it gets merged.

@benenzhu
Contributor Author

benenzhu commented Apr 1, 2026

@benenzhu can you please run pre-commit? We'd like to merge this and get it cherry-picked into 0.19.0 to fix ROCm regressions on gpt-oss and deepseek, cc @zyongye

This will likely break things if it gets merged.

@bnellnm Thanks, I added a platform.is_rocm() check, so only ROCm's AITER MoE takes this path, and it always uses torch.empty() as the output buffer.
I also check that the shapes are equal, so it should only apply to TopKWeightAndReduceNoOp; the other ops will have different shapes.

@benenzhu
Contributor Author

benenzhu commented Apr 1, 2026

@Rohan138 The pre-commit check failed because I don't have the 'ready' label or at least 4 merged PRs.

@gshtras
Collaborator

gshtras commented Apr 1, 2026

@benenzhu can you please run pre-commit? We'd like to merge this and get it cherry-picked into 0.19.0 to fix ROCm regressions on gpt-oss and deepseek, cc @zyongye

This will likely break things if it gets merged.

@bnellnm Thanks, I added a platform.is_rocm() for it. So Rocm uses aiter's moe. And it will always use torch.empty() as the output buffer. Also I checked the shape is equal, so it should only be the TopKWeightAndReduceNoOp, other ops will have different shapes.

Are you saying it's tailored to work just with AITER MoE?

@benenzhu
Contributor Author

benenzhu commented Apr 1, 2026

@benenzhu can you please run pre-commit? We'd like to merge this and get it cherry-picked into 0.19.0 to fix ROCm regressions on gpt-oss and deepseek, cc @zyongye

This will likely break things if it gets merged.

@bnellnm Thanks, I added a platform.is_rocm() for it. So Rocm uses aiter's moe. And it will always use torch.empty() as the output buffer. Also I checked the shape is equal, so it should only be the TopKWeightAndReduceNoOp, other ops will have different shapes.

Are you saying it's tailored to work just with AITER MoE?

@gshtras Yeah, the other backends allocate the MoE output from here: https://github.com/vllm-project/vllm/blob/v0.19.0rc0/vllm/model_executor/layers/fused_moe/modular_kernel.py#L1009-L1070
AITER's MoE doesn't use this one; it creates the output inside the kernel with torch.empty(), and we point the output at AITER's tensor with output.data = result: https://github.com/vllm-project/vllm/blob/v0.19.0rc0/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py#L425-L439

So I think it's tailored for AITER to safely skip this copy.

Member

@zyongye zyongye left a comment


LGTM. We can merge this, and I can take a look at the NV side so we can fix it for all backends.

@zyongye zyongye added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 1, 2026
benenzhu added 3 commits April 1, 2026 16:18
Signed-off-by: zhutaoyu <zhutaoyu97@gmail.com>
Signed-off-by: zhutaoyu <zhutaoyu97@gmail.com>
Signed-off-by: zhutaoyu <zhutaoyu97@gmail.com>
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat left a comment


why cant we update aiter to optionally accept an output buffer like every other kernel we have?

then aiter would fit into the structure that we have and we could avoid these confusing, bug-prone edge cases

@benenzhu
Contributor Author

benenzhu commented Apr 3, 2026

why cant we update aiter to optionally accept an output buffer like every other kernel we have?

then aiter would fit into the structure that we have and we could avoid these confusing, bug-prone edge cases

Yeah, thanks, the copy inside vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py can be moved into AITER.
But the second one, in vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py, can only be changed in vLLM, I think.

nholmber added a commit to nholmber/vllm that referenced this pull request Apr 5, 2026
Cherry-pick of vllm-project#38597 onto v0.19.0.
Eliminates two device-to-device memory copies in the AITER MoE path:
1. Replace output.copy_(result) with output.data = result in rocm_aiter_fused_moe.py
2. Skip copy in TopKWeightAndReduceNoOP when output already points to fused_out
3. Add conditional in modular_kernel.py to reuse fused_out as output when shapes match
tpopp added a commit to amdsiloai/vllm that referenced this pull request Apr 9, 2026
Skip unnecessary d2d copies in the MoE path when using AITER:
- modular_kernel: skip copy when AITER output is contiguous
- rocm_aiter_fused_moe: use output.data assignment instead of copy_
- topk_weight_and_reduce: guard copy_ with data_ptr check

Signed-off-by: Tres Popp <tres.popp@amd.com>
Made-with: Cursor
@tpopp
Contributor

tpopp commented Apr 21, 2026

How important is removing allocation overhead, which is the goal of pre-creating the workspaces? This is only a concern for non-CUDA-graph, eager-mode executions, where workspace allocation is a small overhead relative to the rest. The buffer allocation and subsequent call are only used in a single location rather than being a larger interface consideration, so this forces a specific calling convention on libraries (and even constrains how many workspace buffers can be used) just to remove a couple of allocation calls in a less performant mode for certain backends.

@frida-andersson
Contributor

Missed this PR, but I hit the same redundant copies independently and have a draft at #41020 that does what @robertgshaw2-redhat is asking for. I have an AITER-side change that adds output_buffer_override (ROCm/aiter@da318d0) so it writes directly into the caller's buffer. @benenzhu, let me know if it makes sense to combine the two; the AITER-side changes referenced in #41020 should address the review concerns here.

@benenzhu
Contributor Author

Missed this PR but I hit the same redundant copies independently and have a draft at #41020 that does what @robertgshaw2-redhat is asking for. I have an AITER-side change that adds output_buffer_override (ROCm/aiter@da318d0) so it writes directly into the caller's buffer. @benenzhu let me know if it makes sense to combine the two, the AITER-side changes referenced in #41020 should address the review concerns here

Yeah, I will close this for now. The AITER-side change should be better.


Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done


8 participants