Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aiter/ops/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def top_k_per_row_prefill(
numRows: int,
stride0: int,
stride1: int,
k: int = 2048,
) -> None: ...


Expand All @@ -232,6 +233,7 @@ def top_k_per_row_decode(
numRows: int,
stride0: int,
stride1: int,
k: int = 2048,
) -> None: ...


Expand Down
42 changes: 22 additions & 20 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1646,26 +1646,28 @@ namespace py = pybind11;
m.def("rocb_mm", &RocSolIdxBlas, "mm"); \
m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols");

#define TOP_K_PER_ROW_PYBIND \
m.def("top_k_per_row_prefill", \
&top_k_per_row_prefill, \
py::arg("logits"), \
py::arg("rowStarts"), \
py::arg("rowEnds"), \
py::arg("indices"), \
py::arg("values"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1")); \
m.def("top_k_per_row_decode", \
&top_k_per_row_decode, \
py::arg("logits"), \
py::arg("next_n"), \
py::arg("seqLens"), \
py::arg("indices"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1"));
// Pybind11 registrations for the per-row top-k entry points.
// Both bindings take a trailing optional `k` (py::arg("k") = 2048), matching
// the C++ default and the value that was previously hard-coded in the
// kernels, so existing Python callers that omit `k` keep their behaviour.
// NOTE: // comments cannot appear on the continuation lines of this macro,
// so all documentation lives above the #define.
#define TOP_K_PER_ROW_PYBIND \
m.def("top_k_per_row_prefill", \
&top_k_per_row_prefill, \
py::arg("logits"), \
py::arg("rowStarts"), \
py::arg("rowEnds"), \
py::arg("indices"), \
py::arg("values"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1"), \
py::arg("k") = 2048); \
m.def("top_k_per_row_decode", \
&top_k_per_row_decode, \
py::arg("logits"), \
py::arg("next_n"), \
py::arg("seqLens"), \
py::arg("indices"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1"), \
py::arg("k") = 2048);

#define MLA_METADATA_PYBIND \
m.def("get_mla_metadata_v1", \
Expand Down
6 changes: 4 additions & 2 deletions csrc/include/topk_per_row.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
std::optional<torch::Tensor> values,
int64_t numRows,
int64_t stride0,
int64_t stride1);
int64_t stride1,
int64_t k = 2048);

// Decode-phase per-row top-k: writes the indices of the `k` largest logits
// per row into `indices` (no values tensor is produced, unlike the prefill
// variant). `stride0`/`stride1` are the strides of `logits`; `k` defaults to
// 2048 to preserve the previously hard-coded kTopK.
// NOTE(review): the precise meaning of `next_n`/`seqLens` is defined by the
// kernel implementation — confirm against topk_per_row_kernels.cu.
void top_k_per_row_decode(const torch::Tensor& logits,
int64_t next_n,
const torch::Tensor& seqLens,
torch::Tensor& indices,
int64_t numRows,
int64_t stride0,
int64_t stride1,
int64_t k = 2048);
27 changes: 17 additions & 10 deletions csrc/kernels/topk_per_row_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2435,7 +2435,9 @@ static __global__ void topk_per_row_decode(
} // namespace aiter

template <typename T, aiter::Phase phase = aiter::Phase::Prefill>
int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows,
int32_t stride0,
int k_param = 2048)
{
using IdxT = int32_t;

Expand All @@ -2449,7 +2451,7 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
constexpr bool fused_last_filter = false;
constexpr bool sorted = true;
constexpr bool is_largest = true;
constexpr int k = 2048;
int k = k_param;

int sm_cnt = get_num_cu_func();
unsigned grid_dim =
Expand Down Expand Up @@ -2497,7 +2499,9 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
}

// Explicit template instantiation to ensure the symbol is available for linking
template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows, int32_t stride0);
template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows,
int32_t stride0,
int k_param);

