diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index 8bcbafafd6..dd24da1692 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -32,8 +32,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK int64_t slidingWinSize, double qScale, tvm::ffi::Optional qScaleTensor, TensorView output, double rcpOutScale, TensorView q, tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, - TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, - TensorView seqLen, int64_t batchSize, double kvCacheScale, + TensorView vCacheVLLM, tvm::ffi::Optional kSfCacheVLLM, + tvm::ffi::Optional vSfCacheVLLM, TensorView kvCachePageList, + int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, tvm::ffi::Optional kvScaleTensor, int64_t qSeqLen, tvm::ffi::Optional mask, TensorView semaphores, TensorView scratch, bool enable_pdl); diff --git a/csrc/xqa/defines.h b/csrc/xqa/defines.h index 3794708a3b..32726d031d 100644 --- a/csrc/xqa/defines.h +++ b/csrc/xqa/defines.h @@ -83,6 +83,12 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena #define CACHE_ELEM_ENUM 2 #endif +#if CACHE_ELEM_ENUM == 3 +#define ENABLE_4BIT_KV_CACHE 1 +#else +#define ENABLE_4BIT_KV_CACHE 0 +#endif + // don't modify #define USE_KV_CACHE true @@ -181,8 +187,51 @@ static_assert(CACHE_ELEM_ENUM != 0); #include #include + +#if ENABLE_4BIT_KV_CACHE +#include +#endif + template -using ElemType = mha::conditional_t< - elemTypeEnum == 0, INPUT_ELEM, - mha::conditional_t>>; +struct ElemTypeConverter; +// Specialization for elemTypeEnum = 0 (half/bf16) +template <> +struct ElemTypeConverter<0> { + using Type = INPUT_ELEM; + using ContainerType = INPUT_ELEM; + static constexpr int ElemsPerContainer = 1; + using ScalingFactorType = void; + static constexpr int QuantVectorSize = 1; +}; + +// Specialization for elemTypeEnum = 1 (int8) +template <> +struct ElemTypeConverter<1> { + 
using Type = int8_t; + using ContainerType = int8_t; + static constexpr int ElemsPerContainer = 1; + using ScalingFactorType = void; + static constexpr int QuantVectorSize = 1; +}; + +// Specialization for elemTypeEnum = 2 (fp8) +template <> +struct ElemTypeConverter<2> { + using Type = __nv_fp8_e4m3; + using ContainerType = __nv_fp8_e4m3; + static constexpr int ElemsPerContainer = 1; + using ScalingFactorType = void; + static constexpr int QuantVectorSize = 1; +}; + +#if ENABLE_4BIT_KV_CACHE +// Specialization for elemTypeEnum = 3 (NVFP4) +template <> +struct ElemTypeConverter<3> { + using Type = __nv_fp4_e2m1; + using ContainerType = __nv_fp4x2_e2m1; + static constexpr int ElemsPerContainer = 2; + using ScalingFactorType = __nv_fp8_e4m3; + static constexpr int QuantVectorSize = 16; +}; +#endif diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index c8d6ca2c22..1b13c39157 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -59,6 +59,15 @@ static_assert(inputElemSize >= cacheElemSize); constexpr uint32_t cacheElemsPerGrain = exactDiv(grainBytes, cacheElemSize); constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize); + +// If cache is 4-bit, in GMEM the elements are packed, but after loading to SMEM they are padded. +// Each 16 elements (64b) are padded with 64b to match 128b smem grains. Therefore, +// the grainBytes of GMEM is half of the grainBytes of SMEM. 
+constexpr uint32_t grainBytesGmemCache = grainBytes / CacheElemConverter::ElemsPerContainer; +#if ENABLE_4BIT_KV_CACHE +constexpr uint32_t grainBytesSf = 4; +#endif + constexpr bool enableMicroFastPath = false; // x: horizontal stacking for cta horizontal tile size @@ -314,12 +323,32 @@ struct alignas(128) SharedMem { using VSmemBuffer = Array2D; +#if ENABLE_4BIT_KV_CACHE + using KSfSmemBuffer = + Array2D; + using VSfSmemBuffer = + Array2D; + using KSfSmemBufferPlain = + Array2D<__nv_fp8_e4m3, warpTile.x, kHeadPartBytes / CacheElemConverter::QuantVectorSize>; + using VSfSmemBufferPlain = + Array2D<__nv_fp8_e4m3, cacheVTileSeqLen, + (grpLoadV ? headElems : warpTile.x) / CacheElemConverter::QuantVectorSize>; +#endif + QSmemBuffer q[ctaShapeInWarps.y][nbQBuffers]; KSmemBuffer k[ctaShapeInWarps.x][nbKBuffers]; XSmemBuffer x[ctaShapeInWarps.y][ctaShapeInWarps.x]; static_assert(nbXBuffers == 1); VSmemBuffer v[gemm1NbWarpGrps][grpLoadV ? 1 : gemm1WarpsPerGrp][nbVBuffers]; +#if ENABLE_4BIT_KV_CACHE + KSfSmemBuffer kSf[ctaShapeInWarps.x][nbKBuffers]; + VSfSmemBuffer vSf[gemm1NbWarpGrps][grpLoadV ? 
1 : gemm1WarpsPerGrp][nbVBuffers]; +#endif + SMemWarpRowMax warpRowMax[ctaShapeInWarps.y] [ctaShapeInWarps.x]; // the max used when computing this->x SMemWarpRowMax warpRowSum[ctaShapeInWarps.y][ctaShapeInWarps.x]; // the row sum of gemm0 output @@ -674,6 +703,34 @@ __device__ inline void storeOrderedGemmOutTile(Warp const& warp, SharedMem::XSme #endif } +__device__ inline void storeReorderedXTile(Warp const& warp, SharedMem::XSmemBuffer& dst, + GemmOutRegTile const& src) { + static_assert(sizeof(dst) == sizeof(src) * warp_size); + uint32_t const lane = laneId(); + +#pragma unroll + for (uint32_t m = 0; m < exactDiv(dst.rows, 8); m++) { +#pragma unroll + for (uint32_t n = 0; n < exactDiv(dst.cols * grainBytes / inputElemSize, 16); n++) { + uint32_t const idxRowLocal = laneId() / 4; + // Reorder from e0, e1, e2, e3, e4, e5, e6, e7 | e8, e9, e10, e11, e12, e13, e14, e15 + // to ========> e0, e1, e4, e5, e8, e9, e12, e13 | e2, e3, e6, e7, e10, e11, e14, e15 + + //// Tid: dst reg idx, ldGrain idx, idx word + // T0: 0, 2 0 (0, 2) + // T1: 4, 6 1 (0, 2) + // T2: 1, 3 0 (1, 3) + // T3: 5, 7 1 (1, 3) + uint32_t const idxColLdGrain = n * 2 + laneId() % 2; + uint32_t const idxWordLocal = (laneId() % 4) / 2; + dst.template at(8 * m + idxRowLocal, idxColLdGrain)[idxWordLocal] = + reinterpret_cast(src(m, 2 * n)); + dst.template at(8 * m + idxRowLocal, idxColLdGrain)[idxWordLocal + 2] = + reinterpret_cast(src(m, 2 * n + 1)); + } + } +} + // Reorder to compensate the reorder caused by V cache load+conversion. 
__device__ inline void reorderAndStoreGemmOutTile(Warp const& warp, SharedMem::XSmemBuffer& dst, GemmOutRegTile const& src) { @@ -845,8 +902,8 @@ using InstInMatWTrans = InstInMat +template __device__ inline InstInMatWTrans loadInstInMat( Warp const& warp, Array2D const& src, uint32_t rowOffset, uint32_t colOffset) { @@ -862,7 +919,16 @@ __device__ inline InstInMatWTrans loadInstInMat( LdGrain const* const ptr = &src.template at(rowOffset + 8 * srcIdxMNEx + laneId() % 8, colOffset + srcIdxKEx); - Vec const data = ldmatrix_4x(warp, ptr); + Vec data; +#if ENABLE_4BIT_KV_CACHE + if constexpr (is4BitElem) { + data = ldmatrix_4x_unpack_4b(warp, ptr); + } else { + data = ldmatrix_4x(warp, ptr); + } +#else + data = ldmatrix_4x(warp, ptr); +#endif static_assert(sizeof(Dst) == sizeof(data)); Dst dst; #pragma unroll @@ -880,7 +946,7 @@ using Array2DWTrans = Array2D __device__ inline Array2DWTrans, dstRows, dstCols, transArr2D> @@ -893,7 +959,7 @@ loadMatrix(Warp const& warp, Array2D const& src, uint #pragma unroll for (uint32_t j = 0; j < dstCols; j++) { (transArr2D ? 
dst(j, i) : dst(i, j)) = - loadInstInMat( + loadInstInMat( warp, src, rowBeg + (mnEx * 8) * i, colBeg + kEx * j); } } @@ -906,21 +972,79 @@ loadMatrix(Warp const& warp, Array2D const& src, uint template __device__ inline void smemQKPartGemm(Warp const& warp, WarpAcc& acc, SharedMem::QSmemBuffer const& q, uint32_t qColBeg, - SharedMem::KSmemBuffer const& k) { + SharedMem::KSmemBuffer const& k +#if ENABLE_4BIT_KV_CACHE + , + SharedMem::KSfSmemBuffer const& kSf +#endif +) { assert(qColBeg % (SharedMem::KSmemBuffer::cols) == 0); constexpr uint32_t kEx = 2; constexpr uint32_t mnEx = 2; static_assert(mha::is_same_v || mha::is_same_v, "not implemented"); +#if !ENABLE_4BIT_KV_CACHE static_assert((mha::is_same_v || mha::is_same_v || mha::is_same_v || mha::is_same_v), "not implemented"); +#else + static_assert(mha::is_same_v, "not implemented"); +#endif constexpr uint32_t nbInstInMatPerSliceInGemmKDim = 1; constexpr uint32_t kElemSize = sizeof(KElemType); constexpr uint32_t elemsPerKHeadPart = exactDiv(kHeadPartBytes, kElemSize); constexpr uint32_t gemmKSplit = exactDiv(elemsPerKHeadPart, 8 * kEx * nbInstInMatPerSliceInGemmKDim); +#if ENABLE_4BIT_KV_CACHE + constexpr uint32_t cvtExp = exactDiv(inputElemSize, kElemSize); + constexpr uint32_t mnExK = mnEx * cvtExp; + constexpr uint32_t kExK = exactDiv(kEx, cvtExp); + constexpr uint32_t kSliceRows = exactDiv(warpTile.x, 8 * mnExK); // in InstInMat + constexpr uint32_t kSliceCols = nbInstInMatPerSliceInGemmKDim; + // The FP16/BF16 SF for K tile. Each InputElem2 element holds 2 SFs for 2 mnExK. + uint32_t kSfSlice[gemmKSplit][kSliceRows][kSliceCols][exactDiv(mnExK, 2)][kExK]; + // This simplifies the loop. + static_assert(kExK == 1, "not implemented"); + // We want to load 4 SFs in K dimension to utilize 32b LDS. 
+ static_assert(gemmKSplit % 4 == 0, "not implemented"); +#pragma unroll + for (uint32_t m = 0; m < kSliceRows; m++) { +#pragma unroll + for (uint32_t n = 0; n < kSliceCols; n++) { +#pragma unroll + for (uint32_t i = 0; i < exactDiv(mnExK, 2); i++) { + // The corresponding SF row index. + uint32_t const sfRowIdx = (m * mnExK + i * 2) * 8 + laneId() / 4; +#pragma unroll + for (uint32_t s = 0; s < exactDiv(gemmKSplit, 4); s++) { + // Load the SF from 2 rows. + uint32_t const sfOrigRow0 = kSf(sfRowIdx, s); + uint32_t const sfOrigRow1 = kSf(sfRowIdx + 8, s); + // Permute the elements. + // SFRow0: [(0,0) (0,1) (0,2) (0,3)] + // SFRow1: [(1,0) (1,1) (1,2) (1,3)] + // After permutation: + // tmpReg0: [(0,0) (1,0) (0,2) (1,2)] + // tmpReg1: [(0,1) (1,1) (0,3) (1,3)] + uint32_t tmpReg0 = prmt(sfOrigRow0, sfOrigRow1, {0, 4, 2, 6}); + uint32_t tmpReg1 = prmt(sfOrigRow0, sfOrigRow1, {1, 5, 3, 7}); + // Result after conversion: + // sfVal0[0]: [(0,0) (1,0)]; sfVal0[1]: [(0,2) (1,2)] + // sfVal1[0]: [(0,1) (1,1)]; sfVal1[1]: [(0,3) (1,3)] + auto sfVal0 = convertKCacheWordToF16(tmpReg0); + auto sfVal1 = convertKCacheWordToF16(tmpReg1); + // Store to kSfSlice. + kSfSlice[s * 4 + 0][m][n][i][0] = sfVal0[0]; + kSfSlice[s * 4 + 2][m][n][i][0] = sfVal0[1]; + kSfSlice[s * 4 + 1][m][n][i][0] = sfVal1[0]; + kSfSlice[s * 4 + 3][m][n][i][0] = sfVal1[1]; + } + } + } + } +#endif + // @fixme: check if compiler mixes LDS+HMMA and does prefetch properly. We are not doing prefetch // explicitly. But we do fully unroll and expect compiler to do that for us. constexpr uint32_t nbUnroll = cacheElemSize == 2 ? 
gemmKSplit : 2; @@ -930,7 +1054,7 @@ __device__ inline void smemQKPartGemm(Warp const& warp, WarpAcc& acc, constexpr uint32_t qSliceRows = exactDiv(warpTile.y, 8 * mnEx); // in InstInMat constexpr uint32_t qSliceCols = nbInstInMatPerSliceInGemmKDim; Array2D, qSliceRows, qSliceCols> const qSlice = - loadMatrix( + loadMatrix( warp, q, 0, qColBeg + kEx * qSliceCols * s); // load k constexpr uint32_t cvtExp = exactDiv(inputElemSize, kElemSize); @@ -939,13 +1063,12 @@ __device__ inline void smemQKPartGemm(Warp const& warp, WarpAcc& acc, constexpr uint32_t kSliceRows = exactDiv(warpTile.x, 8 * mnExK); // in InstInMat constexpr uint32_t kSliceCols = nbInstInMatPerSliceInGemmKDim; Array2D, kSliceRows, kSliceCols> const kSliceOrig = - loadMatrix(warp, k, 0, - kExK * kSliceCols * s); + loadMatrix( + warp, k, 0, kExK * kSliceCols * s); auto const kSlice = [&]() -> Array2D, kSliceRows, kSliceCols> { if constexpr (mha::is_same_v) { return kSliceOrig; - } else if constexpr ((mha::is_same_v || - mha::is_same_v)) { + } else { Array2D, kSliceRows, kSliceCols> ret; #pragma unroll for (uint32_t m = 0; m < kSliceRows; m++) { @@ -955,8 +1078,15 @@ __device__ inline void smemQKPartGemm(Warp const& warp, WarpAcc& acc, for (uint32_t i = 0; i < mnExK; i++) { #pragma unroll for (uint32_t j = 0; j < kExK; j++) { - auto const data = + auto data = convertKCacheWordToF16(kSliceOrig(m, n).data[i][j]); +#if ENABLE_4BIT_KV_CACHE + uint32_t const sfPacked = kSfSlice[s][m][n][i / 2][0]; + uint16_t const sf = reinterpret_cast(&sfPacked)[i % 2]; + data[0] = applyF16ScalingFactor(data[0], sf); + data[1] = applyF16ScalingFactor(data[1], sf); +#endif + ret(m, n).data[i][j * cvtExp] = data[0]; ret(m, n).data[i][j * cvtExp + 1] = data[1]; } @@ -964,9 +1094,6 @@ __device__ inline void smemQKPartGemm(Warp const& warp, WarpAcc& acc, } } return ret; - } else { - assert(!"not implemented"); - trap(); } }(); // compute @@ -993,12 +1120,20 @@ __device__ inline void smemXVPartGemm(Warp const& warp, WarpAcc& acc, 
bool skipX UniformRescaleMask xRowNeedRescaleMask, ThrdRegRowMax xRowScales, SharedMem::XSmemBuffer const& x, uint32_t idxVTilePerXTile, SharedMem::VSmemBuffer const& vt, +#if ENABLE_4BIT_KV_CACHE + SharedMem::VSfSmemBuffer const& vSf, +#endif uint32_t idxNSplit) { static_assert(mha::is_same_v || mha::is_same_v, "not implemented"); +#if !ENABLE_4BIT_KV_CACHE static_assert((mha::is_same_v || mha::is_same_v || mha::is_same_v || mha::is_same_v), "not implemented"); +#else + static_assert(mha::is_same_v, "not implemented"); +#endif + constexpr uint32_t kEx = 2; constexpr uint32_t mnEx = 2; constexpr uint32_t nbInstInMatPerSliceInGemmKDim = 1; @@ -1030,6 +1165,25 @@ __device__ inline void smemXVPartGemm(Warp const& warp, WarpAcc& acc, bool skipX replicateForQuad(warp, reinterpret_cast(xRowScalesF16)); } +#if ENABLE_4BIT_KV_CACHE + // Prefetch buffer for SF + constexpr uint32_t nbSfPrefetchBuffers = exactDiv(warpTile.x, 16 * sizeof(uint32_t)); + Array2D vSfPrefetch; +#pragma unroll + for (uint32_t i = 0; i < vSfPrefetch.rows; i += 4) { +#pragma unroll + for (uint32_t j = 0; j < nbSfPrefetchBuffers; j++) { + // T0 reads 0-3, 16-19 + // T1 reads 4-7, 20-23 + // T2 reads 8-11, 24-27 + // T3 reads 12-15, 28-31 + uint32_t const sfRowIdxInSlice = (laneId() % 4) * 4; + uint32_t const sfRowIdx = (i / 4) * 16 + i % 4 + sfRowIdxInSlice; + vSfPrefetch(i, j) = reinterpret_cast(vSf.template at(sfRowIdx, j)); + } + } +#endif + // @fixme: check if compiler mixes LDS+HMMA and does prefetch properly. We are not doing prefetch // explicitly. But we do fully unroll and expect compiler to do that for us. 
#pragma unroll @@ -1041,7 +1195,8 @@ __device__ inline void smemXVPartGemm(Warp const& warp, WarpAcc& acc, bool skipX SharedMem::XSmemBuffer::cols / nbCacheVTilesPerXTile * idxVTilePerXTile + exactDiv(inputElemSize * 8 * kEx * nbInstInMatPerSliceInGemmKDim, grainBytes) * s; Array2D, xSliceRows, xSliceCols> xSlice = - loadMatrix(warp, x, 0u, colBeg); + loadMatrix(warp, x, 0u, + colBeg); if (!enableMicroFastPath || !skipXRowRescale) { #pragma unroll for (uint32_t m = 0; m < xSliceRows; m++) { @@ -1065,13 +1220,71 @@ __device__ inline void smemXVPartGemm(Warp const& warp, WarpAcc& acc, bool skipX constexpr uint32_t vSliceRows = nbInstInMatPerSliceInGemmKDim; uint32_t const rowBeg = 8 * kEx * nbInstInMatPerSliceInGemmKDim * s; Array2D, vSliceCols, vSliceRows> const vSliceOrig = - loadMatrix( + loadMatrix( warp, vt, rowBeg, mnEx * vSliceCols * idxNSplit); + + // Load and convert SFs for V. +#if ENABLE_4BIT_KV_CACHE + // The FP16/BF16 SF for V tile. + uint32_t vSfSlice[vSliceCols][vSliceRows][mnEx][kEx]; + // We want to load 4 SFs in head dimension to utilize 32b LDS. + static_assert(vSliceCols % 2 == 0, "not implemented"); + static_assert(mnEx == 2, "not implemented"); + // Assert kEx is 2 + static_assert(kEx == 2, "not implemented"); + for (uint32_t m = 0; m < exactDiv(vSliceCols, 2); m++) { +#pragma unroll + for (uint32_t n = 0; n < vSliceRows; n++) { + // Load the SF from 4 rows. + uint32_t sfOrig[4]; +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { + uint32_t const rowIdx = rowBeg + (laneId() % 4) * 4 + j; + // The column index is in 32b load unit. + uint32_t const colIdx = idxNSplit * (vSliceCols / 2) + m; + sfOrig[j] = vSf(rowIdx, colIdx); + } + // Permute the elements. 
+ // SFRow0: [(0,0) (0,1) (0,2) (0,3)] + // SFRow1: [(1,0) (1,1) (1,2) (1,3)] + // SFRow2: [(2,0) (2,1) (2,2) (2,3)] + // SFRow3: [(3,0) (3,1) (3,2) (3,3)] + // After permutation: + // tmpReg0: [(0,0) (1,0) (0,2) (1,2)] + // tmpReg1: [(0,1) (1,1) (0,3) (1,3)] + // tmpReg2: [(2,0) (3,0) (2,2) (3,2)] + // tmpReg3: [(2,1) (3,1) (2,3) (3,3)] + uint32_t tmpReg[4]; + tmpReg[0] = prmt(sfOrig[0], sfOrig[1], {0, 4, 2, 6}); + tmpReg[1] = prmt(sfOrig[0], sfOrig[1], {1, 5, 3, 7}); + tmpReg[2] = prmt(sfOrig[2], sfOrig[3], {0, 4, 2, 6}); + tmpReg[3] = prmt(sfOrig[2], sfOrig[3], {1, 5, 3, 7}); + // Result after conversion: + // sfVal0[0]: [(0,0) (1,0)]; sfVal0[1]: [(0,2) (1,2)] + // sfVal1[0]: [(0,1) (1,1)]; sfVal1[1]: [(0,3) (1,3)] + // sfVal2[0]: [(2,0) (3,0)]; sfVal2[1]: [(2,2) (3,2)] + // sfVal3[0]: [(2,1) (3,1)]; sfVal3[1]: [(2,3) (3,3)] + auto sfVal0 = convertKCacheWordToF16(tmpReg[0]); + auto sfVal1 = convertKCacheWordToF16(tmpReg[1]); + auto sfVal2 = convertKCacheWordToF16(tmpReg[2]); + auto sfVal3 = convertKCacheWordToF16(tmpReg[3]); + // Store to vSfSlice. 
+ vSfSlice[m * 2 + 0][n][0][0] = sfVal0[0]; + vSfSlice[m * 2 + 0][n][0][1] = sfVal2[0]; + vSfSlice[m * 2 + 0][n][1][0] = sfVal1[0]; + vSfSlice[m * 2 + 0][n][1][1] = sfVal3[0]; + vSfSlice[m * 2 + 1][n][0][0] = sfVal0[1]; + vSfSlice[m * 2 + 1][n][0][1] = sfVal2[1]; + vSfSlice[m * 2 + 1][n][1][0] = sfVal1[1]; + vSfSlice[m * 2 + 1][n][1][1] = sfVal3[1]; + } + } +#endif + Array2D, vSliceCols, vSliceRows> const vSlice = [&]() { if constexpr (mha::is_same_v) { return vSliceOrig; - } else if constexpr ((mha::is_same_v || - mha::is_same_v)) { + } else { Array2D, vSliceCols, vSliceRows> ret; #pragma unroll for (uint32_t m = 0; m < ret.rows; m++) { @@ -1079,6 +1292,28 @@ __device__ inline void smemXVPartGemm(Warp const& warp, WarpAcc& acc, bool skipX for (uint32_t n = 0; n < ret.cols; n++) { auto const& src = vSliceOrig(m, n); auto& dst = ret(m, n); +#if ENABLE_4BIT_KV_CACHE +#pragma unroll + for (uint32_t i = 0; i < mnEx; i++) { +#pragma unroll + for (uint32_t j = 0; j < kEx; j++) { + // Does not need PRMT, so use convertKCacheWordToF16 instead of + // convertVCacheWordToF16 + auto data = convertKCacheWordToF16(src.data[i][j]); + // Apply scaling factor to the data + InputElem2 scaledData0 = reinterpret_cast(data[0]) * + reinterpret_cast(vSfSlice[m][n][i][0]); + InputElem2 scaledData1 = reinterpret_cast(data[1]) * + reinterpret_cast(vSfSlice[m][n][i][1]); + data[0] = reinterpret_cast(scaledData0); + data[1] = reinterpret_cast(scaledData1); +#pragma unroll + for (uint32_t e = 0; e < cvtExpansion; e++) { + dst.data[i * cvtExpansion + j][e] = data[e]; + } + } + } +#else #pragma unroll for (uint32_t i = 0; i < mnEx; i++) { #pragma unroll @@ -1090,12 +1325,10 @@ __device__ inline void smemXVPartGemm(Warp const& warp, WarpAcc& acc, bool skipX } } } +#endif } } return ret; - } else { - assert(!"not implemented"); - trap(); } }(); // compute @@ -1453,11 +1686,13 @@ CUBIN_EXPORT __global__ bool const isFullTile = (nbValidHeadTokens == warpTile.y); static_assert(nbQBuffers == 1); if 
(isFullTile) { - copyHeadsAsync( - warpIdx.x, smem.q[warpIdx.y][0], src, nbValidHeadTokens, localQHeadTokenIdxMap); + copyHeadsAsync(warpIdx.x, smem.q[warpIdx.y][0], src, + nbValidHeadTokens, localQHeadTokenIdxMap); } else { - copyHeadsAsync( - warpIdx.x, smem.q[warpIdx.y][0], src, nbValidHeadTokens, localQHeadTokenIdxMap); + copyHeadsAsync(warpIdx.x, smem.q[warpIdx.y][0], src, + nbValidHeadTokens, localQHeadTokenIdxMap); } ldgsts::barArrive(smem.qBarrier[warpIdx.y], true); @@ -1488,8 +1723,9 @@ CUBIN_EXPORT __global__ constexpr bool isFullTile = (nbValidRows == warpTile.y); static_assert(nbQBuffers == 1); - copyHeadsAsync(warpIdx.x, smem.q[warpIdx.y][0], src, nbValidRows, localQHeadIdxMap); + copyHeadsAsync(warpIdx.x, smem.q[warpIdx.y][0], src, + nbValidRows, localQHeadIdxMap); ldgsts::barArrive(smem.qBarrier[warpIdx.y], true); } #endif @@ -1533,6 +1769,11 @@ CUBIN_EXPORT __global__ auto const getSMemKTile = [&](uint32_t idx) -> SharedMem::KSmemBuffer& { return smem.k[warpIdx.x][idx]; }; +#if ENABLE_4BIT_KV_CACHE + auto const getSMemKSfTile = [&](uint32_t idx) -> SharedMem::KSfSmemBuffer& { + return smem.kSf[warpIdx.x][idx]; + }; +#endif #if BEAM_WIDTH > 1 auto loadCacheIndir = [&](uint32_t seqIter, uint32_t idxBeam) mutable { auto& dst = smem.gemm0CacheIndir[warpIdx.x]; @@ -1562,6 +1803,10 @@ CUBIN_EXPORT __global__ assert(seqIter % nbSubSeqPerSeq == seqIterInit % nbSubSeqPerSeq); auto const idxNextSMemKBuf = idxCurrSMemKBuf.next(); auto& dst = getSMemKTile(idxNextSMemKBuf); +#if ENABLE_4BIT_KV_CACHE + auto& dstSf = getSMemKSfTile(idxNextSMemKBuf); +#endif + uint32_t const dstHeadOffset = 0; uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * warpIdx.x; uint32_t const tokenOffset = seqOffset % tokensPerPage; @@ -1570,6 +1815,12 @@ CUBIN_EXPORT __global__ HeadPtr const src{ cacheList.kCacheVLLM, pageIdx, tokenOffset, idxHeadGrp, kv_stride_page, kv_stride_token, kv_stride_head}; +#if ENABLE_4BIT_KV_CACHE + HeadPtr const srcSf{ + cacheList.kSfCacheVLLM, 
pageIdx, tokenOffset, idxHeadGrp, + kv_stride_page, kv_stride_token, kv_stride_head}; +#endif + #else IndexedHeadPtr const src{ /*indices=*/smem.gemm0CacheIndir[warpIdx.x].data, @@ -1580,6 +1831,11 @@ CUBIN_EXPORT __global__ /*stride_page=*/kv_stride_page, /*stride_token=*/kv_stride_token, /*stride_head=*/kv_stride_head}; + if constexpr (ENABLE_4BIT_KV_CACHE) { + // Not supported yet. + assert(!"not implemented"); + trap(); + } #endif // if (threadIdx.x == dbgPrintTid) { // printf("K: seqIter=%u, idxBeam=%u, idxPart=%u: pointers={%p, %p}, indices={", seqIter, @@ -1592,16 +1848,23 @@ CUBIN_EXPORT __global__ // } bool const isFullTile = (seqIter + 1 < nbSeqIters); if (isFullTile) { - copyPartialHeadsAsync( - warp, dst, dstHeadOffset, src, idxPart); + copyPartialHeadsAsync(warp, dst, dstHeadOffset, src, + idxPart); } else { uint32_t const nbHeadsAvail = (seqOffset < cacheSeqLen ? cacheSeqLen - seqOffset : 0U); // may also be full but it can be handled correctly anyway - copyPartialHeadsAsync( - warp, dst, dstHeadOffset, src, idxPart, nbHeadsAvail); + copyPartialHeadsAsync(warp, dst, dstHeadOffset, src, + idxPart, nbHeadsAvail); } +#if ENABLE_4BIT_KV_CACHE + copyPartialHeadsAsync(warp, dstSf, dstHeadOffset, srcSf, idxPart); +#endif + #if BEAM_WIDTH > 1 // to make sure all threads has finished usage of cache indir and pages __syncwarp(); @@ -1695,6 +1958,10 @@ CUBIN_EXPORT __global__ constexpr uint32_t qOffsetPerPart = exactDiv(elemsPerKHeadPart, inputElemsPerGrain); uint32_t const smemQOffset = qOffsetPerPart * p; SharedMem::KSmemBuffer const& smemKPart = getSMemKTile(idxCurrSMemKBuf); +#if ENABLE_4BIT_KV_CACHE + SharedMem::KSfSmemBuffer const& smemKSfPart = getSMemKSfTile(idxCurrSMemKBuf); +#endif + // #ifndef NDEGBUG // for (uint32_t i = 0; i < exactDiv(smemKPart.rows * smemKPart.cols, // warp_size); i++) { @@ -1706,7 +1973,12 @@ CUBIN_EXPORT __global__ // } // #endif // do computation. 
- smemQKPartGemm(warp, acc, smemQ, smemQOffset, smemKPart); + smemQKPartGemm(warp, acc, smemQ, smemQOffset, smemKPart +#if ENABLE_4BIT_KV_CACHE + , + smemKSfPart +#endif + ); idxCurrSMemKBuf++; } return acc; @@ -1783,7 +2055,12 @@ CUBIN_EXPORT __global__ initRowMax = smem.ctaRowMax[warpIdx.y][warpIdx.x].loadToReg(warp); #endif #endif + +#if ENABLE_4BIT_KV_CACHE + storeReorderedXTile(warp, smem.x[warpIdx.y][warpIdx.x], fp16Acc); +#else storeOrderedGemmOutTile(warp, smem.x[warpIdx.y][warpIdx.x], fp16Acc); +#endif smem.warpRowMax[warpIdx.y][warpIdx.x].storeFromReg(warp, regRowMax); smem.warpRowSum[warpIdx.y][warpIdx.x].storeFromReg(warp, regRowSum); unused(xBar.produced.arrive()); @@ -1815,6 +2092,12 @@ CUBIN_EXPORT __global__ auto const getSmemVTile = [&](uint32_t idx) -> SharedMem::VSmemBuffer& { return smem.v[warpGrpIdx][grpLoadV ? 0 : warpIdxInGrp][idx]; }; +#if ENABLE_4BIT_KV_CACHE + auto const getSmemVSfTile = [&](uint32_t idx) -> SharedMem::VSfSmemBuffer& { + return smem.vSf[warpGrpIdx][grpLoadV ? 0 : warpIdxInGrp][idx]; + }; +#endif + auto const getSmemVBar = [&](uint32_t idx) -> SharedMem::Barrier* { return smem.vBarrier(warpGrpIdx, idx); }; @@ -1828,6 +2111,10 @@ CUBIN_EXPORT __global__ getPage(cacheList, false, idxReq, idxBeam, idxPageBeg, nbPages); #else auto& dst = smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x]; +#if ENABLE_4BIT_KV_CACHE + static_assert(false, "4bit kv cache + beam search is not implemented"); +#endif + loadPagesForBeamSearchAsync( grpLoadV ? 
warpIdxInGrp : 0U, dst, cacheList, false, idxReq, idxPageBeg, nbPages); #endif @@ -1856,6 +2143,9 @@ CUBIN_EXPORT __global__ assert(seqIter % nbSubSeqPerSeq == seqIterInit % nbSubSeqPerSeq); auto const idxNextSMemVBuf = idxCurrSMemVBuf.next(); auto& dst = getSmemVTile(idxNextSMemVBuf); +#if ENABLE_4BIT_KV_CACHE + auto& dstSf = getSmemVSfTile(idxNextSMemVBuf); +#endif uint32_t const dstHeadOffset = 0; constexpr bool vSwizzle = true; @@ -1867,6 +2157,11 @@ CUBIN_EXPORT __global__ HeadPtr const src{ cacheList.vCacheVLLM, pageIdx, tokenOffset, idxHeadGrp, kv_stride_page, kv_stride_token, kv_stride_head}; +#if ENABLE_4BIT_KV_CACHE + HeadPtr const srcSf{ + cacheList.vSfCacheVLLM, pageIdx, tokenOffset, idxHeadGrp, + kv_stride_page, kv_stride_token, kv_stride_head}; +#endif #else IndexedHeadPtr const src{ /*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data, @@ -1877,6 +2172,11 @@ CUBIN_EXPORT __global__ /*stride_page=*/kv_stride_page, /*stride_token=*/kv_stride_token, /*stride_head=*/kv_stride_head}; + if constexpr (ENABLE_4BIT_KV_CACHE) { + // Not supported yet. + assert(!"not implemented"); + trap(); + } #endif // if (threadIdx.x == dbgPrintTid) { // printf("V: seqIter=%u, xIter=%u, idxBeam=%u, vIter=%u: pointers={%p, %p}, indices={", @@ -1895,8 +2195,13 @@ CUBIN_EXPORT __global__ : (seqOffset < cacheSeqLen ? 
cacheSeqLen - seqOffset : 0U); // may also be full but it can be handled correctly anyway - copyHeadsAsync( - warpIdxInGrp, dst, src, nbHeadsAvail); + copyHeadsAsync(warpIdxInGrp, dst, src, nbHeadsAvail); +#if ENABLE_4BIT_KV_CACHE + copyHeadsAsync(warpIdxInGrp, dstSf, srcSf, nbHeadsAvail); +#endif + #else uint32_t const nbHeadsAvail = (seqOffset < cacheSeqLen @@ -1904,16 +2209,23 @@ CUBIN_EXPORT __global__ : 0U); // may also be full but it can be handled correctly anyway bool const isFullTile = (seqIter + 1 < nbSeqIters); if (isFullTile) { - copyPartialHeadsAsync( - warp, dst, dstHeadOffset, src, warpIdxInGrp); + copyPartialHeadsAsync(warp, dst, dstHeadOffset, src, + warpIdxInGrp); } else { uint32_t const nbHeadsAvail = (seqOffset < cacheSeqLen ? cacheSeqLen - seqOffset : 0U); // may also be full but it can be handled correctly anyway - copyPartialHeadsAsync( + copyPartialHeadsAsync( warp, dst, dstHeadOffset, src, warpIdxInGrp, mha::min(nbHeadsAvail, cacheVTileSeqLen)); } +#if ENABLE_4BIT_KV_CACHE + copyPartialHeadsAsync(warp, dstSf, dstHeadOffset, srcSf, + warpIdxInGrp); +#endif #endif #if BEAM_WIDTH > 1 @@ -2127,14 +2439,25 @@ CUBIN_EXPORT __global__ } } auto const& smemVTile = getSmemVTile(idxCurrSMemVBuf); +#if ENABLE_4BIT_KV_CACHE + auto const& smemVSfPart = getSmemVSfTile(idxCurrSMemVBuf); +#endif + // do computation from shared memory X and V tiles #if BEAM_WIDTH == 1 smemXVPartGemm(warp, acc, skipXRowRescale, xRowNeedRescaleMask, xRowScales, - smemXTile, idxVTile, smemVTile, grpLoadV ? warpIdxInGrp : 0); + smemXTile, idxVTile, smemVTile, +#if ENABLE_4BIT_KV_CACHE + smemVSfPart, +#endif + grpLoadV ? warpIdxInGrp : 0); #else WarpAcc tmpAcc{}; smemXVPartGemm(warp, tmpAcc, skipXRowRescale, xRowNeedRescaleMask, xRowScales, smemXTile, idxVTile, smemVTile, +#if ENABLE_4BIT_KV_CACHE + smemVSfPart, +#endif grpLoadV ? 
warpIdxInGrp : 0); pickAccRowsForBeamSearch(warp, acc, tmpAcc, isConvergedTile(seqIter), idxBeam, [](float& d, float s) { d += s; }); @@ -2227,8 +2550,12 @@ CUBIN_EXPORT __global__ }; // merge results from different warp groups - SharedMem::XSmemBuffer* smemOutTile = - mergeAndSaveOutTile(outTile, inputElemSize == 2 && cacheElemSize == 1); +#if ENABLE_4BIT_KV_CACHE + bool reorderOutRows = false; +#else + bool reorderOutRows = inputElemSize == 2 && cacheElemSize == 1; +#endif + SharedMem::XSmemBuffer* smemOutTile = mergeAndSaveOutTile(outTile, reorderOutRows); if (isMultiBlock) { static_assert(ctaShapeInWarps.y == 1, "not implemented"); #if SPEC_DEC @@ -2494,6 +2821,9 @@ void launchMHA( #endif float const* attentionSinks, // [headGrpSize] GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#if ENABLE_4BIT_KV_CACHE + GMemCacheHeadSf* kSfCacheVLLM, GMemCacheHeadSf* vSfCacheVLLM, +#endif KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]. 
@@ -2551,8 +2881,11 @@ void launchMHA( dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z}; auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); - KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, - maxNbPagesPerSeq}; + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, +#if ENABLE_4BIT_KV_CACHE + kSfCacheVLLM, vSfCacheVLLM, +#endif + kvCachePageList, seqLen, maxNbPagesPerSeq}; // Convert stride from elements to Heads uint32_t const stride_page_in_heads = static_cast(kv_stride_page / validElemsPerHead); uint32_t const stride_token_in_heads = static_cast(kv_stride_token / validElemsPerHead); @@ -2601,9 +2934,13 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, - GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, - uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, float const* kvScalePtr, + GMemCacheHead* vCacheVLLM, +#if ENABLE_4BIT_KV_CACHE + GMemCacheHeadSf* kSfCacheVLLM, GMemCacheHeadSf* vSfCacheVLLM, +#endif + KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, + uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif @@ -2626,12 +2963,20 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z}; auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); - KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, - maxNbPagesPerSeq}; + KVCacheList 
const cacheList{kCacheVLLM, vCacheVLLM, +#if ENABLE_4BIT_KV_CACHE + kSfCacheVLLM, vSfCacheVLLM, +#endif + kvCachePageList, seqLen, maxNbPagesPerSeq}; // Convert stride from elements to Heads - uint32_t const stride_page_in_heads = static_cast(kv_stride_page / validElemsPerHead); - uint32_t const stride_token_in_heads = static_cast(kv_stride_token / validElemsPerHead); - uint32_t const stride_head_in_heads = static_cast(kv_stride_head / validElemsPerHead); + uint32_t const container_elems_per_head = + validElemsPerHead / CacheElemConverter::ElemsPerContainer; + uint32_t const stride_page_in_heads = + static_cast(kv_stride_page / container_elems_per_head); + uint32_t const stride_token_in_heads = + static_cast(kv_stride_token / container_elems_per_head); + uint32_t const stride_head_in_heads = + static_cast(kv_stride_head / container_elems_per_head); cudaLaunchKernelEx(&launchCfg, kernel_mha, #if SPEC_DEC diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index d7ab1c452c..14e5b5adce 100644 --- a/csrc/xqa/mha.h +++ b/csrc/xqa/mha.h @@ -24,7 +24,8 @@ #if SPEC_DEC #include "specDec.h" #endif -using CacheElem = ElemType; +using CacheElemConverter = ElemTypeConverter; +using CacheElem = CacheElemConverter::Type; constexpr uint32_t validElemsPerHead = HEAD_ELEMS; constexpr bool isMLA = IS_MLA; static_assert((isMLA || validElemsPerHead <= 256) && @@ -56,7 +57,12 @@ constexpr uint32_t tokensPerPage = TOKENS_PER_PAGE; using IOHead = Vec; using InputHead = IOHead; -using GMemCacheHead = Vec; +using GMemCacheHead = Vec; +#if ENABLE_4BIT_KV_CACHE +using GMemCacheHeadSf = Vec; +#endif constexpr uint32_t validElemsPerKHead = validElemsPerHead; constexpr bool lowPrecOutput = LOW_PREC_OUTPUT; @@ -72,7 +78,13 @@ using OutputHead = mha::conditional_t; using OutputElem = OutputHead::Elem; using PaddedInputHead = Vec; -using PaddedCacheHead = Vec; +// For 4 bit KV cache, each 16 elements (64b) are padded with 64b to match 128b banks. 
+using PaddedCacheHead = Vec; + +#if ENABLE_4BIT_KV_CACHE +using PaddedCacheHeadSf = + Vec; +#endif // impl detail, may be moved to mha.cu/mha_sm90.cu constexpr bool isHeadPadded = (validElemsPerHead != headElems); @@ -112,6 +124,10 @@ void launchMHA( #endif float const* attentionSinks, // [headGrpSize] GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#if ENABLE_4BIT_KV_CACHE + GMemCacheHeadSf* kSfCacheVLLM, GMemCacheHeadSf* vSfCacheVLLM, +#endif + KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] @@ -133,9 +149,13 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, - GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, - uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, float const* kvScalePtr, + GMemCacheHead* vCacheVLLM, +#if ENABLE_4BIT_KV_CACHE + GMemCacheHeadSf* kSfCacheVLLM, GMemCacheHeadSf* vSfCacheVLLM, +#endif + KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, + uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif diff --git a/csrc/xqa/mhaUtils.cuh b/csrc/xqa/mhaUtils.cuh index 7fd53f9344..98e2801f5b 100644 --- a/csrc/xqa/mhaUtils.cuh +++ b/csrc/xqa/mhaUtils.cuh @@ -98,12 +98,14 @@ struct HeadPtr : TinyPtr {}; // #endif // @fixme: give evict first hint for last part. 
-template __device__ inline void copyPartialHeadsAsync( Warp const& warp, - Array2D& dst, + Array2D<_LdGrain, dstNbHeads, exactDiv(exactDiv(sizeof(Head), nbPartsPerHead), grainBytesSmem)>& + dst, uint32_t dstHeadOffset, SrcHeadPtr const& src, uint32_t idxPart, uint32_t nbAvailHeads = maxNbCopiedHeads, LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) { @@ -117,41 +119,46 @@ __device__ inline void copyPartialHeadsAsync( constexpr uint32_t warpLdBytes = partBytes * maxNbCopiedHeads; constexpr uint32_t thrdLdBytes = exactDiv(warpLdBytes, warp_size); assertIsPowerOf2(); - static_assert(thrdLdBytes >= grainBytes); + static_assert(thrdLdBytes >= grainBytesSmem); // a segment is responsible for loading one partial head collaboratively - constexpr uint32_t thrdsPerSeg = exactDiv(partBytes, grainBytes); + constexpr uint32_t thrdsPerSeg = exactDiv(partBytes, grainBytesSmem); static_assert(thrdsPerSeg > 0 && thrdsPerSeg <= warp_size); assertIsPowerOf2(); assert(__shfl_sync(0xFU << (laneId() / 4 * 4), src.offset, 0, 4) == src.offset); auto const warpLane = laneId(); uint32_t const segIdx = warpLane / thrdsPerSeg; uint32_t const segLane = warpLane % thrdsPerSeg; - constexpr uint32_t partsPerWarpInst = exactDiv(grainBytes * warp_size, partBytes); + constexpr uint32_t partsPerWarpInst = exactDiv(grainBytesSmem * warp_size, partBytes); #pragma unroll - for (uint32_t i = 0; i < thrdLdBytes / grainBytes; i++) { + for (uint32_t i = 0; i < thrdLdBytes / grainBytesSmem; i++) { uint32_t const idxHeadLocal = partsPerWarpInst * i + segIdx; assert(idxHeadLocal < maxNbCopiedHeads); bool const isHeadInBound = isFull || (idxHeadLocal < nbAvailHeads); - constexpr uint32_t grainsPerPart = exactDiv(partBytes, grainBytes); + constexpr uint32_t grainsPerPart = exactDiv(partBytes, grainBytesSmem); using SrcHead = mha::decay_t; - constexpr uint32_t nbValidGrains = exactDiv(sizeof(SrcHead), grainBytes); + constexpr uint32_t nbValidGrains = exactDiv(sizeof(SrcHead), grainBytesGmem); 
uint32_t const idxGrainInsideHead = grainsPerPart * idxPart + segLane; bool const isGrainInBound = (!isHeadPadded || idxGrainInsideHead < nbValidGrains); SrcHead const* const pSrcHead = src + localHeadIdxMap(idxHeadLocal); bool const isValidPage = (pSrcHead != nullptr); - LdGrain const* const pSrc = reinterpret_cast(pSrcHead) + idxGrainInsideHead; - LdGrain* const pDst = &dst.template at(dstHeadOffset + idxHeadLocal, segLane); + Vec const* const pSrc = + reinterpret_cast const*>(pSrcHead) + idxGrainInsideHead; + Vec* const pDst = reinterpret_cast*>( + &dst.template at(dstHeadOffset + idxHeadLocal, segLane)); +#if !ENABLE_4BIT_KV_CACHE + // 4-bit KV cache is not bank-conflict free now. assert(!hasBankConflict(pDst)); - ldgsts::copyAsync(pDst, pSrc, - isValidPage && isHeadInBound && isGrainInBound ? grainBytes : 0u); +#endif + ldgsts::copyAsync( + pDst, pSrc, isValidPage && isHeadInBound && isGrainInBound ? grainBytesGmem : 0u); } } -template +template __device__ inline void copyHeadsAsync( - uint32_t idxWarp, Array2D& dst, + uint32_t idxWarp, Array2D<_LdGrain, dstNbHeads, exactDiv(sizeof(Head), grainBytesSmem)>& dst, SrcHeadPtr const& src, uint32_t nbAvailHeads = maxNbCopiedHeads, LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) { assert(idxWarp < nbWarps); @@ -161,9 +168,9 @@ __device__ inline void copyHeadsAsync( uint32_t const warpNbAvailHeads = (dstHeadOffset < nbAvailHeads ? 
nbAvailHeads - dstHeadOffset : 0); constexpr uint32_t idxPart = 0; - copyPartialHeadsAsync( - warp, dst, dstHeadOffset, src, idxPart, warpNbAvailHeads, - [&](uint32_t x) { return localHeadIdxMap(dstHeadOffset + x); }); + copyPartialHeadsAsync(warp, dst, dstHeadOffset, src, idxPart, warpNbAvailHeads, + [&](uint32_t x) { return localHeadIdxMap(dstHeadOffset + x); }); } template @@ -235,6 +242,10 @@ template <> struct KVCacheList { GMemCacheHead* kCacheVLLM; GMemCacheHead* vCacheVLLM; +#if ENABLE_4BIT_KV_CACHE + GMemCacheHeadSf* kSfCacheVLLM; + GMemCacheHeadSf* vSfCacheVLLM; +#endif KVCachePageIndex const* kvCachePageList; // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]. SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility) diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh index a9ac1805b9..e8b56af0bb 100644 --- a/csrc/xqa/utils.cuh +++ b/csrc/xqa/utils.cuh @@ -552,6 +552,70 @@ __device__ inline Vec ldmatrix_16x16_trans(LdGrain const* r } } +template +__device__ inline Vec ldmatrix_16x16_trans_unpack_4b(LdGrain const* row) { +#if __CUDA_ARCH__ >= 1000 + uint32_t a, b, c, d; + if constexpr (nbMat == 1) { + asm("ldmatrix.sync.aligned.m16n16.x1.trans.shared::cta.b8x16.b4x16_p64 {%0, %1}, [%2];\n" + : "=r"(a), "=r"(b) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + return Vec{a, b}; + } else if constexpr (nbMat == 2) { + asm("ldmatrix.sync.aligned.m16n16.x2.trans.shared::cta.b8x16.b4x16_p64 {%0, %1, %2, %3}, " + "[%4];\n" + : "=r"(a), "=r"(b), "=r"(c), "=r"(d) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + return Vec{a, b, c, d}; + } else { + static_assert(nbMat == 1 || nbMat == 2); + } +#else + trap(); +#endif +} + +template +__device__ inline Vec ldmatrix_8x16_4x_unpack_4b(LdGrain const* row) { +#if __CUDA_ARCH__ >= 1000 + uint32_t a, b, c, d; + if constexpr (nbMat == 4) { + asm("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, [%4];\n" + : "=r"(a), "=r"(b), "=r"(c), 
"=r"(d) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + return Vec{a, b, c, d}; + } else if constexpr (nbMat == 2) { + asm("ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n" + : "=r"(a), "=r"(b) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + return Vec{a, b}; + } else if constexpr (nbMat == 1) { + asm("ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n" + : "=r"(a) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + return Vec{a}; + } else { + static_assert(nbMat == 1 || nbMat == 2 || nbMat == 4); + } +#else + trap(); +#endif +} + +template +__device__ inline Vec ldmatrix_4x_unpack_4b(Warp const& warp, LdGrain const* row) { + if constexpr (transpose) { + return ldmatrix_16x16_trans_unpack_4b<2>(row); + } else { + return ldmatrix_8x16_4x_unpack_4b<4>(row); + } +} + template __device__ inline void stmatrix(LdGrain* row, Vec const& data) { #if __CUDA_ARCH__ >= 900 @@ -688,6 +752,73 @@ __device__ inline Vec convertKCacheWordToF16(uint32_t i8data) { return ret; } +#if ENABLE_4BIT_KV_CACHE +template <> +__device__ inline Vec convertKCacheWordToF16(uint32_t i8data) { + Vec ret; +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t src = i8data | (i8data >> 4); + uint32_t(&dst)[2] = reinterpret_cast(ret); + asm("{\n" + ".reg .b8 byte0, byte2;\n" + "mov.b32 {byte0, _, byte2, _}, %2;\n" + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" + "cvt.rn.f16x2.e2m1x2 %1, byte2;\n" + "}" + : "=r"(dst[0]), "=r"(dst[1]) + : "r"(src)); +#else + assert(!"need arch >= 1000"); + trap(); +#endif + return ret; +} + +template <> +__device__ inline Vec convertKCacheWordToF16<__nv_bfloat16, __nv_fp4_e2m1>( + uint32_t i8data) { + Vec ret; + // This needs CUDA Toolkit version >= 13.2 +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if (defined __CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 13) && \ + (defined __CUDACC_VER_MINOR__) && (__CUDACC_VER_MINOR__ >= 2) + uint32_t src = i8data | (i8data >> 4); + uint32_t(&dst)[2] = 
reinterpret_cast(ret); + asm("{\n" + ".reg .b8 byte0, byte2;\n" + "mov.b32 {byte0, _, byte2, _}, %2;\n" + "cvt.rn.bf16x2.e2m1x2 %0, byte0;\n" + "cvt.rn.bf16x2.e2m1x2 %1, byte2;\n" + "}" + : "=r"(dst[0]), "=r"(dst[1]) + : "r"(src)); +#else + // Fallback: convert e2m1 -> fp16 -> bf16 + uint32_t src = i8data | (i8data >> 4); + __half halfData[4]; + uint32_t(&dst)[2] = reinterpret_cast(halfData); + asm("{\n" + ".reg .b8 byte0, byte2;\n" + "mov.b32 {byte0, _, byte2, _}, %2;\n" + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" + "cvt.rn.f16x2.e2m1x2 %1, byte2;\n" + "}" + : "=r"(dst[0]), "=r"(dst[1]) + : "r"(src)); + auto bf16Data = reinterpret_cast<__nv_bfloat16(&)[4]>(ret); +#pragma unroll + for (uint32_t ii = 0; ii < 4; ii++) { + bf16Data[ii] = __nv_bfloat16(halfData[ii]); + } +#endif +#else + assert(!"need arch >= 1000"); + trap(); +#endif + return ret; +} +#endif + template __device__ inline Vec convertVCacheWordToF16(uint32_t i8data) { static_assert(mha::is_same_v || mha::is_same_v, @@ -726,6 +857,32 @@ __device__ inline Vec convertVCacheWordToF16(uint32_t i8data) { return ret; } +template +__device__ inline uint32_t applyF16ScalingFactor(uint32_t x, uint16_t sf) { + // Broadcasts sf to both lanes and multiplies: + // (o0, o1) = (x0, x1) * (sf, sf) + uint32_t ret; + if constexpr (mha::is_same_v) { + asm("{\n" + ".reg .b32 sf2;\n" + "mov.b32 sf2, {%2, %2};\n" + "mul.rn.f16x2 %0, %1, sf2;\n" + "}" + : "=r"(ret) + : "r"(x), "h"(sf)); + } else { + static_assert(mha::is_same_v); + asm("{\n" + ".reg .b32 sf2;\n" + "mov.b32 sf2, {%2, %2};\n" + "mul.rn.bf16x2 %0, %1, sf2;\n" + "}" + : "=r"(ret) + : "r"(x), "h"(sf)); + } + return ret; +} + struct PermuteOrder { uint16_t x0 : 4; uint16_t x1 : 4; diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 560882bd9f..09c707d2a6 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -54,6 +54,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK int64_t slidingWinSize, double qScale, 
Optional qScaleTensor, TensorView output, double rcpOutScale, TensorView q, Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, + Optional kSfCacheVLLM, Optional vSfCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, Optional kvScaleTensor, int64_t qSeqLen, Optional mask, TensorView semaphores, @@ -84,6 +85,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK mask.has_value() ? reinterpret_cast(mask.value().data_ptr()) : nullptr; #endif + void* kSfCachePtr = kSfCacheVLLM.has_value() ? kSfCacheVLLM.value().data_ptr() : nullptr; + void* vSfCachePtr = vSfCacheVLLM.has_value() ? vSfCacheVLLM.value().data_ptr() : nullptr; + mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, qScalePtr, reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT @@ -92,6 +96,10 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK reinterpret_cast(q.data_ptr()), attentionSinksPtr, reinterpret_cast(kCacheVLLM.data_ptr()), reinterpret_cast(vCacheVLLM.data_ptr()), +#if ENABLE_4BIT_KV_CACHE + reinterpret_cast(kSfCachePtr), + reinterpret_cast(vSfCachePtr), +#endif reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, kvCacheScale, kvScalePtr, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index eb901c4f5b..450c082398 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2150,6 +2150,7 @@ def trtllm_batch_decode_with_kv_cache( max_q_len: Optional[int] = None, cum_seq_lens_q: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -2284,6 +2285,7 @@ def trtllm_batch_decode_with_kv_cache( return xqa_batch_decode_with_kv_cache( query=query, kv_cache=(k_cache, v_cache), + kv_cache_sf=kv_cache_sf, 
workspace_buffer=workspace_buffer, block_tables=block_tables, seq_lens=seq_lens, @@ -2451,6 +2453,9 @@ def xqa_batch_decode_with_kv_cache( q_len_per_req: Optional[int] = 1, o_scale: Optional[float] = 1.0, mask: Optional[torch.Tensor] = None, + kv_cache_sf: Union[ + torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> torch.Tensor: """ Parameters @@ -2505,6 +2510,9 @@ def xqa_batch_decode_with_kv_cache( mask : Optional[torch.Tensor] = None causal attention mask for xqa speculative decoding. + kv_cache_sf : Optional[torch.Tensor] = None + KV cache scaling factors. Must provide when NVFP4 KV cache is used. + Returns ------- out : torch.Tensor @@ -2525,6 +2533,19 @@ def xqa_batch_decode_with_kv_cache( # it doesn't change underlying storage k_cache, v_cache = kv_cache.unbind(dim=1) + k_cache_sf = None + v_cache_sf = None + if kv_cache_sf is not None: + if isinstance(kv_cache_sf, tuple): + k_cache_sf, v_cache_sf = kv_cache_sf + else: + assert kv_cache_sf.shape[1] == 2, ( + "When kv_cache is a single tensor, the second dimension must be 1 or 2" + ) + # NOTE(Zihao): unbind transforms [num_pages, 2, ...] 
to ([num_pages, ...], [num_pages, ...]) + # it doesn't change underlying storage + k_cache_sf, v_cache_sf = kv_cache_sf.unbind(dim=1) + sm_count = get_device_sm_count(query.device) # Extract shape parameters based on layout @@ -2561,6 +2582,8 @@ def xqa_batch_decode_with_kv_cache( query_new, k_cache, v_cache, + k_cache_sf, + v_cache_sf, block_tables, seq_lens_new, out_4d, diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index 7a2e0bde6f..e2750cfd7b 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -865,7 +865,7 @@ def nvfp4_quantize( a_global_sf.cuda(), sf_vec_size, sf_use_ue8m0=False, - is_sf_swizzled_layout=True, + is_sf_swizzled_layout=sfLayout != SfLayout.layout_linear, is_sf_8x4_layout=sfLayout == SfLayout.layout_8x4, enable_pdl=enable_pdl, ) diff --git a/flashinfer/jit/xqa.py b/flashinfer/jit/xqa.py index cb9d823708..110131e197 100644 --- a/flashinfer/jit/xqa.py +++ b/flashinfer/jit/xqa.py @@ -54,6 +54,8 @@ def gen_xqa_module( flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=2"] elif kv_cache_dtype == torch.int8: flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=1"] + elif kv_cache_dtype == torch.uint8: + flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=3"] else: flag_kv_cache_dtype = ["-DCACHE_ELEM_ENUM=0"] diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index 176c4bd013..4d21714cf2 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -74,6 +74,8 @@ def xqa( sinks: Optional[torch.Tensor], k_cache: torch.Tensor, v_cache: torch.Tensor, + k_sf_cache: Optional[torch.Tensor], + v_sf_cache: Optional[torch.Tensor], page_table: torch.Tensor, max_seq_len: int, seq_lens: torch.Tensor, @@ -98,6 +100,8 @@ def xqa( sinks, k_cache, v_cache, + k_sf_cache, + v_sf_cache, page_table, max_seq_len, seq_lens, @@ -126,6 +130,8 @@ def _fake_xqa( sinks: Optional[torch.Tensor], k_cache: torch.Tensor, v_cache: torch.Tensor, + k_sf_cache: Optional[torch.Tensor], + v_sf_cache: Optional[torch.Tensor], page_table: torch.Tensor, max_seq_len: 
int, seq_lens: torch.Tensor, @@ -149,6 +155,8 @@ def xqa( q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, + k_sf_cache: Optional[torch.Tensor], + v_sf_cache: Optional[torch.Tensor], page_table: torch.Tensor, seq_lens: torch.Tensor, output: torch.Tensor, @@ -179,12 +187,18 @@ def xqa( Paged K cache tensor with shape ``[num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or ``[num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. - Should be the same data type as v_cache. + Should be the same data type as v_cache. When using NVFP4 KV, the data type is torch.uint8, and the last dimension should be `head_dim / 2`. v_cache: torch.Tensor Paged V cache tensor with shape ``[num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or ``[num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. - Should be the same data type as k_cache. + Should be the same data type as k_cache. When using NVFP4 KV, the data type is torch.uint8, and the last dimension should be `head_dim / 2`. + k_sf_cache: Optional[torch.Tensor] + Optional scale factor cache tensor for the K cache. Use when NVFP4 KV is used. Expected shape is ``[num_pages, page_size, num_kv_heads, head_dim / 16]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, num_kv_heads, page_size, head_dim / 16]`` if :attr:`kv_layout` is ``HND``. Should be the same data type as v_sf_cache. Data type should be torch.uint8. + v_sf_cache: Optional[torch.Tensor] + Optional scale factor cache tensor for the V cache. Use when NVFP4 KV is used. 
Expected shape is ``[num_pages, page_size, num_kv_heads, head_dim / 16]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, num_kv_heads, page_size, head_dim / 16]`` if :attr:`kv_layout` is ``HND``. Should be the same data type as k_sf_cache. Data type should be torch.uint8. page_table : torch.Tensor Page table tensor with shape ``batch_size, nb_pages_per_seq``. Data type should be torch.int32. @@ -279,7 +293,10 @@ def xqa( # For HND: [..., H, N, D] -> NHD: [..., N, H, D] k_cache = k_cache.transpose(-3, -2) v_cache = v_cache.transpose(-3, -2) - + if k_sf_cache is not None: + k_sf_cache = k_sf_cache.transpose(-3, -2) + if v_sf_cache is not None: + v_sf_cache = v_sf_cache.transpose(-3, -2) if ( k_cache.dtype == torch.float8_e4m3fn and get_compute_capability(torch.device(device="cuda"))[0] == 9 @@ -288,6 +305,13 @@ def xqa( else: run_sm90_fp8_mha = False + if k_cache.dtype == torch.uint8: + assert get_compute_capability(torch.device(device="cuda"))[0] in [12], ( + "XQA NVFP4 KV is only supported on SM120 GPUs" + ) + assert k_sf_cache is not None, "K SF cache is required when NVFP4 KV is used" + assert v_sf_cache is not None, "V SF cache is required when NVFP4 KV is used" + if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]: raise RuntimeError("XQA is only supported on SM90, SM100, SM120/SM121 GPUs") @@ -319,6 +343,8 @@ def xqa( sinks, k_cache, v_cache, + k_sf_cache, + v_sf_cache, page_table, max_seq_len, seq_lens, diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index 0194611dfb..3fac89f726 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -346,6 +346,8 @@ def test_xqa( q_heads, cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads, cache_v_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_v_heads, + None, + None, page_list_arg, seq_len_list, output, diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py index 
2c51510cc5..19a218826d 100644 --- a/tests/attention/test_xqa_batch_decode.py +++ b/tests/attention/test_xqa_batch_decode.py @@ -4,6 +4,11 @@ import flashinfer from flashinfer.utils import get_compute_capability +from flashinfer.fp4_quantization import ( + SfLayout, + nvfp4_quantize, + e2m1_and_ufp8sf_scale_to_float, +) DTYPE_MAP = { "fp16": torch.float16, @@ -27,6 +32,27 @@ def to_float8(x, dtype=torch.float8_e4m3fn): return x_scl_sat.to(dtype), scale.float().reciprocal() +def to_nvfp4(x): + # Get the amax + min_val, max_val = x.float().aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + # The global scale, which is amax / (448. * 6.) + global_scale = amax / (448.0 * 6.0) + global_scale_inv = 1.0 / global_scale + # Global sf is 1 for now. + val, sf = nvfp4_quantize(x, global_scale_inv, sfLayout=SfLayout.layout_linear) + return val, sf.reshape(*val.shape[:-1], sf.shape[-1]), global_scale + + +def nvfp4_to_float(x, sf, global_sf): + x_flatten = x.reshape(-1, x.shape[-1]) + sf_flatten = sf.reshape(-1, sf.shape[-1]) + x_dq_flatten = e2m1_and_ufp8sf_scale_to_float( + x_flatten, sf_flatten, global_sf, sf_vec_size=16, is_sf_swizzled_layout=False + ) + return x_dq_flatten.reshape(*x.shape[:-1], -1).to(GPU_DEVICE) + + def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len): q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32) in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int) @@ -78,7 +104,7 @@ def create_kv_cache( num_pages_per_seq = (max_seq_len + page_size - 1) // page_size num_pages = num_pages_per_seq * batch_size ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype] - if kv_dtype != "fp8": + if kv_dtype != "fp8" and kv_dtype != "nvfp4": assert kv_dtype == ref_kv_dtype, ( "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache" ) @@ -121,6 +147,8 @@ def create_kv_cache( device=GPU_DEVICE, ) + k_global_scale = None + v_global_scale = None # Convert K and V separately to fp8 if 
needed if kv_dtype == "fp8": k_cache, k_scale = to_float8(k_cache / 4.0) @@ -133,13 +161,25 @@ def create_kv_cache( ], dim=1, ) + elif kv_dtype == "nvfp4": + k_cache, k_scale, k_global_scale = to_nvfp4(k_cache / 4.0) + v_cache, v_scale, v_global_scale = to_nvfp4(v_cache / 4.0) + k_cache_dq = nvfp4_to_float(k_cache, k_scale, k_global_scale) + v_cache_dq = nvfp4_to_float(v_cache, v_scale, v_global_scale) + ref_kv_cache = torch.stack( + [ + k_cache_dq.to(ref_kv_dtype_torch), + v_cache_dq.to(ref_kv_dtype_torch), + ], + dim=1, + ) else: k_scale = v_scale = 1.0 ref_kv_cache = torch.stack([k_cache, v_cache], dim=1) # Combine K and V into interleaved format for the API kv_cache = torch.stack([k_cache, v_cache], dim=1) - return kv_cache, k_scale, v_scale, ref_kv_cache + return kv_cache, k_scale, v_scale, k_global_scale, v_global_scale, ref_kv_cache def create_page_table(batch_size, seq_lens, page_size): @@ -417,7 +457,7 @@ def test_xqa_batch_decode( q_indptr = generate_cumsum_lens(q_lens) # Create KV cache and related data - kv_cache, k_scale, v_scale, ref_kv_cache = create_kv_cache( + kv_cache, k_scale, v_scale, _, _, ref_kv_cache = create_kv_cache( batch_size, seq_lens, page_size, @@ -539,6 +579,193 @@ def test_xqa_batch_decode( ) +@pytest.mark.skipif( + get_compute_capability(torch.device(device="cuda"))[0] not in [12], + reason="XQA with NVFP4 KV is only supported on SM120 GPUs", +) +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (4, 4, 64, 4, 2), + (1, 1, 64, 2, 4), + (1, 1, 64, 2, 8), + ], +) +@pytest.mark.parametrize("window_left", [-1]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("fp16", "nvfp4", "fp16"), + ("bf16", "nvfp4", "bf16"), + ], +) +@pytest.mark.parametrize("enable_pdl", [False]) +@pytest.mark.parametrize("enable_sink", [False]) +@pytest.mark.parametrize("max_in_kv_len", [110]) +@pytest.mark.parametrize("kv_layout", ["NHD"]) +def test_xqa_batch_decode_nvfp4_kv( + batch_size, + 
q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + kv_layout, +): + """Test xqa_batch_decode_with_kv_cache function. + + This test supports both NHD and HND layouts. + """ + + # Set up test parameters + torch.manual_seed(0) + head_dim = 256 + + # Generate random sequence lengths + num_qo_heads = num_kv_heads * head_grp_size + q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode( + batch_size, q_len_per_req, max_in_kv_len + ) + + # Create query tensor and related data + q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype) + q_indptr = generate_cumsum_lens(q_lens) + + # Create KV cache and related data + kv_cache, k_scale, v_scale, k_global_scale, v_global_scale, ref_kv_cache = ( + create_kv_cache( + batch_size, + seq_lens, + page_size, + num_kv_heads, + head_dim, + kv_dtype, + "bf16" if q_dtype == "fp8" else q_dtype, + kv_layout, + ) + ) + + kv_cache_sf = torch.stack([k_scale, v_scale], dim=1) + + page_table, all_page_ids, page_per_seq = create_page_table( + batch_size, seq_lens, page_size + ) + kv_indptr = generate_cumsum_lens(page_per_seq) + kv_last_page_len = get_last_page_len(seq_lens, page_size) + + workspace_buffer, workspace_buffer_ref = create_workspace_buffers(GPU_DEVICE) + + # Create output tensor and related data + out, o_scale = create_output(q, o_dtype) + + sm_scale = float(1.0 / (head_dim**0.5)) + + # Build reference output + plan_params = { + "indptr": kv_indptr, + "indices": all_page_ids, + "last_page_len": kv_last_page_len.to(GPU_DEVICE), + "num_qo_heads": num_qo_heads, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "page_size": page_size, + "pos_encoding_mode": "NONE", + "kv_data_type": ref_kv_cache.dtype, + "q_data_type": ref_q.dtype, + "window_left": window_left, + } + if not enable_sink: + if q_len_per_req == 1: + wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + 
workspace_buffer_ref, kv_layout, use_tensor_cores=True + ) + wrapper_ref.plan(**plan_params) + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + else: + # speculative decoding test + wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer_ref, kv_layout + ) + plan_params_prefill = plan_params.copy() + plan_params_prefill.update( + { + "qo_indptr": q_indptr, + "paged_kv_indptr": plan_params_prefill.pop("indptr"), + "paged_kv_indices": plan_params_prefill.pop("indices"), + "paged_kv_last_page_len": plan_params_prefill.pop("last_page_len"), + "head_dim_qk": plan_params_prefill.pop("head_dim"), + "causal": True, + "logits_soft_cap": 0.0, + } + ) + wrapper_ref.plan(**plan_params_prefill) + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + else: + # Construct flat K/V via helper + k_flat, v_flat, kv_indptr_tokens = flatten_paged_kv( + ref_kv_cache, + page_table, + seq_lens.to(GPU_DEVICE), + page_size, + kv_last_page_len, + kv_layout, + ) + sink = torch.rand(num_qo_heads, device=GPU_DEVICE, dtype=torch.float32) * 5 + output_ref = sink_attention_unified( + ref_q, + k_flat, + v_flat, + sink, + window_left, + True, + sm_scale, + mode="varlen", + batch_size=batch_size, + qo_indptr=q_indptr, + kv_indptr=kv_indptr_tokens, + ) + + if q_len_per_req > 1: + mask = generate_causal_mask(batch_size, q_len_per_req, GPU_DEVICE) + else: + mask = None + + # Run xqa_batch_decode_with_kv_cache function + output = flashinfer.decode.xqa_batch_decode_with_kv_cache( + q.contiguous(), + kv_cache, + workspace_buffer, + page_table, + seq_lens.to(GPU_DEVICE), + torch.max(seq_lens).item(), + q_scale * k_global_scale * sm_scale, # bmm1_scale + v_global_scale / o_scale, # bmm2_scale + window_left, # window_left + out=out, + enable_pdl=enable_pdl, + sinks=(sink if enable_sink else None), + kv_layout=kv_layout, + q_len_per_req=q_len_per_req, + o_scale=o_scale, + mask=mask, + kv_cache_sf=kv_cache_sf, + ) + + # Verification + torch.testing.assert_close( + 
output.float(), + output_ref.float() / o_scale, + rtol=1e-1, + atol=1e-1, + ) + + if __name__ == "__main__": # Run a simple test case test_xqa_batch_decode(