diff --git a/aiter/ops/topk.py b/aiter/ops/topk.py
index 809a23c08a..2ec6ed15a3 100755
--- a/aiter/ops/topk.py
+++ b/aiter/ops/topk.py
@@ -207,6 +207,7 @@ def top_k_per_row_prefill(
     numRows: int,
     stride0: int,
     stride1: int,
+    k: int = 2048,
 ) -> None: ...
 
 
@@ -232,6 +233,7 @@ def top_k_per_row_decode(
     numRows: int,
     stride0: int,
     stride1: int,
+    k: int = 2048,
 ) -> None: ...
 
 
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index f43072b1c1..d2a2bd654f 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -1646,26 +1646,28 @@ namespace py = pybind11;
     m.def("rocb_mm", &RocSolIdxBlas, "mm");                                  \
     m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols");
 
-#define TOP_K_PER_ROW_PYBIND              \
-    m.def("top_k_per_row_prefill",        \
-          &top_k_per_row_prefill,         \
-          py::arg("logits"),              \
-          py::arg("rowStarts"),           \
-          py::arg("rowEnds"),             \
-          py::arg("indices"),             \
-          py::arg("values"),              \
-          py::arg("numRows"),             \
-          py::arg("stride0"),             \
-          py::arg("stride1"));            \
-    m.def("top_k_per_row_decode",         \
-          &top_k_per_row_decode,          \
-          py::arg("logits"),              \
-          py::arg("next_n"),              \
-          py::arg("seqLens"),             \
-          py::arg("indices"),             \
-          py::arg("numRows"),             \
-          py::arg("stride0"),             \
-          py::arg("stride1"));
+#define TOP_K_PER_ROW_PYBIND                \
+    m.def("top_k_per_row_prefill",          \
+          &top_k_per_row_prefill,           \
+          py::arg("logits"),                \
+          py::arg("rowStarts"),             \
+          py::arg("rowEnds"),               \
+          py::arg("indices"),               \
+          py::arg("values"),                \
+          py::arg("numRows"),               \
+          py::arg("stride0"),               \
+          py::arg("stride1"),               \
+          py::arg("k") = 2048);             \
+    m.def("top_k_per_row_decode",           \
+          &top_k_per_row_decode,            \
+          py::arg("logits"),                \
+          py::arg("next_n"),                \
+          py::arg("seqLens"),               \
+          py::arg("indices"),               \
+          py::arg("numRows"),               \
+          py::arg("stride0"),               \
+          py::arg("stride1"),               \
+          py::arg("k") = 2048);
 
 #define MLA_METADATA_PYBIND                 \
     m.def("get_mla_metadata_v1",            \
diff --git a/csrc/include/topk_per_row.h b/csrc/include/topk_per_row.h
index e3bae1887d..fe4970d412 100644
--- a/csrc/include/topk_per_row.h
+++ b/csrc/include/topk_per_row.h
@@ -9,7 +9,8 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
                           std::optional<torch::Tensor> values,
                           int64_t numRows,
                           int64_t stride0,
-                          int64_t stride1);
+                          int64_t stride1,
+                          int64_t k = 2048);
 
 void top_k_per_row_decode(const torch::Tensor& logits,
                           int64_t next_n,
@@ -17,4 +18,5 @@ void top_k_per_row_decode(const torch::Tensor& logits,
                           torch::Tensor& indices,
                           int64_t numRows,
                           int64_t stride0,
-                          int64_t stride1);
+                          int64_t stride1,
+                          int64_t k = 2048);
diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu
index 6edf377ca8..199626e0cb 100644
--- a/csrc/kernels/topk_per_row_kernels.cu
+++ b/csrc/kernels/topk_per_row_kernels.cu
@@ -2435,7 +2435,9 @@ static __global__ void topk_per_row_decode(
 } // namespace aiter
 
 template <typename T>
-int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
+int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows,
+                                              int32_t stride0,
+                                              int k_param = 2048)
 {
     using IdxT = int32_t;
 
@@ -2449,7 +2451,7 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
     constexpr bool fused_last_filter = false;
     constexpr bool sorted            = true;
     constexpr bool is_largest        = true;
-    constexpr int k                  = 2048;
+    int k                            = k_param;
 
     int sm_cnt = get_num_cu_func();
     unsigned grid_dim =
@@ -2497,7 +2499,9 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
 }
 
 // Explicit template instantiation to ensure the symbol is available for linking
