Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 254 files
16 changes: 16 additions & 0 deletions aiter/jit/optCompilerConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,22 @@
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_moe_topk": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/moe_topk_pybind.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_ck/topk_sigmoid_kernels.cu'",
"f'{CK_DIR}/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{AITER_CSRC_DIR}/include/ck_tile'",
"f'{CK_DIR}/example/ck_tile/09_topk_softmax'"
],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_norm": {
"srcs": [
"f'{AITER_CSRC_DIR}/py_itfs_ck/norm_kernels.cu'",
Expand Down
6 changes: 6 additions & 0 deletions aiter/ops/moe_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def topk_softmax_asm(
) -> None: ...


@compile_ops("module_moe_topk")
def topk_sigmoid(
    topk_weights: Tensor, topk_indices: Tensor, gating_output: Tensor
) -> None:
    """Fused top-k + sigmoid routing over MoE gating outputs.

    Writes the selected expert scores into ``topk_weights`` ([tokens, topk])
    and the chosen expert ids into ``topk_indices`` ([tokens, topk]), computed
    from ``gating_output`` ([tokens, experts]).  This Python body is a stub;
    the implementation comes from the compiled ``module_moe_topk`` extension.
    """
    ...


@compile_ops("module_moe_asm")
def moe_sum(input: Tensor, output: Tensor) -> None: ...

Expand Down
4 changes: 4 additions & 0 deletions csrc/include/moe_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,8 @@ void moe_align_block_size(torch::Tensor topk_ids,

void moe_sum(torch::Tensor& input, torch::Tensor& output);

// Fused top-k + sigmoid routing: selects the top-k gating values per token,
// applies sigmoid to them, and writes the scores/expert-ids into the two
// output tensors.  Implemented in csrc/py_itfs_ck/topk_sigmoid_kernels.cu.
void topk_sigmoid(torch::Tensor topk_weights, // [tokens, topk]
torch::Tensor topk_indices, // [tokens, topk]
torch::Tensor gating_output); // [tokens, experts]

} // namespace aiter
8 changes: 8 additions & 0 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,14 @@
py::arg("sorted_weights") = std::nullopt); \
m.def("moe_sum", &aiter::moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()");

// Registers the fused top-k + sigmoid routing op on a pybind11 module `m`.
// Used both by the standalone module_moe_topk extension and by rocm_ops.cpp.
#define MOE_TOPK_PYBIND \
m.def("topk_sigmoid", \
&aiter::topk_sigmoid, \
py::arg("topk_weights"), \
py::arg("topk_indices"), \
py::arg("gating_output"), \
"Apply topk sigmoid to the gating outputs.");

#define MOE_SORTING_PYBIND \
m.def("moe_sorting_fwd", \
&moe_sorting_fwd, \
Expand Down
72 changes: 72 additions & 0 deletions csrc/py_itfs_ck/topk_sigmoid_kernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <torch/all.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include "py_itfs_common.h"

// from CK examples:
#include "topk_softmax_api.hpp"

namespace aiter
{

// Fused top-k + sigmoid routing.
//
// Selects the top-k gating values per token and applies sigmoid, writing the
// resulting scores into `topk_weights` and the chosen expert ids into
// `topk_indices`.  Dispatches to the CK tile topk_softmax kernel with the
// "sigmoid" activation.
//
// topk_weights:  [tokens, topk]    output scores (fp16/bf16/fp32)
// topk_indices:  [tokens, topk]    output expert ids
// gating_output: [tokens, experts] input logits (fp16/bf16/fp32)
//
// Throws std::runtime_error on unsupported dtypes; TORCH_CHECK failures on
// shape/layout violations.
void topk_sigmoid(torch::Tensor topk_weights, // [tokens, topk]
                  torch::Tensor topk_indices, // [tokens, topk]
                  torch::Tensor gating_output) // [tokens, experts]
{
    // Ensure the kernel launches on the device owning the gating tensor.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output));

    // The kernel arguments below hard-code row strides to the row widths, so
    // all tensors must be 2-D and densely packed.  Validate that up front
    // instead of silently reading/writing with wrong strides.
    TORCH_CHECK(gating_output.dim() == 2,
                "topk_sigmoid: gating_output must be 2-D [tokens, experts]");
    TORCH_CHECK(topk_weights.dim() == 2 && topk_indices.dim() == 2,
                "topk_sigmoid: topk_weights/topk_indices must be 2-D [tokens, topk]");
    TORCH_CHECK(gating_output.is_contiguous() && topk_weights.is_contiguous() &&
                    topk_indices.is_contiguous(),
                "topk_sigmoid: all tensors must be contiguous");
    TORCH_CHECK(topk_indices.sizes() == topk_weights.sizes(),
                "topk_sigmoid: topk_weights and topk_indices must have the same shape");

    // Extract dimensions
    const int tokens  = gating_output.size(0);
    const int experts = gating_output.size(1);
    const int topk    = topk_weights.size(1);
    TORCH_CHECK(topk_weights.size(0) == tokens,
                "topk_sigmoid: token dimension mismatch between outputs and gating_output");

    // Contiguous rows => row stride equals the row width.
    const int stride_input  = experts;
    const int stride_output = topk;

    // Map torch dtypes onto the precision strings the CK API expects.
    auto dtype_to_string = [](const auto dtype) -> std::string {
        if(dtype == torch::kFloat16)
        {
            return "fp16";
        }
        else if(dtype == torch::kBFloat16)
        {
            return "bf16";
        }
        else if(dtype == torch::kFloat32)
        {
            return "fp32";
        }
        else
        {
            throw std::runtime_error("invalid datatype for topk_sigmoid: only fp16/bf16/fp32!");
        }
    };
    std::string input_prec  = dtype_to_string(gating_output.dtype());
    std::string weight_prec = dtype_to_string(topk_weights.dtype());

    // Prepare kernel arguments; the CK topk_softmax entry point selects the
    // sigmoid variant via the activation field of the trait.
    static const std::string activation = "sigmoid";
    topk_softmax_trait trait{input_prec, weight_prec, experts, activation};

    topk_softmax_kargs karg{gating_output.data_ptr(),
                            topk_weights.data_ptr(),
                            topk_indices.data_ptr(),
                            tokens,
                            experts,
                            topk,
                            stride_input,
                            stride_output};

    // Run on the current PyTorch HIP stream so ordering with other torch ops
    // is preserved.
    ck_tile::stream_config sc{at::hip::getCurrentHIPStream()};

    topk_softmax(trait, karg, sc);
}

} // namespace aiter
10 changes: 10 additions & 0 deletions csrc/pybind/moe_topk_pybind.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/* SPDX-License-Identifier: MIT
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
*/
#include "moe_op.h"
#include "rocm_ops.hpp"

// Standalone pybind11 extension exposing only the MoE top-k routing op
// (the MOE_TOPK_PYBIND macro is defined in rocm_ops.hpp).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
MOE_TOPK_PYBIND;
}
1 change: 1 addition & 0 deletions csrc/rocm_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
ATTENTION_RAGGED_PYBIND;
ATTENTION_V1_PYBIND;
MOE_OP_PYBIND;
MOE_TOPK_PYBIND;
ROPE_GENERAL_FWD_PYBIND;
ROPE_GENERAL_BWD_PYBIND;
ROPE_POS_FWD_PYBIND;
Expand Down
221 changes: 221 additions & 0 deletions op_tests/test_moe_topk_sigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

"""
Test topk_sigmoid operation with various configurations.

This test can be run in two ways:

1. Using pytest (for automated testing):
pytest test_moe_topk_sigmoid.py -v

2. Using command line arguments (for benchmarking with summary table):
python test_moe_topk_sigmoid.py --num-experts 64,128 --topk 2,4,8 --dtype fp16
"""

import argparse
import itertools

import pandas as pd
import pytest
import torch
import aiter
from aiter.test_common import (
checkAllclose,
perftest,
)
from aiter.utility.dtypes import str2Dtype, str2tuple


@perftest(num_iters=10, num_warmup=1)
def run_torch(gating_output: torch.Tensor, topk: int):
    """Reference routing (llama4-maverick style): top-k first, then sigmoid."""
    scores, indices = torch.topk(gating_output, topk, dim=-1)
    weights = torch.sigmoid(scores.float())
    return weights, indices.to(torch.int32)


@perftest(num_iters=10, num_warmup=1)
def run_fused(gating_output: torch.Tensor, topk: int):
    """Run the fused aiter.topk_sigmoid kernel, allocating its output buffers."""
    num_tokens = gating_output.shape[0]
    device = gating_output.device
    router_scores = torch.empty((num_tokens, topk), dtype=torch.float32, device=device)
    router_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device=device)
    aiter.topk_sigmoid(router_scores, router_indices, gating_output)
    return router_scores, router_indices


