[FIX] Add NO_MUL activation support for modular kernel path #31528
Merged
Changes from all commits (12 commits)
a8c791b fix
danielafrimi 3218c78 refactor
danielafrimi 6787f72 refactor
danielafrimi 0d7fa53 refactor
aae2367 refactor
3e1c987 fix
3fb7514 Merge branch 'main' into fix_path_moe
mgoin 34f4eab Merge branch 'main' into fix_path_moe
mgoin 8d4a90c Merge branch 'main' into fix_path_moe
danielafrimi 64ea89c refactor
1a391e7 refactor
e5c5630 Merge branch 'main' into fix_path_moe
@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for MoE with non-gated activations (*_no_mul).

These tests verify that MoE layers work correctly with activations like
silu_no_mul, gelu_no_mul, relu2_no_mul where the activation output dimension
equals N (not N // 2 like gated activations).
"""

import pytest
import torch

from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import (
    GELU_NO_MUL,
    RELU2_NO_MUL,
    SILU_NO_MUL,
)
from vllm.platforms import current_platform

# Test parameters
M_SIZES = [1, 16, 64]
N_SIZES = [128, 256]
K_SIZES = [64, 128]
TOPK_VALUES = [1, 2]
NUM_EXPERTS = 8
NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]


def make_test_tensors(
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
):
    """Create test tensors for MoE with non-gated activation.

    For non-gated activations (*_no_mul):
    - w1: (E, N, K) - projects from K to N
    - w2: (E, K, N) - projects from N back to K (note: N, not N//2)
    """
    hidden_states = torch.randn(m, k, dtype=dtype, device=device)

    # For non-gated: w1 projects K -> N, w2 projects N -> K
    w1 = torch.randn(num_experts, n, k, dtype=dtype, device=device) * 0.1
    w2 = torch.randn(num_experts, k, n, dtype=dtype, device=device) * 0.1

    topk_weights = torch.ones(m, topk, dtype=torch.float32, device=device) / topk
    topk_ids = torch.randint(0, num_experts, (m, topk), device=device)

    return hidden_states, w1, w2, topk_weights, topk_ids


@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@pytest.mark.parametrize("m", M_SIZES)
@pytest.mark.parametrize("n", N_SIZES)
@pytest.mark.parametrize("k", K_SIZES)
@pytest.mark.parametrize("topk", TOPK_VALUES)
@pytest.mark.parametrize("activation", NO_MUL_ACTIVATIONS)
@torch.inference_mode()
def test_triton_experts_no_mul_activation(
    m: int,
    n: int,
    k: int,
    topk: int,
    activation: str,
):
    hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
        m, n, k, NUM_EXPERTS, topk
    )

    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    ws1_shape, ws2_shape, out_shape = experts.workspace_shapes(
        M=m,
        N=n,
        K=k,
        topk=topk,
        global_num_experts=NUM_EXPERTS,
        local_num_experts=NUM_EXPERTS,
        expert_tokens_meta=None,
        activation=activation,
    )

    # Verify workspace shapes are correct for no_mul activation
    # workspace1 should handle activation_out_dim = N (not N//2)
    assert ws1_shape == (m, topk, max(n, k)), (
        f"workspace1 shape mismatch: expected {(m, topk, max(n, k))}, got {ws1_shape}"
    )
    # workspace2 should handle max(N, K) for intermediate_cache1/cache3
    assert ws2_shape == (m, topk, max(n, k)), (
        f"workspace2 shape mismatch: expected {(m, topk, max(n, k))}, got {ws2_shape}"
    )
    assert out_shape == (m, k), (
        f"output shape mismatch: expected {(m, k)}, got {out_shape}"
    )

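    # Example: for m=16, topk=2, n=256, k=128, each workspace below flattens to
    # 16 * 2 * 256 = 8192 elements.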
    workspace1 = torch.empty(
        ws1_shape[0] * ws1_shape[1] * ws1_shape[2],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    workspace2 = torch.empty(
        ws2_shape[0] * ws2_shape[1] * ws2_shape[2],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    output = torch.zeros(m, k, dtype=hidden_states.dtype, device=hidden_states.device)

    experts.apply(
        output=output,
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        global_num_experts=NUM_EXPERTS,
        expert_map=None,
        a1q_scale=None,
        a2_scale=None,
        workspace13=workspace1,
        workspace2=workspace2,
        expert_tokens_meta=None,
        apply_router_weight_on_input=False,
    )

    assert output.shape == (m, k), f"Expected shape {(m, k)}, got {output.shape}"
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"
    assert output.abs().sum() > 0, "Output is all zeros"


@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_workspace_shapes_no_mul_vs_gated():
    """Test that workspace shapes differ correctly between gated and non-gated."""
    from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

    M, N, K, topk = 64, 256, 128, 2

    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
        M, N, K, topk, 8, 8, None, SILU_NO_MUL
    )

    ws1_gated, _, out_gated = experts.workspace_shapes(
        M, N, K, topk, 8, 8, None, "silu"
    )

    # For no_mul: activation_out_dim = N
    # For gated: activation_out_dim = N // 2
    # workspace1 should use max(activation_out_dim, K)
    activation_out_dim_no_mul = N
    activation_out_dim_gated = N // 2

    assert ws1_no_mul[2] == max(activation_out_dim_no_mul, K), (
        f"no_mul workspace1 last dim should be max({activation_out_dim_no_mul}, {K})"
    )
    assert ws1_gated[2] == max(activation_out_dim_gated, K), (
        f"gated workspace1 last dim should be max({activation_out_dim_gated}, {K})"
    )
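    # Concretely, with N = 256 and K = 128: ws1_no_mul[2] == 256 and ws1_gated[2] == 128.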

    # Output shapes should be the same
    assert out_no_mul == out_gated == (M, K)


@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_adjust_n_for_activation():
    """Test the adjust_N_for_activation method."""
    from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    N = 256

    # Gated activations should return N // 2
    assert experts.adjust_N_for_activation(N, "silu") == N // 2
    assert experts.adjust_N_for_activation(N, "gelu") == N // 2

    # Non-gated activations should return N
    assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N
    assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N
    assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N
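
For context, here is a minimal standalone sketch (not part of the PR diff above; "pick_activation_out_dim" and "apply_silu_variant" are hypothetical helper names) of the width rule these tests exercise: gated activations halve the w1 output width, while *_no_mul activations keep it.

# Illustrative sketch only; assumes a recent PyTorch install.
import torch
import torch.nn.functional as F


def pick_activation_out_dim(n: int, activation: str) -> int:
    # Gated activations ("silu", "gelu") split the w1 output of width n into two
    # halves and multiply them, so the activation output has width n // 2.
    # Non-gated "*_no_mul" variants apply the activation elementwise and keep width n.
    return n if activation.endswith("_no_mul") else n // 2


def apply_silu_variant(x: torch.Tensor, activation: str) -> torch.Tensor:
    if activation == "silu":  # gated: silu(gate) * up
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up
    if activation == "silu_no_mul":  # non-gated: elementwise silu only
        return F.silu(x)
    raise ValueError(f"unsupported activation: {activation}")


x = torch.randn(4, 256)
assert apply_silu_variant(x, "silu").shape[-1] == pick_activation_out_dim(256, "silu")  # 128
assert apply_silu_variant(x, "silu_no_mul").shape[-1] == pick_activation_out_dim(256, "silu_no_mul")  # 256

This width difference is why the tests above size workspace1 with max(N, K) on the no_mul path rather than max(N // 2, K).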