Skip to content

[Performance][Kernel] Fused_moe Performance Improvement #9384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 24, 2024

Conversation

charlifu
Copy link
Contributor

@charlifu charlifu commented Oct 15, 2024

This PR tries to improve the performance of vllm fuse moe layer by:

  • Increase the thread number of moe_align_block_size_kernel to warp size.
  • Replace torch.sum with customized moe_sum.
  • Move moe_align_block_size_kernels.cu to moe/moe_align_sum_kernels.cu.

Motivations:

  • moe_align_block_size_kernel only uses num_experts as the thread number, which causes low parallelism when handling big input size, i.e., prefilling.
  • At the end of fused_moe, torch.sum is used to sum the partial results from all selected experts. But torch.sum is not optimized for this use case.

Performance Improvement (Mixtral 8x22b):
moe_sum: x6 for 8k prefill.
image
moe_align_block_size: x5 for 8k prefill
image


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@comaniac
Copy link
Collaborator

Thanks for the improvements! Do you have some end-to-end benchmark results to share with?

@charlifu
Copy link
Contributor Author

charlifu commented Oct 15, 2024

Thanks for the improvements! Do you have some end-to-end benchmark results to share with?

Here are the latency (s) numbers tested on upstream (prefill, tp 2, Mixtral 8x22b, bf16):

num-prompt input-len w/o w/
8 8192 14.80 14.27
16 8192 29.84 29.21
32 8192 59.55 57.36

But upstream's triton moe kernel is not tuned for rocm, so its huge overhead offsets the performance gain. On the rocm fork of vllm, we are able to see around 10% boost on tp8 for Mixtral 8x22b.

@comaniac
Copy link
Collaborator

Thanks for the numbers. In this case it would be great to come up with tuned configurations (we have them for some NVIDIA GPUs under https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/fused_moe/configs).

Also how does this PR change the performance on NVIDIA GPUs?

@charlifu
Copy link
Contributor Author

charlifu commented Oct 15, 2024

Thanks for the numbers. In this case it would be great to come up with tuned configurations (we have them for some NVIDIA GPUs under https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/fused_moe/configs).

Also how does this PR change the performance on NVIDIA GPUs?

Will add the tuned configs!

Nvidia GPUs should also benefit from this PR.

@charlifu
Copy link
Contributor Author

charlifu commented Oct 18, 2024

@comaniac
H100 e2e numbers: Mixtral 8x22b, tp8, prefill, 8k input len, 6.25% gain

num-prompt input-len w/o w/
8 8192 2.00 1.88
16 8192 3.99 3.75
32 8192 8.00 7.50
64 8192 16.01 15.04
128 8192 32.01 30.07

MI300X e2e numbers: Mixtral 8x7B (tuned), tp8, prefill, 8k input len, 18% gain

num-prompt input-len w/o w/
8 8192 1.26 1.01
16 8192 2.53 2.04
32 8192 5.08 4.10
64 8192 10.19 8.29
128 8192 20.53 16.78

@comaniac
Copy link
Collaborator

Awesome! Do you want to add AMD tuned configs in this PR or you prefer to have a follow-up PR for it?

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I feel we could merge this PR first as it doesn't actually touch the MoE triton kernel.
cc @tlrmchlsmth

@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 18, 2024
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! looks good to merge

Signed-off-by: charlifu <[email protected]>
Signed-off-by: charlifu <[email protected]>
Signed-off-by: charlifu <[email protected]>
@charlifu
Copy link
Contributor Author

charlifu commented Oct 23, 2024

Looks like it is causing some CI failures on Nvidia side. Looking at it.

Signed-off-by: charlifu <[email protected]>
Signed-off-by: charlifu <[email protected]>
@charlifu
Copy link
Contributor Author

The kernel test failure is fixed. I think this PR is ready to be merged. For other two failures, I noticed recently merged PRs failed on those two as well. So I suppose it is not the issue caused by this PR. @tlrmchlsmth @comaniac

@comaniac
Copy link
Collaborator

Ok I'll ping folks to force merge.

@simon-mo simon-mo merged commit 5944909 into vllm-project:main Oct 24, 2024
76 of 79 checks passed
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
@charlifu charlifu deleted the moe_opt branch October 31, 2024 14:12
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants