-
Notifications
You must be signed in to change notification settings - Fork 825
feat: Fuse shared experts into trtllm_gen moe (fp8) #2625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2255bca
95a01d0
2aa949e
dd679ea
ef6efb5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -250,17 +250,28 @@ __global__ void routingMainKernel(KernelParams params) { | |
| auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; | ||
|
|
||
| // write expert idx out already | ||
| auto idxTopK = blockIdx.x * params.mTopK + laneIdx; | ||
| auto idxTopK = blockIdx.x * params.mTotalExpertsPerToken + laneIdx; | ||
| auto idxShared = blockIdx.x * params.mTotalExpertsPerToken + params.mTopK + laneIdx; | ||
| if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) { | ||
| PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore), | ||
| static_cast<int16_t>(expertIdx)}; | ||
| params.mPtrTopKPacked[idxTopK] = packedScore; | ||
| } | ||
|
|
||
| if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKPacked != nullptr) { | ||
| PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(1.0F), | ||
| static_cast<int16_t>(params.mNumExperts + laneIdx)}; | ||
| params.mPtrTopKPacked[idxShared] = packedScore; | ||
| } | ||
|
|
||
| if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && | ||
| params.mPtrTopKIds == nullptr) { | ||
| params.mPtrTopKWeights[idxTopK] = finalScore; | ||
| } | ||
|
|
||
| if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKWeights != nullptr) { | ||
| params.mPtrTopKWeights[idxShared] = static_cast<OutputT>(1.0F); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -561,6 +572,11 @@ void runImpl(Data& data, void* stream) { | |
| FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, | ||
| "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, | ||
| data.mNumLimitedGroups); | ||
|
|
||
| int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts; | ||
| int const topK = data.mTopK + data.mNumFusedSharedExperts; | ||
| int const numThreadsHist = getMaxNumExperts(numExperts); | ||
|
|
||
| // Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK | ||
| if (data.mNumExperts <= NumKimiK2Experts) { | ||
| FLASHINFER_CHECK( | ||
|
|
@@ -573,6 +589,9 @@ void runImpl(Data& data, void* stream) { | |
| "When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d", | ||
| MaxSupportedTopExperts, data.mTopK); | ||
| } | ||
| FLASHINFER_CHECK(topK <= MaxSupportedTopExperts, | ||
| "Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, | ||
| topK); | ||
| FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", | ||
| data.mTopK); | ||
| FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, | ||
|
|
@@ -598,14 +617,19 @@ void runImpl(Data& data, void* stream) { | |
| data.mNumExperts / data.mNumExpertGroups <= WarpSize, | ||
| "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", | ||
| data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); | ||
|
|
||
| FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize, | ||
| "Number of fused shared experts (%d) must be less than warp size.", | ||
| data.mNumFusedSharedExperts); | ||
|
Comment on lines
+621
to
+623
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The check for
Comment on lines
+620
to
+623
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This validation is guarded by Suggested fix: Move the check out of the + FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+ "Number of fused shared experts (%d) must be less than warp size.",
+ data.mNumFusedSharedExperts);
+
if (data.mNumExpertGroups > 1) {
FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups,
...);
...
-
- FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
- "Number of fused shared experts (%d) must be less than warp size.",
- data.mNumFusedSharedExperts);
} 🤖 Prompt for AI Agents |
||
| } | ||
| FLASHINFER_CHECK(data.mNumExperts % 4 == 0, | ||
| "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); | ||
|
|
||
| int const numBlocks = data.mNumTokens; | ||
| int const numThreadsHist = getMaxNumExperts(data.mNumExperts); | ||
|
|
||
| bool const useSingleCluster = data.mNumTokens <= 1024; | ||
| int numThreadsPerCluster = numThreadsHist * NumBlocksPerCluster; | ||
| bool const useSingleCluster = | ||
| data.mNumTokens <= 1024 && data.mNumTokens * topK <= numThreadsPerCluster; | ||
| if (!useSingleCluster) { | ||
| // Reset the global histograms (not used in single-cluster code path). | ||
| // Cover both for the cooperative and two-kernel code paths. | ||
|
|
@@ -629,7 +653,7 @@ void runImpl(Data& data, void* stream) { | |
| int const numBlocksCoop = 128; | ||
|
|
||
| // Maximum number of tokens supported by the kernel using a cooperative launch. | ||
| int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; | ||
| int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / topK; | ||
| if (data.mPtrTopKIds == nullptr) { | ||
| int const numThreadsMain = | ||
| max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); | ||
|
|
@@ -645,6 +669,12 @@ void runImpl(Data& data, void* stream) { | |
| stream, data.mNumExpertGroups > 1); | ||
| } | ||
|
|
||
| if (data.mNumFusedSharedExperts > 0) { | ||
| data.mNumExperts += data.mNumFusedSharedExperts; | ||
| data.mTopK += data.mNumFusedSharedExperts; | ||
| data.mNumLocalExperts += data.mNumFusedSharedExperts; | ||
| } | ||
|
Comment on lines
+672
to
+676
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updating
You should calculate the total expert count and top-k at the beginning of |
||
|
|
||
| if (data.mPtrPermutedIdxSize != nullptr) { | ||
| if (useSingleCluster) { | ||
| LAUNCH_ROUTING_DEEPSEEK(data, | ||
|
|
@@ -659,7 +689,7 @@ void runImpl(Data& data, void* stream) { | |
| /*smemSize=*/0, // No dynamic smem | ||
| stream, data.mNumExpertGroups > 1); | ||
| } else { | ||
| const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; | ||
| const int32_t expandedIdxSize = data.mNumTokens * topK; | ||
| const int32_t histogramEltsPerBlock = 8 * numThreadsHist; | ||
| const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.