Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix MLA && trick avoid append_dec #18

Merged
merged 1 commit into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/gpu/append_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
}
}

if (max_dec_len_this_time_data > 0) {
if (rotary_embs && max_dec_len_this_time_data > 0) {
cudaStream_t exec_stream;
if (max_enc_len_this_time_data > 0) {
cudaStreamWaitEvent(decoder_stream, main_event);
Expand Down
149 changes: 39 additions & 110 deletions csrc/gpu/append_attn/decode_attention_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -187,84 +187,42 @@ __device__ __forceinline__ void compute_qk(const CacheT* base_smem,
smem = base_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
#ifdef DEBUG_DEC_PRE
__syncthreads();
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
printf("q_vec:\n");
for (uint32_t i = 0; i < vec_size; ++i) {
printf("%f ", static_cast<float>(q_vec[i]));
if (iter_base + j < iter_bound) {
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
if constexpr (std::is_same<T, half>::value) {
s[j] = __float2half(0.f);
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = __float2bfloat16(0.f);
}
printf("\n");
printf("k_vec:\n");
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
printf("%f ", static_cast<float>(k_vec[i]));
s[j] += q_vec[i] * k_vec[i];
}
printf("\n");
}
__syncthreads();
#endif
if constexpr (std::is_same<T, half>::value) {
s[j] = __float2half(0.f);
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = __float2bfloat16(0.f);
}
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
s[j] += q_vec[i] * k_vec[i];
}
#ifdef DEBUG_DEC_PRE
__syncthreads();
if (threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
printf("be t%d-s[2]:%f \n", threadIdx.x, static_cast<float>(s[j]));
}
__syncthreads();
#endif
#pragma unroll
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
}

__syncthreads();
#ifdef DEBUG_DEC_PRE
__syncthreads();
if (threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
printf("af t%d-s[2]:%f \n", threadIdx.x, static_cast<float>(s[j]));
}
__syncthreads();
#endif
tmp_smem[bidy] = s[j];
#ifdef DEBUG_DEC_PRE
__syncthreads();
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
printf("tmp_smem:\n");
for (uint32_t i = 0; i < bdy; ++i) {
printf("%f ", static_cast<float>(tmp_smem[i]));
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
}

__syncthreads();

tmp_smem[bidy] = s[j];

__syncthreads();
if constexpr (std::is_same<T, half>::value) {
s[j] = __float2half(0.f);
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = __float2bfloat16(0.f);
}
for(uint32_t i = 0; i < bdy; ++i) {
s[j] += tmp_smem[i];
}

} else {
if constexpr (std::is_same<T, half>::value) {
s[j] = __float2half(-5e4f);
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = __float2bfloat16(-3.38953e38f);
}
printf("\n");
}
__syncthreads();
#endif
__syncthreads();
if constexpr (std::is_same<T, half>::value) {
s[j] = __float2half(0.f);
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = __float2bfloat16(0.f);
}
for(uint32_t i = 0; i < bdy; ++i) {
s[j] += tmp_smem[i];
}
#ifdef DEBUG_DEC_PRE
__syncthreads();
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
printf("s[%d]: %f\n", j, static_cast<float>(s[j]));
}
__syncthreads();
#endif
if constexpr (std::is_same<T, half>::value) {
s[j] = (iter_base + j < iter_bound) ? s[j] : __float2half(-5e4f);
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = (iter_base + j < iter_bound) ? s[j] : __float2bfloat16(-3.38953e38f);
}
st.m = st.m > s[j] ? st.m : s[j];
}
Expand All @@ -280,7 +238,7 @@ __device__ __forceinline__ void compute_qk(const CacheT* base_smem,
st.d += s[j];
#ifdef DEBUG_DEC_ATTN
int tile_id = iter_base + j;
if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && (iter_base + j * bdz + zid < iter_bound)) {
if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0 && (iter_base + j * bdz + zid < iter_bound)) {
printf("update s and d, zip: %d, gid: %d, vid: %d, tile_id: %d, j: %d, s[%d]: %f, m: %f, d: %f\n",
(int)zid, (int)threadIdx.y, (int)vid, (int)tile_id, (int)j, (int)j, static_cast<float>(s[j]), static_cast<float>(st.m), static_cast<float>(st.d));
}
Expand All @@ -297,56 +255,27 @@ template<uint32_t vec_size, uint32_t half_vec_size, uint32_t DEAL_EACH_TIME, uin
__device__ __forceinline__ void compute_sv(const T *s,
const CacheT *base_v_smem,
const uint32_t stage_idx,
const uint32_t iter_base,
const uint32_t iter_bound,
const uint32_t vid,
softmax_state_t<vec_size, T>& st) {
uint32_t zid = threadIdx.z;
const CacheT* v_smem;
AlignedVector<T, vec_size> v_vec;
AlignedVector<T, vec_size> bac_vec;
// v_smem = base_v_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
v_smem = base_v_smem + (stage_idx * DEAL_EACH_TIME + zid) * HEAD_DIM_QK;
#pragma unroll
for (int j = 0; j < DEAL_EACH_TIME; ++j) {
// Load<T, vec_size>(v_smem + j * HEAD_DIM + vid * vec_size, &v_vec);
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
Load<T, vec_size>(v_smem + j * HEAD_DIM_QK + vid * vec_size, &v_vec);
#ifdef DEBUG_DEC_PRE
__syncthreads();
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
printf("v_vec:\n");
for (uint32_t i = 0; i < vec_size; ++i) {
printf("%f ", static_cast<float>(v_vec[i]));
}
printf("\n");
}
__syncthreads();
#endif
#pragma unroll
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {

bac_vec[reg_id] = st.o[reg_id];
st.o[reg_id] += s[j] * v_vec[reg_id];
#ifdef DEBUG_DEC_ATTN
if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && j < 31) {
printf("zip: %d, gid: %d, vid: %d, j: %d, s_vec[%d]: %f, v_vec[%d]: %f, o[%d]: %f, d: %f\n",
(int)zid, (int)threadIdx.y, (int)vid, (int)j, (int)j, static_cast<float>(s[j]), (int)reg_id, static_cast<float>(v_vec[reg_id]), (int)reg_id, static_cast<float>(st.o[reg_id]), static_cast<float>(st.d));
}
__syncthreads();
#endif
}
}
#ifdef DEBUG_DEC_PRE
__syncthreads();
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
printf("st.o:\n");
for (uint32_t i = 0; i < vec_size; ++i) {
printf("%f ", static_cast<float>(st.o[i]));
}
printf("\n");
printf("st.d:\n");
for (uint32_t i = 0; i < vec_size; ++i) {
printf("%f ", static_cast<float>(st.d));
}
printf("\n");
}
__syncthreads();
#endif

}

// template<uint32_t vec_size, uint32_t HEAD_DIM, uint32_t bdy, uint32_t bdz, typename T>
Expand Down
Loading
Loading