
Commit

[XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (PaddlePaddle#64003)

* [XPU] fix segmentation fault caused by setting fix_seed_offset on XPU (a sketch of the guarded copy follows below)

* cast attention_mask to float32 when necessary (a sketch of the cast path follows after the change summary)
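
The first fix addresses a host-side read of fixed_seed_offset's data pointer, which faults when the tensor lives in XPU device memory. The snippet below is a hypothetical, standalone sketch of the guarded read that flash_attn_kernel.cc now performs; the helper name ReadSeedOffsetToHost is invented, and the fully qualified phi::memory_utils::Copy call assumes Paddle's paddle/phi/common/memory_utils.h, while the logic itself mirrors the diff further down.

```cpp
// Hypothetical sketch (not part of the commit): read a 2-element int64
// fixed_seed_offset tensor into a host buffer, copying through the runtime
// when the tensor is XPU-resident instead of dereferencing a device pointer.
#include <cstdint>

#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"

void ReadSeedOffsetToHost(const phi::DenseTensor& fixed_seed_offset,
                          int64_t* seed_offset_data /* host buffer, 2 elems */) {
  if (fixed_seed_offset.place().GetType() == phi::AllocationType::XPU) {
    // Device-resident: stage through an explicit device-to-host copy.
    phi::memory_utils::Copy(phi::CPUPlace(),
                            seed_offset_data,
                            fixed_seed_offset.place(),
                            fixed_seed_offset.data<int64_t>(),
                            sizeof(int64_t) * 2);
  } else {
    // Host-resident: a direct read is safe.
    const int64_t* src = fixed_seed_offset.data<int64_t>();
    seed_offset_data[0] = src[0];
    seed_offset_data[1] = src[1];
  }
}
```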
runzhech authored and co63oc committed May 10, 2024
1 parent 2ed82cd commit 5c5e46d
Showing 2 changed files with 94 additions and 48 deletions.
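
Both kernels now normalize attn_mask to float32 the same way. Below is a hypothetical extraction of that pattern, not code from the commit: the MaskAsFloat name and the template shape are invented, and the unsupported-dtype branch is wrapped in PADDLE_THROW here so the sketch actually raises; the building blocks (RAII_GUARD.alloc_l3_or_gm, xpu::cast, PADDLE_ENFORCE_XDNN_SUCCESS) are the ones used in the diffs below.

```cpp
// Hypothetical helper (illustrative only): return a float32 view of the
// attention mask, casting float16/bfloat16 masks into scratch memory that
// the caller's RAII guard owns, and rejecting any other dtype.
template <typename T, typename XPUType, typename Context>
const float* MaskAsFloat(const Context& ctx,
                         const phi::DenseTensor& attn_mask,
                         xpu::ctx_guard& RAII_GUARD) {
  if (attn_mask.dtype() == phi::DataType::FLOAT32) {
    // Already float32: the fused attention call can consume it directly.
    return attn_mask.data<float>();
  }
  if (attn_mask.dtype() == phi::DataType::FLOAT16 ||
      attn_mask.dtype() == phi::DataType::BFLOAT16) {
    // Cast element-wise into L3/GM scratch allocated via the RAII guard.
    float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask.numel());
    int r = xpu::cast<XPUType, float>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(attn_mask.data<T>()),
        bias_tmp,
        attn_mask.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    return bias_tmp;
  }
  // Illustrative error handling for the sketch.
  PADDLE_THROW(phi::errors::Unimplemented(
      "Unsupported dtype for attention_mask in xpu flash attention, only "
      "float32, float16 and bfloat16 are supported."));
}
```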
66 changes: 43 additions & 23 deletions paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc
@@ -69,9 +69,28 @@ void FlashAttnGradKernel(const Context& ctx,
const XPUType* out_data = reinterpret_cast<const XPUType*>(out.data<T>());
const float* softmax_lse_data = softmax_lse.data<float>();
const XPUType* dout_data = reinterpret_cast<const XPUType*>(dout.data<T>());
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
const float* bias_data = nullptr;
if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast<XPUType, float>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data<T>()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported.");
+    }
}
// output
XPUType* dq_data = reinterpret_cast<XPUType*>(dq->data<T>());
@@ -92,6 +111,7 @@ void FlashAttnGradKernel(const Context& ctx,

// get seed offset
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
+
// template<typename T, typename TACCUM, typename TGEMM, typename TID = int>
// int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T*
// k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T*
@@ -106,28 +126,28 @@ void FlashAttnGradKernel(const Context& ctx,
// dv_maxptr = nullptr, const float* do_maxptr = nullptr);
int r = baidu::xpu::xfa::mha_varlen_bwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
-      dout_data,  // dout
-      q_data,  // q
-      k_data,  // k
-      v_data,  // v
-      out_data,  // out
-      softmax_lse_data,  // softmax_lse
-      dq_data,  // dq
-      dk_data,  // dk
-      dv_data,  // dv
-      qlod,  // lod_seqlens_q
-      kvlod,  // lod_seqlens_k
-      seqlen_q,  // max_seqlen_q
-      seqlen_k,  // max_seqlen_k
-      num_heads,  // head_num
-      num_heads_k,  // head_num_k
-      head_size,  // head_dim
-      1.0f / std::sqrt(head_size),  // softmax_scale
-      dropout,  // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,  // is_causal
-      nullptr,  // attn_mask
-      bias_data  // bias
+      dout_data,  // dout
+      q_data,  // q
+      k_data,  // k
+      v_data,  // v
+      out_data,  // out
+      softmax_lse_data,  // softmax_lse
+      dq_data,  // dq
+      dk_data,  // dk
+      dv_data,  // dv
+      qlod,  // lod_seqlens_q
+      kvlod,  // lod_seqlens_k
+      seqlen_q,  // max_seqlen_q
+      seqlen_k,  // max_seqlen_k
+      num_heads,  // head_num
+      num_heads_k,  // head_num_k
+      head_size,  // head_dim
+      1.0f / std::sqrt(head_size),  // softmax_scale
+      dropout,  // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,  // is_causal
+      nullptr,  // attn_mask
+      bias_data  // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
#else
76 changes: 51 additions & 25 deletions paddle/phi/kernels/xpu/flash_attn_kernel.cc
@@ -14,7 +14,7 @@

#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"

#ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
} else {
std::pair<uint64_t, uint64_t> seed_offset_pair;
uint64_t inc = batch_size * num_heads * 32;
@@ -263,11 +271,29 @@ void FlashAttnKernel(const Context& ctx,
const XPUType* k_data = reinterpret_cast<const XPUType*>(k.data<T>());
const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
-  float* softmax_lse_data = softmax_lse->data<float>();

+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  float* softmax_lse_data = softmax_lse->data<float>();
const float* bias_data = nullptr;
if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast<XPUType, float>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data<T>()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported.");
+    }
}
// template <typename T, typename TACCUM, typename TGEMM, typename TID> int
// mha_varlen_fwd(xdnn::Context* ctx, const T* q, const T* k, const T* v, T*
@@ -281,24 +307,24 @@ void FlashAttnKernel(const Context& ctx,
// nullptr);
int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
-      q_data,  // q
-      k_data,  // k
-      v_data,  // v
-      out_data,  // out
-      softmax_lse_data,  // softmax_lse
-      qlod,  // lod_seqlens_q
-      kvlod,  // lod_seqlens_k
-      seqlen_q,  // max_seqlen_q
-      seqlen_k,  // max_seqlen_k
-      num_heads,  // head_num
-      num_heads_k,  // head_num_k
-      head_size,  // head_dim
-      1.0f / std::sqrt(head_size),  // softmax_scale
-      dropout,  // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,  // is_causal
-      nullptr,  // attn_mask
-      bias_data  // bias
+      q_data,  // q
+      k_data,  // k
+      v_data,  // v
+      out_data,  // out
+      softmax_lse_data,  // softmax_lse
+      qlod,  // lod_seqlens_q
+      kvlod,  // lod_seqlens_k
+      seqlen_q,  // max_seqlen_q
+      seqlen_k,  // max_seqlen_k
+      num_heads,  // head_num
+      num_heads_k,  // head_num_k
+      head_size,  // head_dim
+      1.0f / std::sqrt(head_size),  // softmax_scale
+      dropout,  // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,  // is_causal
+      nullptr,  // attn_mask
+      bias_data  // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
#else
