feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend #2001
```cuda
@@ -129,7 +129,16 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
#ifndef ENABLE_PDL
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#if __CUDA_ARCH__ == 900
#define ENABLE_PDL 2
#else
#define ENABLE_PDL 1
#endif
#else
/* default for host or older architectures */
#define ENABLE_PDL 0
#endif
#endif

#ifndef USE_INPUT_KV
```
Comment on lines 116 to 127

Contributor

Runtime-vs-compile-time PDL mismatch: ENABLE_PDL defaults to 1/2 for SM ≥ 90 at compile time, but the kernels still use this macro while the host now passes a runtime enable_pdl. Host compilation sees ENABLE_PDL == 0 (no __CUDA_ARCH__), so kernels may execute preExit/acqBulk even when enable_pdl = false at launch. This is inconsistent and can lead to invalid usage of programmatic stream serialization. Thread enable_pdl into the kernels and guard the PDL intrinsics with the runtime flag (while keeping the arch guards). See the follow-up diffs in the kernel files below.
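As a hedged sketch of the direction suggested above (preExit/acqBulk are the helpers named in the review; their exact signatures and how the flag is threaded through in this repo are assumptions), the kernel-side PDL intrinsics can keep the SM90+ architecture guard while also honoring the runtime flag passed at launch:

```cuda
#include <cuda_runtime.h>

// Sketch only: guard the PDL intrinsics with BOTH the compile-time arch check
// and the runtime enable_pdl flag threaded in from the host launch path.
// The real kernels may carry the flag in their parameter structs instead.
__device__ inline void acqBulk(bool enable_pdl) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  if (enable_pdl) {
    // Wait for the predecessor grid of the programmatic dependent launch
    // to make its memory visible before consuming it.
    cudaGridDependencySynchronize();
  }
#endif
}

__device__ inline void preExit(bool enable_pdl) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  if (enable_pdl) {
    // Signal that the dependent grid in the stream may begin launching.
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif
}
```

If the launch site only attaches the programmatic-stream-serialization attribute when enable_pdl is true (as the makeLaunchConfig change below does), this keeps device behavior consistent with what the host requested.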
```diff
@@ -1966,9 +1966,12 @@ __device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gme
   for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) {
     static_assert(nbThrdsPerInstNBase * RegColWiseVec::size ==
                   exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols));
-    ret[i] = reinterpret_cast<Vec<Vec<float, GmmaAccCoreMat::cols>,
-                                  exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>(
-        gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)];
+    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
+#pragma unroll
+    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+      ret[i][j] = gmemVec[baseOffset + j];
+    }
   }
   return ret;
 }
```
Comment on lines +1877 to 1883

Contributor

Out-of-bounds read in loadGmemColWiseVecWithDup for attention sinks: gmemVec points to a buffer of size headGrpSize (see finalizeAndWriteOut_sync passing attentionSinksVec[0]), but this code multiplies the index by GmmaAccCoreMat::cols and reads baseOffset + j, which can exceed headGrpSize. We should load a single sink value per head and duplicate it across columns, without advancing memory by cols. Apply this fix:

```diff
-    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
-    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
-#pragma unroll
-    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
-      ret[i][j] = gmemVec[baseOffset + j];
-    }
+    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+#pragma unroll
+    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+      // Duplicate the same head sink across the 2 columns
+      ret[i][j] = gmemVec[clampedIdx];
+    }
```
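For intuition only, here is a tiny standalone illustration of the indexing issue the comment describes (the head-group size and column count are made-up values, not the kernel's real types): reading sink[head * cols + j] walks past an array that holds one sink per head, while reading sink[head] and duplicating it per column stays in bounds.

```cuda
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative host-side model of the fix: one attention-sink value per head,
// duplicated across the accumulator columns without scaling the read index.
std::vector<float> duplicateSinksAcrossCols(std::vector<float> const& sinks, uint32_t nbCols) {
  std::vector<float> out(sinks.size() * nbCols);
  for (uint32_t head = 0; head < sinks.size(); ++head) {
    for (uint32_t j = 0; j < nbCols; ++j) {
      // Read index depends only on the head, so it never exceeds sinks.size() - 1.
      // The buggy form read sinks[head * nbCols + j], which goes out of bounds
      // once head * nbCols + j >= sinks.size().
      out[head * nbCols + j] = sinks[head];
    }
  }
  return out;
}

int main() {
  std::vector<float> const sinks{0.f, 1.f, 2.f, 3.f};    // headGrpSize == 4 (made up)
  auto const dup = duplicateSinksAcrossCols(sinks, /*nbCols=*/2);
  assert(dup[6] == 3.f && dup[7] == 3.f);                // last head's sink fills both columns
  return 0;
}
```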
```diff
@@ -3033,7 +3036,7 @@ void launchHopperF8MHA(
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
   if (beamWidth != 1) {
     throw std::runtime_error("not implemented");
   }
```
```diff
@@ -3070,7 +3073,7 @@ void launchHopperF8MHA(
   // nbInputSeqSplit
   dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
   dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
   auto const dtype = [] {
```
```diff
@@ -3191,7 +3194,8 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
 #if SPEC_DEC
     uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl,
+    cudaStream_t stream) {
   uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
     float const factor = 0.25f;
     return mha::min<uint32_t>(
```
```diff
@@ -3207,7 +3211,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
 #endif
   dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
   dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
   auto const dtype = [] {
```
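For context on what the new enable_pdl argument ultimately controls: makeLaunchConfig is this repo's helper and its real implementation may differ, but a runtime PDL toggle typically ends up as the programmatic-stream-serialization launch attribute, roughly like the hypothetical sketch below (makePdlLaunchConfig and its parameters are stand-ins, not the actual API).

```cuda
#include <cuda_runtime.h>

// Hypothetical stand-in for makeLaunchConfig: attach the PDL launch attribute
// only when the caller requested it at runtime.
inline cudaLaunchConfig_t makePdlLaunchConfig(dim3 grid, dim3 cta, size_t dynSmemBytes,
                                              cudaStream_t stream, bool enablePdl,
                                              cudaLaunchAttribute* attrStorage) {
  // attrStorage must outlive the returned config, since cfg.attrs is a pointer.
  cudaLaunchConfig_t cfg{};
  cfg.gridDim = grid;
  cfg.blockDim = cta;
  cfg.dynamicSmemBytes = dynSmemBytes;
  cfg.stream = stream;
  if (enablePdl) {
    attrStorage[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrStorage[0].val.programmaticStreamSerializationAllowed = 1;
    cfg.attrs = attrStorage;
    cfg.numAttrs = 1;
  }
  return cfg;
}
```

The kernel is then launched with cudaLaunchKernelEx(&cfg, kernel, args...); when enablePdl is false, no attribute is attached and the launch behaves like an ordinary one, matching the device-side runtime guard sketched earlier.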