
Commit ef687e9

qsang-nv and yzh119 authored
add xqa fp8 mha and fp8 kv cache (#1769)
<!-- .github/pull_request_template.md -->

## 📌 Description

Add XQA FP8 MHA and FP8 KV cache. Add FP8 MLA for SM120. Use the vLLM KV layout.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * MLA-based attention path and dedicated MLA entrypoints (SM120/121)
  * FP8 KV-cache support with optional paged KV layout and separate K/V cache inputs
  * Asynchronous tensor-map/TMA and matrix-descriptor primitives for high-throughput GPU transfers
  * Dtype-driven config and expanded GPU SM gating for builds/runtimes
* **Bug Fixes**
  * Improved numerical stability for attention mask initialization
* **Tests**
  * Expanded coverage for MLA, FP8, FP16/BF16, and new cache layouts
* **Documentation**
  * Added XQA API docs and new public symbols

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>
1 parent 99657ed commit ef687e9
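
The commit adds FP8 MHA and an FP8 KV cache whose scale is passed as a `kvCacheScale` tensor (visible in the `xqa_wrapper` signature in the diff below). As a rough illustration of what per-tensor FP8 (e4m3) KV-cache quantization involves, here is a minimal standalone CUDA sketch; the kernel name, the flat layout, and the assumption that the scale tensor holds a single dequantization factor are ours, not taken from this PR.

```cuda
// Hypothetical sketch, not code from this PR: quantize an fp16 K/V page to e4m3
// with one per-tensor scale, mirroring the idea behind a kvCacheScale argument.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__global__ void quantizeKvPageE4M3(__nv_fp8_e4m3* dst, __half const* src,
                                   float const* kvCacheScale,  // assumed: one dequant factor
                                   int n) {
  float const rcpScale = 1.f / kvCacheScale[0];  // quantize = divide by the dequant factor
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    // __nv_fp8_e4m3's float constructor performs the narrowing conversion to e4m3.
    dst[i] = __nv_fp8_e4m3(__half2float(src[i]) * rcpScale);
  }
}
```

The matching dequantization (multiplying the stored e4m3 values back by the same factor) would happen inside the attention kernel that consumes the cache.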

File tree

19 files changed: +11894 −223 lines

β€Žcsrc/flashinfer_xqa_binding.cuβ€Ž

Lines changed: 25 additions & 3 deletions
@@ -16,12 +16,32 @@
 
 #include "tvm_ffi_utils.h"
 
-void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize,
-                 double qScale, TensorView output,
+#if MLA_WRAPPER
+void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
+#if PAGED_KV_CACHE_LAYOUT == 1
+                     TensorView kCacheVLLM, TensorView vCacheVLLM,
+#else
+                     TensorView pool,
+#endif
+                     TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
+                     int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
+                     TensorView scratch);
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);
+
+#else
+
+void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
+                 int64_t slidingWinSize, double qScale, TensorView output,
 #if LOW_PREC_OUTPUT
                  TensorView rcpOutScale,
 #endif
-                 TensorView q, TensorView attentionSinks, TensorView pool,
+                 TensorView q, tvm::ffi::Optional<TensorView> attentionSinks,
+#if PAGED_KV_CACHE_LAYOUT == 1
+                 TensorView kCacheVLLM, TensorView vCacheVLLM,
+#else
+                 TensorView pool,
+#endif
                  TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
                  int64_t batchSize, TensorView kvCacheScale,
 #if SPEC_DEC
@@ -30,3 +50,5 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW
                  TensorView semaphores, TensorView scratch);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper);
+
+#endif
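
Note the compile-time split introduced here: with `PAGED_KV_CACHE_LAYOUT == 1` both entrypoints take separate `kCacheVLLM`/`vCacheVLLM` tensors (vLLM-style paged layout), otherwise a single `pool` tensor. The toy program below only mirrors that preprocessor pattern; `MockTensor` and `launchXqa` are hypothetical stand-ins, not FlashInfer's host API.

```cuda
// Illustrative mock of the PAGED_KV_CACHE_LAYOUT switch used in the binding above.
// A real caller would forward these arguments to the exported xqa_wrapper entrypoint.
#include <cstdint>
#include <cstdio>

#ifndef PAGED_KV_CACHE_LAYOUT
#define PAGED_KV_CACHE_LAYOUT 1  // 1: separate K/V caches (vLLM layout), 0: single pool
#endif

struct MockTensor {  // stand-in for TensorView
  void* data;
  int64_t numel;
};

void launchXqa(MockTensor q, MockTensor output,
#if PAGED_KV_CACHE_LAYOUT == 1
               MockTensor kCacheVLLM, MockTensor vCacheVLLM
#else
               MockTensor pool
#endif
) {
#if PAGED_KV_CACHE_LAYOUT == 1
  std::printf("vLLM layout: K=%lld elems, V=%lld elems\n",
              static_cast<long long>(kCacheVLLM.numel), static_cast<long long>(vCacheVLLM.numel));
#else
  std::printf("pooled layout: %lld elems\n", static_cast<long long>(pool.numel));
#endif
  (void)q;
  (void)output;
}

int main() {
  MockTensor q{nullptr, 0}, out{nullptr, 0};
#if PAGED_KV_CACHE_LAYOUT == 1
  launchXqa(q, out, MockTensor{nullptr, 1 << 20}, MockTensor{nullptr, 1 << 20});
#else
  launchXqa(q, out, MockTensor{nullptr, 2 << 20});
#endif
  return 0;
}
```

Building the same translation unit with `-DPAGED_KV_CACHE_LAYOUT=0` switches the caller to the pooled layout, which is analogous to how the binding above selects which wrapper signature gets compiled and exported.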

β€Žcsrc/xqa/gmma.cuhβ€Ž

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+#include "cuda_hint.cuh"
+#include "mha_stdheaders.cuh"
+#include "utils.cuh"
+#ifndef __CUDACC__
+#include <cuda_runtime.h>
+#endif
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+namespace gmma {
+
+enum class SwizzleMode : uint64_t { kNONE = 0, k128 = 1, k64 = 2, k32 = 3 };
+
+struct MatDesc {
+  uint64_t addr : 16;
+  uint64_t dimKOffset : 16;
+  uint64_t dimMNOffset : 16;
+  uint64_t pad0 : 1;
+  uint64_t baseOffset : 3;
+  uint64_t pad1 : 10;
+  SwizzleMode swizzle : 2;
+
+  enum class Raw : uint64_t {};
+
+  [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const {
+    MatDesc ret = *this;
+    ret.addr = encode(__cvta_generic_to_shared(data));
+    return ret;
+  }
+
+  static __device__ inline uint32_t encode(uint32_t val) { return (val & 0x3FFFFU) >> 4; }
+
+  __device__ inline bool operator==(MatDesc const& other) const { return raw() == other.raw(); }
+
+  __device__ inline Raw const& raw() const {
+    static_assert(sizeof(MatDesc) == 8);
+    return reinterpret_cast<Raw const&>(*this);
+  }
+
+  static __device__ inline MatDesc fromRaw(Raw const& raw) {
+    return reinterpret_cast<MatDesc const&>(raw);
+  }
+};
+
+static_assert(sizeof(MatDesc) == 8);
+
+[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) {
+  assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0);
+  MatDesc::Raw ret = base;
+  auto& u32x2 = reinterpret_cast<uint32_t(&)[2]>(ret);
+  u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4;
+  return ret;
+}
+
+__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
+                                      uint32_t dimMNByteOffset, void const* patternStartAddr,
+                                      SwizzleMode swizzleMode) {
+  uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr);
+  uint32_t const baseAlign = [&]() -> uint32_t {
+    switch (swizzleMode) {
+      case SwizzleMode::kNONE:
+        return 1;
+      case SwizzleMode::k128:
+        return 1024;
+      case SwizzleMode::k64:
+        return 512;
+      case SwizzleMode::k32:
+        return 256;
+    }
+    asm volatile("trap;\n");
+    return 0;
+  }();
+  uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
+  return MatDesc{
+      /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
+      /*dimKOffset=*/MatDesc::encode(dimKByteOffset),
+      /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset),
+      /*pad0=*/0,
+      /*baseOffset=*/baseOffset,
+      /*pad1=*/0,
+      /*swizzle=*/swizzleMode,
+  };
+}
+
+__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
+                                      uint32_t dimMNByteOffset, SwizzleMode swizzleMode) {
+  return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode);
+}
+
+inline constexpr uint32_t instM = 64;
+
+template <typename MathElem>
+inline constexpr uint32_t instK = 32 / sizeof(MathElem);
+
+inline constexpr uint32_t instNBase = 8;
+
+// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N
+// acc is used as both input and output.
+template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
+__device__ void mma_async_shmA(float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA,
+                               MatDesc::Raw descB, bool accHasVal);
+template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
+__device__ void mma_async_regA(float (&acc)[exactDiv(n, instNBase)][2][2],
+                               uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal);
+
+__device__ inline void fence() { asm volatile("wgmma.fence.sync.aligned;\n"); }
+
+__device__ inline void commit_group() { asm volatile("wgmma.commit_group.sync.aligned;\n"); }
+
+template <uint32_t targetNbInFlightGroups>
+__device__ inline void wait_group() {
+  asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups));
+}
+
+template <bool swizzle, typename T, uint32_t rows, uint32_t cols, bool alignedForSwizzle>
+constexpr SwizzleMode getSwizzleMode(Array2D<T, rows, cols, alignedForSwizzle> const&) {
+  constexpr auto rowBytes = Array2D<T, rows, cols, alignedForSwizzle>::rowBytes;
+  if constexpr (!swizzle) {
+    return SwizzleMode::kNONE;
+  }
+  if constexpr (rowBytes % 128 == 0) {
+    return SwizzleMode::k128;
+  } else if constexpr (rowBytes == 64) {
+    return SwizzleMode::k64;
+  } else {
+    static_assert(rowBytes == 32);
+    return SwizzleMode::k32;
+  }
+}
+}  // namespace gmma
+
+#include "gmma_impl.cuh"
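
For orientation, the intended call sequence for these helpers is: stage operand tiles in shared memory, build descriptors with `makeMatDesc`, then bracket `mma_async_shmA` with `fence()`, `commit_group()` and `wait_group()`. The kernel below is a minimal sketch of that sequence only; the tile shape, swizzle mode, and byte offsets are illustrative guesses, a real kernel needs a full warpgroup (128 threads) on SM90-class hardware, and the production uses live in the xqa kernels and `gmma_impl.cuh`.

```cuda
// Minimal usage sketch (not code from this PR): one 64x64x16 fp16 wgmma issue using
// the helpers declared above. Assumes csrc/xqa/ is on the include path.
#include "gmma.cuh"

__global__ void gmmaSketch(float* out) {
  // fp16 operand tiles; a real kernel fills these via async copies first (omitted here).
  __shared__ __align__(128) __half smemA[64][16];  // M x K tile
  __shared__ __align__(128) __half smemB[64][16];  // N x K tile

  constexpr uint32_t n = 64;
  float acc[n / gmma::instNBase][2][2];  // same shape as exactDiv(n, instNBase) in the header

  // Shared-memory matrix descriptors for the A and B tiles; offsets are placeholders.
  auto const descA = gmma::makeMatDesc(smemA, /*dimKByteOffset=*/sizeof(smemA[0]),
                                       /*dimMNByteOffset=*/sizeof(smemA),
                                       gmma::SwizzleMode::kNONE)
                         .raw();
  auto const descB = gmma::makeMatDesc(smemB, /*dimKByteOffset=*/sizeof(smemB[0]),
                                       /*dimMNByteOffset=*/sizeof(smemB),
                                       gmma::SwizzleMode::kNONE)
                         .raw();

  gmma::fence();                                                            // order prior smem writes
  gmma::mma_async_shmA<__half, n>(acc, descA, descB, /*accHasVal=*/false);  // acc = A * B
  gmma::commit_group();                                                     // close this wgmma group
  gmma::wait_group<0>();                                                    // wait for completion

  if (threadIdx.x == 0) out[0] = acc[0][0][0];
}
```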
