fix(sm12x): fix micro-kernel workspace sizing when routed_rows > num_local_experts #3191
@@ -1067,6 +1067,78 @@ def test_micro_single_token_unique_path(self):

```python
                f"(atol={atol:.4f})"
            )

    @pytest.mark.parametrize(
        "num_tokens,top_k,num_experts",
        [
            (2, 8, 8),   # total_pairs=16 > num_local_experts=8
            (4, 8, 16),  # total_pairs=32 > num_local_experts=16
            (4, 4, 8),   # total_pairs=16 > num_local_experts=8
        ],
    )
    def test_micro_pairs_exceed_local_experts(
        self, num_tokens: int, top_k: int, num_experts: int
    ):
        """Regression test: micro kernel when num_tokens * top_k > num_local_experts.

        The workspace compact_topk_ids buffer was previously sized state_E
        (num_local_experts), but the micro kernel fills it with total_pairs =
        num_tokens * top_k. When total_pairs > num_local_experts the assertion
        'flat_ids.numel() <= workspace.compact_topk_ids.numel()' fired.

        Fixed by sizing compact_topk_ids as max(state_E, max_rows).
        """
        from flashinfer import b12x_fused_moe

        hidden_size, intermediate_size = 256, 512

        tensors = create_moe_tensors(
            num_tokens=num_tokens,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_experts=num_experts,
            num_local_experts=num_experts,
            top_k=top_k,
        )

        result = b12x_fused_moe(
            x=tensors["x_bf16"],
            w1_weight=tensors["w1_weight"],
            w1_weight_sf=tensors["w1_weight_sf"],
            w1_alpha=tensors["w1_alpha"],
            fc2_input_scale=tensors["fc2_input_scale"],
            w2_weight=tensors["w2_weight"],
            w2_weight_sf=tensors["w2_weight_sf"],
            w2_alpha=tensors["w2_alpha"],
            token_selected_experts=tensors["token_selected_experts"],
            token_final_scales=tensors["token_final_scales"],
            num_experts=num_experts,
            top_k=top_k,
        )

        assert result.shape == (num_tokens, hidden_size)
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()

        ref_output = compute_reference_moe_fp4(
            hidden_states=tensors["x_bf16"].float().cuda(),
            gemm1_weights=tensors["w1_weight_bf16"].float().cuda(),
            gemm2_weights=tensors["w2_weight_bf16"].float().cuda(),
            token_selected_experts=tensors["token_selected_experts"],
            token_final_scales=tensors["token_final_scales"],
            num_tokens=num_tokens,
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            fc2_input_scale=tensors["fc2_input_scale"],
        )

        passed, percent_within, atol = check_accuracy(result, ref_output)
        assert passed, (
            f"Micro pairs>experts: {percent_within * 100:.2f}% within tolerance "
            f"(atol={atol:.4f}, tokens={num_tokens}, top_k={top_k}, experts={num_experts})"
        )


    # =========================================================================
    # Test Class: ReLU2 Activation (SM120-only, non-gated)
```

Review comment from a contributor on the parametrize cases (lines +1073 to +1076), marked as resolved by coderabbitai[bot]. Suggested replacement:

```diff
-            (4, 8, 16),  # total_pairs=32 > num_local_experts=16
+            (2, 8, 16),  # total_pairs=16 > num_local_experts=16, routed_rows safely within micro cutover
```
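The fix described in the test docstring reduces to a one-line sizing rule: the `compact_topk_ids` buffer must hold `max(state_E, total_pairs)` entries rather than `state_E` alone. A minimal sketch of that rule follows; the function name and signature are illustrative, not the actual flashinfer workspace code.

```python
def compact_topk_ids_capacity(
    num_local_experts: int, num_tokens: int, top_k: int, fixed: bool = True
) -> int:
    """Illustrative sizing rule for the compact_topk_ids workspace buffer."""
    # The micro kernel writes one entry per (token, expert) routing pair.
    total_pairs = num_tokens * top_k
    if not fixed:
        # Pre-fix behavior: sized by local expert count only (state_E).
        return num_local_experts
    # Post-fix behavior: large enough for either filling pattern.
    return max(num_local_experts, total_pairs)


# First parametrized case above: 2 tokens, top_k=8, 8 local experts.
# total_pairs = 16 exceeds the old capacity of 8, which is what tripped
# 'flat_ids.numel() <= workspace.compact_topk_ids.numel()'.
assert compact_topk_ids_capacity(8, 2, 8, fixed=False) == 8   # too small
assert compact_topk_ids_capacity(8, 2, 8) == 16               # fits all pairs
```

The `max` keeps the buffer backward compatible with call paths that still index it by local expert id, while guaranteeing room for all routing pairs in the micro-kernel path.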
Review comment: Docstring contract is now stale after removing the size check. Line 62 still says `weight_expert_ids` must be `[>=total_pairs]`, but lines 71-72 explicitly relax that. Please update the docstring to reflect the new expected sizing contract.