void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
Expand All @@ -2506,15 +2510,17 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
std::optional<torch::Tensor> values,
int64_t numRows,
int64_t stride0,
int64_t stride1)
int64_t stride1,
int64_t k)
{
size_t buf_size = 0; // will be overwritten by the kernel

static constexpr int kTopK = 2048;
int kTopK = static_cast<int>(k);
static constexpr bool is_largest = true;

const hipStream_t stream = at::hip::getCurrentHIPStream();
int64_t workspace_size = invokeComputeTopkLastDimWorkspaceSize<float>(numRows, stride0);
int64_t workspace_size =
invokeComputeTopkLastDimWorkspaceSize<float>(numRows, stride0, kTopK);
// int64_t workspace_size = int64_t(1024)*1024*1024*2;
Comment on lines 2516 to 2524
auto options = torch::TensorOptions().dtype(torch::kUInt8).device(logits.device());
torch::Tensor workspace = torch::empty({workspace_size}, options);
Expand Down Expand Up @@ -2630,16 +2636,17 @@ void top_k_per_row_decode(const torch::Tensor& logits,
torch::Tensor& indices,
int64_t numRows,
int64_t stride0,
int64_t stride1)
int64_t stride1,
int64_t k)
{
size_t buf_size = 0; // will be overwritten by the kernel

static constexpr int kTopK = 2048;
int kTopK = static_cast<int>(k);
static constexpr bool is_largest = true;

const hipStream_t stream = at::hip::getCurrentHIPStream();
int64_t workspace_size =
invokeComputeTopkLastDimWorkspaceSize<float, aiter::Phase::Decode>(numRows, stride0);
int64_t workspace_size = invokeComputeTopkLastDimWorkspaceSize<float, aiter::Phase::Decode>(
numRows, stride0, kTopK);
Comment on lines 2641 to +2649
auto options = torch::TensorOptions().dtype(torch::kUInt8).device(logits.device());
torch::Tensor workspace = torch::empty({workspace_size}, options);

Expand Down
7 changes: 5 additions & 2 deletions csrc/kernels/topk_plain_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,13 @@ void standalone_stable_radix_11bits(void* buf,

// Forward declaration of workspace size calculation function (at global scope).
// The default `k_param = 2048` mirrors the definition in
// topk_per_row_kernels.cu, where k was previously a hard-coded constexpr.
template <typename T, aiter::Phase phase = aiter::Phase::Prefill>
int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows,
int32_t stride0,
int k_param = 2048);
// Suppress implicit instantiation in this TU; the explicit instantiation for
// <float, Prefill> lives in topk_per_row_kernels.cu.
extern template int64_t
invokeComputeTopkLastDimWorkspaceSize<float, aiter::Phase::Prefill>(int32_t numRows,
int32_t stride0,
int k_param);

// Forward declaration of helper function to call topk_per_row kernel
template <typename IdxT>
Expand Down
22 changes: 18 additions & 4 deletions op_tests/test_topk_per_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def run_top_k_per_row_prefill(
num_rows: int,
stride_row: int,
stride_col: int,
k: int = 2048,
) -> None:
"""
Run the top_k_per_row kernel.
Expand All @@ -146,6 +147,7 @@ def run_top_k_per_row_prefill(
num_rows,
stride_row,
stride_col,
k=k,
)


Expand All @@ -159,11 +161,17 @@ def run_top_k_per_row_decode(
stride0: int,
stride1: int,
fast: bool,
k: int = 2048,
) -> None:
"""
Run the top_k_per_row kernel.

Note: the `_fast` ASM-kernel variant has `kTopK=2048` baked into its
precompiled `.co`; it ignores any caller-supplied `k`. The dispatch
here only allows `_fast` when k == 2048.
"""
if fast:
assert k == 2048, "top_k_per_row_decode_fast only supports k=2048"
return aiter.top_k_per_row_decode_fast(
logits,
next_n,
Expand All @@ -182,6 +190,7 @@ def run_top_k_per_row_decode(
numRows,
stride0,
stride1,
k=k,
)


Expand Down Expand Up @@ -216,6 +225,7 @@ def test_top_k_per_row_prefill(
num_rows,
logits.stride(0),
logits.stride(1),
k=top_k,
)

# Run reference implementation
Expand Down Expand Up @@ -277,6 +287,7 @@ def test_top_k_per_row_decode(
logits.stride(0),
logits.stride(1),
fast,
k=top_k,
)

torch.cuda.synchronize()
Expand Down Expand Up @@ -319,10 +330,12 @@ def test_top_k_per_row_decode(
"-k",
"--top_k",
type=int,
default=[2048],
default=[512, 1024, 2048],
nargs="+",
help="""top-k elements per row.
e.g.: -k 2048""",
help="""top-k elements per row. The radix backend supports any positive
int; the `_fast` ASM-kernel path only supports 2048 and is skipped
for other values.
e.g.: -k 512 1024 2048""",
)

parser.add_argument(
Expand Down Expand Up @@ -391,7 +404,8 @@ def test_top_k_per_row_decode(
m, ctx, k, n, data_generation, False
)
df.append(ret)
if get_gfx() == "gfx942":
# `_fast` ASM kernel hardcodes k=2048; skip otherwise.
if get_gfx() == "gfx942" and k == 2048:
ret = test_top_k_per_row_decode(
m, ctx, k, n, data_generation, True
)
Expand Down
Loading