Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions csrc/flashinfer_xqa_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
int64_t slidingWinSize, double qScale, tvm::ffi::Optional<TensorView> qScaleTensor,
TensorView output, double rcpOutScale, TensorView q,
tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
TensorView seqLen, int64_t batchSize, double kvCacheScale,
TensorView vCacheVLLM, tvm::ffi::Optional<TensorView> kSfCacheVLLM,
tvm::ffi::Optional<TensorView> vSfCacheVLLM, TensorView kvCachePageList,
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale,
tvm::ffi::Optional<TensorView> kvScaleTensor, int64_t qSeqLen,
tvm::ffi::Optional<TensorView> mask, TensorView semaphores, TensorView scratch,
bool enable_pdl);
Expand Down
57 changes: 53 additions & 4 deletions csrc/xqa/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define CACHE_ELEM_ENUM 2
#endif

#if CACHE_ELEM_ENUM == 3
#define ENABLE_4BIT_KV_CACHE 1
#else
#define ENABLE_4BIT_KV_CACHE 0
#endif

// don't modify
#define USE_KV_CACHE true

Expand Down Expand Up @@ -181,8 +187,51 @@ static_assert(CACHE_ELEM_ENUM != 0);

#include <cuda_fp16.h>
#include <cuda_fp8.h>

#if ENABLE_4BIT_KV_CACHE
#include <cuda_fp4.h>
#endif

template <int32_t elemTypeEnum>
using ElemType = mha::conditional_t<
elemTypeEnum == 0, INPUT_ELEM,
mha::conditional_t<elemTypeEnum == 1, int8_t,
mha::conditional_t<elemTypeEnum == 2, __nv_fp8_e4m3, void>>>;
struct ElemTypeConverter;
// Specialization for elemTypeEnum = 0 (half/bf16)
template <>
struct ElemTypeConverter<0> {
using Type = INPUT_ELEM;
using ContainerType = INPUT_ELEM;
static constexpr int ElemsPerContainer = 1;
using ScalingFactorType = void;
static constexpr int QuantVectorSize = 1;
};

// Specialization for elemTypeEnum = 1 (int8)
template <>
struct ElemTypeConverter<1> {
using Type = int8_t;
using ContainerType = int8_t;
static constexpr int ElemsPerContainer = 1;
using ScalingFactorType = void;
static constexpr int QuantVectorSize = 1;
};

// Specialization for elemTypeEnum = 2 (fp8)
template <>
struct ElemTypeConverter<2> {
using Type = __nv_fp8_e4m3;
using ContainerType = __nv_fp8_e4m3;
static constexpr int ElemsPerContainer = 1;
using ScalingFactorType = void;
static constexpr int QuantVectorSize = 1;
};

#if ENABLE_4BIT_KV_CACHE
// Specialization for elemTypeEnum = 3 (NVFP4)
template <>
struct ElemTypeConverter<3> {
using Type = __nv_fp4_e2m1;
using ContainerType = __nv_fp4x2_e2m1;
static constexpr int ElemsPerContainer = 2;
using ScalingFactorType = __nv_fp8_e4m3;
static constexpr int QuantVectorSize = 16;
};
#endif
447 changes: 396 additions & 51 deletions csrc/xqa/mha.cu

Large diffs are not rendered by default.