def benchmark_topk_sigmoid(
    num_experts: int = 128,
    num_tokens: int = 1024,
    topk: int = 4,
    dtype: torch.dtype = torch.float16,
):
    """Benchmark and cross-check aiter.topk_sigmoid against the torch reference.

    Returns a dict with the configuration, average runtimes in microseconds,
    the speedup ("uplift"), and the error ratios reported by checkAllclose.
    """
    # generate data - each row has only unique values, randomly permuted per
    # row, so the top-k selection has a unique answer.
    gating_output = (
        torch.arange(-1, 1, 2.0 / num_experts)
        .repeat((num_tokens, 1))
        .to(dtype=dtype, device="cuda")
    )
    permutation = torch.argsort(torch.rand_like(gating_output), dim=-1)
    gating_output = torch.gather(gating_output, dim=-1, index=permutation)
    assert gating_output.is_contiguous()
    # run benchmarks (clone so neither path can perturb the other's input)
    (scores_torch, indices_torch), avg_torch = run_torch(gating_output.clone(), topk)
    (scores_fused, indices_fused), avg_fused = run_fused(gating_output.clone(), topk)
    # check correctness
    score_errors = checkAllclose(scores_torch, scores_fused, tol_err_ratio=0.01)
    index_errors = checkAllclose(indices_torch, indices_fused, tol_err_ratio=0.01)

    # Collect results for summary
    result = {
        "num_experts": num_experts,
        "num_tokens": num_tokens,
        "topk": topk,
        "dtype": str(dtype).split(".")[-1],
        "torch_us": avg_torch,
        "fused_us": avg_fused,
        "uplift": avg_torch / avg_fused,
        "score_errors": score_errors,
        "index_errors": index_errors,
    }

    # print some failed rows if errors are significant
    if score_errors > 0.01 or index_errors > 0.01:
        failed_rows = (indices_torch != indices_fused).sum(dim=-1) > 0
        print(
            f"\n[ERROR] Configuration: num_experts={num_experts}, num_tokens={num_tokens}, topk={topk}, dtype={str(dtype).split('.')[-1]}"
        )
        print("Wrong scores:")
        print(scores_torch[failed_rows][:5])
        print(scores_fused[failed_rows][:5])
        print("Wrong indices:")
        print(indices_torch[failed_rows][:5])
        print(indices_fused[failed_rows][:5])
        print("Gating outputs:")
        failed_values = gating_output[failed_rows][:5]
        failed_values, _ = failed_values.sort(dim=-1, descending=True)
        print(failed_values[:, :10])
        # Reduce on-device and convert once: builtin sum() would iterate the
        # CUDA tensor element-by-element and print a tensor repr, not a count.
        num_failed = int(failed_rows.sum().item())
        total = failed_rows.numel()
        print(
            f"Number of wrong tokens: {num_failed} / {total}, {100 * num_failed / total:.2f} %"
        )

    return result


