Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/xqa/mha.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1327,8 +1327,8 @@ CUBIN_EXPORT __global__
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {

float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these changes matter but wouldn't hurt as well.

float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
assert(allowMultiBlockMode || gridDim.x == 1);
bool const isMultiBlock = allowMultiBlockMode && (gridDim.x != 1);
uint32_t const nbSubSeqPerSeq = allowMultiBlockMode ? gridDim.x : 1;
Expand Down
4 changes: 2 additions & 2 deletions csrc/xqa/mha_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -640,8 +640,8 @@ __launch_bounds__(128 * 3)
uint32_t* __restrict__ const semaphores =
nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
void* __restrict__ const scratch = nullptr) {
float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \
(IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1
uint32_t const idxReq = blockIdx.z / nbKHeads;
Expand Down
4 changes: 2 additions & 2 deletions csrc/xqa/mla_sm120.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1564,8 +1564,8 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha
PartialResult* __restrict__ const partialResults =
nullptr) // [totalNbInputTokens][maxNbSubSeq]
{
float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1);
extern __shared__ char smemBuf[];
uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size);
Expand Down
Loading