38 changes: 3 additions & 35 deletions backends/iluvatar_gpu/CMakeLists.txt
@@ -116,17 +116,14 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/core/mixed_vector.cc
${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cusparse.cc
# kernels/funcs
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/concat_and_split_functor.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/deformable_conv_functor.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/eigen/*.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/math_function.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/*.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/math/*.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/eigen/*.cu
# cudnn/cublas
${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cudnn.cc
${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublas.cc
${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cublasLt.cc
${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cufft.cc
# kernels/gpu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/abs_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/abs_kernel.cu
@@ -219,6 +216,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/tril_triu_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/unbind_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/uniform_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/where_kernel.cu
# kernels/selected_rows
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
@@ -655,7 +653,6 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/multinomial_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/nll_loss_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/moe_unpermute_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/pool_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/moving_average_abs_max_scale_kernel.cu
@@ -752,7 +749,6 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/pad3d_kernel.cu
# ############################################################################
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/array_grad_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/set_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/is_empty_kernel.cc
@@ -860,7 +856,6 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/skip_layernorm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -869,12 +864,9 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/math_function.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/log_softmax_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/backends/context_pool.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/repeat_tensor2index_tensor.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/binomial_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/bernoulli_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/bmm_grad_kernel_impl.h
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/bmm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/box_coder_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
@@ -896,31 +888,9 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gather_tree_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/graph_reindex_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/group_norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_act_dequant_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_weighted_swiglu_act_quant_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_elemwise_activation_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/fp8_gemm/fp8_gemm_with_cublasLt/fp8_fp8_half_gemm.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_grad_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/gemm_epilogue_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_elemwise_activation_grad_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/as_real_kernel.cc
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/as_complex_kernel.cc
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/complex_grad_kernel.cc
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/stride/complex_kernel.cc
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/shape_kernel.cc
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/conv_kernel_igemm.cu
# ############################################################################
# kernels/fusion kernels/selected_rows
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
# kernels/kps
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/kps/elementwise_kernel.cu
@@ -932,7 +902,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/array_kernel.cc)

set(CUDA_SRCS ${CUDA_SRCS1} ${CUDA_SRCS2})
list(REMOVE_DUPLICATES CUDA_SRCS1)
list(REMOVE_DUPLICATES CUDA_SRCS)

list(
REMOVE_ITEM
@@ -942,14 +912,12 @@ list(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/softmax.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/weight_only_gemv.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/math/context_project.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/fft.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/lstm_compute.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/fake_quantize_functor.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/masked_multihead_attention_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/qkv_unpack_mha_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/check_numerics_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/collect_fpn_proposals_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/dgc_kernel.cu
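One hunk above deserves a note: besides pruning the source list, this file fixes a dedup bug. REMOVE_DUPLICATES was applied to CUDA_SRCS1 after CUDA_SRCS had already been assembled from CUDA_SRCS1 and CUDA_SRCS2, so the merged list kept its duplicates; the fix deduplicates the merged list itself. A minimal CMake sketch of the corrected pattern (list names are illustrative):

    set(SRCS_A a.cu b.cu)
    set(SRCS_B b.cu c.cu)
    set(ALL_SRCS ${SRCS_A} ${SRCS_B})
    # Wrong: deduplicating an input after the merge leaves b.cu twice in ALL_SRCS
    # list(REMOVE_DUPLICATES SRCS_A)
    # Right: deduplicate the list that is actually consumed downstream
    list(REMOVE_DUPLICATES ALL_SRCS)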
3 changes: 0 additions & 3 deletions backends/iluvatar_gpu/build_paddle.sh
@@ -36,8 +36,6 @@ else
WITH_FLAGCX="OFF"
fi

bash clean_paddle.sh

if ! git -C "$PADDLE_SOURCE_DIR" apply --reverse --check "$PATCH_FILE" > /dev/null 2>&1; then
if ! git -C "$PADDLE_SOURCE_DIR" apply "$PATCH_FILE"; then
echo "Error: Failed to apply patch!"
@@ -68,7 +66,6 @@ cmake -G Ninja -DPY_VERSION=${PYTHON_VERSION} -DWITH_COREX=ON -DPADDLE_SOURCE_DI
-DCMAKE_CUDA_FLAGS='-Xclang -fcuda-allow-variadic-functions -mllvm --skip-double' \
-DCMAKE_C_FLAGS="-pthread" \
-DWITH_ARM=OFF -DWITH_DGC=OFF .. 2>&1 | tee compile.log
# make VERBOSE=1 -j$(nproc) 2>&1 | tee -a compile.log
ninja -k 0 -j$(nproc) 2>&1 | tee -a compile.log
FAILED_LOG="failed_files.log"
grep -E "FAILED: " compile.log | tee ${FAILED_LOG}
@@ -22,7 +22,7 @@ limitations under the License. */
namespace cub = hipcub;
#endif

#include "../gpudnn/softmax_gpudnn.h"
#include "kernels/gpudnn/softmax_gpudnn.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/amp_type_traits.h"
@@ -277,4 +277,5 @@ PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax_grad,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
@@ -23,7 +23,7 @@ limitations under the License. */
namespace cub = hipcub;
#endif

#include "../gpudnn/softmax_gpudnn.h"
#include "kernels/gpudnn/softmax_gpudnn.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/amp_type_traits.h"
@@ -1412,4 +1412,5 @@ PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
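The two kernel files above (the backward and forward cross-entropy-with-softmax kernels for this backend) make the same pair of changes: the fragile relative include ../gpudnn/softmax_gpudnn.h becomes the project-rooted kernels/gpudnn/softmax_gpudnn.h, and phi::dtype::bfloat16 joins the registered dtypes. Each dtype listed in the registration macro instantiates one kernel variant, so bfloat16 support on the registration side is just the extra list entry. A hedged sketch of the shape, with a hypothetical op and kernel standing in for the real ones:

    #include "paddle/phi/core/kernel_registry.h"

    namespace phi {
    // Hypothetical kernel template; the real ones live in the files above.
    template <typename T, typename Context>
    void MyKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {}
    }  // namespace phi

    PD_REGISTER_PLUGIN_KERNEL(my_op,          // hypothetical op name
                              iluvatar_gpu,
                              ALL_LAYOUT,
                              phi::MyKernel,  // hypothetical kernel
                              float,
                              phi::dtype::float16,
                              phi::dtype::bfloat16) {}  // newly added dtype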
@@ -21,21 +21,15 @@ PD_CUSTOM_KERNEL_REGISTER(fft_c2c_grad,
iluvatar_gpu,
ALL_LAYOUT,
phi::FFTC2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_CUSTOM_KERNEL_REGISTER(fft_c2r_grad,
iluvatar_gpu,
ALL_LAYOUT,
phi::FFTC2RGradKernel,
float,
double) {
phi::dtype::complex<float>) {}
PD_CUSTOM_KERNEL_REGISTER(
fft_c2r_grad, iluvatar_gpu, ALL_LAYOUT, phi::FFTC2RGradKernel, float) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
PD_CUSTOM_KERNEL_REGISTER(fft_r2c_grad,
iluvatar_gpu,
ALL_LAYOUT,
phi::FFTR2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
phi::dtype::complex<float>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
@@ -21,17 +21,15 @@ PD_CUSTOM_KERNEL_REGISTER(fft_c2c,
iluvatar_gpu,
ALL_LAYOUT,
phi::FFTC2CKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<float>) {}
PD_CUSTOM_KERNEL_REGISTER(fft_c2r,
iluvatar_gpu,
ALL_LAYOUT,
phi::FFTC2RKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
phi::dtype::complex<float>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_CUSTOM_KERNEL_REGISTER(
fft_r2c, iluvatar_gpu, ALL_LAYOUT, phi::FFTR2CKernel, float, double) {
fft_r2c, iluvatar_gpu, ALL_LAYOUT, phi::FFTR2CKernel, float) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
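Both FFT registration files drop the double-precision variants (double and phi::dtype::complex<double>), leaving float-only registrations; the motivation is not shown in this view, so "no FP64 FFT on this backend" is an assumption. The registration bodies that remain are the interesting part: a kernel is keyed on its input dtype, so transforms whose output dtype differs must pin it explicitly in the registration hook. An annotated copy of the r2c case from above:

    // fft_r2c is registered on float (its input) but emits complex output,
    // so the hook derives the paired complex type from the kernel key:
    PD_CUSTOM_KERNEL_REGISTER(
        fft_r2c, iluvatar_gpu, ALL_LAYOUT, phi::FFTR2CKernel, float) {
      // ToComplex(FLOAT32) -> COMPLEX64; fft_c2r uses ToReal for the inverse.
      kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
    }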
@@ -298,6 +298,34 @@ void FlashAttnUnpaddedGradBaseKernel(
int64_t offset = static_cast<int64_t>(seed_offset_data[1]);
PhiloxCudaState philox_state = PhiloxCudaState(seed, offset);

bool accuracy_first = false;
if (attn_mask) {
causal = false;
phi::DenseTensor min_tensor;
min_tensor.Resize({1});
dev_ctx.template Alloc<T>(&min_tensor);

std::vector<int> reduce_dims;
for (int64_t i = 0; i < attn_mask->dims().size(); ++i) {
reduce_dims.push_back(i);
}
funcs::ReduceKernel<T, T, kps::MinFunctor, kps::IdentityFunctor<T>>(
dev_ctx,
*attn_mask,
&min_tensor,
kps::IdentityFunctor<T>(),
reduce_dims);

std::vector<T> host_min;
TensorToVector(min_tensor, dev_ctx, &host_min);

float min_val = static_cast<float>(host_min[0]);
constexpr float threshold = -3.3895313892515355e+37f;
accuracy_first = (min_val < threshold);
VLOG(2) << "flash_attn attn_mask accuracy_first: " << accuracy_first
<< ", causal: " << causal;
}

if (FLAGS_enable_ixattnbkd) {
// ixattnbkd unpad bwd
ixAttnBkdConfigInfo ixAttnbkdInfo;
@@ -317,7 +345,7 @@ void FlashAttnUnpaddedGradBaseKernel(
ixAttnbkdInfo.batch = batch_size;
ixAttnbkdInfo.max_seq_len_src = max_seqlen_q;
ixAttnbkdInfo.max_seq_len_trg = max_seqlen_k;
ixAttnbkdInfo.accuracy_first = false;
ixAttnbkdInfo.accuracy_first = accuracy_first;

ixAttnBkdDataType_t dataType;
if (q.dtype() == phi::DataType::FLOAT16) {
@@ -338,7 +366,7 @@ void FlashAttnUnpaddedGradBaseKernel(
SetIxAttnBkdTensor(&k_desc, k, dataType);
SetIxAttnBkdTensor(&v_desc, v, dataType);
SetIxAttnBkdTensor(&o_desc, out, dataType);
if (attn_mask.get_ptr()) {
if (attn_mask) {
PADDLE_ENFORCE_NE(causal,
true,
phi::errors::InvalidArgument(
@@ -420,7 +448,7 @@ void FlashAttnUnpaddedGradBaseKernel(
flashAttnInfo.batch = batch_size;
flashAttnInfo.max_seq_len_src = max_seqlen_q;
flashAttnInfo.max_seq_len_trg = max_seqlen_k;
flashAttnInfo.accuracy_first = false;
flashAttnInfo.accuracy_first = accuracy_first;

int32_t nb_dims = 3;
std::vector<int32_t> qShape, kShape, vShape, oShape, lseShape, dqShape,
@@ -506,7 +534,7 @@ void FlashAttnUnpaddedGradBaseKernel(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor(
do_desc, dataType, nb_dims, doShape.data(), doStride.data()));

if (attn_mask.get_ptr()) {
if (attn_mask) {
PADDLE_ENFORCE_NE(causal,
true,
phi::errors::InvalidArgument(
@@ -590,7 +618,7 @@ void FlashAttnUnpaddedGradBaseKernel(
destroy_tensor_desc(v_desc);
destroy_tensor_desc(o_desc);
destroy_tensor_desc(lse_desc);
if (attn_mask.get_ptr()) {
if (attn_mask) {
destroy_tensor_desc(m_desc);
}
destroy_tensor_desc(dq_desc);
@@ -604,7 +632,7 @@ void FlashAttnUnpaddedGradBaseKernel(
v_desc = nullptr;
o_desc = nullptr;
lse_desc = nullptr;
if (attn_mask.get_ptr()) {
if (attn_mask) {
m_desc = nullptr;
}
dq_desc = nullptr;
@@ -883,6 +911,34 @@ void FlashAttnGradBaseKernel(
int64_t offset = static_cast<int64_t>(seed_offset_data[1]);
PhiloxCudaState philox_state = PhiloxCudaState(seed, offset);

bool accuracy_first = false;
if (attn_mask) {
causal = false;
phi::DenseTensor min_tensor;
min_tensor.Resize({1});
dev_ctx.template Alloc<T>(&min_tensor);

std::vector<int> reduce_dims;
for (int64_t i = 0; i < attn_mask->dims().size(); ++i) {
reduce_dims.push_back(i);
}
funcs::ReduceKernel<T, T, kps::MinFunctor, kps::IdentityFunctor<T>>(
dev_ctx,
*attn_mask,
&min_tensor,
kps::IdentityFunctor<T>(),
reduce_dims);

std::vector<T> host_min;
TensorToVector(min_tensor, dev_ctx, &host_min);

float min_val = static_cast<float>(host_min[0]);
constexpr float threshold = -3.3895313892515355e+37f;
accuracy_first = (min_val < threshold);
VLOG(2) << "flash_attn attn_mask accuracy_first: " << accuracy_first
<< ", causal: " << causal;
}

if (FLAGS_enable_ixattnbkd) {
// ixattnbkd bwd
ixAttnBkdConfigInfo ixAttnbkdInfo;
@@ -902,7 +958,7 @@ void FlashAttnGradBaseKernel(
ixAttnbkdInfo.batch = batch_size;
ixAttnbkdInfo.max_seq_len_src = seqlen_q;
ixAttnbkdInfo.max_seq_len_trg = seqlen_k;
ixAttnbkdInfo.accuracy_first = false;
ixAttnbkdInfo.accuracy_first = accuracy_first;

ixAttnBkdDataType_t dataType;
if (q.dtype() == phi::DataType::FLOAT16) {
@@ -923,7 +979,7 @@ void FlashAttnGradBaseKernel(
SetIxAttnBkdTensor(&k_desc, k, dataType);
SetIxAttnBkdTensor(&v_desc, v, dataType);
SetIxAttnBkdTensor(&o_desc, out, dataType);
if (attn_mask.get_ptr()) {
if (attn_mask) {
PADDLE_ENFORCE_NE(causal,
true,
phi::errors::InvalidArgument(
@@ -1005,7 +1061,7 @@ void FlashAttnGradBaseKernel(
flashAttnInfo.batch = batch_size;
flashAttnInfo.max_seq_len_src = seqlen_q;
flashAttnInfo.max_seq_len_trg = seqlen_k;
flashAttnInfo.accuracy_first = false;
flashAttnInfo.accuracy_first = accuracy_first;

int32_t nb_dims = 4;
std::vector<int32_t> qShape, kShape, vShape, oShape, lseShape, dqShape,
@@ -1090,7 +1146,7 @@ void FlashAttnGradBaseKernel(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor(
do_desc, dataType, nb_dims, doShape.data(), doStride.data()));

if (attn_mask.get_ptr()) {
if (attn_mask) {
PADDLE_ENFORCE_NE(causal,
true,
phi::errors::InvalidArgument(
@@ -1173,7 +1229,7 @@ void FlashAttnGradBaseKernel(
destroy_tensor_desc(v_desc);
destroy_tensor_desc(o_desc);
destroy_tensor_desc(lse_desc);
if (attn_mask.get_ptr()) {
if (attn_mask) {
destroy_tensor_desc(m_desc);
}
destroy_tensor_desc(dq_desc);
@@ -1187,7 +1243,7 @@ void FlashAttnGradBaseKernel(
v_desc = nullptr;
o_desc = nullptr;
lse_desc = nullptr;
if (attn_mask.get_ptr()) {
if (attn_mask) {
m_desc = nullptr;
}
dq_desc = nullptr;
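The final file threads a computed accuracy_first flag through both flash-attention backward paths (unpadded and padded) in place of the hard-coded false, and normalizes the optional-mask checks from attn_mask.get_ptr() to attn_mask's boolean conversion. The heuristic: whenever a mask is supplied, causal is forced off, the mask is min-reduced on device (funcs::ReduceKernel with kps::MinFunctor, then a one-element TensorToVector copy), and if the minimum falls below roughly -3.39e+37 (about -0.1 * FLT_MAX, meaning the mask encodes hard -inf masking), the vendor attention backend is asked for its accuracy-first mode. A standalone host-side sketch of the same decision, with illustrative names:

    #include <algorithm>
    #include <vector>

    bool NeedsAccuracyFirst(const std::vector<float>& mask) {
      // Same threshold the kernel uses, approximately -0.1f * FLT_MAX.
      // Values this negative act as -inf in softmax: exp(x) underflows to 0.
      constexpr float kThreshold = -3.3895313892515355e+37f;
      if (mask.empty()) return false;
      const float min_val = *std::min_element(mask.begin(), mask.end());
      return min_val < kThreshold;
    }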