# Pytest-parametrized test functions
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("topk", [1, 2, 4, 8])
@pytest.mark.parametrize("num_tokens", [64, 1024, 2048])
@pytest.mark.parametrize("num_experts", [64, 128])
def test_topk_sigmoid_correctness(num_experts, num_tokens, topk, dtype):
    """Pytest test for correctness of topk_sigmoid operation."""
    torch.random.manual_seed(0)

    # Build gating logits where every row contains the same set of unique
    # values, shuffled independently per row.
    base = torch.arange(-1, 1, 2.0 / num_experts)
    gating_output = base.repeat((num_tokens, 1)).to(dtype=dtype, device="cuda")
    shuffle = torch.argsort(torch.rand_like(gating_output), dim=-1)
    gating_output = torch.gather(gating_output, dim=-1, index=shuffle)
    assert gating_output.is_contiguous()

    # Run the reference and fused implementations on identical inputs.
    (scores_torch, indices_torch), _ = run_torch(gating_output.clone(), topk)
    (scores_fused, indices_fused), _ = run_fused(gating_output.clone(), topk)

    # Compare both outputs and enforce the tolerance.
    score_errors = checkAllclose(scores_torch, scores_fused, tol_err_ratio=0.01)
    index_errors = checkAllclose(indices_torch, indices_fused, tol_err_ratio=0.01)
    assert score_errors <= 0.01, f"Score errors {score_errors} exceed tolerance"
    assert index_errors <= 0.01, f"Index errors {index_errors} exceed tolerance"


