/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#ifndef TRTLLM_MOETOPKFUNCS_CUH_H
#define TRTLLM_MOETOPKFUNCS_CUH_H

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>

#include "tensorrt_llm/kernels/archCondition.h"

namespace tensorrt_llm::kernels
{

namespace reduce_topk
{
namespace cg = cooperative_groups;
static constexpr int kWARP_SIZE = 32;
static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;

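// Packs a (value, index) candidate into a single unsigned key so that a plain integer max over
// the packed keys implements the desired top-K semantics: larger values win, and for equal
// values the candidate with the smaller index wins (the low 16 bits store kMaxIdx - idx).
// cub::Traits<T>::TwiddleIn maps the value's bit pattern to an order-preserving unsigned integer.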
template <typename T_>
struct TopKRedType
{
    using T = T_;
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>
            || std::is_same_v<T, int>,
        "Top K reduction only implemented for int, float, float16 and bfloat16");

    using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
    using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;

    static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
    static constexpr int kMaxIdx = 65535;
    TypeCmp compValIdx;

    static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
    {
        auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
        TypeCmp compactTmp = valueBits;
        compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
        // Use 65535 minus idx to give higher priority to elements with smaller indices.
        return compactTmp;
    }

    static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
    {
        // Since "65535 - idx" is always smaller than 65536 and positive, we can directly use it as the lower 16 bits.
        index = kMaxIdx - static_cast<int32_t>(cmp & 0xFFFF);

        auto compactTmp = cmp >> kMoveBits;
        auto valueBits
            = cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
        value = reinterpret_cast<T&>(valueBits);
    }

    __host__ __device__ TopKRedType() = default;

    __host__ __device__ TopKRedType(T val, int32_t idx)
        : compValIdx(makeCmpVal(val, idx))
    {
    }

    __host__ __device__ operator TypeCmp() const noexcept
    {
        return compValIdx;
    }

    __device__ inline TypeCmp reduce(cg::thread_block_tile<kWARP_SIZE> const& warp)
    {
        if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
        {
            return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
        }
        else
        {
            TypeCmp result;
            asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
            return result;
        }
    }
};
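
// Informal illustration of the packed-key ordering (the names a and b below are only for
// exposition): with a = TopKRedType<float>::makeCmpVal(3.0f, 7) and
// b = TopKRedType<float>::makeCmpVal(2.0f, 1), a > b holds because the twiddled value bits
// occupy the high half of the key. If both values are equal, the entry with the smaller index
// compares greater, since the low 16 bits hold (kMaxIdx - idx).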

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int K_, bool Enable_>
struct TopKIdx
{
    // by default, empty
};

template <int K_>
struct TopKIdx<K_, true>
{
    static constexpr int K = K_;
    int32_t val[K];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

#define TOPK_SWAP(I, J)                                                                                                \
    {                                                                                                                  \
        auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        topK[I].compValIdx = pairMax;                                                                                  \
        topK[J].compValIdx = pairMin;                                                                                  \
    }

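// TOPK_SWAP is a single compare-exchange on the packed keys. The Sort<N> specializations below
// chain these exchanges into small sorting networks that leave up to four candidates in
// descending order (largest key in topK[0]).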
template <int N, typename RedType>
struct Sort;

template <typename RedType>
struct Sort<1, RedType>
{
    static __device__ void run(RedType* topK) {}
};

template <typename RedType>
struct Sort<2, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<3, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
        TOPK_SWAP(1, 2);
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<4, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 2);
        TOPK_SWAP(1, 3);
        TOPK_SWAP(0, 1);
        TOPK_SWAP(2, 3);
        TOPK_SWAP(1, 2);
    }
};

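// Warp-cooperative top-K over one candidate per lane: every lane contributes a single
// (value, idx) pair, and after the call each lane holds the actualK largest pairs of the warp
// in out[]/outIdx[], ordered from largest to smallest. Once a candidate has been emitted, the
// owning lane replaces it with minValue so the next iteration finds the next-largest entry.
//
// Illustrative usage (a sketch only; myScore, myExpertIdx and the -FLT_MAX sentinel are
// placeholders chosen by the caller, not part of this header):
//
//   namespace cg = cooperative_groups;
//   auto warp = cg::tiled_partition<reduce_topk::kWARP_SIZE>(cg::this_thread_block());
//   float out[4];
//   int32_t outIdx[4];
//   reduce_topk::reduceTopK(warp, out, outIdx, myScore, myExpertIdx, /*minValue=*/-FLT_MAX);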
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    using RedType = TopKRedType<Type>;
    RedType topK{value, idx};
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < actualK; ++kk) // @todo: check if actualK is correct
    {
        topK = kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK;
        // get the next largest value
        packedMax = topK.reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
};

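// Warp-cooperative top-K where every lane contributes N candidates (up to N * 32 in total).
// The per-lane candidates are first sorted in descending order (unless IsSorted), so only the
// head element topK[0] takes part in each warp-wide reduction; when a lane's head is the entry
// that was just emitted, its local list is shifted down by one and backfilled with minValue.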
template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
    Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N < 5, "Only support candidates number less than or equal to 128");
    using RedType = TopKRedType<Type>;
    RedType topK[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        topK[nn] = RedType{value[nn], idx[nn]};
    }

    if constexpr (!IsSorted)
    {
        Sort<N, RedType>::run(topK);
    }
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < actualK; ++kk)
    {
        bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
        for (int nn = 0; nn < N; ++nn)
        {
            topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
        }
        // get the next largest value
        packedMax = topK[0].reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
};

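// General entry point for per-lane candidate arrays. For N <= 4 this forwards directly to
// reduceTopKFunc. For larger N (a multiple of 4), the candidates are processed in chunks of
// four: each chunk's top-K winners are parked in different lanes of the warp, and a final
// reduction over those parked winners produces the overall top-K.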
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512");
    static_assert(
        N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");
    using RedType = TopKRedType<Type>;

    if constexpr (N <= 4)
    {
        reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue, actualK);
    }
    else
    {
        constexpr int numLoops = N / 4;
        constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;

        Type topKBufferValue[numResults];
        int32_t topKBufferIdx[numResults];
        int32_t laneIdx = threadIdx.x % kWARP_SIZE;

        for (int ii = 0; ii < numResults; ++ii)
        {
            topKBufferValue[ii] = minValue;
            topKBufferIdx[ii] = ii * kWARP_SIZE - 1; // @todo: check if this is correct
        }
        for (int loop = 0; loop < numLoops; ++loop)
        {
            int start = loop * 4;
            Type topKValue[K];
            int32_t topKIdx[K];
            Type inValue[4];
            int32_t inIdx[4];
            for (int i = 0; i < 4; ++i)
            {
                inValue[i] = value[start + i];
                inIdx[i] = idx[start + i];
            }
            reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK);
            int inOffset = laneIdx % K;
            if (laneIdx >= loop * K && laneIdx < (loop + 1) * K)
            {
                topKBufferValue[0] = topKValue[inOffset];
                topKBufferIdx[0] = topKIdx[inOffset];
            }
            if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE)))
            {
                topKBufferValue[1] = topKValue[inOffset];
                topKBufferIdx[1] = topKIdx[inOffset];
            }
        }

        reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, actualK);
    }
};

#undef TOPK_SWAP

} // namespace reduce_topk
} // namespace tensorrt_llm::kernels
#endif // TRTLLM_MOETOPKFUNCS_CUH_H