-
Notifications
You must be signed in to change notification settings - Fork 382
Fused TopK and Sigmoid kernel #1251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 12 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
2c6096e
Add topk softmax
samremes f1da3a3
Add test for topk sigmoid
samremes 1a9ff75
register the op properly
samremes 7505a7c
apply black
samremes 47d9750
don't use constexpr with std::string
samremes d86d952
bump ck to include topk sigmoid commit
samremes ff47eb7
Merge remote-tracking branch 'origin/main' into samremes/topk_sigmoid
samremes 784e3f2
hipify
samremes 47f9168
add argparse to the topk sigmoid test, also add pytest
samremes a7e4dd2
use own module instead of asm moe
samremes 6a53df0
black formatting
samremes efbd7e6
add missing file
samremes 367253a
revert changes to module_moe_asm
samremes File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Submodule composable_kernel
updated
254 files
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| // SPDX-License-Identifier: MIT | ||
| // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| #include <torch/all.h> | ||
| #include <ATen/hip/HIPContext.h> | ||
| #include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h> | ||
| #include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h> | ||
| #include "py_itfs_common.h" | ||
|
|
||
| // from CK examples: | ||
| #include "topk_softmax_api.hpp" | ||
|
|
||
| namespace aiter | ||
| { | ||
|
|
||
| void topk_sigmoid(torch::Tensor topk_weights, // [tokens, topk] | ||
| torch::Tensor topk_indices, // [tokens, topk] | ||
| torch::Tensor gating_output) // [tokens, experts] | ||
| { | ||
| // Ensure the tensors are on the correct device | ||
| const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); | ||
|
|
||
| // Extract dimensions | ||
| const int tokens = gating_output.size(0); | ||
| const int experts = gating_output.size(1); | ||
| const int topk = topk_weights.size(1); | ||
|
|
||
| // Assume default strides | ||
| const int stride_input = experts; | ||
| const int stride_output = topk; | ||
|
|
||
| // Determine datatypes | ||
| auto dtype_to_string = [](const auto dtype) -> std::string { | ||
| if(dtype == torch::kFloat16) | ||
| { | ||
| return "fp16"; | ||
| } | ||
| else if(dtype == torch::kBFloat16) | ||
| { | ||
| return "bf16"; | ||
| } | ||
| else if(dtype == torch::kFloat32) | ||
| { | ||
| return "fp32"; | ||
| } | ||
| else | ||
| { | ||
| throw std::runtime_error("invalid datatype for topk_sigmoid: only fp16/bf16/fp32!"); | ||
| } | ||
| }; | ||
| std::string input_prec = dtype_to_string(gating_output.dtype()); | ||
| std::string weight_prec = dtype_to_string(topk_weights.dtype()); | ||
|
|
||
| // Prepare kernel arguments | ||
| static const std::string activation = "sigmoid"; | ||
| topk_softmax_trait trait{input_prec, weight_prec, experts, activation}; | ||
|
|
||
| topk_softmax_kargs karg{gating_output.data_ptr(), | ||
| topk_weights.data_ptr(), | ||
| topk_indices.data_ptr(), | ||
| tokens, | ||
| experts, | ||
| topk, | ||
| stride_input, | ||
| stride_output}; | ||
|
|
||
| ck_tile::stream_config sc{at::hip::getCurrentHIPStream()}; | ||
|
|
||
| topk_softmax(trait, karg, sc); | ||
| } | ||
|
|
||
| } // namespace aiter |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| /* SPDX-License-Identifier: MIT | ||
| Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
| */ | ||
| #include "moe_op.h" | ||
| #include "rocm_ops.hpp" | ||
|
|
||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||
| { | ||
| MOE_TOPK_PYBIND; | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,221 @@ | ||
| # SPDX-License-Identifier: MIT | ||
| # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| """ | ||
| Test topk_sigmoid operation with various configurations. | ||
|
|
||
| This test can be run in two ways: | ||
|
|
||
| 1. Using pytest (for automated testing): | ||
| pytest test_moe_topk_sigmoid.py -v | ||
|
|
||
| 2. Using command line arguments (for benchmarking with summary table): | ||
| python test_moe_topk_sigmoid.py --num-experts 64,128 --topk 2,4,8 --dtype fp16 | ||
| """ | ||
|
|
||
| import argparse | ||
| import itertools | ||
|
|
||
| import pandas as pd | ||
| import pytest | ||
| import torch | ||
| import aiter | ||
| from aiter.test_common import ( | ||
| checkAllclose, | ||
| perftest, | ||
| ) | ||
| from aiter.utility.dtypes import str2Dtype, str2tuple | ||
|
|
||
|
|
||
| @perftest(num_iters=10, num_warmup=1) | ||
| def run_torch(gating_output: torch.Tensor, topk: int): | ||
| # llama4 maverick custom routing function | ||
| router_scores, router_indices = torch.topk(gating_output, topk, dim=-1) | ||
| router_scores = torch.sigmoid(router_scores.float()) | ||
| return router_scores, router_indices.to(torch.int32) | ||
|
|
||
|
|
||
| @perftest(num_iters=10, num_warmup=1) | ||
| def run_fused(gating_output: torch.Tensor, topk: int): | ||
| tokens, _ = gating_output.shape | ||
| router_scores = torch.empty( | ||
| (tokens, topk), dtype=torch.float32, device=gating_output.device | ||
| ) | ||
| router_indices = torch.empty( | ||
| (tokens, topk), dtype=torch.int32, device=gating_output.device | ||
| ) | ||
| aiter.topk_sigmoid(router_scores, router_indices, gating_output) | ||
| return router_scores, router_indices | ||
|
|
||
|
|
||
| def benchmark_topk_sigmoid( | ||
| num_experts: int = 128, | ||
| num_tokens: int = 1024, | ||
| topk: int = 4, | ||
| dtype: torch.dtype = torch.float16, | ||
| ): | ||
| # generate data - each row has only unique values | ||
| gating_output = ( | ||
| torch.arange(-1, 1, 2.0 / num_experts) | ||
| .repeat((num_tokens, 1)) | ||
| .to(dtype=dtype, device="cuda") | ||
| ) | ||
| permutation = torch.argsort(torch.rand_like(gating_output), dim=-1) | ||
| gating_output = torch.gather(gating_output, dim=-1, index=permutation) | ||
| assert gating_output.is_contiguous() | ||
| # run benchmarks | ||
| (scores_torch, indices_torch), avg_torch = run_torch(gating_output.clone(), topk) | ||
| (scores_fused, indices_fused), avg_fused = run_fused(gating_output.clone(), topk) | ||
| # check correctness | ||
| score_errors = checkAllclose(scores_torch, scores_fused, tol_err_ratio=0.01) | ||
| index_errors = checkAllclose(indices_torch, indices_fused, tol_err_ratio=0.01) | ||
|
|
||
| # Collect results for summary | ||
| result = { | ||
| "num_experts": num_experts, | ||
| "num_tokens": num_tokens, | ||
| "topk": topk, | ||
| "dtype": str(dtype).split(".")[-1], | ||
| "torch_us": avg_torch, | ||
| "fused_us": avg_fused, | ||
| "uplift": avg_torch / avg_fused, | ||
| "score_errors": score_errors, | ||
| "index_errors": index_errors, | ||
| } | ||
|
|
||
| # print some failed rows if errors are significant | ||
| if score_errors > 0.01 or index_errors > 0.01: | ||
| failed_rows = (indices_torch != indices_fused).sum(dim=-1) > 0 | ||
| print( | ||
| f"\n[ERROR] Configuration: num_experts={num_experts}, num_tokens={num_tokens}, topk={topk}, dtype={str(dtype).split('.')[-1]}" | ||
| ) | ||
| print("Wrong scores:") | ||
| print(scores_torch[failed_rows][:5]) | ||
| print(scores_fused[failed_rows][:5]) | ||
| print("Wrong indices:") | ||
| print(indices_torch[failed_rows][:5]) | ||
| print(indices_fused[failed_rows][:5]) | ||
| print("Gating outputs:") | ||
| failed_values = gating_output[failed_rows][:5] | ||
| failed_values, _ = failed_values.sort(dim=-1, descending=True) | ||
| print(failed_values[:, :10]) | ||
| print( | ||
| f"Number of wrong tokens: {sum(failed_rows)} / {len(failed_rows)}, {100 * sum(failed_rows) / len(failed_rows):.2f} %" | ||
| ) | ||
|
|
||
| return result | ||
|
|
||
|
|
||
| # Pytest-parametrized test functions | ||
| @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| @pytest.mark.parametrize("topk", [1, 2, 4, 8]) | ||
| @pytest.mark.parametrize("num_tokens", [64, 1024, 2048]) | ||
| @pytest.mark.parametrize("num_experts", [64, 128]) | ||
| def test_topk_sigmoid_correctness(num_experts, num_tokens, topk, dtype): | ||
| """Pytest test for correctness of topk_sigmoid operation.""" | ||
| torch.random.manual_seed(0) | ||
|
|
||
| # generate data - each row has only unique values | ||
| gating_output = ( | ||
| torch.arange(-1, 1, 2.0 / num_experts) | ||
| .repeat((num_tokens, 1)) | ||
| .to(dtype=dtype, device="cuda") | ||
| ) | ||
| permutation = torch.argsort(torch.rand_like(gating_output), dim=-1) | ||
| gating_output = torch.gather(gating_output, dim=-1, index=permutation) | ||
| assert gating_output.is_contiguous() | ||
|
|
||
| # run both implementations | ||
| (scores_torch, indices_torch), _ = run_torch(gating_output.clone(), topk) | ||
| (scores_fused, indices_fused), _ = run_fused(gating_output.clone(), topk) | ||
|
|
||
| # check correctness | ||
| score_errors = checkAllclose(scores_torch, scores_fused, tol_err_ratio=0.01) | ||
| index_errors = checkAllclose(indices_torch, indices_fused, tol_err_ratio=0.01) | ||
|
|
||
| # Assert correctness | ||
| assert score_errors <= 0.01, f"Score errors {score_errors} exceed tolerance" | ||
| assert index_errors <= 0.01, f"Index errors {index_errors} exceed tolerance" | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser( | ||
| description="Test topk_sigmoid operation with various configurations" | ||
| ) | ||
| parser.add_argument( | ||
| "--num-experts", | ||
| type=str2tuple, | ||
| default=[128], | ||
| help="Comma-separated list of number of experts (default: 16,128)", | ||
| ) | ||
| parser.add_argument( | ||
| "--num-tokens", | ||
| type=str2tuple, | ||
| default=[1024], | ||
| help="Comma-separated list of number of tokens (default: 1024)", | ||
| ) | ||
| parser.add_argument( | ||
| "--topk", | ||
| type=str2tuple, | ||
| default=[8], | ||
| help="Comma-separated list of topk values (default: 1,2,8)", | ||
| ) | ||
| parser.add_argument( | ||
| "--dtype", | ||
| type=str2Dtype, | ||
| default=[torch.float16, torch.bfloat16], | ||
| help="Comma-separated list of dtypes: fp16, bf16 (default: fp16,bf16)", | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| # Get parsed parameter lists | ||
| num_experts_list = args.num_experts | ||
| num_tokens_list = args.num_tokens | ||
| topk_list = args.topk | ||
| dtype_list = args.dtype | ||
|
|
||
| # Run all combinations (cartesian product) | ||
| configs = list( | ||
| itertools.product(num_experts_list, num_tokens_list, topk_list, dtype_list) | ||
| ) | ||
|
|
||
| print(f"Running {len(configs)} configuration(s):") | ||
| print(f" num_experts: {num_experts_list}") | ||
| print(f" num_tokens: {num_tokens_list}") | ||
| print(f" topk: {topk_list}") | ||
| print(f" dtype: {[str(dt).split('.')[-1] for dt in dtype_list]}") | ||
| print("=" * 80) | ||
|
|
||
| # Collect results from all configurations | ||
| collected = [] | ||
| for i, (num_experts, num_tokens, topk, dtype) in enumerate(configs, 1): | ||
| result = benchmark_topk_sigmoid( | ||
| num_experts=num_experts, num_tokens=num_tokens, topk=topk, dtype=dtype | ||
| ) | ||
| collected.append(result) | ||
|
|
||
| print("\n" + "=" * 80) | ||
| print("SUMMARY") | ||
| print("=" * 80) | ||
|
|
||
| # Create and print DataFrame | ||
| df = pd.DataFrame(collected) | ||
| print(df.to_string(index=False)) | ||
|
|
||
| # Print additional statistics | ||
| print("\n" + "=" * 80) | ||
| print(f"Average uplift: {df['uplift'].mean():.2f}x") | ||
| print(f"Max uplift: {df['uplift'].max():.2f}x") | ||
| print(f"Min uplift: {df['uplift'].min():.2f}x") | ||
|
|
||
| # Check for any errors | ||
| errors = df[(df["score_errors"] > 0.01) | (df["index_errors"] > 0.01)] | ||
| if len(errors) > 0: | ||
| print( | ||
| f"\nWARNING: {len(errors)} configuration(s) had errors exceeding tolerance!" | ||
| ) | ||
| print(errors.to_string(index=False)) | ||
| else: | ||
| print("\nAll tests passed with errors within tolerance!") | ||
| print("=" * 80) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.