Skip to content

Commit 74bda86

Browse files
committed
initial update: fused DeepSeek routing kernel
1 parent 54101e9 commit 74bda86

File tree

10 files changed

+1022
-0
lines changed

10 files changed

+1022
-0
lines changed

csrc/fused_moe/moeTopKFuncs.cuh

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
2+
/*
3+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#pragma once
18+
#ifndef TRTLLM_MOETOPKFUNCS_CUH_H
19+
#define TRTLLM_MOETOPKFUNCS_CUH_H
20+
21+
#include <cooperative_groups.h>
22+
#include <cooperative_groups/reduce.h>
23+
#include <cub/cub.cuh>
24+
25+
#include "tensorrt_llm/kernels/archCondition.h"
26+
27+
namespace tensorrt_llm::kernels
28+
{
29+
30+
namespace reduce_topk
31+
{
32+
namespace cg = cooperative_groups;

// Number of threads in a warp on all supported NVIDIA GPUs.
static constexpr int kWARP_SIZE = 32;
// Enables the redux.sync fast path in TopKRedType::reduce only on SM major
// version 10, per the project's arch-condition helper; all other
// architectures fall back to cg::reduce. NOTE(review): redux.sync requires
// SM80+ — the gate above must guarantee that on every enabled target.
static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
35+
36+
// Packs a (value, index) pair into a single unsigned key so that a plain
// integer max-reduction over the keys yields the largest value, with ties
// broken in favor of the smaller index. Supports int, float, half, bfloat16.
template <typename T_>
struct TopKRedType
{
    using T = T_;
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>
            || std::is_same_v<T, int>,
        "Top K reduction only implemented for int, float, float16 and bfloat16");

    // 4-byte values pack into a 64-bit key (value in the high 32 bits);
    // 2-byte values pack into a 32-bit key (value in the high 16 bits).
    using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
    using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;

    // Shift that places the twiddled value bits above the 16-bit index field.
    static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
    // Largest index representable in the 16-bit field; larger indices alias.
    static constexpr int kMaxIdx = 65535;
    // Packed comparison key: [twiddled value | (kMaxIdx - idx)].
    TypeCmp compValIdx;

    // Builds the packed key. cub::Traits<T>::TwiddleIn maps the value to an
    // unsigned encoding whose integer ordering matches the value ordering
    // (handles the sign-bit flip for signed/floating types).
    static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
    {
        auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
        TypeCmp compactTmp = valueBits;
        compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
        // Use 65535 minus idx to give higher priority to elements with smaller indices.
        return compactTmp;
    }

    // Recovers the (value, index) pair from a key produced by makeCmpVal.
    static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
    {
        // Since "65535 - idx" is always in [0, 65535], it can be read back
        // directly from the lower 16 bits.
        index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));

        auto compactTmp = cmp >> kMoveBits;
        auto valueBits
            = cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
        value = reinterpret_cast<T&>(valueBits);
    }

    __host__ __device__ TopKRedType() = default;

    __host__ __device__ TopKRedType(T val, int32_t idx)
        : compValIdx(makeCmpVal(val, idx))
    {
    }

    // Implicit conversion to the packed key for direct comparisons.
    __host__ __device__ operator TypeCmp() const noexcept
    {
        return compValIdx;
    }

    // Warp-wide max over the packed keys. 64-bit keys (and architectures
    // without fast redux) use cooperative-groups reduce; 32-bit keys on
    // fast-redux architectures use the redux.sync.max.u32 PTX instruction
    // across the full warp mask. All 32 lanes must participate.
    __device__ inline TypeCmp reduce(cg::thread_block_tile<kWARP_SIZE> const& warp)
    {
        if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
        {
            return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
        }
        else
        {
            TypeCmp result;
            asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
            return result;
        }
    }
};
97+
98+
////////////////////////////////////////////////////////////////////////////////////////////////////
99+
100+
// Fixed-capacity holder for the indices of the current top-K candidates.
// The Enable_ switch lets callers conditionally materialize the storage:
// the primary template is intentionally empty, so a disabled instance
// carries no data members.
template <int K_, bool Enable_>
struct TopKIdx
{
    // Disabled variant: no storage.
};

// Enabled specialization: carries K_ candidate indices.
template <int K_>
struct TopKIdx<K_, true>
{
    static constexpr int K = K_;
    int32_t val[K_];
};
112+
113+
////////////////////////////////////////////////////////////////////////////////////////////////////
114+
115+
// Comparator for the sorting networks below: after the swap, slot I holds
// the larger packed key and slot J the smaller one, so each network orders
// topK[] in descending order. Packed keys are unique (they embed the
// candidate index), so the sorted result is fully determined.
#define TOPK_SWAP(I, J)                                                                                                \
    {                                                                                                                  \
        auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        topK[I].compValIdx = pairMax;                                                                                  \
        topK[J].compValIdx = pairMin;                                                                                  \
    }

// Sort<N>: fixed-size descending sorting networks over packed top-k keys.
// Only N in [1, 4] is provided, matching the per-lane candidate limit.
template <int N, typename RedType>
struct Sort;

// One element is trivially sorted.
template <typename RedType>
struct Sort<1, RedType>
{
    static __device__ void run(RedType* topK) {}
};

// Two elements: a single comparator.
template <typename RedType>
struct Sort<2, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
    }
};

// Three elements: 3-comparator insertion network.
template <typename RedType>
struct Sort<3, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(1, 2);
        TOPK_SWAP(0, 1);
        TOPK_SWAP(1, 2);
    }
};

// Four elements: 5-comparator sorting network.
template <typename RedType>
struct Sort<4, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
        TOPK_SWAP(2, 3);
        TOPK_SWAP(0, 2);
        TOPK_SWAP(1, 3);
        TOPK_SWAP(1, 2);
    }
};
164+
165+
// Warp-cooperative top-K over one candidate per lane (32 candidates total).
//
// Each lane contributes a single (value, idx) pair; the warp iteratively
// extracts the largest remaining pair. out[kk]/outIdx[kk] receive the kk-th
// largest value and its original index for kk in [0, actualK); entries at
// actualK and beyond are left untouched. All 32 lanes of `warp` must call
// this together (it performs warp-wide reductions).
//
// minValue must compare below every real candidate: once a lane's pair has
// been selected, that lane replaces it with minValue so it cannot win again.
// actualK allows extracting fewer than K results at runtime.
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    using RedType = TopKRedType<Type>;
    RedType topK{value, idx};
    typename RedType::TypeCmp packedMax{};
    // Loop to the compile-time bound K so "#pragma unroll" can fully unroll
    // (a runtime bound defeats unrolling), exiting early once actualK
    // results are produced. This also caps writes at K if a caller ever
    // passes actualK > K.
#pragma unroll
    for (int kk = 0; kk < K; ++kk)
    {
        if (kk >= actualK)
        {
            break;
        }
        // A lane whose pair won the previous round retires it by swapping
        // in minValue, so the next round finds the next-largest pair.
        topK = kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK;
        // get the next largest value
        packedMax = topK.reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
}
183+
184+
// Warp-cooperative top-K over N candidates per lane (N * 32 candidates).
//
// Each lane keeps its N (value, idx) pairs as packed keys in descending
// order; round kk publishes the warp-wide maximum into out[kk]/outIdx[kk].
// All 32 lanes of `warp` must call this together. minValue must compare
// below every real candidate. IsSorted skips the initial per-lane sort when
// the caller already supplies candidates in descending order.
// Sort<N> is only specialized for N in [1, 4], matching the static_assert
// below (4 candidates/lane * 32 lanes = 128 candidates max).
template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
    Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N < 5, "Only support candidates number less than or equal to 128");
    using RedType = TopKRedType<Type>;
    RedType topK[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        topK[nn] = RedType{value[nn], idx[nn]};
    }

    // Sort each lane's candidates descending so topK[0] is always the
    // lane's best remaining candidate.
    if constexpr (!IsSorted)
    {
        Sort<N, RedType>::run(topK);
    }
    typename RedType::TypeCmp packedMax{};
    // NOTE(review): actualK is a runtime bound, so this unroll request may
    // be ineffective for the outer loop.
#pragma unroll
    for (int kk = 0; kk < actualK; ++kk)
    {
        // If this lane's head won the previous round, shift the remaining
        // candidates down one slot and refill the tail with minValue.
        bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
        for (int nn = 0; nn < N; ++nn)
        {
            // The nn == N - 1 arm is selected before topK[nn + 1] would be
            // evaluated, so there is no out-of-bounds read.
            topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
        }
        // get the next largest value
        packedMax = topK[0].reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
};
219+
220+
// Warp-cooperative top-K over N candidates per lane for large N (up to 16
// per lane, i.e. 512 candidates per warp).
//
// For N <= 4 this forwards directly to reduceTopKFunc. For larger N it
// processes the candidates in chunks of 4 per lane: each chunk's warp-wide
// top-K is scattered so that chunk `loop`'s results occupy "virtual slots"
// [loop*K, (loop+1)*K) across lanes/buffer entries, then a final reduction
// over the buffered partials produces the overall top-K. minValue must
// compare below every real candidate; all 32 lanes must call this together.
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512");
    static_assert(
        N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");
    using RedType = TopKRedType<Type>;

    if constexpr (N <= 4)
    {
        reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue, actualK);
    }
    else
    {
        // Chunks of 4 candidates per lane, and per-lane buffer slots needed
        // to hold numLoops*K partial results across 32 lanes.
        // NOTE(review): the final reduceTopKFunc below statically requires
        // numResults < 5, which bounds numLoops * K, but nothing here
        // enforces that — confirm against intended K values.
        constexpr int numLoops = N / 4;
        constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;

        Type topKBufferValue[numResults];
        int32_t topKBufferIdx[numResults];
        int32_t laneIdx = threadIdx.x % kWARP_SIZE;

        for (int ii = 0; ii < numResults; ++ii)
        {
            topKBufferValue[ii] = minValue;
            // Sentinel index for unused slots. NOTE(review): for ii == 0 this
            // is -1, which wraps in the 16-bit packed-index field of
            // TopKRedType — confirm this is the intended padding behavior.
            topKBufferIdx[ii] = ii * kWARP_SIZE - 1;
        }
        for (int loop = 0; loop < numLoops; ++loop)
        {
            int start = loop * 4;
            Type topKValue[K];
            int32_t topKIdx[K];
            Type inValue[4];
            int32_t inIdx[4];
            for (int i = 0; i < 4; ++i)
            {
                inValue[i] = value[start + i];
                inIdx[i] = idx[start + i];
            }
            // Warp-wide top-K of this 4-candidate chunk; idx values are the
            // caller's original indices, so they survive re-buffering.
            reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK);
            // Scatter this chunk's K results into lanes loop*K..(loop+1)*K-1.
            int inOffset = laneIdx % K;
            if (laneIdx >= loop * K && laneIdx < (loop + 1) * K)
            {
                topKBufferValue[0] = topKValue[inOffset];
                topKBufferIdx[0] = topKIdx[inOffset];
            }
            // Virtual slots past lane 31 spill into the second buffer entry.
            // NOTE(review): this only handles spill from the final chunk and
            // uses laneIdx % K as the result offset — verify both for
            // configurations where numLoops * K > kWARP_SIZE with more than
            // one chunk crossing the lane-31 boundary.
            if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE)))
            {
                topKBufferValue[1] = topKValue[inOffset];
                topKBufferIdx[1] = topKIdx[inOffset];
            }
        }

        // Final reduction over the buffered per-chunk winners.
        reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, actualK);
    }
};
280+
281+
#undef TOPK_SWAP
282+
283+
} // namespace reduce_topk
284+
} // namespace tensorrt_llm::kernels
285+
#endif // TRTLLM_MOETOPKFUNCS_CUH_H
286+

0 commit comments

Comments
 (0)