if __name__ == "__main__":
    # CLI entry point: run benchmark_topk_sigmoid over the cartesian product
    # of the requested configurations and print a summary table.
    parser = argparse.ArgumentParser(
        description="Test topk_sigmoid operation with various configurations"
    )
    # NOTE: help strings state the actual defaults below.
    parser.add_argument(
        "--num-experts",
        type=str2tuple,
        default=[128],
        help="Comma-separated list of number of experts (default: 128)",
    )
    parser.add_argument(
        "--num-tokens",
        type=str2tuple,
        default=[1024],
        help="Comma-separated list of number of tokens (default: 1024)",
    )
    parser.add_argument(
        "--topk",
        type=str2tuple,
        default=[8],
        help="Comma-separated list of topk values (default: 8)",
    )
    parser.add_argument(
        "--dtype",
        type=str2Dtype,
        default=[torch.float16, torch.bfloat16],
        help="Comma-separated list of dtypes: fp16, bf16 (default: fp16,bf16)",
    )

    args = parser.parse_args()

    # Get parsed parameter lists
    num_experts_list = args.num_experts
    num_tokens_list = args.num_tokens
    topk_list = args.topk
    dtype_list = args.dtype

    # Run all combinations (cartesian product)
    configs = list(
        itertools.product(num_experts_list, num_tokens_list, topk_list, dtype_list)
    )

    print(f"Running {len(configs)} configuration(s):")
    print(f"  num_experts: {num_experts_list}")
    print(f"  num_tokens: {num_tokens_list}")
    print(f"  topk: {topk_list}")
    print(f"  dtype: {[str(dt).split('.')[-1] for dt in dtype_list]}")
    print("=" * 80)

    # Collect results from all configurations
    collected = []
    for num_experts, num_tokens, topk, dtype in configs:
        result = benchmark_topk_sigmoid(
            num_experts=num_experts, num_tokens=num_tokens, topk=topk, dtype=dtype
        )
        collected.append(result)

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    # Create and print DataFrame
    df = pd.DataFrame(collected)
    print(df.to_string(index=False))

    # Print additional statistics
    print("\n" + "=" * 80)
    print(f"Average uplift: {df['uplift'].mean():.2f}x")
    print(f"Max uplift: {df['uplift'].max():.2f}x")
    print(f"Min uplift: {df['uplift'].min():.2f}x")

    # Check for any errors
    errors = df[(df["score_errors"] > 0.01) | (df["index_errors"] > 0.01)]
    if len(errors) > 0:
        print(
            f"\nWARNING: {len(errors)} configuration(s) had errors exceeding tolerance!"
        )
        print(errors.to_string(index=False))
    else:
        print("\nAll tests passed with errors within tolerance!")
    print("=" * 80)