diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 46494488ac..17f90ef867 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -935,7 +935,7 @@ "verbose": "False", "blob_gen_cmd": "''" }, - "module_topk_per_row": { + "module_top_k_per_row": { "srcs": [ "f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'", "f'{AITER_CSRC_DIR}/pybind/topk_per_row_pybind.cu'" diff --git a/aiter/ops/topk.py b/aiter/ops/topk.py index a72a7e3968..b29f8cc27a 100755 --- a/aiter/ops/topk.py +++ b/aiter/ops/topk.py @@ -196,8 +196,8 @@ def grouped_topk_torch( return topk_weights.to(dtypes.fp32), topk_ids.to(dtypes.i32) -@compile_ops("module_topk_per_row") -def topk_per_row( +@compile_ops("module_top_k_per_row") +def top_k_per_row_prefill( logits: torch.Tensor, rowStarts: torch.Tensor, rowEnds: torch.Tensor, @@ -208,8 +208,8 @@ def topk_per_row( ) -> None: ... -@compile_ops("module_topk_per_row") -def topk_per_row_decode( +@compile_ops("module_top_k_per_row") +def top_k_per_row_decode( logits: torch.Tensor, next_n: int, seqLens: torch.Tensor, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 7926085f17..ba4af687b9 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1289,56 +1289,56 @@ namespace py = pybind11; #define GEMM_COMMON_PYBIND \ m.def("get_padded_m", &getPaddedM, py::arg("M"), py::arg("N"), py::arg("K"), py::arg("gl")); -#define TOPK_PER_ROW_PYBIND \ - m.def("topk_per_row", \ - &topk_per_row, \ - py::arg("logits"), \ - py::arg("rowStarts"), \ - py::arg("rowEnds"), \ - py::arg("indices"), \ - py::arg("numRows"), \ - py::arg("stride0"), \ - py::arg("stride1")); \ - m.def("topk_per_row_decode", \ - &topk_per_row_decode, \ - py::arg("logits"), \ - py::arg("next_n"), \ - py::arg("seqLens"), \ - py::arg("indices"), \ - py::arg("numRows"), \ - py::arg("stride0"), \ +#define TOP_K_PER_ROW_PYBIND \ + m.def("top_k_per_row_prefill", \ + &top_k_per_row_prefill, \ + py::arg("logits"), \ + py::arg("rowStarts"), \ + py::arg("rowEnds"), \ + py::arg("indices"), \ + py::arg("numRows"), \ + py::arg("stride0"), \ + py::arg("stride1")); \ + m.def("top_k_per_row_decode", \ + &top_k_per_row_decode, \ + py::arg("logits"), \ + py::arg("next_n"), \ + py::arg("seqLens"), \ + py::arg("indices"), \ + py::arg("numRows"), \ + py::arg("stride0"), \ py::arg("stride1")); -#define MLA_METADATA_PYBIND \ - m.def("get_mla_metadata_v1", \ - &get_mla_metadata_v1, \ - "get_mla_metadata_v1", \ - py::arg("seqlens_qo_indptr"), \ - py::arg("seqlens_kv_indptr"), \ - py::arg("num_heads_per_head_k"), \ - py::arg("num_heads_k"), \ - py::arg("is_causal"), \ - py::arg("work_metadata_ptrs"), \ - py::arg("work_info_set"), \ - py::arg("work_indptr"), \ - py::arg("reduce_indptr"), \ - py::arg("reduce_final_map"), \ - py::arg("reduce_partial_map"), \ - py::arg("kv_granularity") = 16, \ - py::arg("max_seqlen_qo") = -1, \ - py::arg("uni_seqlen_qo") = -1, \ - py::arg("fast_mode") = true, \ - py::arg("topk") = -1); \ +#define MLA_METADATA_PYBIND \ + m.def("get_mla_metadata_v1", \ + &get_mla_metadata_v1, \ + "get_mla_metadata_v1", \ + py::arg("seqlens_qo_indptr"), \ + py::arg("seqlens_kv_indptr"), \ + py::arg("num_heads_per_head_k"), \ + py::arg("num_heads_k"), \ + py::arg("is_causal"), \ + py::arg("work_metadata_ptrs"), \ + py::arg("work_info_set"), \ + py::arg("work_indptr"), \ + py::arg("reduce_indptr"), \ + py::arg("reduce_final_map"), \ + py::arg("reduce_partial_map"), \ + py::arg("kv_granularity") = 16, \ + py::arg("max_seqlen_qo") = -1, \ + py::arg("uni_seqlen_qo") = -1, \ + py::arg("fast_mode") = true, \ + py::arg("topk") = -1); \ m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant); -#define MLA_REDUCE_PYBIND \ - m.def("mla_reduce_v1", \ - &mla_reduce_v1, \ - "mla_reduce_v1", \ - py::arg("partial_output"), \ - py::arg("partial_lse"), \ - py::arg("reduce_indptr"), \ - py::arg("reduce_final_map"), \ - py::arg("reduce_partial_map"), \ - py::arg("final_output"), \ - py::arg("final_lse") = std::nullopt); +#define MLA_REDUCE_PYBIND \ + m.def("mla_reduce_v1", \ + &mla_reduce_v1, \ + "mla_reduce_v1", \ + py::arg("partial_output"), \ + py::arg("partial_lse"), \ + py::arg("reduce_indptr"), \ + py::arg("reduce_final_map"), \ + py::arg("reduce_partial_map"), \ + py::arg("final_output"), \ + py::arg("final_lse") = std::nullopt); diff --git a/csrc/include/topk_per_row.h b/csrc/include/topk_per_row.h index dae19dbdf1..b2894fe54a 100644 --- a/csrc/include/topk_per_row.h +++ b/csrc/include/topk_per_row.h @@ -2,18 +2,18 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -void topk_per_row(const torch::Tensor& logits, - const torch::Tensor& rowStarts, - const torch::Tensor& rowEnds, - torch::Tensor& indices, - int64_t numRows, - int64_t stride0, - int64_t stride1); +void top_k_per_row_prefill(const torch::Tensor& logits, + const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, + torch::Tensor& indices, + int64_t numRows, + int64_t stride0, + int64_t stride1); -void topk_per_row_decode(const torch::Tensor& logits, - int64_t next_n, - const torch::Tensor& seqLens, - torch::Tensor& indices, - int64_t numRows, - int64_t stride0, - int64_t stride1); +void top_k_per_row_decode(const torch::Tensor& logits, + int64_t next_n, + const torch::Tensor& seqLens, + torch::Tensor& indices, + int64_t numRows, + int64_t stride0, + int64_t stride1); diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 17f1d22938..af59458122 100755 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -47,23 +47,238 @@ struct to_vector<4> using type = fp32x4; }; -template +static inline __device__ uint32_t floatAsSortableUint(float x) +{ + uint32_t bits = __float_as_uint(x); + bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff; + return bits; +} + +template +static inline __device__ uint32_t extractBinIdx(float x) +{ + uint32_t bits = floatAsSortableUint(x); + + if constexpr(step == 0) + { + return bits >> 21; + } + else if constexpr(step == 1) + { + return (bits >> 10) & 0x7ff; + } + else + { + return bits & 0x3ff; + } +} + +template +static inline __device__ bool isPartialMatch(float x, uint32_t pattern) +{ + if constexpr(shift == 0) + { + return true; + } + uint32_t bits = floatAsSortableUint(x); + return (bits ^ pattern) >> shift == 0; +} + +template +__device__ bool processHistogramStep(const float* logits, + int rowEnd, + uint32_t& logitPattern, + int& thresholdBinIdx, + int* smemHistogram, + int* smemIndices, + int* smemThresholdBinIdx, + int* smemFinalDstIdx, + int* smemFinalBinSize, + int* smemFoundTopKValues, + SmemFinalType& smemFinal, + int stride1, + int rowStart) +{ + using VectorType = typename to_vector::type; + // Clear the histogram. +#pragma unroll + for(int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock) + { + smemHistogram[idx] = 0; + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Update pattern + constexpr auto patternShift = step == 0 ? 0 : step == 1 ? 21 : 10; + if constexpr(step == 1) + { + logitPattern = static_cast(thresholdBinIdx & 0x7ff) << patternShift; + } + else if constexpr(step == 2) + { + logitPattern |= static_cast(thresholdBinIdx & 0x7ff) << patternShift; + } + + // Fetch elements one-by-one. + for(int vecIdx = (rowStart / Vector) + threadIdx.x; vecIdx < (rowEnd + Vector - 1) / Vector; + vecIdx += kNumThreadsPerBlock) + { + auto v = reinterpret_cast(logits)[vecIdx]; +#pragma unroll + for(int j = 0; j < Vector; j++) + { + int vIdx = vecIdx * Vector + j; + if(vIdx >= rowEnd) + break; + float logit = v[j]; + if(isPartialMatch(logit, logitPattern)) + { + uint32_t binIdx = extractBinIdx(logit); + atomicAdd(&smemHistogram[binIdx], 1); + } + } + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Reads the value of the starting position in the smemIndices array + int lastValue = smemFoundTopKValues[0]; + + for(int round = 0; round < kNumBins / kNumThreadsPerBlock; round++) + { + // Read the values from SMEM. + int idx = threadIdx.x + kNumThreadsPerBlock * round; + int binCount{0}; + binCount = smemHistogram[idx]; + + // Make sure each thread has read its value. + __syncthreads(); + + // Compute the prefix sum. + int prefixSum{0}, totalSum{0}; + using Scan = hipcub::BlockScan; + Scan(smemFinal.smemScan).ExclusiveSum(binCount, prefixSum, totalSum); + + // Update the histogram with the prefix sums. + prefixSum += lastValue; + totalSum += lastValue; + smemHistogram[idx] = prefixSum; + + // Make sure the data is in shared memory. + __syncthreads(); + + // Find the last valid bin. + bool foundThreshold = false; + if(prefixSum < kTopK) + { + int nextPrefixSum = + threadIdx.x == kNumThreadsPerBlock - 1 ? totalSum : smemHistogram[idx + 1]; + + if(nextPrefixSum >= kTopK) + { + smemThresholdBinIdx[0] = idx; + smemFinalBinSize[0] = nextPrefixSum - prefixSum; + smemFoundTopKValues[0] = prefixSum; + foundThreshold = true; + } + } + + // Early exit: if any thread found the threshold, we can skip remaining + // rounds + if(__syncthreads_or(foundThreshold)) + { + break; + } + + lastValue = totalSum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The threshold bin. + thresholdBinIdx = smemThresholdBinIdx[0]; + + // Fetch elements one-by-one and populate the shared memory buffers. + for(int vecIdx = (rowStart / Vector) + threadIdx.x; vecIdx < (rowEnd + Vector - 1) / Vector; + vecIdx += kNumThreadsPerBlock) + { + // Compute the vector offset for coalesced VectorType load + auto v = reinterpret_cast(logits)[vecIdx]; +#pragma unroll + for(int j = 0; j < Vector; j++) + { + int vIdx = vecIdx * Vector + j; + if(vIdx >= rowEnd) + break; + float logit = v[j]; + + // Check for pattern match + if(!isPartialMatch(logit, logitPattern)) + continue; + + uint32_t binIdx = extractBinIdx(logit); + + if(binIdx < thresholdBinIdx) + { + int dstIdx = atomicAdd(&smemHistogram[binIdx], 1); + smemIndices[dstIdx] = vIdx; + } + + if constexpr(step < 2) + { + // Fill final items only if threshold bin fits + if(binIdx == thresholdBinIdx && smemFinalBinSize[0] <= kNumFinalItems) + { + int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); + smemFinal.items.logits[dstIdx] = logit; + smemFinal.items.indices[dstIdx] = vIdx; + } + } + else + { + if(binIdx == thresholdBinIdx) + { + int dstIdx = atomicAdd(&smemHistogram[binIdx], 1); + if(dstIdx < kTopK) + { + smemIndices[dstIdx] = vIdx; + } + } + } + } + } + + // Make sure the elements are in shared memory. + __syncthreads(); + + // Check if we should continue to next step + return smemFinalBinSize[0] > kNumFinalItems; +} + +template __device__ void topk_per_row_kernel(const float* logits, const int rowStart, const int rowEnd, - const int rowIdx, int* outIndices, - int stride0, int stride1) { - // The number of elements per thread for the final top-k sort. - static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; - // The class to sort the elements during the final top-k sort. - using TopKSort = - hipcub::BlockRadixSort; - // The number of slots for the final pass. - static constexpr int kNumFinalItems = 3072; + static constexpr int kNumFinalItems = 2048; // The number of elements per thread for the final sort. static constexpr int kNumFinalItemsPerThread = kNumFinalItems / kNumThreadsPerBlock; // The class to sort the elements during the final pass. @@ -73,8 +288,6 @@ __device__ void topk_per_row_kernel(const float* logits, // The class to compute the inclusive prefix-sum over the histogram. using Scan = hipcub::BlockScan; - using VectorType = typename to_vector::type; - // Shared memory to compute the block scan. __shared__ typename Scan::TempStorage smemScan; @@ -92,7 +305,7 @@ __device__ void topk_per_row_kernel(const float* logits, { FinalItems items; typename FinalSort::TempStorage finalSort; - typename TopKSort::TempStorage topKSort; + typename Scan::TempStorage smemScan; } smemFinal; // Shared memory to store the histogram. @@ -103,6 +316,11 @@ __device__ void topk_per_row_kernel(const float* logits, __shared__ int smemThresholdBinIdx[1]; // Shared memory counter to register the candidates for the final phase. __shared__ int smemFinalDstIdx[1]; + // Shared memory to determine if the threshold bin fits in the final items. + __shared__ int smemFinalBinSize[1]; + // Shared memory to keep track of the top-k values found so far by the + // previous iterations + __shared__ int smemFoundTopKValues[1]; // The length of the row. int rowLen = rowEnd - rowStart; @@ -113,229 +331,261 @@ __device__ void topk_per_row_kernel(const float* logits, { for(int rowIt = threadIdx.x; rowIt < rowLen; rowIt += kNumThreadsPerBlock) { - int idx = rowStart + rowIt; - outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; + outIndices[rowIt] = rowIt - rowStart; } for(int rowIt = rowLen + threadIdx.x; rowIt < kTopK; rowIt += kNumThreadsPerBlock) { - outIndices[rowIdx * kTopK + rowIt] = -1; + outIndices[rowIt] = -1; } return; } - // Clear the histogram. - if(threadIdx.x < kNumBins) + // Initialize values + if(threadIdx.x == 0) { - smemHistogram[threadIdx.x] = 0; + smemFinalDstIdx[0] = 0; + smemFoundTopKValues[0] = 0; } - - // Make sure the histogram is ready. __syncthreads(); - - // Fetch elements one-by-one. - for(int rowIt = rowStart + threadIdx.x; rowIt < (rowEnd + Vector - 1) / Vector; - rowIt += kNumThreadsPerBlock) + int thresholdBinIdx = -1; + uint32_t logitPattern = 0; + + // Step 0: Process first 11 bits + bool continueToNextStep = + processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kTopK, kNumFinalItems, Vector>( + logits, + rowEnd, + logitPattern, + thresholdBinIdx, + smemHistogram, + smemIndices, + smemThresholdBinIdx, + smemFinalDstIdx, + smemFinalBinSize, + smemFoundTopKValues, + smemFinal, + stride1, + rowStart); + + if(continueToNextStep) { - int64_t offset = ((int64_t)rowIdx) * (stride0 / Vector) + ((int64_t)rowIt) * stride1; - auto v = reinterpret_cast(logits)[offset]; - -#pragma unroll - for(int j = 0; j < Vector; j++) + // Step 1: Process next 11 bits + continueToNextStep = + processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kTopK, kNumFinalItems, Vector>( + logits, + rowEnd, + logitPattern, + thresholdBinIdx, + smemHistogram, + smemIndices, + smemThresholdBinIdx, + smemFinalDstIdx, + smemFinalBinSize, + smemFoundTopKValues, + smemFinal, + stride1, + rowStart); + + if(continueToNextStep) { - float logit = (rowIt * Vector + j) < rowEnd ? v[j] : -INFINITY; - uint16_t idx = extractBinIdx(logit); - atomicAdd(&smemHistogram[idx], 1); + // Step 2: Process final 10 bits + processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kTopK, kNumFinalItems, Vector>( + logits, + rowEnd, + logitPattern, + thresholdBinIdx, + smemHistogram, + smemIndices, + smemThresholdBinIdx, + smemFinalDstIdx, + smemFinalBinSize, + smemFoundTopKValues, + smemFinal, + stride1, + rowStart); } } - // Make sure the histogram is ready. - __syncthreads(); - - // Read the values from SMEM. - int binCount{0}; - if(threadIdx.x < kNumBins) - { - binCount = smemHistogram[threadIdx.x]; - } - - // Make sure each thread has read its value. - __syncthreads(); - - // Compute the prefix sum. - int prefixSum{0}, totalSum{0}; - Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum); - - // Update the histogram with the prefix sums. - if(threadIdx.x < kNumBins) - { - smemHistogram[threadIdx.x] = prefixSum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Find the last valid bin. - if(threadIdx.x < kNumBins) + if(!continueToNextStep) { - int nextPrefixSum = threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1]; - if(prefixSum < kTopK && nextPrefixSum >= kTopK) + // The histogram did not proceed to the final 10 bits, therefore we need to + // sort the final items The logits of the elements to be sorted in the final + // pass. + if constexpr(useRadixSort) { - smemThresholdBinIdx[0] = threadIdx.x; - } - } + // Sorting with radix sort + float finalLogits[kNumFinalItemsPerThread]; + // The indices of the elements to be sorted in the final pass. + int finalIndices[kNumFinalItemsPerThread]; - // Clear the counter to store the items for the final phase. - if(threadIdx.x == 0) - { - smemFinalDstIdx[0] = 0; - } +#pragma unroll + for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) + { + finalLogits[ii] = -FLT_MAX; + } - // Make sure the data is in shared memory. - __syncthreads(); + // Read the elements from SMEM. +#pragma unroll + for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) + { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + if(srcIdx < smemFinalDstIdx[0]) + { + finalLogits[ii] = smemFinal.items.logits[srcIdx]; + finalIndices[ii] = smemFinal.items.indices[srcIdx]; + } + } + // Make sure the shared memory has been read. + __syncthreads(); - // The threshold bin. - int thresholdBinIdx = smemThresholdBinIdx[0]; + // Sort the elements. + FinalSort(smemFinal.finalSort) + .SortDescendingBlockedToStriped(finalLogits, finalIndices); - // Fetch elements one-by-one and populate the shared memory buffers. - for(int rowIt = rowStart + threadIdx.x; rowIt < (rowEnd + Vector - 1) / Vector; - rowIt += kNumThreadsPerBlock) - { - int64_t offset = ((int64_t)rowIdx) * stride0 / Vector + ((int64_t)rowIt) * stride1; - auto v = reinterpret_cast(logits)[offset]; + // Copy the data back to the shared memory storage. + int baseIdx = smemFoundTopKValues[0]; #pragma unroll - for(auto j = 0; j < Vector; j++) - { - float logit = (rowIt * Vector + j) < rowEnd ? v[j] : -INFINITY; - // float logit = v[j]; - uint16_t idx = extractBinIdx(logit); - if(idx < thresholdBinIdx) + for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - int dstIdx = atomicAdd(&smemHistogram[idx], 1); - smemIndices[dstIdx] = Vector * rowIt + j; + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + int dstIdx = baseIdx + srcIdx; + + if(dstIdx < kTopK) + { + smemIndices[dstIdx] = finalIndices[ii]; + } } - else if(idx == thresholdBinIdx) + } + else + { + // Sorting with insertion sort + auto baseIdx = smemFoundTopKValues[0]; + for(int i = threadIdx.x; i < smemFinalDstIdx[0]; i += kNumThreadsPerBlock) { - int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); - if(dstIdx < kNumFinalItems) + int outIndex = 0; + auto logit = smemFinal.items.logits[i]; + for(int j = 0; j < smemFinalDstIdx[0]; j++) { - smemFinal.items.logits[dstIdx] = logit; - smemFinal.items.indices[dstIdx] = Vector * rowIt + j; + auto otherLogit = smemFinal.items.logits[j]; + if(logit < otherLogit || (logit == otherLogit && i < j)) + { + outIndex++; + } + } + // Store if outIndex is in bounds + if(outIndex + baseIdx < kTopK) + { + smemIndices[outIndex + baseIdx] = smemFinal.items.indices[i]; } } } + __syncthreads(); } - // Make sure the elements are in shared memory. - // __syncthreads(); - - // The logits of the elements to be sorted in the final pass. - float finalLogits[kNumFinalItemsPerThread]; - // The indices of the elements to be sorted in the final pass. - int finalIndices[kNumFinalItemsPerThread]; - -// Init. -#pragma unroll - for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) + if constexpr(sortResultLogitDescending) { - finalLogits[ii] = -FLT_MAX; - } - - __syncthreads(); + // Sorting with radix sort + float finalLogits[kNumFinalItemsPerThread]; + // The indices of the elements to be sorted in the final pass. + int finalIndices[kNumFinalItemsPerThread]; // Read the elements from SMEM. #pragma unroll - for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) - { - int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; - if(srcIdx < smemFinalDstIdx[0]) + for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - finalLogits[ii] = smemFinal.items.logits[srcIdx]; - finalIndices[ii] = smemFinal.items.indices[srcIdx]; + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + const auto index = smemIndices[srcIdx]; + const auto logit = logits[index * stride1]; + finalLogits[ii] = logit; + finalIndices[ii] = index; } - } - // Make sure the shared memory has been read. - __syncthreads(); + // Make sure the shared memory has been read. + __syncthreads(); - // Sort the elements. - FinalSort(smemFinal.finalSort).SortDescendingBlockedToStriped(finalLogits, finalIndices); + // Sort the elements. + FinalSort(smemFinal.finalSort).SortDescendingBlockedToStriped(finalLogits, finalIndices); - // Copy the data back to the shared memory storage. - int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0; + // Store to global memory #pragma unroll - for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) - { - int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; - int dstIdx = baseIdx + srcIdx; - if(dstIdx < kTopK) + for(int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - smemIndices[dstIdx] = finalIndices[ii]; + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + outIndices[srcIdx] = finalIndices[ii] - rowStart; } } - // Make sure the data is in shared memory. - __syncthreads(); - -// Store to global memory. -#pragma unroll - for(int ii = 0; ii < kNumTopKItemsPerThread; ++ii) + if constexpr(!sortResultLogitDescending) { - int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; - outIndices[offset] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart; + // Store to global memory. +#pragma unroll + for(int i = threadIdx.x; i < kTopK; i += kNumThreadsPerBlock) + { + outIndices[i] = smemIndices[i] - rowStart; + } } } -template +template static __global__ void topk_per_row(const float* logits, const int* rowStarts, const int* rowEnds, int* outIndices, int stride0, - int stride1) + int stride1, + int rowOffset) { // The number of bins in the histogram. - static constexpr int kNumBins = 512; + static constexpr int kNumBins = 2048; // The top-k width. static constexpr int kTopK = 2048; // The row computed by this block. - int rowIdx = blockIdx.x; + int64_t rowIdx = static_cast(blockIdx.x) + rowOffset; // The range of logits within the row. int rowStart = rowStarts[rowIdx]; int rowEnd = rowEnds[rowIdx]; - topk_per_row_kernel( - logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); + // Local pointers to this block + auto outIndicesLocal = outIndices + rowIdx * kTopK; + auto logitsLocal = logits + rowIdx * stride0; + + topk_per_row_kernel( + logitsLocal, rowStart, rowEnd, outIndicesLocal, stride1); } -template +template static __global__ void topk_per_row_decode( const float* logits, const int* seqLens, int* outIndices, int stride0, int stride1, int next_n) { // The number of bins in the histogram. - static constexpr int kNumBins = kNumThreadsPerBlock; + static constexpr int kNumBins = 2048; // The top-k width. static constexpr int kTopK = 2048; // The row computed by this block. - int rowIdx = blockIdx.x; + int64_t rowIdx = static_cast(blockIdx.x); // The range of logits within the row. int rowStart = 0; int seq_len = seqLens[rowIdx / next_n]; int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; - topk_per_row_kernel( - logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); + // Local pointers to this block + auto outIndicesLocal = outIndices + rowIdx * kTopK; + auto logitsLocal = logits + rowIdx * stride0; + + topk_per_row_kernel( + logitsLocal, rowStart, rowEnd, outIndicesLocal, stride1); } } // namespace aiter -void topk_per_row(const torch::Tensor& logits, +void top_k_per_row_prefill(const torch::Tensor& logits, const torch::Tensor& rowStarts, const torch::Tensor& rowEnds, torch::Tensor& indices, @@ -343,6 +593,8 @@ void topk_per_row(const torch::Tensor& logits, int64_t stride0, int64_t stride1) { + constexpr int kSortingAlgorithmThreshold = 12288; + // Compute the results on the device. constexpr int kNumThreadsPerBlock = 512; @@ -351,25 +603,60 @@ void topk_per_row(const torch::Tensor& logits, const hipStream_t stream = at::hip::getCurrentHIPStream(); + int numInsertionBlocks = std::min(static_cast(numRows), kSortingAlgorithmThreshold); + if(stride0 % 4 == 0) - aiter::topk_per_row - <<>>(logits.data_ptr(), - rowStarts.data_ptr(), - rowEnds.data_ptr(), - indices.data_ptr(), - static_cast(stride0), - static_cast(stride1)); + { + aiter::topk_per_row + <<>>(logits.data_ptr(), + rowStarts.data_ptr(), + rowEnds.data_ptr(), + indices.data_ptr(), + static_cast(stride0), + static_cast(stride1), + 0); + } else - aiter::topk_per_row - <<>>(logits.data_ptr(), - rowStarts.data_ptr(), - rowEnds.data_ptr(), - indices.data_ptr(), - static_cast(stride0), - static_cast(stride1)); + { + aiter::topk_per_row + <<>>(logits.data_ptr(), + rowStarts.data_ptr(), + rowEnds.data_ptr(), + indices.data_ptr(), + static_cast(stride0), + static_cast(stride1), + 0); + } + + if(numRows > kSortingAlgorithmThreshold) + { + int numRadixBlocks = numRows - kSortingAlgorithmThreshold; + if(stride0 % 4 == 0) + { + aiter::topk_per_row + <<>>(logits.data_ptr(), + rowStarts.data_ptr(), + rowEnds.data_ptr(), + indices.data_ptr(), + static_cast(stride0), + static_cast(stride1), + kSortingAlgorithmThreshold); + } + else + { + aiter::topk_per_row + <<>>(logits.data_ptr(), + rowStarts.data_ptr(), + rowEnds.data_ptr(), + indices.data_ptr(), + static_cast(stride0), + static_cast(stride1), + kSortingAlgorithmThreshold); + } + } } -void topk_per_row_decode(const torch::Tensor& logits, +void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, const torch::Tensor& seqLens, torch::Tensor& indices, @@ -377,24 +664,54 @@ void topk_per_row_decode(const torch::Tensor& logits, int64_t stride0, int64_t stride1) { + constexpr int kSortingAlgorithmThreshold = 12288; // Compute the results on the device. constexpr int kNumThreadsPerBlock = 1024; const hipStream_t stream = at::hip::getCurrentHIPStream(); + const auto numColumns = logits.size(1); - if(stride0 % 4 == 0) - aiter::topk_per_row_decode - <<>>(logits.data_ptr(), - seqLens.data_ptr(), - indices.data_ptr(), - static_cast(stride0), + if(numColumns < kSortingAlgorithmThreshold) + { + if(stride0 % 4 == 0) + { + aiter::topk_per_row_decode + <<>>(logits.data_ptr(), + seqLens.data_ptr(), + indices.data_ptr(), + static_cast(stride0), + static_cast(stride1), + static_cast(next_n)); + } + else + { + aiter::topk_per_row_decode + <<>>(logits.data_ptr(), + seqLens.data_ptr(), + indices.data_ptr(), + static_cast(stride0), + static_cast(stride1), + static_cast(next_n)); + } + } + else + { + if (stride0 % 4 == 0) + { + aiter::topk_per_row_decode + <<>>(logits.data_ptr(), + seqLens.data_ptr(), + indices.data_ptr(), + static_cast(stride0), static_cast(stride1), static_cast(next_n)); - else - aiter::topk_per_row_decode + } else { + aiter::topk_per_row_decode <<>>(logits.data_ptr(), seqLens.data_ptr(), indices.data_ptr(), static_cast(stride0), static_cast(stride1), static_cast(next_n)); + } +} } diff --git a/csrc/pybind/topk_per_row_pybind.cu b/csrc/pybind/topk_per_row_pybind.cu index 5f031dc155..471c07efc4 100755 --- a/csrc/pybind/topk_per_row_pybind.cu +++ b/csrc/pybind/topk_per_row_pybind.cu @@ -5,5 +5,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - TOPK_PER_ROW_PYBIND; + TOP_K_PER_ROW_PYBIND; } diff --git a/op_tests/test_topk_per_row.py b/op_tests/test_topk_per_row.py index 9d572da744..a4ca0f7a05 100755 --- a/op_tests/test_topk_per_row.py +++ b/op_tests/test_topk_per_row.py @@ -13,12 +13,34 @@ def create_random_logits( row_ends: torch.Tensor, dtype: torch.dtype, seed: int, + data_generation: str = "random", ) -> torch.Tensor: """Create random logits tensor for testing.""" torch.manual_seed(seed) np.random.seed(seed) # Generate logits with some structure to make testing more meaningful - logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda") + if data_generation == "random": + logits = torch.randn( + row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda" + ) + elif data_generation == "10LSBits": + top_22_bits_mask = 0xFFFFFC00 + last_10_bits_mask = 0x000003FF + fixed_top_22_bits = 0x3F900000 + # Generate random bits for the last 10 bits + random_bottom_bits = torch.randint( + 0, + 2**10, + (row_starts.shape[0], max(row_ends)), + dtype=torch.int32, + device="cuda", + ) + # Combine: fixed top 22 bits with random last 10 bits + logits_bits = (fixed_top_22_bits & top_22_bits_mask) | ( + random_bottom_bits & last_10_bits_mask + ) + logits = logits_bits.view(dtype) + for i, end in enumerate(row_ends): logits[i, end:] = float("-inf") return logits @@ -91,7 +113,7 @@ def compare_topk_results( @perftest() -def run_topk_per_row( +def run_top_k_per_row_prefill( logits: torch.Tensor, row_starts: torch.Tensor, row_ends: torch.Tensor, @@ -103,7 +125,7 @@ def run_topk_per_row( """ Run the top_k_per_row kernel. """ - return aiter.topk_per_row( + return aiter.top_k_per_row_prefill( logits, row_starts, row_ends, @@ -115,7 +137,7 @@ def run_topk_per_row( @perftest() -def run_topk_per_row_decode( +def run_top_k_per_row_decode( logits: torch.Tensor, next_n: int, seqLens: torch.Tensor, @@ -127,7 +149,7 @@ def run_topk_per_row_decode( """ Run the top_k_per_row kernel. """ - return aiter.topk_per_row_decode( + return aiter.top_k_per_row_decode( logits, next_n, seqLens, @@ -139,9 +161,9 @@ def run_topk_per_row_decode( @benchmark() -def test_topk_per_row(num_rows: int, top_k: int) -> dict: +def test_top_k_per_row_prefill(num_rows: int, top_k: int) -> dict: """ - Test topk_per_row. + Test topk_per_row_prefill. """ ret = {} torch.set_default_device("cuda:0") @@ -154,7 +176,7 @@ def test_topk_per_row(num_rows: int, top_k: int) -> dict: indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") # Run the kernel - _, us = run_topk_per_row( + _, us = run_top_k_per_row_prefill( logits, row_starts, row_ends, @@ -184,14 +206,15 @@ def test_topk_per_row(num_rows: int, top_k: int) -> dict: @benchmark() -def test_topk_per_row_decode( +def test_top_k_per_row_decode( batch_size: int, context_len: int, top_k: int, next_n: int, + data_generation: str = "random", ) -> None: """ - Test topk_per_row with seq_lens tensor. + Test top_k_per_row_decode with seq_lens tensor. """ torch.set_default_device("cuda:0") ret = {} @@ -210,7 +233,7 @@ def test_topk_per_row_decode( indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") # Run the kernel - _, us = run_topk_per_row_decode( + _, us = run_top_k_per_row_decode( logits, next_n, seq_lens, @@ -285,17 +308,28 @@ def test_topk_per_row_decode( e.g.: -n 4""", ) +parser.add_argument( + "-d", + "--data_generation", + type=str, + default=["random"], + choices=["random", "10LSBits"], + nargs="+", + help="""Specify method for generating logits. + e.g.: -d random""", +) + args = parser.parse_args() df = [] for m in args.context_len: for k in args.top_k: - ret = test_topk_per_row(m, k) + ret = test_top_k_per_row_prefill(m, k) df.append(ret) df = pd.DataFrame(df) -aiter.logger.info(f"summary for topk_per_row kernel:\n{df}") +aiter.logger.info(f"summary for top_k_per_row_prefill kernel:\n{df}") df = [] @@ -303,8 +337,8 @@ def test_topk_per_row_decode( for ctx in args.context_len: for k in args.top_k: for n in args.next_n: - ret = test_topk_per_row_decode(m, ctx, k, n) + ret = test_top_k_per_row_decode(m, ctx, k, n) df.append(ret) df = pd.DataFrame(df) -aiter.logger.info(f"summary for topk_per_row_decode kernel:\n{df}") +aiter.logger.info(f"summary for top_k_per_row_decode kernel:\n{df}")