diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 211d64e18a..515e283091 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 211d64e18a1bf2ecb1d13c5eb87983bdcabb3b5e +Subproject commit 515e28309153ae8ab6fa3cbed81b44e2c01c43cd diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index ed8eead45c..6ee3147913 100644 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -410,6 +410,22 @@ "verbose": "False", "blob_gen_cmd": "''" }, + "module_moe_topk": { + "srcs": [ + "f'{AITER_CSRC_DIR}/pybind/moe_topk_pybind.cu'", + "f'{AITER_CSRC_DIR}/py_itfs_ck/topk_sigmoid_kernels.cu'", + "f'{CK_DIR}/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": "None", + "extra_include": [ + "f'{AITER_CSRC_DIR}/include/ck_tile'", + "f'{CK_DIR}/example/ck_tile/09_topk_softmax'" + ], + "verbose": "False", + "blob_gen_cmd": "''" + }, "module_norm": { "srcs": [ "f'{AITER_CSRC_DIR}/py_itfs_ck/norm_kernels.cu'", diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index 4087fa787d..f70438b4e0 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -32,6 +32,12 @@ def topk_softmax_asm( ) -> None: ... +@compile_ops("module_moe_topk") +def topk_sigmoid( + topk_weights: Tensor, topk_indices: Tensor, gating_output: Tensor +) -> None: ... + + @compile_ops("module_moe_asm") def moe_sum(input: Tensor, output: Tensor) -> None: ... 
diff --git a/csrc/include/moe_op.h b/csrc/include/moe_op.h index 488c104889..27b7f5fbaa 100644 --- a/csrc/include/moe_op.h +++ b/csrc/include/moe_op.h @@ -190,4 +190,8 @@ void moe_align_block_size(torch::Tensor topk_ids, void moe_sum(torch::Tensor& input, torch::Tensor& output); +void topk_sigmoid(torch::Tensor topk_weights, // [tokens, topk] + torch::Tensor topk_indices, // [tokens, topk] + torch::Tensor gating_output); // [tokens, experts] + } // namespace aiter diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index e8dec972f7..53a5f26a12 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -906,6 +906,14 @@ py::arg("sorted_weights") = std::nullopt); \ m.def("moe_sum", &aiter::moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()"); +#define MOE_TOPK_PYBIND \ + m.def("topk_sigmoid", \ + &aiter::topk_sigmoid, \ + py::arg("topk_weights"), \ + py::arg("topk_indices"), \ + py::arg("gating_output"), \ + "Apply topk sigmoid to the gating outputs."); + #define MOE_SORTING_PYBIND \ m.def("moe_sorting_fwd", \ &moe_sorting_fwd, \ diff --git a/csrc/py_itfs_ck/topk_sigmoid_kernels.cu b/csrc/py_itfs_ck/topk_sigmoid_kernels.cu new file mode 100644 index 0000000000..4b92ff7949 --- /dev/null +++ b/csrc/py_itfs_ck/topk_sigmoid_kernels.cu @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+
+// NOTE(review): the four bracketed #include targets were lost in this rendering
+// (angle-bracket contents stripped); the headers below are reconstructed from the
+// symbols used in this file (torch::Tensor, OptionalHIPGuardMasqueradingAsCUDA,
+// getCurrentHIPStream, std::string/std::runtime_error) — verify against the
+// original commit.
+#include <string>
+#include <torch/all.h>
+#include <ATen/hip/HIPContext.h>
+#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
+#include "py_itfs_common.h"
+
+// from CK examples:
+#include "topk_softmax_api.hpp"
+
+namespace aiter
+{
+
+void topk_sigmoid(torch::Tensor topk_weights,  // [tokens, topk]
+                  torch::Tensor topk_indices,  // [tokens, topk]
+                  torch::Tensor gating_output) // [tokens, experts]
+{
+    // Ensure the tensors are on the correct device
+    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output));
+
+    // Extract dimensions
+    const int tokens = gating_output.size(0);
+    const int experts = gating_output.size(1);
+    const int topk = topk_weights.size(1);
+
+    // Assume default strides
+    const int stride_input = experts;
+    const int stride_output = topk;
+
+    // Determine datatypes
+    auto dtype_to_string = [](const auto dtype) -> std::string {
+        if(dtype == torch::kFloat16)
+        {
+            return "fp16";
+        }
+        else if(dtype == torch::kBFloat16)
+        {
+            return "bf16";
+        }
+        else if(dtype == torch::kFloat32)
+        {
+            return "fp32";
+        }
+        else
+        {
+            throw std::runtime_error("invalid datatype for topk_sigmoid: only fp16/bf16/fp32!");
+        }
+    };
+    std::string input_prec = dtype_to_string(gating_output.dtype());
+    std::string weight_prec = dtype_to_string(topk_weights.dtype());
+
+    // Prepare kernel arguments
+    static const std::string activation = "sigmoid";
+    topk_softmax_trait trait{input_prec, weight_prec, experts, activation};
+
+    topk_softmax_kargs karg{gating_output.data_ptr(),
+                            topk_weights.data_ptr(),
+                            topk_indices.data_ptr(),
+                            tokens,
+                            experts,
+                            topk,
+                            stride_input,
+                            stride_output};
+
+    ck_tile::stream_config sc{at::hip::getCurrentHIPStream()};
+
+    topk_softmax(trait, karg, sc);
+}
+
+} // namespace aiter
diff --git a/csrc/pybind/moe_topk_pybind.cu b/csrc/pybind/moe_topk_pybind.cu
new file mode 100644
index 0000000000..42351d379f
--- /dev/null
+++ b/csrc/pybind/moe_topk_pybind.cu
@@ -0,0 +1,10 @@
+/* SPDX-License-Identifier: MIT
+   Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
+*/ +#include "moe_op.h" +#include "rocm_ops.hpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + MOE_TOPK_PYBIND; +} \ No newline at end of file diff --git a/csrc/rocm_ops.cpp b/csrc/rocm_ops.cpp index db81f468e9..099bf7f619 100644 --- a/csrc/rocm_ops.cpp +++ b/csrc/rocm_ops.cpp @@ -88,6 +88,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) ATTENTION_RAGGED_PYBIND; ATTENTION_V1_PYBIND; MOE_OP_PYBIND; + MOE_TOPK_PYBIND; ROPE_GENERAL_FWD_PYBIND; ROPE_GENERAL_BWD_PYBIND; ROPE_POS_FWD_PYBIND; diff --git a/op_tests/test_moe_topk_sigmoid.py b/op_tests/test_moe_topk_sigmoid.py new file mode 100644 index 0000000000..8e90860958 --- /dev/null +++ b/op_tests/test_moe_topk_sigmoid.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +""" +Test topk_sigmoid operation with various configurations. + +This test can be run in two ways: + +1. Using pytest (for automated testing): + pytest test_moe_topk_sigmoid.py -v + +2. Using command line arguments (for benchmarking with summary table): + python test_moe_topk_sigmoid.py --num-experts 64,128 --topk 2,4,8 --dtype fp16 +""" + +import argparse +import itertools + +import pandas as pd +import pytest +import torch +import aiter +from aiter.test_common import ( + checkAllclose, + perftest, +) +from aiter.utility.dtypes import str2Dtype, str2tuple + + +@perftest(num_iters=10, num_warmup=1) +def run_torch(gating_output: torch.Tensor, topk: int): + # llama4 maverick custom routing function + router_scores, router_indices = torch.topk(gating_output, topk, dim=-1) + router_scores = torch.sigmoid(router_scores.float()) + return router_scores, router_indices.to(torch.int32) + + +@perftest(num_iters=10, num_warmup=1) +def run_fused(gating_output: torch.Tensor, topk: int): + tokens, _ = gating_output.shape + router_scores = torch.empty( + (tokens, topk), dtype=torch.float32, device=gating_output.device + ) + router_indices = torch.empty( + (tokens, topk), dtype=torch.int32, 
device=gating_output.device + ) + aiter.topk_sigmoid(router_scores, router_indices, gating_output) + return router_scores, router_indices + + +def benchmark_topk_sigmoid( + num_experts: int = 128, + num_tokens: int = 1024, + topk: int = 4, + dtype: torch.dtype = torch.float16, +): + # generate data - each row has only unique values + gating_output = ( + torch.arange(-1, 1, 2.0 / num_experts) + .repeat((num_tokens, 1)) + .to(dtype=dtype, device="cuda") + ) + permutation = torch.argsort(torch.rand_like(gating_output), dim=-1) + gating_output = torch.gather(gating_output, dim=-1, index=permutation) + assert gating_output.is_contiguous() + # run benchmarks + (scores_torch, indices_torch), avg_torch = run_torch(gating_output.clone(), topk) + (scores_fused, indices_fused), avg_fused = run_fused(gating_output.clone(), topk) + # check correctness + score_errors = checkAllclose(scores_torch, scores_fused, tol_err_ratio=0.01) + index_errors = checkAllclose(indices_torch, indices_fused, tol_err_ratio=0.01) + + # Collect results for summary + result = { + "num_experts": num_experts, + "num_tokens": num_tokens, + "topk": topk, + "dtype": str(dtype).split(".")[-1], + "torch_us": avg_torch, + "fused_us": avg_fused, + "uplift": avg_torch / avg_fused, + "score_errors": score_errors, + "index_errors": index_errors, + } + + # print some failed rows if errors are significant + if score_errors > 0.01 or index_errors > 0.01: + failed_rows = (indices_torch != indices_fused).sum(dim=-1) > 0 + print( + f"\n[ERROR] Configuration: num_experts={num_experts}, num_tokens={num_tokens}, topk={topk}, dtype={str(dtype).split('.')[-1]}" + ) + print("Wrong scores:") + print(scores_torch[failed_rows][:5]) + print(scores_fused[failed_rows][:5]) + print("Wrong indices:") + print(indices_torch[failed_rows][:5]) + print(indices_fused[failed_rows][:5]) + print("Gating outputs:") + failed_values = gating_output[failed_rows][:5] + failed_values, _ = failed_values.sort(dim=-1, descending=True) + 
+        print(failed_values[:, :10])
+        print(
+            f"Number of wrong tokens: {sum(failed_rows)} / {len(failed_rows)}, {100 * sum(failed_rows) / len(failed_rows):.2f} %"
+        )
+
+    return result
+
+
+# Pytest-parametrized test functions
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("topk", [1, 2, 4, 8])
+@pytest.mark.parametrize("num_tokens", [64, 1024, 2048])
+@pytest.mark.parametrize("num_experts", [64, 128])
+def test_topk_sigmoid_correctness(num_experts, num_tokens, topk, dtype):
+    """Pytest test for correctness of topk_sigmoid operation."""
+    torch.random.manual_seed(0)
+
+    # generate data - each row has only unique values
+    gating_output = (
+        torch.arange(-1, 1, 2.0 / num_experts)
+        .repeat((num_tokens, 1))
+        .to(dtype=dtype, device="cuda")
+    )
+    permutation = torch.argsort(torch.rand_like(gating_output), dim=-1)
+    gating_output = torch.gather(gating_output, dim=-1, index=permutation)
+    assert gating_output.is_contiguous()
+
+    # run both implementations
+    (scores_torch, indices_torch), _ = run_torch(gating_output.clone(), topk)
+    (scores_fused, indices_fused), _ = run_fused(gating_output.clone(), topk)
+
+    # check correctness
+    score_errors = checkAllclose(scores_torch, scores_fused, tol_err_ratio=0.01)
+    index_errors = checkAllclose(indices_torch, indices_fused, tol_err_ratio=0.01)
+
+    # Assert correctness
+    assert score_errors <= 0.01, f"Score errors {score_errors} exceed tolerance"
+    assert index_errors <= 0.01, f"Index errors {index_errors} exceed tolerance"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Test topk_sigmoid operation with various configurations"
+    )
+    parser.add_argument(
+        "--num-experts",
+        type=str2tuple,
+        default=[128],
+        help="Comma-separated list of number of experts (default: 128)",
+    )
+    parser.add_argument(
+        "--num-tokens",
+        type=str2tuple,
+        default=[1024],
+        help="Comma-separated list of number of tokens (default: 1024)",
+    )
+    parser.add_argument(
"--topk", + type=str2tuple, + default=[8], + help="Comma-separated list of topk values (default: 1,2,8)", + ) + parser.add_argument( + "--dtype", + type=str2Dtype, + default=[torch.float16, torch.bfloat16], + help="Comma-separated list of dtypes: fp16, bf16 (default: fp16,bf16)", + ) + + args = parser.parse_args() + + # Get parsed parameter lists + num_experts_list = args.num_experts + num_tokens_list = args.num_tokens + topk_list = args.topk + dtype_list = args.dtype + + # Run all combinations (cartesian product) + configs = list( + itertools.product(num_experts_list, num_tokens_list, topk_list, dtype_list) + ) + + print(f"Running {len(configs)} configuration(s):") + print(f" num_experts: {num_experts_list}") + print(f" num_tokens: {num_tokens_list}") + print(f" topk: {topk_list}") + print(f" dtype: {[str(dt).split('.')[-1] for dt in dtype_list]}") + print("=" * 80) + + # Collect results from all configurations + collected = [] + for i, (num_experts, num_tokens, topk, dtype) in enumerate(configs, 1): + result = benchmark_topk_sigmoid( + num_experts=num_experts, num_tokens=num_tokens, topk=topk, dtype=dtype + ) + collected.append(result) + + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + + # Create and print DataFrame + df = pd.DataFrame(collected) + print(df.to_string(index=False)) + + # Print additional statistics + print("\n" + "=" * 80) + print(f"Average uplift: {df['uplift'].mean():.2f}x") + print(f"Max uplift: {df['uplift'].max():.2f}x") + print(f"Min uplift: {df['uplift'].min():.2f}x") + + # Check for any errors + errors = df[(df["score_errors"] > 0.01) | (df["index_errors"] > 0.01)] + if len(errors) > 0: + print( + f"\nWARNING: {len(errors)} configuration(s) had errors exceeding tolerance!" + ) + print(errors.to_string(index=False)) + else: + print("\nAll tests passed with errors within tolerance!") + print("=" * 80)