
Conversation

@vllmellm (Contributor) commented Mar 17, 2025

This PR integrates fused MoE kernels from AITER (AI Tensor Engine for ROCm).

Several fused MoE kernels have been integrated for different scenarios:

  1. The ck_moe kernel from AITER is integrated for unquantized model weights. It is enabled by default when VLLM_ROCM_USE_AITER=1 is set. It can be specifically enabled or disabled using the dedicated environment variable VLLM_ROCM_USE_AITER_MOE. This is suitable for MoE models such as Mixtral.

  2. The asm_moe kernel from AITER is integrated for model weights using dynamic per-tensor quantization. It is enabled by default when VLLM_ROCM_USE_AITER=1 is set and can be specifically enabled or disabled using the same dedicated environment variable, VLLM_ROCM_USE_AITER_MOE. This is suitable for FP8-quantized MoE models such as Mixtral.

  3. The fmoe_fp8_block_scaled kernel from AITER is integrated for the block FP8 quantization method. Unlike the features above, it is disabled by default even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use this kernel, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE must be enabled. This kernel is suitable for DeepSeek models.

These MoE kernels are integrated in /vllm/model_executor/layers/fused_moe/fused_moe.py. The necessary processing steps required for these kernels are included in their respective MoE Methods for both Unquantized (UnquantizedMoEMethod) in /vllm/model_executor/layers/fused_moe/layer.py and FP8 quantized (FP8MoEMethod) in /vllm/model_executor/layers/quantization/fp8.py.
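
To make the switching behaviour above concrete, here is a minimal sketch (not the exact code from this PR) of how the flag cascading and kernel selection fit together; `_env_flag`, `use_aiter_moe`, `use_aiter_block_fp8_moe`, and `dispatch_fused_moe` are hypothetical helper names used only for illustration:

```python
import os


def _env_flag(name: str, default: str) -> bool:
    """Interpret an environment variable as a boolean switch."""
    return os.getenv(name, default).lower() in ("true", "1")


def use_aiter_moe() -> bool:
    # Child switch defaults to on, but only matters when the parent
    # VLLM_ROCM_USE_AITER switch is enabled.
    return (_env_flag("VLLM_ROCM_USE_AITER", "False")
            and _env_flag("VLLM_ROCM_USE_AITER_MOE", "True"))


def use_aiter_block_fp8_moe() -> bool:
    # Disabled by default even when VLLM_ROCM_USE_AITER=1.
    return (_env_flag("VLLM_ROCM_USE_AITER", "False")
            and _env_flag("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false"))


def dispatch_fused_moe(weights_are_quantized: bool,
                       uses_block_fp8: bool) -> str:
    """Pick a fused-MoE backend name from the flags above."""
    if uses_block_fp8 and use_aiter_block_fp8_moe():
        return "aiter.fmoe_fp8_block_scaled"  # e.g. DeepSeek-V3
    if use_aiter_moe():
        # asm_moe covers per-tensor FP8 weights, ck_moe unquantized weights.
        return "aiter.asm_moe" if weights_are_quantized else "aiter.ck_moe"
    return "triton_fused_moe"  # default non-AITER path
```

In the actual change, this selection happens inside fused_moe.py and the MoE method classes listed above rather than in standalone helpers like these.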

Performance Improvement Tables

Mixtral-8x7B-FP8

| Summary | Performance Improvement Over No AITER |
|---|---|
| With Fused MoE | -14% to 75% |

Mixtral-8x7B-FP16

| Summary | Performance Improvement Over No AITER |
|---|---|
| With Fused MoE | -11% to 2% |

DeepSeekV3 Throughput

| Summary | Performance Improvement Over No AITER |
|---|---|
| fmoe_fp8_block_scaled (VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE=1) | 8% to 26.7% |

DeepSeekV3 Latency

| Summary | SpeedUp in TPOT | SpeedUp in TTFT |
|---|---|---|
| fmoe_fp8_block_scaled (VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE=1) | -2% | 41% |

AITER Operations Testing Overview

1. High-Level Integration Tests

The integration of AITER ops is tested at a higher module level in the following files under /tests/models/decoder_only/language:

  • test_models.py
  • test_mistral.py

These tests involve running various models to ensure overall functionality.

2. AITER MoE Specific Test

  • The AITER Mixture of Experts (MoE) is specifically tested for the Mixtral model in:
    /tests/kernels/test_moe.py

3. Quantization Testing

  • Quantization methods for AITER-enabled modules are tested in:
    /tests/quantization/test_fp8.py

4. Kernel Function Dispatch Testing

  • The correct dispatching of kernel functions (AITER-enabled or not) is verified in:
    /tests/model_executor/test_enabled_custom_ops.py
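
For a rough idea of what the dispatch test in section 4 asserts, here is a hedged pytest-style sketch; it targets the hypothetical `dispatch_fused_moe` helper from the sketch earlier in this description (assumed to live in a module named `moe_dispatch_sketch`), not the real code in test_enabled_custom_ops.py:

```python
import pytest

# Hypothetical module name; see the dispatch sketch earlier in this description.
from moe_dispatch_sketch import dispatch_fused_moe


@pytest.mark.parametrize("aiter, moe, expected", [
    ("0", "1", "triton_fused_moe"),  # parent switch off -> no AITER
    ("1", "1", "aiter.ck_moe"),      # both switches on -> AITER ck_moe
    ("1", "0", "triton_fused_moe"),  # child switch off -> fallback
])
def test_unquantized_moe_dispatch(monkeypatch, aiter, moe, expected):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", moe)
    assert dispatch_fused_moe(weights_are_quantized=False,
                              uses_block_fp8=False) == expected
```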

lm_eval results

mistralai/Mixtral-8x7B-Instruct-v0.1

| Tasks | Version | Filter | n-shot | Metric | Quantization | Value (Without AITER) | Stderr (Without AITER) | Value (With AITER) | Stderr (With AITER) |
|---|---|---|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | Unquantized | 0.6338 | ±0.0133 | 0.6475 | ±0.0132 |
| gsm8k | 3 | strict-match | 5 | exact_match ↑ | Unquantized | 0.6315 | ±0.0133 | 0.6437 | ±0.0132 |
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | FP8 | 0.6399 | ±0.0132 | 0.6376 | ±0.0132 |
| gsm8k | 3 | strict-match | 5 | exact_match ↑ | FP8 | 0.6353 | ±0.0133 | 0.6323 | ±0.0133 |

mistralai/Mixtral-8x22B-Instruct-v0.1

| Tasks | Version | Filter | n-shot | Metric | Quantization | Value (Without AITER) | Stderr (Without AITER) | Value (With AITER) | Stderr (With AITER) |
|---|---|---|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | Unquantized | 0.8544 | ±0.0097 | 0.8522 | ±0.0098 |
| gsm8k | 3 | strict-match | 5 | exact_match ↑ | Unquantized | 0.8415 | ±0.0101 | 0.8415 | ±0.0101 |
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | FP8 | 0.8506 | ±0.0098 | 0.8552 | ±0.0097 |
| gsm8k | 3 | strict-match | 5 | exact_match ↑ | FP8 | 0.8378 | ±0.0102 | 0.8469 | ±0.0099 |

Deepseek-V3

| Tasks | Version | Filter | n-shot | Metric | Value (Without AITER) | Stderr (Without AITER) | Value (With AITER) | Stderr (With AITER) |
|---|---|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | 0.9469 | ±0.0062 | 0.9492 | ±0.0060 |
| gsm8k | 3 | strict-match | 5 | exact_match ↑ | 0.9477 | ±0.0061 | 0.9484 | ±0.0061 |

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@vllmellm vllmellm marked this pull request as ready for review March 18, 2025 04:28
…o that the models unit tests would be triggered when aiter envs are switched on and off

Signed-off-by: vllmellm <[email protected]>
vllm/envs.py Outdated
Comment on lines 535 to 547
"VLLM_ROCM_USE_AITER_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),

# use aiter block scaled moe op if aiter ops are enabled.
# by default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE",
"false").lower() in ("true", "1")),
@DarkLight1337 (Member) commented Mar 18, 2025

Let's keep vllm.envs simple by not doing any cascading here. The cascading logic should belong somewhere else (e.g. in the platform class, or in the place where it's actually being used)

Contributor

I agree that the cascading logic is a bit much for the vllm.envs, but I don't think that the platforms class is really the right place for kernel selection logic. I'd prefer to keep all of these environment variable checks down in the "layer" level where we are actually selecting kernels.
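
For illustration only, one way the cascading could sit at the layer level while vllm.envs stays flat is sketched below; `is_rocm_aiter_moe_enabled` and `is_rocm_aiter_block_fp8_moe_enabled` are hypothetical helper names, and the sketch assumes vllm.envs exposes each flag as an independent boolean:

```python
import vllm.envs as envs


def is_rocm_aiter_moe_enabled() -> bool:
    # Cascade here, next to the kernel selection, so vllm.envs can expose
    # each flag independently without combining them.
    return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE


def is_rocm_aiter_block_fp8_moe_enabled() -> bool:
    return (envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE)
```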

@vllmellm (Contributor, Author) commented

@DarkLight1337 @SageMoore These have been addressed in this commit.

@SageMoore (Contributor) commented

I have two high level requests for this PR. The first is that we remove AITER enablement in any unit test that does not exercise this kernel. It's important that we have a good understanding of where this kernel is being unit tested and that's hard to figure out in this PR's current state. The second is that you include lm_eval results for any models that should be supported by this kernel. It sounds like that's just Deepseek V3 and Mixtral? Regardless, we need to make sure that accuracy is maintained with those models before we merge.

Thank you so much for the contribution and for working with us to get this merged. We are very excited about the Deepseek performance improvements!

@hongxiayang added the rocm (Related to AMD ROCm) label Mar 24, 2025

@hongxiayang (Collaborator) commented

Hi @SageMoore, can we prioritize merging this PR ASAP? It is a very important feature. Thanks.

@SageMoore (Contributor) left a comment

This looks reasonable to me. Thanks for cleaning up the tests and running lm_eval.

@mergify (bot) commented Mar 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 25, 2025
@hongxiayang added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 25, 2025
@mergify mergify bot removed the needs-rebase label Mar 26, 2025
…o its default value which is false

Signed-off-by: vllmellm <[email protected]>
@DarkLight1337 (Member) left a comment

Stamp

@DarkLight1337 DarkLight1337 merged commit 5ebf667 into vllm-project:main Mar 26, 2025
40 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
@tjtanaa tjtanaa deleted the aiter-fmoe-integration branch May 16, 2025 16:27
Labels: ci/build, ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)