-
Notifications
You must be signed in to change notification settings - Fork 578
feat: add xqa fp8 mha and fp8 kv cache #1769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary of ChangesHello @qsang-nv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the FlashInfer library by integrating FP8 support for both Multi-Head Attention computations and the Key-Value cache within the XQA framework. These changes are primarily aimed at leveraging the advanced capabilities of NVIDIA Hopper GPUs (SM90+) to achieve substantial performance and memory efficiency gains. The implementation includes new CUDA kernels utilizing GMMA and TMA, along with Python-side modifications to enable configurable FP8 execution paths, ensuring that users can opt into these optimizations while maintaining numerical stability. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
Summary of ChangesHello @qsang-nv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the FlashInfer library by introducing support for FP8 Multi-Head Attention (MHA) and FP8 Key-Value (KV) cache. These additions leverage advanced features of NVIDIA Hopper GPUs, such as GMMA and TMA, to achieve higher performance and memory efficiency for large language model inference. The changes span the CUDA C++ backend, including new kernel implementations and memory management utilities, as well as updates to the Python AOT compilation and testing framework to ensure robust integration and validation of the new FP8 capabilities. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache within the XQA kernel, primarily targeting the NVIDIA Hopper architecture. This is a significant feature addition, enabled by new CUDA primitives for Hopper's Tensor Memory Access (TMA) and Grace Hopper MMA (GMMA) instructions. The changes are well-implemented, including new CUDA headers for hardware abstraction, a dispatch mechanism for the new FP8 kernel path, and corresponding updates to the Python build system and tests. The tests have been thoughtfully adjusted with relaxed tolerances for FP8 precision. My review includes one suggestion to refactor a small piece of duplicated code to enhance maintainability.
csrc/xqa/xqa_wrapper.cu
Outdated
| if (run_fp8_mha) { | ||
| launchHopperF8MHAFlashInfer( | ||
| multiProcessorCount, nbKHeads, slidingWinSize, qScale, | ||
| reinterpret_cast<OutputHead*>(output.data_ptr()), | ||
| #if LOW_PREC_OUTPUT | ||
| reinterpret_cast<float const*>(rcpOutScale.data_ptr()), | ||
| reinterpret_cast<float const*>(rcpOutScale.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr, | ||
| reinterpret_cast<GMemCacheHead*>(pool.data_ptr()), | ||
| reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), | ||
| maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, | ||
| reinterpret_cast<float const*>(kvCacheScale.data_ptr()), | ||
| reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr, | ||
| reinterpret_cast<GMemCacheHead*>(pool.data_ptr()), | ||
| reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen, | ||
| reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, | ||
| reinterpret_cast<float const*>(kvCacheScale.data_ptr()), | ||
| #if SPEC_DEC | ||
| qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()), | ||
| reinterpret_cast<MaskType const*>(mask.data_ptr()), | ||
| qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()), | ||
| reinterpret_cast<MaskType const*>(mask.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<uint32_t*>(semaphores.data_ptr()), | ||
| reinterpret_cast<void*>(scratch.data_ptr()), stream); | ||
| reinterpret_cast<uint32_t*>(semaphores.data_ptr()), | ||
| reinterpret_cast<void*>(scratch.data_ptr()), stream); | ||
| } else { | ||
| launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, | ||
| reinterpret_cast<OutputHead*>(output.data_ptr()), | ||
| #if LOW_PREC_OUTPUT | ||
| reinterpret_cast<float const*>(rcpOutScale.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr, | ||
| reinterpret_cast<GMemCacheHead*>(pool.data_ptr()), | ||
| reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), | ||
| maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, | ||
| reinterpret_cast<float const*>(kvCacheScale.data_ptr()), | ||
| #if SPEC_DEC | ||
| qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()), | ||
| reinterpret_cast<MaskType const*>(mask.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<uint32_t*>(semaphores.data_ptr()), | ||
| reinterpret_cast<void*>(scratch.data_ptr()), stream); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The if and else blocks contain identical arguments passed to two different functions (launchHopperF8MHAFlashInfer and launchMHAFlashInfer). This code duplication can be reduced to improve maintainability. Since both functions share the same signature, you can use a function pointer to select the appropriate kernel and then call it once with the common set of arguments.
using mha_launcher_t = decltype(&launchMHAFlashInfer);
mha_launcher_t launcher = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
launcher(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
Summary of ChangesHello @qsang-nv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the XQA (eXtended Query Attention) kernels by introducing support for FP8 Multi-Head Attention and FP8 Key-Value cache. These changes are designed to optimize performance and memory usage on NVIDIA Hopper (SM90+) GPUs through low-level CUDA programming, including asynchronous tensor memory access and matrix operations. The integration ensures that the system can efficiently handle lower precision data types, with comprehensive testing to maintain accuracy within acceptable bounds. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache in the XQA kernels, targeting Hopper architecture for performance improvements. The changes include new low-level CUDA files (gmma.cuh, tma.h, tensorMap.cpp) with Hopper-specific WGMMA and TMA instructions, a new FP8 MHA kernel entry point, and updates to the AOT compilation scripts and Python wrappers to handle the new FP8 variants. The tests have also been updated to include FP8 configurations and use a more lenient assertion method to account for precision differences.
My review focuses on code maintainability and clarity. I've suggested refactoring a duplicated code block in the C++ wrapper to improve readability and proposed adding a comment in the Python tests to clarify a magic number used for data scaling. Overall, the changes are well-structured and the addition of FP8 support is a valuable performance enhancement.
csrc/xqa/xqa_wrapper.cu
Outdated
| if (run_fp8_mha) { | ||
| launchHopperF8MHAFlashInfer( | ||
| multiProcessorCount, nbKHeads, slidingWinSize, qScale, | ||
| reinterpret_cast<OutputHead*>(output.data_ptr()), | ||
| #if LOW_PREC_OUTPUT | ||
| reinterpret_cast<float const*>(rcpOutScale.data_ptr()), | ||
| reinterpret_cast<float const*>(rcpOutScale.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr, | ||
| reinterpret_cast<GMemCacheHead*>(pool.data_ptr()), | ||
| reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), | ||
| maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, | ||
| reinterpret_cast<float const*>(kvCacheScale.data_ptr()), | ||
| reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr, | ||
| reinterpret_cast<GMemCacheHead*>(pool.data_ptr()), | ||
| reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen, | ||
| reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, | ||
| reinterpret_cast<float const*>(kvCacheScale.data_ptr()), | ||
| #if SPEC_DEC | ||
| qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()), | ||
| reinterpret_cast<MaskType const*>(mask.data_ptr()), | ||
| qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()), | ||
| reinterpret_cast<MaskType const*>(mask.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<uint32_t*>(semaphores.data_ptr()), | ||
| reinterpret_cast<void*>(scratch.data_ptr()), stream); | ||
| reinterpret_cast<uint32_t*>(semaphores.data_ptr()), | ||
| reinterpret_cast<void*>(scratch.data_ptr()), stream); | ||
| } else { | ||
| launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, | ||
| reinterpret_cast<OutputHead*>(output.data_ptr()), | ||
| #if LOW_PREC_OUTPUT | ||
| reinterpret_cast<float const*>(rcpOutScale.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr, | ||
| reinterpret_cast<GMemCacheHead*>(pool.data_ptr()), | ||
| reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), | ||
| maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, | ||
| reinterpret_cast<float const*>(kvCacheScale.data_ptr()), | ||
| #if SPEC_DEC | ||
| qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()), | ||
| reinterpret_cast<MaskType const*>(mask.data_ptr()), | ||
| #endif | ||
| reinterpret_cast<uint32_t*>(semaphores.data_ptr()), | ||
| reinterpret_cast<void*>(scratch.data_ptr()), stream); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a large block of duplicated code for launching the MHA kernels. The only difference between the if and else blocks is the function being called (launchHopperF8MHAFlashInfer vs. launchMHAFlashInfer). This could be refactored to improve maintainability and reduce redundancy.
Consider using a function pointer to select the kernel, and then make a single call. This would make the code cleaner and easier to manage if more arguments are added in the future.
For example:
void (*mha_func)(uint32_t, uint32_t, ...); // Using a function pointer type alias
if (run_fp8_mha) {
mha_func = &launchHopperF8MHAFlashInfer;
} else {
mha_func = &launchMHAFlashInfer;
}
mha_func(
multiProcessorCount,
nbKHeads,
slidingWinSize,
// ... other arguments
); using mha_func_t = void (*)(uint32_t, uint32_t, uint32_t, float, OutputHead*,
#if LOW_PREC_OUTPUT
float const*,
#endif
InputHead const*, float const*, GMemCacheHead*,
KVCachePageIndex const*, uint32_t, uint32_t const*, uint32_t,
float const* __restrict__,
#if SPEC_DEC
uint32_t, uint32_t const*, MaskType const*,
#endif
uint32_t*, void*, cudaStream_t);
mha_func_t mha_func = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
tests/attention/test_xqa.py
Outdated
| if fp8_kv_cache: | ||
| cache_heads /= 4.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The value 4.0 is used to scale down the cache_heads tensor when fp8_kv_cache is enabled. This appears to be a magic number. To improve code clarity and maintainability, please add a comment explaining the rationale for this specific scaling factor. For example, explaining that it's to prevent overflow and how 4.0 was determined would be very helpful for future readers.
| if fp8_kv_cache: | |
| cache_heads /= 4.0 | |
| if fp8_kv_cache: | |
| # Scale down the cache heads to keep values within the representable range of FP8 | |
| # and prevent overflow during computation. The factor 4.0 is chosen empirically. | |
| cache_heads /= 4.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache for Cross-Query Attention (XQA), targeting NVIDIA's Hopper architecture. This is a significant feature addition that leverages low-level hardware capabilities like TMA and GMMA for performance. The changes include new CUDA files for these Hopper-specific features, along with updates to the Python build system and tests to accommodate the new configurations. The review identified a critical bug in the new TMA storeAsync implementation and a high-severity correctness issue related to the handling of masked values in the softmax computation.
| : "memory"); | ||
| } else if constexpr (nbDims == 5) { | ||
| asm volatile( | ||
| "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], " | ||
| "[%6];\n" | ||
| : | ||
| : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), | ||
| "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src)) | ||
| : "memory"); | ||
| } else { | ||
| static_assert(nbDims >= 1 && nbDims <= 5); | ||
| } | ||
| } | ||
|
|
||
| __device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr) { | ||
| asm volatile( | ||
| "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap), | ||
| "l"(ptr) | ||
| : "memory"); | ||
| } | ||
|
|
||
| __device__ inline void commitGroup() { | ||
| asm volatile("cp.async.bulk.commit_group;\n" : : : "memory"); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There appears to be a copy-paste error in the storeAsync template function. For nbDims of 3, 4, and 5, the inline assembly instruction is cp.async.bulk.tensor.2d..., but it should be cp.async.bulk.tensor.3d..., cp.async.bulk.tensor.4d..., and cp.async.bulk.tensor.5d... respectively. This will lead to incorrect memory access patterns and likely data corruption for higher-dimensional tensors.
else if constexpr (nbDims == 3)
{
asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
:
: "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
"l"(__cvta_generic_to_shared(src))
: "memory");
}
else if constexpr (nbDims == 4)
{
asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
:
: "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
"r"(offset[3]), "l"(__cvta_generic_to_shared(src))
: "memory");
}
else if constexpr (nbDims == 5)
{
asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n"
:
: "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
"r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
: "memory");
}| ? true | ||
| : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart)); | ||
| acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY; | ||
| acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using safeInitRowMax for masked elements can lead to incorrect results. When an entire row/sequence is masked, all attention scores become safeInitRowMax. In the softmax computation, maxVal also becomes safeInitRowMax, and exp(score - maxVal) evaluates to 1 for all masked positions. This results in a uniform attention distribution over masked tokens, and the output becomes the average of values in V, instead of zero.
A correct implementation should ensure that the softmax output for masked tokens is zero. If the entire row is masked, the final output should also be zero. This might require changes in the softmax function to handle safeInitRowMax specially, and in the final normalization step to handle a row sum of zero.
csrc/flashinfer_xqa_binding.cu
Outdated
| @@ -16,8 +16,8 @@ | |||
|
|
|||
| #include "pytorch_extension_utils.h" | |||
|
|
|||
| void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, | |||
| double qScale, at::Tensor output, | |||
| void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of making this a flag, could we pass a dtype?
Same for the other places where we pass:
- the type of the input (only bf16 and fp16 supported I think)
- the type of the kv-cache (fp8 or bf16)
- the type in which we perform arithmetic (the same type as the kv-cache I think?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now it is passing dtype in flashinfer/flashinfer/xqa.py
|
|
||
| inline constexpr float log2e = 1.4426950408889634; // std::log2(M_E) | ||
| inline constexpr float safeInitRowMax = -1e+30F; | ||
| // we used an optimization where exp(x-rowMax) is computed as: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's interesting: what were the symptoms of the instability? Accuracy loss?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is copied from NVIDIA/TensorRT-LLM@c1aa7f3, you may ask the author, I am not sure about this question.
tests/attention/test_xqa.py
Outdated
| @@ -354,4 +364,21 @@ def cache_head_at( | |||
| kernel_output = output[req][b][ | |||
| idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size | |||
| ].to(torch.float32) | |||
| assert torch.allclose(ref_output, kernel_output, atol=0.01, rtol=0.01) | |||
| if fp8_kv_cache or run_fp8_mha: | |||
| atol = 0.05 | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How did you tune this tolerance? Can it be smaller?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From 0.01 to 0.05, add 0.01 every step. And it can't be smaller from my test.
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
flashinfer/jit/xqa.py
Outdated
| else: | ||
| flag_sliding_window = ["-DSLIDING_WINDOW=0"] | ||
|
|
||
| if sm_version == 100: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to add SM103 support by targeting SM100f instead of SM100a?
And similarly, can we add SM121 support by targeting SM120f instead of SM120a?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the difference in those archs? I mean SM103/SM100f/SM100a, and SM121/SM120f/SM120a.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "a" means "arch-specific" and the "f" means "family".
SM100 and SM103 are in the "SM100f family".
SM100a (SM100 arch-specific) will only run on SM100 devices.
But I believe SM103 devices have a superset of the SM100 features, and therefore if you target SM100f instead of SM100a during compilation, your cubin will be able to run on SM103 as well, without any loss of optimization on either device. So I think it's strictly better than targeting SM103a.
SM121 and SM120 have a similar story: it's better to target SM120f as a compilation target, yielding a cubin that will run on both SM120 and SM121 devices without any compromise to performance.
See this documentation for details: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#family-specific-features
@aleozlx can you confirm my understanding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes sm_100f is known as family specific or family conditional. this is important to enhance device compatibility in a sort of out of the box fashion.
i'd like to point out a few things tho based on my experience:
- with strictly jit compilation where the premise is that at runtime fewer devices are in the compatibility question, the arch conditionals may be the safest way to target the instruction supersets. (when the target is available to compile by the toolkit at the time of implementation)
- family conditionals are important for compatibility story (indeed aligning with your understanding conceptually) but it is not without inherent engineering complexity. from an engineer's perspective i naturally experience a slightly more complicated story at the levels beneath. i'll spare the details but leave it as a reliability intuition (in the context of jit).
however, i want to bring this CompilationContext.get_nvcc_flags_list() up for consideration to not hard code any guidance either way but have it abstracted so we can adjust if situation changes. briefly, how this works is for each op/backend if you whitelist your supported targets, this function shall serve as the mapping to provide the recommended flags.
we can put in sm103 support (likely fine for attn) supposing our cicd will catch the issue if not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, let me clarify my guidance then:
- Since SM120 and SM121 are identical architectures, naively I would think that supporting both is only marginally harder than supporting one.
- The story seems even a little harder for SM100 and SM103, because they are not identical, but if you don't care about SM103-specific features, I would think the marginal effort of supporting SM103 shouldn't be massive.
- Therefore the default design for any solution for SM100 and SM120 should at least try to include SM103 and SM121 support, or should at least be designed with SM103 and SM121 in mind, even if some of the details are left as a future TODO.
- What compilation targets you use, and how you architect the code to query for those compilation targets, is an engineering implementation detail, and I probably shouldn't be opinionated about that. :)
The problem here is that the PR doesn't address (3), maybe because @qsang-nv didn't know about (1) and (2). I'm only proposing that we re-think the design here with the above in mind.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the detailed explanation! I've added support for sm100f and sm121a.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should rely on compilation context instead:
flashinfer/flashinfer/compilation_context.py
Lines 50 to 52 in d4a3ff4
| def get_nvcc_flags_list( | |
| self, supported_major_versions: list[int] = None | |
| ) -> list[str]: |
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughAdds MLA (SM120/121) MLA-specific MHA path, FP8 KV-cache support, conditional paged KV-cache layout and tensor-map/TMA async APIs, gmma MatDesc/MMA helpers, MLA kernels/launchers, Python dtype-driven JIT wiring and tests, plus docs entry for flashinfer.xqa. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Py as Python Client
participant Gen as JIT/AOT Gen
participant NVCC as NVCC
participant Module as Compiled Module
participant Bind as C++ Binding
participant Wrapper as xqa_wrapper / xqa_wrapper_mla
participant Launcher as MHA Dispatcher
participant MLA as MLA Kernel (SM120/121)
participant Hopper as Hopper F8 Kernel
participant Std as Std MHA Kernel
Py->>Gen: request module (input_dtype, kv_cache_dtype, page_size, head_dim, sm_versions...)
Gen->>NVCC: compile sources (tensorMap.cpp, gmma.cuh, tma.h, mla_sm120.cu, xqa_wrapper.cu, ...)
NVCC-->>Gen: compiled Module
Gen-->>Py: module handle
Py->>Module: call xqa(...) or xqa_mla(...)
Module->>Bind: call exported wrapper
Bind->>Wrapper: forward args (run_sm90_fp8_mha?, k_cache, v_cache or pool, page_table, seq_lens, semaphores, workspace)
Wrapper->>Launcher: select launcher (MLA / HopperF8 / Std) based on MLA macro and run_sm90_fp8_mha
alt MLA path
Launcher->>MLA: launchMLAFlashInfer(...)
else if run_sm90_fp8_mha
Launcher->>Hopper: launchHopperF8MHAFlashInfer(...)
else
Launcher->>Std: launchMHAFlashInfer(...)
end
Launcher-->>Wrapper: kernel completes
Wrapper-->>Bind: return
Bind-->>Module: return
Module-->>Py: return
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_xqa.py (1)
28-33: Avoid GPU property access at import time.Accessing
torch.cuda.get_device_properties(0)during import can break test discovery on CPU/multi-device envs. Move it inside the test after skip checks.-props = torch.cuda.get_device_properties(0) -sm_count = props.multi_processor_count +sm_count = None # set inside test to avoid import-time CUDA queries
♻️ Duplicate comments (4)
csrc/xqa/tma.h (1)
208-229: Bug: 3D/4D/5D storeAsync use 2D opcode (will corrupt data).The cp.async store paths for nbDims 3–5 incorrectly use tensor.2d. Must be tensor.3d/4d/5d.
Apply this diff:
@@ } else if constexpr (nbDims == 3) { - asm volatile( - "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n" - : - : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), - "r"(offset[2]), "l"(__cvta_generic_to_shared(src)) - : "memory"); + asm volatile( + "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n" + : + : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), + "r"(offset[2]), "l"(__cvta_generic_to_shared(src)) + : "memory"); } else if constexpr (nbDims == 4) { - asm volatile( - "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n" - : - : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), - "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(src)) - : "memory"); + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n" + : + : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), + "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(src)) + : "memory"); } else if constexpr (nbDims == 5) { - asm volatile( - "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], " - "[%6];\n" - : - : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), - "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src)) - : "memory"); + asm volatile( + "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n" + : + : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), + "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src)) + : "memory"); }csrc/xqa/mha.cu (1)
479-479: Critical: Masked position initialization may cause incorrect attention output.Using
safeInitRowMaxfor masked elements can lead to incorrect results. When an entire row is masked, all scores becomesafeInitRowMax, and in softmax computationexp(score - maxVal)evaluates to1for all positions, producing a uniform distribution over masked tokens instead of zero output.As noted in the previous review, the softmax function should handle
safeInitRowMaxspecially to ensure masked tokens contribute zero to the output, or alternatively masked positions should use a different sentinel value that results in zero after softmax.csrc/flashinfer_xqa_binding.cu (1)
19-21: Prefer a typed precision enum over a new boolean flag.Using
bool run_fp8_mhadoes not scale. Replace with a small enum (e.g.,int32_t precision: {bf16, fp16, fp8}) and, similarly, pass/cache element/compute dtypes as enums instead of separate flags. This reduces combinatorial overload and ABI churn.tests/attention/test_xqa.py (1)
241-246: Good: FP8 cache scaling is documented.Comment explains the 4.0 factor and overflow concerns.
🧹 Nitpick comments (10)
csrc/xqa/tma.h (1)
74-83: Comment/code mismatch for nbDims==1 path.The comment says “nbDims==1 does not need tensormap,” but the code uses the tensor.1d variant taking a tensor map. Either drop the map for 1D linear copies or update the comment.
Also applies to: 129-138
csrc/xqa/gmma.cuh (1)
27-56: Bitfield layout is implementation-defined; prefer explicit packing.Relying on 64‑bit bitfield layout and reinterpret_cast to Raw can be brittle across compilers/ABIs. Recommend encoding/decoding with shifts/masks into a uint64_t to guarantee layout and endianness. Keep sizeof(MatDesc)==8 as a guard.
csrc/xqa/tensorMap.h (1)
3-3: cuda.h include: make header robust to non-CUDA analysis/compiles.Static analysis flagged ‘cuda.h’ not found. If this header is transitively included by non‑CUDA TU(s), guard the include or move these declarations behind a build flag. Example: wrap with a small shim header included only from .cpp, or add a dedicated config that ensures CUDA include paths are present in CI.
csrc/xqa/tensorMap.cpp (1)
43-73: Tensor map for contiguous KV cache looks correct.The function properly constructs a tensor map for contiguous KV cache layout:
- Global dimensions and strides are configured appropriately
- Swizzle selection based on cache line size (128B or 64B)
- Error handling via
checkCuwrapperMinor suggestion: The error message on line 64 "unsupported cache head size" could be more specific about expected values.
- throw std::runtime_error("unsupported cache head size"); + throw std::runtime_error("unsupported partElems: " + std::to_string(partElems) + + ", expected 128 or 64");flashinfer/jit/xqa.py (1)
76-100: SM version selection and build configuration verified; optional refactor still recommended for clarity and error handling.The changes are correct:
- New source files (
mha_sm90.cu,tensorMap.cpp) exist incsrc/xqa/- Build configuration properly references and links them
- Required CUDA Driver API linker flag (
-lcuda) and cache layout flag includedHowever, the SM version selection logic could be improved for maintainability. The current code defaults to
sm90a_nvcc_flagsfor unrecognized versions, which implicitly handlessm_version=90but obscures intent and provides no validation for truly unsupported architectures.Consider making the SM90 case explicit and adding validation:
- if sm_version == 100: + if sm_version == 90: + sm_nvcc_flags = sm90a_nvcc_flags + elif sm_version == 100: sm_nvcc_flags = sm100a_nvcc_flags elif sm_version == 120: sm_nvcc_flags = sm120a_nvcc_flags else: - sm_nvcc_flags = sm90a_nvcc_flags + raise ValueError(f"Unsupported sm_version: {sm_version}")This makes supported architectures explicit and catches invalid SM versions early.
csrc/flashinfer_xqa_binding.cu (1)
25-35: KV cache params behind preprocessor guards: keep Python and C++ signatures locked.Since params differ when
PAGED_KV_CACHE_LAYOUT!=1, ensure JIT always defines it (the JIT path does) and document this contract near the binding to avoid accidental ABI mismatches. Consider adding a static assert/log print on init when it’s not set.Also applies to: 28-29
tests/attention/test_xqa.py (3)
263-275: Remove unusedbeam_widtharg (lint: ARG001).
beam_widthincache_head_atis unused; drop it and update call sites.-def cache_head_at( +def cache_head_at( batch, is_k, idx_kv_head, pos, - cache_k_heads, - cache_v_heads, - page_list, - beam_width, + cache_k_heads, + cache_v_heads, + page_list, nb_k_heads, tokens_per_page, ): @@ - cache_head = cache_head_at( + cache_head = cache_head_at( batch, kv == 0, idx_kv_head, pos, cache_k_heads, cache_v_heads, - page_list_arg, - beam_width, + page_list_arg, nb_k_heads, tokens_per_page, )Also applies to: 291-303
317-319: Make scratch size configurable; 256 MiB may OOM CI.Read from an env var with a sane default to reduce flakiness.
- scratch_size = 256 << 20 + import os + scratch_mb = int(os.environ.get("FLASHINFER_TEST_SCRATCH_MB", "256")) + scratch_size = scratch_mb << 20You can validate with different values in CI matrix.
392-397: Stable epsilon for relative diff.Optional: use dtype-aware epsilon via
torch.finfoto avoid hard-coded 1e-8.- diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8) + eps = torch.finfo(torch.float32).eps + diff_rel = diff_abs / (torch.abs(ref_output) + eps)flashinfer/xqa.py (1)
147-150: Avoid repeated capability queries and shorten the error.Cache CC once and use a shorter exception message (addresses Ruff TRY003).
- if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]: - raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs") - sm_version = int(get_compute_capability(torch.device(device="cuda"))[0] * 10) + cc_major, _ = get_compute_capability(torch.device(device="cuda")) + if cc_major not in (9, 10, 12): + raise RuntimeError("Unsupported GPU (require SM90/100/120)") + sm_version = int(cc_major * 10)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
csrc/flashinfer_xqa_binding.cu(1 hunks)csrc/xqa/gmma.cuh(1 hunks)csrc/xqa/mha.cu(5 hunks)csrc/xqa/mha.h(2 hunks)csrc/xqa/tensorMap.cpp(1 hunks)csrc/xqa/tensorMap.h(1 hunks)csrc/xqa/tma.h(1 hunks)csrc/xqa/utils.cuh(2 hunks)csrc/xqa/xqa_wrapper.cu(2 hunks)flashinfer/aot.py(3 hunks)flashinfer/jit/xqa.py(2 hunks)flashinfer/xqa.py(6 hunks)tests/attention/test_xqa.py(10 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
csrc/xqa/tensorMap.h (1)
csrc/xqa/tensorMap.cpp (6)
getElemBytes(10-41)getElemBytes(10-10)makeTensorMapForContiguousKVCache(43-73)makeTensorMapForContiguousKVCache(43-47)makeTensorMapForPagedKVCache(75-117)makeTensorMapForPagedKVCache(75-78)
csrc/xqa/tensorMap.cpp (1)
csrc/xqa/utils.h (1)
checkCu(39-48)
flashinfer/jit/xqa.py (1)
flashinfer/jit/core.py (2)
JitSpec(181-280)gen_jit_spec(283-347)
flashinfer/xqa.py (3)
flashinfer/jit/xqa.py (1)
gen_xqa_module(38-101)flashinfer/jit/core.py (1)
build_and_load(268-280)flashinfer/utils.py (3)
register_custom_op(266-275)register_custom_op(285-304)get_compute_capability(245-248)
csrc/xqa/xqa_wrapper.cu (2)
csrc/xqa/mha_sm90.cu (4)
launchHopperF8MHAFlashInfer(3168-3275)launchHopperF8MHAFlashInfer(3168-3185)scratch(506-513)scratch(506-506)csrc/xqa/mha.cu (2)
launchMHAFlashInfer(2657-2749)launchMHAFlashInfer(2657-2674)
flashinfer/aot.py (1)
flashinfer/jit/core.py (1)
JitSpec(181-280)
csrc/xqa/mha.h (1)
csrc/xqa/mha_sm90.cu (4)
launchHopperF8MHAFlashInfer(3168-3275)launchHopperF8MHAFlashInfer(3168-3185)scratch(506-513)scratch(506-506)
tests/attention/test_xqa.py (1)
flashinfer/utils.py (1)
get_compute_capability(245-248)
csrc/xqa/tma.h (1)
csrc/xqa/mha_sm90.cu (16)
void(548-577)void(579-584)void(588-598)void(1693-1727)void(1765-1797)void(1799-1816)void(1841-1887)void(1976-1997)void(1999-2017)void(2049-2131)void(2180-2254)void(2256-2275)void(2278-2296)void(2316-2332)void(2336-2359)void(2396-2420)
csrc/flashinfer_xqa_binding.cu (1)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (4)
output(230-396)output(230-238)output(398-566)output(398-409)
🪛 Clang (14.0.6)
csrc/xqa/tensorMap.h
[error] 3-3: 'cuda.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.0)
flashinfer/xqa.py
148-148: Avoid specifying long messages outside the exception class
(TRY003)
tests/attention/test_xqa.py
271-271: Unused function argument: beam_width
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (21)
csrc/xqa/utils.cuh (2)
34-41: Numerical-stability note: initialize rowMax safely but validate ranges seen in practice.Lowering safeInitRowMax to -1e5 avoids FMA overflow in x*log2e - bias, but it changes the effective lower bound. Please validate on adversarial logits (very negative rows) to ensure no early saturation and no accuracy regressions. Consider guarding the optimization per-arch or switching to compute (x - rowMax) before scaling to avoid FMA on large magnitudes.
49-51: Code is correct; review comment contains incorrect assumptions.SM100 (Blackwell) opt-in dynamic shared memory per block is 227 KB, which matches the value at line 50. SM120 (Hopper Next) is 99 KB, which is already correctly configured on line 46—not on lines 49-51 as the review suggests.
The conditional structure properly segregates architectures:
- Line 45-46: SM120 (
__CUDA_ARCH__ == 1200) → 99 KB ✓- Line 49-50: SM100 (
__CUDA_ARCH__ == 1000) → 227 KB ✓Lines 49-51 handle only SM90 and SM100; SM120 is on a separate branch.
Likely an incorrect or invalid review comment.
csrc/xqa/gmma.cuh (1)
60-66: Assumption: shared address fits 18 bits (0x3FFFF) — verify on SM100/SM120.The assert/encode mask assumes __cvta_generic_to_shared(data) fits 18 bits (>>4 stored). Please validate against SM100/SM120 SMEM addressing; increase field width or add debug fallback if this trips on larger SMEM configs.
csrc/xqa/mha.h (2)
138-150: API surface change verified — no action required.Declaration (mha.h:133), definition (mha.cu:2657), and the single call site (xqa_wrapper.cu:43–68) all use identical
#if PAGED_KV_CACHE_LAYOUT == 1preprocessor guards. Arguments are correctly conditionalized at the call site to match the conditional signature. Python bindings invoke throughxqa_wrapper, which handles the dispatch internally, isolating the FFI boundary from the conditional signature. ODR/ABI consistency is maintained within the compilation unit.
194-211: All concerns verified successfully.The function declaration (mha.h:194) and definition (mha_sm90.cu:3168) match exactly, including all conditional parameters (LOW_PREC_OUTPUT, PAGED_KV_CACHE_LAYOUT, SPEC_DEC). The function is properly integrated: it's used in xqa_wrapper.cu:43 with correct FP8 conditional logic (
run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer), defined in the correct TU (mha_sm90.cu), and linked properly without issues.csrc/xqa/mha.cu (2)
95-95: LGTM: SM100 architecture support added correctly.The addition of
__CUDA_ARCH__ == 1000appropriately extends support to SM100 devices with the same configuration as SM800/870/900, using larger kHeadPartBytes (128) and cacheVTileSeqLen (64).
2657-2722: Based on my analysis of the kernel signature and the function call, there is a parameter mismatch in the launchMHAFlashInfer function call when SPEC_DEC is defined.Looking at the kernel_mha signature (mha_sm90.cu:615-652), the kernel expects:
nbKHeads(always)slidingWinSize(if SLIDING_WINDOW)qScale, output, rcpOutScale, q, attentionSinks, cacheListbeamSearchParams(if USE_BEAM_SEARCH)batchSize, kvCacheScale- Tensor maps via grid_constant (not passed as regular parameters)
specDecParams(if SPEC_DEC)semaphores, scratchHowever, at line 2707-2722, when SPEC_DEC is defined, the call passes
qSeqLen, nbKHeads, headGrpSize, qCuSeqLensas four separate parameters, but the kernel expects onlynbKHeadsat that position. Additionally, the call passesmask(line 2722) but the kernel has nomaskparameter—it expectsspecDecParamsinstead.The review comment requires verification of how SpecDecParams and BeamSearchParams should be constructed and passed, since the current call site appears to pass individual fields separately rather than properly constructed structs.
flashinfer/jit/xqa.py (4)
18-24: LGTM: Imports updated appropriately.The added imports for SM-specific NVCC flags enable proper multi-architecture support.
26-35: LGTM: NVCC flags configured correctly.The flags properly enable paged KV cache with layout 1, consistent with the conditional compilation paths in the C++ code.
47-55: LGTM: Flag generation logic is correct.The conditional flag generation properly handles:
- FP16 vs BF16 input (
DTYPEandINPUT_FP16)- FP8 vs FP16/BF16 KV cache (
CACHE_ELEM_ENUM)
38-46: All call sites are already updated with the new signature.Verification confirms that:
- The new
fp16_input,fp8_kv_cache, andsm_versionparameters are consistently used across the codebase- Both call sites (
flashinfer/aot.py:404andflashinfer/xqa.py:40) correctly pass the new parameters- Wrapper functions (
get_xqa_moduleandxqa) use the updated signature- No
use_fp16parameter exists anywhere in the codebaseThe API changes are complete and properly integrated.
csrc/xqa/xqa_wrapper.cu (2)
22-38: LGTM: Function signature updated appropriately.The signature changes are well-designed:
run_fp8_mhaparameter enables runtime selection between FP8 and standard MHAOptional<TensorView>forattentionSinksis more idiomatic than raw pointers- Conditional KV cache parameters based on
PAGED_KV_CACHE_LAYOUTproperly support both layout modes
39-65: Function pointer approach is type-safe; signatures are compatible.Verification confirms that both
launchHopperF8MHAFlashInferandlaunchMHAFlashInferhave identical signatures with matching conditional compilation blocks and parameter lists, making the function pointer assignment safe and correct.flashinfer/aot.py (3)
358-372: LGTM: gen_xqa signature updated for multi-architecture support.The function signature changes are consistent with the JIT module updates:
- Parameter renaming improves clarity
- SM gating ensures generation only when supported architectures are available
- New
fp8_kv_cache_parameter enables FP8 KV cache configurations
373-412: Multi-SM architecture support implemented correctly.The iteration logic properly:
- Constructs
sm_versionslist based on available architectures- Iterates over SM versions along with other configuration parameters
- Validates configurations before generating modules
- Passes all parameters to
gen_xqa_moduleconsistently
527-546: LGTM: gen_all_modules updated consistently.The changes to
gen_all_modulesproperly wire through the new parameters and SM version support to the XQA generator.csrc/xqa/tensorMap.cpp (2)
10-41: LGTM: Data type size lookup implemented correctly.The
getElemBytesfunction provides comprehensive coverage of CUDA tensor map data types with appropriate error handling.
75-117: Paged KV cache tensor map correctly supports two layout modes with consistent stride calculations.The implementation correctly configures tensor map dimensions and strides for two distinct layouts:
- VLLM Layout (PAGED_KV_CACHE_LAYOUT == 1): dimensions
{headElems, nbKHeads, tokensPerPage, pages}with strides accounting for head-first ordering- XQA Layout (PAGED_KV_CACHE_LAYOUT == 0, default): dimensions
{headElems, tokensPerPage, nbKHeads, pages}with strides accounting for token-first orderingThe dimension ordering aligns with memory access patterns throughout the codebase (verified in mha.cu, mhaUtils.cuh, and mha_sm90.cu). Both layouts apply the same swizzle modes and error handling. No issues identified.
csrc/flashinfer_xqa_binding.cu (2)
24-25: Good: Optional attention sinks.Switching to
tvm::ffi::Optional<TensorView>makes the API safer and clearer.
21-23: No changes needed; LOW_PREC_OUTPUT=0 is already set in compilation flags.The codebase already includes
"-DLOW_PREC_OUTPUT=0"in thexqa_nvcc_flagslist withinflashinfer/jit/xqa.py. This flag is passed toextra_cuda_cflagsin thegen_jit_spec()call, ensuring thercpOutScaleparameter is not included in the C++ function signature. There is no ABI drift risk because the conditional parameter is compiled out consistently.flashinfer/xqa.py (1)
50-72: Signature wiring looks consistent with the binding.Param order matches
xqa_wrapper(includingrun_fp8_mha, optionalattentionSinks, and separate K/V caches).If
LOW_PREC_OUTPUTis ever enabled, extend these call sites to passrcpOutScaleor force-DLOW_PREC_OUTPUT=0in JIT.Also applies to: 73-91
tests/attention/test_xqa.py
Outdated
| compute_capability = get_compute_capability(torch.device(device="cuda")) | ||
| if compute_capability[0] != 9: | ||
| pytest.skip("XQA only supports on Hopper at this moment") | ||
| if compute_capability[0] != 9 and run_fp8_mha: | ||
| pytest.skip("XQA supports fp8 mha only on Hopper GPUs") | ||
| set_random_seed(42) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
Compute sm_count inside the test.
Set SM count after capability checks to avoid premature CUDA access.
def test_xqa(
@@
- compute_capability = get_compute_capability(torch.device(device="cuda"))
+ compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] != 9 and run_fp8_mha:
pytest.skip("XQA supports fp8 mha only on Hopper GPUs")
set_random_seed(42)
+ props = torch.cuda.get_device_properties(torch.cuda.current_device())
+ sm_count = props.multi_processor_countAlso applies to: 329-330
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 180-183, the test currently calls
into CUDA to compute sm_count before checking compute capability and may access
CUDA prematurely; move the sm_count computation so it runs after the
compute_capability check and any pytest.skip decision (i.e., compute sm_count
only after verifying compute_capability and run_fp8_mha), and apply the same
change to the other occurrence around lines 329-330; ensure you call the
sm_count helper (or get_sm_count) with the CUDA device only after the skip logic
and after set_random_seed(42) if that ordering is required.
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (2)
flashinfer/aot.py (2)
374-383: Simplify sm_versions construction (nit).Current appends are verbose; a small tidy improves readability.
- sm_versions = [] - if has_sm90: - sm_versions.append(90) - if has_sm100: - sm_versions.append(100) - if has_sm120: - sm_versions.append(120) - if has_sm121: - sm_versions.append(121) + sm_versions = [ + v for v, present in [(90, has_sm90), (100, has_sm100), (120, has_sm120), (121, has_sm121)] + if present + ]
423-446: Apply optional refactor to deduplicate MLA generation for SM120/121.The two blocks are identical except
sm_version. Looping over(version, flag)tuples reduces repetition while maintaining unique JitSpec names (which includesm_versionas a suffix).- if has_sm120: - for token_per_page in token_per_page_: - yield gen_xqa_module_mla( - input_dtype=torch.float8_e4m3fn, - kv_cache_dtype=torch.float8_e4m3fn, - page_size=token_per_page, - head_dim=576, - head_group_ratio=128, - use_sliding_window=False, - sm_version=120, - ) - - if has_sm121: - for token_per_page in token_per_page_: - yield gen_xqa_module_mla( - input_dtype=torch.float8_e4m3fn, - kv_cache_dtype=torch.float8_e4m3fn, - page_size=token_per_page, - head_dim=576, - head_group_ratio=128, - use_sliding_window=False, - sm_version=121, - ) + for sm_v, present in [(120, has_sm120), (121, has_sm121)]: + if not present: + continue + for token_per_page in token_per_page_: + yield gen_xqa_module_mla( + input_dtype=torch.float8_e4m3fn, + kv_cache_dtype=torch.float8_e4m3fn, + page_size=token_per_page, + head_dim=576, + head_group_ratio=128, + use_sliding_window=False, + sm_version=sm_v, + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/aot.py(3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/aot.py (3)
flashinfer/xqa.py (2)
xqa(56-93)xqa(124-267)flashinfer/jit/xqa.py (2)
gen_xqa_module(41-116)gen_xqa_module_mla(119-180)flashinfer/jit/cpp_ext.py (1)
get_cuda_version(59-78)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/aot.py (2)
70-70: XQA MLA import wiring — looks good.The new import aligns with the added MLA generation paths below.
359-369: Docstring is incomplete; update to document the new parameters.The
gen_xqafunction at line 358 has been extended with new parameters (fp16_input_,fp8_kv_cache_,has_sm100,has_sm120,has_sm121), but the docstring (line 370) only contains a generic description. Update it to document each parameter's purpose and type, including the new SM version flags.The call site at line 570 is correctly updated with all new arguments.
| if ( | ||
| add_xqa and get_cuda_version() > Version("12.8") | ||
| ): # TODO: Earlier cuda versions have compile issues, will be fixed in future releases |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Fix CUDA version gate: ‘> 12.8’ contradicts the comment; use ‘>= 12.8’.
As written, CUDA exactly 12.8 will skip XQA despite the note implying only earlier versions have issues. Align the condition with the intent.
- if (
- add_xqa and get_cuda_version() > Version("12.8")
- ): # TODO: Earlier cuda versions have compile issues, will be fixed in future releases
+ if (
+ add_xqa and get_cuda_version() >= Version("12.8")
+ ): # Earlier CUDA versions (< 12.8) have compile issues; enable for 12.8+🏁 Script executed:
#!/bin/bash
# Show current CUDA version as parsed to confirm gating behavior around 12.8
rg -n 'def get_cuda_version' -n --type=py -C5
# Quick check of other gates for consistency (>= in detect_sm_capabilities)
rg -nP 'get_cuda_version\(\)\s*[<>]=?\s*Version' --type=py -C2Length of output: 2238
Fix CUDA version gate: change > Version("12.8") to >= Version("12.8").
The current condition > Version("12.8") excludes CUDA 12.8 itself, contradicting the comment which states only earlier versions have compile issues. Align with the codebase pattern: lines 760 and 82 use >= for version gates.
if (
- add_xqa and get_cuda_version() > Version("12.8")
+ add_xqa and get_cuda_version() >= Version("12.8")
): # TODO: Earlier cuda versions have compile issues, will be fixed in future releases🤖 Prompt for AI Agents
In flashinfer/aot.py around lines 559 to 561, the CUDA version check currently
uses a strict greater-than ("> Version(\"12.8\")") which incorrectly excludes
CUDA 12.8; change the comparison to greater-than-or-equal (">=
Version(\"12.8\")") so CUDA 12.8 is allowed, matching the comment and the
project's version-gate pattern used elsewhere.
tests/attention/test_xqa.py
Outdated
| seq_len, | ||
| tokens_per_page, | ||
| use_fp16, | ||
| fp16_input, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make this parameter more explicit? like input_data_type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
/bot run |
|
Hi @qsang-nv xqa UT failed on https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/225084624, would you mind taking a look? |
|
[FAILED] Pipeline #37222632: 5/17 passed |
Signed-off-by: Qidi Sang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_xqa.py (1)
28-30: Avoid CUDA device queries at import time; compute sm_count inside testsMove SM count retrieval into each test after skip logic.
-props = torch.cuda.get_device_properties(0) -sm_count = props.multi_processor_count +# sm_count will be computed inside tests after capability checks
♻️ Duplicate comments (11)
flashinfer/jit/xqa.py (2)
74-74: Validate head_group_ratio earlyFail fast on invalid head_group_ratio.
- flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"] + if not isinstance(head_group_ratio, int) or head_group_ratio <= 0: + raise ValueError(f"Invalid head_group_ratio: {head_group_ratio}") + flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"]
120-129: Replace runtime asserts with explicit exceptionsAsserts can be stripped; keep validation always-on.
- assert head_group_ratio == 128, "Only head group ratio 128 is supported for xqa MLA" - assert head_dim == 576, "Only head dim 576 is supported for xqa_module_mla" - assert input_dtype == torch.float8_e4m3fn, ( - "Only fp8 input is supported for xqa_module_mla" - ) - assert kv_cache_dtype == torch.float8_e4m3fn, ( - "Only fp8 kv cache is supported for xqa_module_mla" - ) - assert not use_sliding_window, "Sliding window is not supported for xqa_module_mla" + if head_group_ratio != 128: + raise ValueError("Only head_group_ratio=128 is supported for xqa MLA") + if head_dim != 576: + raise ValueError("Only head_dim=576 is supported for xqa_module_mla") + if input_dtype != torch.float8_e4m3fn: + raise ValueError("Only fp8 input (float8_e4m3fn) is supported for xqa_module_mla") + if kv_cache_dtype != torch.float8_e4m3fn: + raise ValueError("Only fp8 kv cache (float8_e4m3fn) is supported for xqa_module_mla") + if use_sliding_window: + raise ValueError("Sliding window is not supported for xqa_module_mla")flashinfer/xqa.py (3)
211-213: Validate head_group_ratio divisibility before floor-divisionAvoid silent truncation: enforce num_q_heads % num_kv_heads == 0 before computing head_group_ratio.
- # Calculate head_group_ratio - head_group_ratio = num_q_heads // num_kv_heads + # Calculate head_group_ratio + if num_kv_heads <= 0 or (num_q_heads % num_kv_heads) != 0: + raise ValueError( + f"num_q_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + ) + head_group_ratio = num_q_heads // num_kv_heads
223-233: Use q.device for CC, avoid repeated queries, and clarify support messageQuery the actual tensor device once; reuse it and fix messages.
- if ( - k_cache.dtype == torch.float8_e4m3fn - and get_compute_capability(torch.device(device="cuda"))[0] == 9 - ): - run_sm90_fp8_mha = True - else: - run_sm90_fp8_mha = False - - if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]: - raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs") + cc_major, cc_minor = get_compute_capability(q.device) + run_sm90_fp8_mha = (k_cache.dtype == torch.float8_e4m3fn) and (cc_major == 9) + if cc_major not in (9, 10, 12): + raise RuntimeError("XQA is only supported on SM90, SM100, or SM12x GPUs")
425-427: MLA: use q.device for CC, and mention SM12x explicitlyEnsure device alignment on multi‑GPU and clearer error.
- if get_compute_capability(torch.device(device="cuda"))[0] not in [12]: - raise RuntimeError("XQA is only supported on SM120 GPUs") + cc_major, cc_minor = get_compute_capability(q.device) + if cc_major != 12: + raise RuntimeError("XQA MLA is only supported on SM12x (SM120/SM121) GPUs")flashinfer/aot.py (1)
532-534: Fix CUDA 12.8 gate: use >= to include 12.8Current “> 12.8” excludes 12.8 but the comment implies enabling from 12.8+.
- if ( - add_xqa and get_cuda_version() > Version("12.8") - ): # TODO: Earlier cuda versions have compile issues, will be fixed in future releases + if ( + add_xqa and get_cuda_version() >= Version("12.8") + ): # Earlier CUDA versions (< 12.8) have compile issues; enable for 12.8+tests/attention/test_xqa.py (4)
166-169: Protect skipif on CPU-only runnersShort-circuit when CUDA isn’t available.
-@pytest.mark.skipif( - get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12], - reason="XQA is only supported on SM90, SM100, SM120 GPUs", -) +@pytest.mark.skipif( + (not torch.cuda.is_available()) + or (get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]), + reason="XQA is only supported on SM90, SM100, SM120 GPUs", +)
328-344: Compute sm_count inside the test (after skip/seed) and use current devicePrevents premature CUDA access and mismatched devices.
- xqa( + # Determine SM count on the active device + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + sm_count = props.multi_processor_count + xqa( q_heads, @@ - sm_count=sm_count, + sm_count=sm_count, )
406-409: Guard MLA skipif with CUDA availabilitySame issue as above.
-@pytest.mark.skipif( - get_compute_capability(torch.device(device="cuda"))[0] not in [12], - reason="XQA mla is only supported on SM120 GPUs", -) +@pytest.mark.skipif( + (not torch.cuda.is_available()) + or (get_compute_capability(torch.device(device="cuda"))[0] not in [12]), + reason="XQA mla is only supported on SM120 GPUs", +)
545-558: Compute sm_count inside MLA testInline with main test fix.
- xqa_mla( + # Determine SM count on the active device + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + sm_count = props.multi_processor_count + xqa_mla( q_heads.to(torch.float8_e4m3fn), @@ - sm_count=sm_count, + sm_count=sm_count, )csrc/xqa/mha.cu (1)
480-481: Masked rows produce incorrect uniform attention; set −inf and handle zero-row sumsSetting masked elements to safeInitRowMax yields exp(score−max)=1, causing uniform distribution over masked tokens. Use −inf and ensure softmax handles fully-masked rows as zeros.
- acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax; + acc(m, n)(i, j) = (maskFlag && col < nbValidCols) ? acc(m, n)(i, j) : -CUDART_INF_F;Follow-up: in the softmax path, detect rows where rowMax is −inf and force outputs to 0 (rowSum=0 guard) to avoid NaNs and unintended mass. I can draft a patch for the softmax/update section if helpful.
🧹 Nitpick comments (2)
flashinfer/xqa.py (1)
221-222: Don’t use assert for user input validationAsserts can be stripped with -O. Use explicit check and raise.
- assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype" + if k_cache.dtype != v_cache.dtype: + raise TypeError("K and V cache must have the same dtype")tests/attention/test_xqa.py (1)
271-283: Remove unused function argument ‘beam_width’Clean up helper signatures and call sites to satisfy linters.
-def cache_head_at( +def cache_head_at( batch, is_k, idx_kv_head, pos, cache_k_heads, cache_v_heads, page_list, - beam_width, nb_k_heads, tokens_per_page, ): @@ - cache_head = cache_head_at( + cache_head = cache_head_at( batch, kv == 0, idx_kv_head, pos, cache_k_heads, cache_v_heads, page_list_arg, - beam_width, nb_k_heads, tokens_per_page, )Apply the same change to the MLA helper and its call sites.
Also applies to: 488-500
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
csrc/xqa/mha.cu(5 hunks)csrc/xqa/utils.cuh(2 hunks)flashinfer/aot.py(4 hunks)flashinfer/jit/xqa.py(1 hunks)flashinfer/xqa.py(3 hunks)tests/attention/test_xqa.py(11 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/xqa/utils.cuh
🧰 Additional context used
🧬 Code graph analysis (5)
tests/attention/test_xqa.py (2)
flashinfer/xqa.py (4)
xqa(54-91)xqa(122-260)xqa_mla(285-314)xqa_mla(341-450)flashinfer/utils.py (1)
get_compute_capability(251-254)
csrc/xqa/mha.cu (1)
csrc/xqa/mla_sm120.cu (1)
cacheList(88-94)
flashinfer/xqa.py (4)
flashinfer/jit/xqa.py (2)
gen_xqa_module(38-109)gen_xqa_module_mla(112-168)flashinfer/utils.py (6)
get_device_sm_count(595-596)register_custom_op(272-281)register_custom_op(291-310)register_fake_op(283-287)register_fake_op(312-317)get_compute_capability(251-254)csrc/xqa/xqa_wrapper.cu (4)
xqa_wrapper(51-95)xqa_wrapper(51-67)xqa_wrapper_mla(23-48)xqa_wrapper_mla(23-31)csrc/flashinfer_xqa_binding.cu (2)
xqa_wrapper(34-50)xqa_wrapper_mla(20-28)
flashinfer/aot.py (3)
flashinfer/xqa.py (2)
xqa(54-91)xqa(122-260)flashinfer/jit/xqa.py (2)
gen_xqa_module(38-109)gen_xqa_module_mla(112-168)flashinfer/jit/cpp_ext.py (1)
get_cuda_version(64-83)
flashinfer/jit/xqa.py (2)
flashinfer/compilation_context.py (2)
CompilationContext(27-68)get_nvcc_flags_list(50-68)flashinfer/jit/core.py (2)
JitSpec(181-280)gen_jit_spec(283-349)
🪛 Ruff (0.14.1)
tests/attention/test_xqa.py
279-279: Unused function argument: beam_width
(ARG001)
496-496: Unused function argument: beam_width
(ARG001)
flashinfer/xqa.py
232-232: Avoid specifying long messages outside the exception class
(TRY003)
426-426: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/jit/xqa.py
51-53: Avoid specifying long messages outside the exception class
(TRY003)
63-65: Avoid specifying long messages outside the exception class
(TRY003)
69-71: Avoid specifying long messages outside the exception class
(TRY003)
133-135: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (1)
csrc/xqa/mha.cu (1)
2663-2705: KV cache VLLM layout wiring looks correctThe wrapper correctly switches between kCacheVLLM/vCacheVLLM and pool under PAGED_KV_CACHE_LAYOUT. LGTM.
|
/bot run |
|
@qsang-nv is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww |
|
/bot run |
|
[SUCCESS] Pipeline #37374779: 7/17 passed |
yzh119
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
📌 Description
Add xqa fp8 mha and fp8 kv cache. Add fp8 mla for sm120. Use vllm kv layout.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation