-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Use dsv3 optimized routing fused_topk_deepseek instead of moe_fused_gate
#15347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
8cb9f3d
use fused_topk_deepseek instead of moe_fused_gate
leejnau 59d1bee
keep old kernel; use new one when possible
leejnau 2a37618
add unit test for fused_topk_deepseek
leejnau c4f49ef
run pre-commit to fix formatting
leejnau 67a1850
move test_fused_topk_deepseek from sgl-kernel to sglang/test
leejnau 21e4828
move test to test/registered/kernels and register as nightly-1-gpu
leejnau 67878ac
Merge branch 'main' into fused_topk_routing
Fridge003 ba18369
lower accept length threshold to 2.8
leejnau File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
trevor-m marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| import pytest | ||
Fridge003 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import torch | ||
|
|
||
| from sglang.srt.layers.moe.topk import biased_grouped_topk_gpu, biased_grouped_topk_impl | ||
| from sglang.test.ci.ci_register import register_cuda_ci | ||
|
|
||
# Register this test file with the CI system: run in the "nightly-1-gpu"
# suite on CUDA hardware (est_time=2 — presumably minutes; confirm the unit
# against sglang.test.ci.ci_register).
register_cuda_ci(est_time=2, suite="nightly-1-gpu", nightly=True)
|
|
||
|
|
||
@pytest.mark.parametrize(
    "seq_length",
    list(range(1, 10))
    + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
)
@pytest.mark.parametrize(
    "params",
    [
        (128, 4, 2, 4),  # 128 experts configuration
        (256, 8, 4, 8),  # DeepSeek V3 config - most important to test
        (64, 2, 2, 4),  # Smaller configuration
    ],
)
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
def test_fused_topk_deepseek(seq_length, params, apply_routed_scaling_factor_on_output):
    """
    Test the fused_topk_deepseek code path in biased_grouped_topk_gpu.

    Compares the fused GPU implementation against the pure-PyTorch reference
    (biased_grouped_topk_impl) on random gating inputs. Two checks are used,
    both tolerant of tie-breaking differences between the implementations:

    1. Row-wise sums of the top-k weights must match (invariant to which of
       two tied experts gets picked).
    2. Weights scattered by expert index into a dense (seq, experts) grid
       must match element-wise, with a small allowance of large deviations
       attributed to ties.
    """
    num_experts, num_expert_group, topk_group, topk = params
    dtype = torch.float32

    # Seed per seq_length so each parametrization is deterministic yet distinct.
    torch.manual_seed(seq_length)
    hidden_states = torch.randn(seq_length, 128, dtype=dtype, device="cuda")
    gating_output = torch.randn(seq_length, num_experts, dtype=dtype, device="cuda")
    correction_bias = torch.randn(num_experts, dtype=dtype, device="cuda")

    routed_scaling_factor = 2.5 if apply_routed_scaling_factor_on_output else None

    # Fused implementation (uses fused_topk_deepseek when conditions are met)
    output, indices = biased_grouped_topk_gpu(
        hidden_states,
        gating_output,
        correction_bias,
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        num_fused_shared_experts=0,
        routed_scaling_factor=routed_scaling_factor,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )

    # Reference implementation (pure PyTorch)
    ref_output, ref_indices = biased_grouped_topk_impl(
        hidden_states,
        gating_output,
        correction_bias,
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        num_fused_shared_experts=0,
        routed_scaling_factor=routed_scaling_factor,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )

    # Check 1: Row-wise sums should match (invariant to tie-breaking).
    # Assert immediately with the measured deviation so a failure pinpoints
    # this check instead of a combined opaque boolean.
    output_sum = output.sum(dim=-1)
    ref_output_sum = ref_output.sum(dim=-1)
    max_sum_diff = (output_sum - ref_output_sum).abs().max().item()
    assert torch.allclose(output_sum, ref_output_sum, rtol=1e-03, atol=1e-04), (
        f"Row-sum mismatch (max abs diff {max_sum_diff:.3e}) at "
        f"seq_length {seq_length}, params {params}, "
        f"apply_routed_scaling_factor_on_output {apply_routed_scaling_factor_on_output}"
    )

    # Check 2: Scatter-based comparison with allowance for tie-breaking.
    res = torch.zeros(seq_length, num_experts, dtype=torch.float32, device="cuda")
    ref = torch.zeros(seq_length, num_experts, dtype=torch.float32, device="cuda")

    res.scatter_(1, indices.long(), output)
    ref.scatter_(1, ref_indices.long(), ref_output)

    diff = torch.abs(ref - res)
    # Larger batches combined with output scaling accumulate slightly more
    # numerical error, hence the looser tolerance there.
    atol = (
        5e-03
        if (seq_length >= 4096 and apply_routed_scaling_factor_on_output)
        else 1e-03
    )
    num_large_diffs = (diff > atol).sum().item()

    # Allow a small number of differences for tie-breaking situations.
    max_allowed_diffs = max(16, seq_length // 500)
    assert num_large_diffs <= max_allowed_diffs, (
        f"{num_large_diffs} scattered weights differ by more than {atol} "
        f"(allowed {max_allowed_diffs}) at seq_length {seq_length}, "
        f"params {params}, apply_routed_scaling_factor_on_output "
        f"{apply_routed_scaling_factor_on_output}"
    )
|
|
||
|
|
||
if __name__ == "__main__":
    # Propagate pytest's exit status to the shell; the bare pytest.main(...)
    # call discarded it, so a direct run exited 0 even when tests failed.
    raise SystemExit(pytest.main([__file__]))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.