32 changes: 26 additions & 6 deletions csrc/xqa/mha.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
#if SPEC_DEC
#include "specDec.h"
#endif
using CacheElem = ElemType<CACHE_ELEM_ENUM>;
using CacheElemConverter = ElemTypeConverter<CACHE_ELEM_ENUM>;
using CacheElem = CacheElemConverter::Type;
constexpr uint32_t validElemsPerHead = HEAD_ELEMS;
constexpr bool isMLA = IS_MLA;
static_assert((isMLA || validElemsPerHead <= 256) &&
Expand Down Expand Up @@ -56,7 +57,12 @@ constexpr uint32_t tokensPerPage = TOKENS_PER_PAGE;

using IOHead = Vec<InputElem, validElemsPerHead>;
using InputHead = IOHead;
using GMemCacheHead = Vec<CacheElem, validElemsPerHead>;
using GMemCacheHead = Vec<CacheElemConverter::ContainerType,
exactDiv(validElemsPerHead, CacheElemConverter::ElemsPerContainer)>;
#if ENABLE_4BIT_KV_CACHE
using GMemCacheHeadSf = Vec<CacheElemConverter::ScalingFactorType,
exactDiv(validElemsPerHead, CacheElemConverter::QuantVectorSize)>;
#endif

constexpr uint32_t validElemsPerKHead = validElemsPerHead;
constexpr bool lowPrecOutput = LOW_PREC_OUTPUT;
Expand All @@ -72,7 +78,13 @@ using OutputHead = mha::conditional_t<lowPrecOutput, GMemCacheHead, InputHead>;
using OutputElem = OutputHead::Elem;

using PaddedInputHead = Vec<InputElem, headElems>;
using PaddedCacheHead = Vec<CacheElem, headElems>;
// For 4 bit KV cache, each 16 elements (64b) are padded with 64b to match 128b banks.
using PaddedCacheHead = Vec<CacheElemConverter::ContainerType, headElems>;

#if ENABLE_4BIT_KV_CACHE
using PaddedCacheHeadSf =
Vec<CacheElemConverter::ScalingFactorType, headElems / CacheElemConverter::QuantVectorSize>;
#endif

// impl detail, may be moved to mha.cu/mha_sm90.cu
constexpr bool isHeadPadded = (validElemsPerHead != headElems);
Expand Down Expand Up @@ -112,6 +124,10 @@ void launchMHA(
#endif
float const* attentionSinks, // [headGrpSize]
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
#if ENABLE_4BIT_KV_CACHE
GMemCacheHeadSf* kSfCacheVLLM, GMemCacheHeadSf* vSfCacheVLLM,
#endif

KVCachePageIndex const*
kvCachePageList, // device pointer. shape:
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
Expand All @@ -133,9 +149,13 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
float rcpOutScale,
#endif
InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM,
GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList,
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
float kvCacheScale, float const* kvScalePtr,
GMemCacheHead* vCacheVLLM,
#if ENABLE_4BIT_KV_CACHE
GMemCacheHeadSf* kSfCacheVLLM, GMemCacheHeadSf* vSfCacheVLLM,
#endif
KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale,
float const* kvScalePtr,
#if SPEC_DEC
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
#endif
Expand Down
51 changes: 31 additions & 20 deletions csrc/xqa/mhaUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,14 @@ struct HeadPtr<Head, 0, 0> : TinyPtr<Head> {};
// #endif

// @fixme: give evict first hint for last part.
template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbPartsPerHead, bool swizzle,
bool isFull, uint32_t dstNbHeads, typename SrcHeadPtr,
template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbPartsPerHead,
uint32_t grainBytesSmem, uint32_t grainBytesGmem, bool swizzle, bool isFull,
uint32_t dstNbHeads, typename SrcHeadPtr, typename _LdGrain,
typename LocalHeadIdxMap = uint32_t (*)(uint32_t)>
__device__ inline void copyPartialHeadsAsync(
Warp const& warp,
Array2D<LdGrain, dstNbHeads, exactDiv(exactDiv(sizeof(Head), nbPartsPerHead), grainBytes)>& dst,
Array2D<_LdGrain, dstNbHeads, exactDiv(exactDiv(sizeof(Head), nbPartsPerHead), grainBytesSmem)>&
dst,
uint32_t dstHeadOffset, SrcHeadPtr const& src, uint32_t idxPart,
uint32_t nbAvailHeads = maxNbCopiedHeads,
LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) {
Expand All @@ -117,41 +119,46 @@ __device__ inline void copyPartialHeadsAsync(
constexpr uint32_t warpLdBytes = partBytes * maxNbCopiedHeads;
constexpr uint32_t thrdLdBytes = exactDiv(warpLdBytes, warp_size);
assertIsPowerOf2<thrdLdBytes>();
static_assert(thrdLdBytes >= grainBytes);
static_assert(thrdLdBytes >= grainBytesSmem);
// a segment is responsible for loading one partial head collaboratively
constexpr uint32_t thrdsPerSeg = exactDiv(partBytes, grainBytes);
constexpr uint32_t thrdsPerSeg = exactDiv(partBytes, grainBytesSmem);
static_assert(thrdsPerSeg > 0 && thrdsPerSeg <= warp_size);
assertIsPowerOf2<thrdsPerSeg>();
assert(__shfl_sync(0xFU << (laneId() / 4 * 4), src.offset, 0, 4) == src.offset);
auto const warpLane = laneId();
uint32_t const segIdx = warpLane / thrdsPerSeg;
uint32_t const segLane = warpLane % thrdsPerSeg;
constexpr uint32_t partsPerWarpInst = exactDiv(grainBytes * warp_size, partBytes);
constexpr uint32_t partsPerWarpInst = exactDiv(grainBytesSmem * warp_size, partBytes);
#pragma unroll
for (uint32_t i = 0; i < thrdLdBytes / grainBytes; i++) {
for (uint32_t i = 0; i < thrdLdBytes / grainBytesSmem; i++) {
uint32_t const idxHeadLocal = partsPerWarpInst * i + segIdx;
assert(idxHeadLocal < maxNbCopiedHeads);
bool const isHeadInBound = isFull || (idxHeadLocal < nbAvailHeads);
constexpr uint32_t grainsPerPart = exactDiv(partBytes, grainBytes);
constexpr uint32_t grainsPerPart = exactDiv(partBytes, grainBytesSmem);
using SrcHead = mha::decay_t<decltype(src[0])>;
constexpr uint32_t nbValidGrains = exactDiv(sizeof(SrcHead), grainBytes);
constexpr uint32_t nbValidGrains = exactDiv(sizeof(SrcHead), grainBytesGmem);
uint32_t const idxGrainInsideHead = grainsPerPart * idxPart + segLane;
bool const isGrainInBound = (!isHeadPadded || idxGrainInsideHead < nbValidGrains);
SrcHead const* const pSrcHead = src + localHeadIdxMap(idxHeadLocal);
bool const isValidPage = (pSrcHead != nullptr);
LdGrain const* const pSrc = reinterpret_cast<LdGrain const*>(pSrcHead) + idxGrainInsideHead;
LdGrain* const pDst = &dst.template at<swizzle>(dstHeadOffset + idxHeadLocal, segLane);
Vec<uint8_t, grainBytesGmem> const* const pSrc =
reinterpret_cast<Vec<uint8_t, grainBytesGmem> const*>(pSrcHead) + idxGrainInsideHead;
Vec<uint8_t, grainBytesSmem>* const pDst = reinterpret_cast<Vec<uint8_t, grainBytesSmem>*>(
&dst.template at<swizzle>(dstHeadOffset + idxHeadLocal, segLane));
#if !ENABLE_4BIT_KV_CACHE
// 4-bit KV cache is not bank-conflict free now.
assert(!hasBankConflict(pDst));
ldgsts::copyAsync<grainBytes>(pDst, pSrc,
isValidPage && isHeadInBound && isGrainInBound ? grainBytes : 0u);
#endif
ldgsts::copyAsync<grainBytesGmem>(
pDst, pSrc, isValidPage && isHeadInBound && isGrainInBound ? grainBytesGmem : 0u);
}
}

template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbWarps, bool swizzle, bool isFull,
uint32_t dstNbHeads, typename SrcHeadPtr,
typename LocalHeadIdxMap = uint32_t (*)(uint32_t)>
template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbWarps, uint32_t grainBytesSmem,
uint32_t grainBytesGmem, bool swizzle, bool isFull, uint32_t dstNbHeads,
typename SrcHeadPtr, typename _LdGrain, typename LocalHeadIdxMap = uint32_t (*)(uint32_t)>
__device__ inline void copyHeadsAsync(
uint32_t idxWarp, Array2D<LdGrain, dstNbHeads, exactDiv(sizeof(Head), grainBytes)>& dst,
uint32_t idxWarp, Array2D<_LdGrain, dstNbHeads, exactDiv(sizeof(Head), grainBytesSmem)>& dst,
SrcHeadPtr const& src, uint32_t nbAvailHeads = maxNbCopiedHeads,
LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) {
assert(idxWarp < nbWarps);
Expand All @@ -161,9 +168,9 @@ __device__ inline void copyHeadsAsync(
uint32_t const warpNbAvailHeads =
(dstHeadOffset < nbAvailHeads ? nbAvailHeads - dstHeadOffset : 0);
constexpr uint32_t idxPart = 0;
copyPartialHeadsAsync<Head, maxNbHeadsPerWarp, 1, swizzle, isFull, dstNbHeads>(
warp, dst, dstHeadOffset, src, idxPart, warpNbAvailHeads,
[&](uint32_t x) { return localHeadIdxMap(dstHeadOffset + x); });
copyPartialHeadsAsync<Head, maxNbHeadsPerWarp, 1, grainBytesSmem, grainBytesGmem, swizzle, isFull,
dstNbHeads>(warp, dst, dstHeadOffset, src, idxPart, warpNbAvailHeads,
[&](uint32_t x) { return localHeadIdxMap(dstHeadOffset + x); });
}

template <bool isAsync, uint32_t maxTotalNbGrains, uint32_t nbWarps, bool isFull = true>
Expand Down Expand Up @@ -235,6 +242,10 @@ template <>
struct KVCacheList<true> {
GMemCacheHead* kCacheVLLM;
GMemCacheHead* vCacheVLLM;
#if ENABLE_4BIT_KV_CACHE
GMemCacheHeadSf* kSfCacheVLLM;
GMemCacheHeadSf* vSfCacheVLLM;
#endif
KVCachePageIndex const*
kvCachePageList; // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility)
Expand Down
Loading
Loading