[XPU] fix segmentfault caused by setting fix_seed_offset on XPU
runzhech committed Apr 30, 2024
1 parent a13f7dc commit 8cf862d
Showing 2 changed files with 54 additions and 46 deletions.
45 changes: 23 additions & 22 deletions paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc
@@ -92,6 +92,7 @@ void FlashAttnGradKernel(const Context& ctx,

// get seed offset
const int64_t* seed_offset_data = seed_offset.data<int64_t>();

// template<typename T, typename TACCUM, typename TGEMM, typename TID = int>
// int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T*
// k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T*
@@ -106,28 +107,28 @@ void FlashAttnGradKernel(const Context& ctx,
// dv_maxptr = nullptr, const float* do_maxptr = nullptr);
int r = baidu::xpu::xfa::mha_varlen_bwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
- dout_data, // dout
- q_data, // q
- k_data, // k
- v_data, // v
- out_data, // out
- softmax_lse_data, // softmax_lse
- dq_data, // dq
- dk_data, // dk
- dv_data, // dv
- qlod, // lod_seqlens_q
- kvlod, // lod_seqlens_k
- seqlen_q, // max_seqlen_q
- seqlen_k, // max_seqlen_k
- num_heads, // head_num
- num_heads_k, // head_num_k
- head_size, // head_dim
- 1.0f / std::sqrt(head_size), // softmax_scale
- dropout, // p_dropout
- static_cast<uint64_t>(seed_offset_data[0]), // seed
- causal, // is_causal
- nullptr, // attn_mask
- bias_data // bias
+ dout_data, // dout
+ q_data, // q
+ k_data, // k
+ v_data, // v
+ out_data, // out
+ softmax_lse_data, // softmax_lse
+ dq_data, // dq
+ dk_data, // dk
+ dv_data, // dv
+ qlod, // lod_seqlens_q
+ kvlod, // lod_seqlens_k
+ seqlen_q, // max_seqlen_q
+ seqlen_k, // max_seqlen_k
+ num_heads, // head_num
+ num_heads_k, // head_num_k
+ head_size, // head_dim
+ 1.0f / std::sqrt(head_size), // softmax_scale
+ dropout, // p_dropout
+ static_cast<int32_t>(seed_offset_data[0]), // seed
+ causal, // is_causal
+ nullptr, // attn_mask
+ bias_data // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
#else
55 changes: 31 additions & 24 deletions paddle/phi/kernels/xpu/flash_attn_kernel.cc
@@ -14,7 +14,7 @@

#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"

#ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
} else {
std::pair<uint64_t, uint64_t> seed_offset_pair;
uint64_t inc = batch_size * num_heads * 32;
@@ -264,7 +272,6 @@
const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
float* softmax_lse_data = softmax_lse->data<float>();

const float* bias_data = nullptr;
if (attn_mask.get_ptr() != nullptr) {
bias_data = attn_mask->data<float>();
@@ -281,24 +288,24 @@
// nullptr);
int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
- q_data, // q
- k_data, // k
- v_data, // v
- out_data, // out
- softmax_lse_data, // softmax_lse
- qlod, // lod_seqlens_q
- kvlod, // lod_seqlens_k
- seqlen_q, // max_seqlen_q
- seqlen_k, // max_seqlen_k
- num_heads, // head_num
- num_heads_k, // head_num_k
- head_size, // head_dim
- 1.0f / std::sqrt(head_size), // softmax_scale
- dropout, // p_dropout
- static_cast<uint64_t>(seed_offset_data[0]), // seed
- causal, // is_causal
- nullptr, // attn_mask
- bias_data // bias
+ q_data, // q
+ k_data, // k
+ v_data, // v
+ out_data, // out
+ softmax_lse_data, // softmax_lse
+ qlod, // lod_seqlens_q
+ kvlod, // lod_seqlens_k
+ seqlen_q, // max_seqlen_q
+ seqlen_k, // max_seqlen_k
+ num_heads, // head_num
+ num_heads_k, // head_num_k
+ head_size, // head_dim
+ 1.0f / std::sqrt(head_size), // softmax_scale
+ dropout, // p_dropout
+ static_cast<int32_t>(seed_offset_data[0]), // seed
+ causal, // is_causal
+ nullptr, // attn_mask
+ bias_data // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
#else
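
Why the change in flash_attn_kernel.cc fixes the crash: the forward kernel previously read fixed_seed_offset straight through data<int64_t>() on the host, which faults when the tensor is allocated in XPU device memory. The added branch checks the tensor's placement and stages the two int64 values (seed and offset) through a device-to-host copy before using them. Below is a minimal sketch of that pattern, lifted from the added code; the ReadSeedOffsetToHost helper, its standalone framing, and the dense_tensor.h include are illustrative assumptions, not part of the commit.

// Sketch only: mirrors the pattern added in flash_attn_kernel.cc above.
#include <cstdint>

#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Copies the two int64 values {seed, offset} held by `fixed_seed_offset`
// into host memory at `dst`, regardless of where the tensor is allocated.
void ReadSeedOffsetToHost(const DenseTensor& fixed_seed_offset, int64_t* dst) {
  if (fixed_seed_offset.place().GetType() == AllocationType::XPU) {
    // Device-resident tensor: dereferencing its data pointer on the host
    // would fault, so stage the values through an explicit device-to-host copy.
    memory_utils::Copy(CPUPlace(),
                       dst,
                       fixed_seed_offset.place(),
                       fixed_seed_offset.data<int64_t>(),
                       sizeof(int64_t) * 2);
  } else {
    // Host-resident tensor: the pointer can be read directly.
    const int64_t* src = fixed_seed_offset.data<int64_t>();
    dst[0] = src[0];  // seed
    dst[1] = src[1];  // offset
  }
}

}  // namespace phi

The backward kernel never touches fixed_seed_offset directly; its only functional change is passing the recorded seed to mha_varlen_bwd as int32_t instead of uint64_t, matching the forward call, while the rest of its diff is comment re-alignment.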
