From 5c5e46d31c8b7349e69c443e714cc78e09707de8 Mon Sep 17 00:00:00 2001
From: enzodechine
Date: Tue, 7 May 2024 10:54:28 +0800
Subject: [PATCH] [XPU] fix bugs in processing of attention_mask and
 fix_seed_offset on XPU (#64003)

* [XPU] fix segmentation fault caused by setting fix_seed_offset on XPU

* cast attention_mask to float32 when necessary
---
 .../phi/kernels/xpu/flash_attn_grad_kernel.cc | 66 ++++++++++------
 paddle/phi/kernels/xpu/flash_attn_kernel.cc   | 76 +++++++++++++------
 2 files changed, 94 insertions(+), 48 deletions(-)

diff --git a/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc b/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc
index 0dd3c137898680..82d44ac0aad429 100644
--- a/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc
@@ -69,9 +69,28 @@ void FlashAttnGradKernel(const Context& ctx,
   const XPUType* out_data = reinterpret_cast<const XPUType*>(out.data<T>());
   const float* softmax_lse_data = softmax_lse.data<float>();
   const XPUType* dout_data = reinterpret_cast<const XPUType*>(dout.data<T>());
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
   const float* bias_data = nullptr;
   if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      PADDLE_THROW(errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported."));
+    }
   }
   // output
   XPUType* dq_data = reinterpret_cast<XPUType*>(dq->data<T>());
@@ -92,6 +111,7 @@ void FlashAttnGradKernel(const Context& ctx,
 
   // get seed offset
   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
+
   // template
   // int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T*
   // k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T*
@@ -106,28 +126,28 @@ void FlashAttnGradKernel(const Context& ctx,
   // dv_maxptr = nullptr, const float* do_maxptr = nullptr);
   int r = baidu::xpu::xfa::mha_varlen_bwd(
       ctx.x_context(),
-      dout_data,                         // dout
-      q_data,                            // q
-      k_data,                            // k
-      v_data,                            // v
-      out_data,                          // out
-      softmax_lse_data,                  // softmax_lse
-      dq_data,                           // dq
-      dk_data,                           // dk
-      dv_data,                           // dv
-      qlod,                              // lod_seqlens_q
-      kvlod,                             // lod_seqlens_k
-      seqlen_q,                          // max_seqlen_q
-      seqlen_k,                          // max_seqlen_k
-      num_heads,                         // head_num
-      num_heads_k,                       // head_num_k
-      head_size,                         // head_dim
-      1.0f / std::sqrt(head_size),       // softmax_scale
-      dropout,                           // p_dropout
-      static_cast(seed_offset_data[0]),  // seed
-      causal,                            // is_causal
-      nullptr,                           // attn_mask
-      bias_data                          // bias
+      dout_data,                             // dout
+      q_data,                                // q
+      k_data,                                // k
+      v_data,                                // v
+      out_data,                              // out
+      softmax_lse_data,                      // softmax_lse
+      dq_data,                               // dq
+      dk_data,                               // dk
+      dv_data,                               // dv
+      qlod,                                  // lod_seqlens_q
+      kvlod,                                 // lod_seqlens_k
+      seqlen_q,                              // max_seqlen_q
+      seqlen_k,                              // max_seqlen_k
+      num_heads,                             // head_num
+      num_heads_k,                           // head_num_k
+      head_size,                             // head_dim
+      1.0f / std::sqrt(head_size),           // softmax_scale
+      dropout,                               // p_dropout
+      static_cast(seed_offset_data[0]),      // seed
+      causal,                                // is_causal
+      nullptr,                               // attn_mask
+      bias_data                              // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
 #else
diff --git a/paddle/phi/kernels/xpu/flash_attn_kernel.cc b/paddle/phi/kernels/xpu/flash_attn_kernel.cc
index bdfab918db027c..0e4da3483290dc 100644
--- a/paddle/phi/kernels/xpu/flash_attn_kernel.cc
+++ b/paddle/phi/kernels/xpu/flash_attn_kernel.cc
@@ -14,7 +14,7 @@
 #include "paddle/phi/kernels/flash_attn_kernel.h"
 
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 #ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
   seed_offset->Resize({2});
   int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
   } else {
     std::pair<uint64_t, uint64_t> seed_offset_pair;
     uint64_t inc = batch_size * num_heads * 32;
@@ -263,11 +271,29 @@ void FlashAttnKernel(const Context& ctx,
   const XPUType* k_data = reinterpret_cast<const XPUType*>(k.data<T>());
   const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
   XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
-  float* softmax_lse_data = softmax_lse->data<float>();
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  float* softmax_lse_data = softmax_lse->data<float>();
   const float* bias_data = nullptr;
   if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      PADDLE_THROW(errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported."));
+    }
   }
 
   // template int
   // mha_varlen_fwd(xdnn::Context* ctx, const T* q, const T* k, const T* v, T*
@@ -281,24 +307,24 @@ void FlashAttnKernel(const Context& ctx,
   // nullptr);
   int r = baidu::xpu::xfa::mha_varlen_fwd(
       ctx.x_context(),
-      q_data,                            // q
-      k_data,                            // k
-      v_data,                            // v
-      out_data,                          // out
-      softmax_lse_data,                  // softmax_lse
-      qlod,                              // lod_seqlens_q
-      kvlod,                             // lod_seqlens_k
-      seqlen_q,                          // max_seqlen_q
-      seqlen_k,                          // max_seqlen_k
-      num_heads,                         // head_num
-      num_heads_k,                       // head_num_k
-      head_size,                         // head_dim
-      1.0f / std::sqrt(head_size),       // softmax_scale
-      dropout,                           // p_dropout
-      static_cast(seed_offset_data[0]),  // seed
-      causal,                            // is_causal
-      nullptr,                           // attn_mask
-      bias_data                          // bias
+      q_data,                                // q
+      k_data,                                // k
+      v_data,                                // v
+      out_data,                              // out
+      softmax_lse_data,                      // softmax_lse
+      qlod,                                  // lod_seqlens_q
+      kvlod,                                 // lod_seqlens_k
+      seqlen_q,                              // max_seqlen_q
+      seqlen_k,                              // max_seqlen_k
+      num_heads,                             // head_num
+      num_heads_k,                           // head_num_k
+      head_size,                             // head_dim
+      1.0f / std::sqrt(head_size),           // softmax_scale
+      dropout,                               // p_dropout
+      static_cast(seed_offset_data[0]),      // seed
+      causal,                                // is_causal
+      nullptr,                               // attn_mask
+      bias_data                              // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
 #else
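
The core of the attention_mask change above is a dtype dispatch: use the mask directly when it is already float32, widen float16/bfloat16 masks into temporary float32 scratch space, and reject anything else. The sketch below is only a minimal host-side analogue of that branch structure in plain C++, not the Paddle/XDNN code path: the names DType, Bf16ToFloat, and AsFloat32Mask are made up for illustration, and the real kernels cast on the device with xpu::cast into RAII_GUARD-managed scratch memory.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the mask's element type tag.
enum class DType { kFloat32, kFloat16, kBFloat16 };

// Widen one bfloat16 value (raw bits) to float32: bf16 is the upper
// 16 bits of an IEEE-754 float32, so a left shift is enough.
inline float Bf16ToFloat(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &wide, sizeof(out));
  return out;
}

// Return a float32 view of a mask buffer, casting into caller-owned
// `scratch` when the source is not already float32 (mirrors the branch
// structure of the patch above, on the host instead of the XPU).
const float* AsFloat32Mask(const void* data,
                           std::size_t numel,
                           DType dtype,
                           std::vector<float>* scratch) {
  if (dtype == DType::kFloat32) {
    return static_cast<const float*>(data);  // use the buffer directly
  }
  if (dtype == DType::kBFloat16) {
    const uint16_t* src = static_cast<const uint16_t*>(data);
    scratch->resize(numel);
    for (std::size_t i = 0; i < numel; ++i) (*scratch)[i] = Bf16ToFloat(src[i]);
    return scratch->data();
  }
  // A float16 branch would follow the same pattern with an fp16 -> fp32
  // conversion; any other dtype is rejected, as in the kernels above.
  throw std::runtime_error("unsupported mask dtype in this sketch");
}

Returning a pointer into caller-owned scratch mirrors the kernels' design choice: the temporary float32 buffer must stay alive (via RAII_GUARD in the patch, via `scratch` here) for as long as the attention call reads the bias pointer.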