-template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows, int32_t stride0);
+template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows,
+                                                              int32_t stride0,
+                                                              int k_param);
 
 void top_k_per_row_prefill(const torch::Tensor& logits,
                            const torch::Tensor& rowStarts,
@@ -2506,15 +2510,17 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
                            std::optional<torch::Tensor> values,
                            int64_t numRows,
                            int64_t stride0,
-                           int64_t stride1)
+                           int64_t stride1,
+                           int64_t k)
 {
     size_t buf_size = 0; // will be overwritten by the kernel
-    static constexpr int kTopK = 2048;
+    int kTopK = static_cast<int>(k);
     static constexpr bool is_largest = true;
 
     const hipStream_t stream = at::hip::getCurrentHIPStream();
 
-    int64_t workspace_size = invokeComputeTopkLastDimWorkspaceSize<float>(numRows, stride0);
+    int64_t workspace_size =
+        invokeComputeTopkLastDimWorkspaceSize<float>(numRows, stride0, kTopK);
     // int64_t workspace_size = int64_t(1024)*1024*1024*2;
     auto options = torch::TensorOptions().dtype(torch::kUInt8).device(logits.device());
     torch::Tensor workspace = torch::empty({workspace_size}, options);
@@ -2630,16 +2636,17 @@ void top_k_per_row_decode(const torch::Tensor& logits,
                           torch::Tensor& indices,
                           int64_t numRows,
                           int64_t stride0,
-                          int64_t stride1)
+                          int64_t stride1,
+                          int64_t k)
 {
     size_t buf_size = 0; // will be overwritten by the kernel
-    static constexpr int kTopK = 2048;
+    int kTopK = static_cast<int>(k);
    static constexpr bool is_largest = true;
 
     const hipStream_t stream = at::hip::getCurrentHIPStream();
 
-    int64_t workspace_size =
-        invokeComputeTopkLastDimWorkspaceSize<float>(numRows, stride0);
+    int64_t workspace_size = invokeComputeTopkLastDimWorkspaceSize<float>(
+        numRows, stride0, kTopK);
 
     auto options = torch::TensorOptions().dtype(torch::kUInt8).device(logits.device());
     torch::Tensor workspace = torch::empty({workspace_size}, options);
diff --git a/csrc/kernels/topk_plain_kernels.cu b/csrc/kernels/topk_plain_kernels.cu
index 7c03823ae0..f96674c0af 100644
--- a/csrc/kernels/topk_plain_kernels.cu
+++ b/csrc/kernels/topk_plain_kernels.cu
@@ -92,10 +92,13 @@ void standalone_stable_radix_11bits(void* buf,
 
 // Forward declaration of workspace size calculation function (at global scope)
 template <typename T>
-int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0);
+int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows,
+                                              int32_t stride0,
+                                              int k_param = 2048);
 
 extern template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows,
-                                                                     int32_t stride0);
+                                                                     int32_t stride0,
+                                                                     int k_param);
 
 // Forward declaration of helper function to call topk_per_row kernel
 template <typename T>
diff --git a/op_tests/test_topk_per_row.py b/op_tests/test_topk_per_row.py
index 34055eb1a6..25aae216e9 100755
--- a/op_tests/test_topk_per_row.py
+++ b/op_tests/test_topk_per_row.py
@@ -133,6 +133,7 @@ def run_top_k_per_row_prefill(
     num_rows: int,
     stride_row: int,
     stride_col: int,
+    k: int = 2048,
 ) -> None:
     """
     Run the top_k_per_row kernel.
@@ -146,6 +147,7 @@
         num_rows,
         stride_row,
         stride_col,
+        k=k,
     )
 
 
@@ -159,11 +161,17 @@ def run_top_k_per_row_decode(
     stride0: int,
     stride1: int,
     fast: bool,
+    k: int = 2048,
 ) -> None:
     """
     Run the top_k_per_row kernel.
+
+    Note: the `_fast` ASM-kernel variant has `kTopK=2048` baked into its
+    precompiled `.co`; it ignores any caller-supplied `k`. The dispatch
+    here only allows `_fast` when k == 2048.
     """
     if fast:
+        assert k == 2048, "top_k_per_row_decode_fast only supports k=2048"
         return aiter.top_k_per_row_decode_fast(
             logits,
             next_n,
@@ -182,6 +190,7 @@
         numRows,
         stride0,
         stride1,
+        k=k,
     )
 
 
@@ -216,6 +225,7 @@ def test_top_k_per_row_prefill(
         num_rows,
         logits.stride(0),
         logits.stride(1),
+        k=top_k,
     )
 
     # Run reference implementation
@@ -277,6 +287,7 @@ def test_top_k_per_row_decode(
         logits.stride(0),
         logits.stride(1),
         fast,
+        k=top_k,
     )
 
     torch.cuda.synchronize()
@@ -319,10 +330,12 @@
         "-k",
         "--top_k",
         type=int,
-        default=[2048],
+        default=[512, 1024, 2048],
         nargs="+",
-        help="""top-k elements per row.
-            e.g.: -k 2048""",
+        help="""top-k elements per row. The radix backend supports any positive
+            int; the `_fast` ASM-kernel path only supports 2048 and is skipped
+            for other values.
+            e.g.: -k 512 1024 2048""",
     )
 
     parser.add_argument(
@@ -391,7 +404,8 @@
                 m, ctx, k, n, data_generation, False
             )
             df.append(ret)
-            if get_gfx() == "gfx942":
+            # `_fast` ASM kernel hardcodes k=2048; skip otherwise.
+            if get_gfx() == "gfx942" and k == 2048:
                 ret = test_top_k_per_row_decode(
                     m, ctx, k, n, data_generation, True
                 )