6 changes: 6 additions & 0 deletions backends/iluvatar_gpu/common/cuda_flags.cc
@@ -291,4 +291,10 @@ PHI_DEFINE_EXPORTED_int32(
1,
"Whether use the impMode of ixdnn for flash attention "
", default is CUDNN_FATTN_LEAST_MEM_MODE.");

PHI_DEFINE_EXPORTED_int32(
ixdnn_causal_mode,
0,
"Whether use the causalMode of ixdnn for flash attention "
", default is 0.");
#endif
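Context for the new flag (not part of the diff): exported PHI flags are normally picked up from an environment variable named after the flag. A minimal usage sketch follows, assuming the standard FLAGS_<name> environment-variable mechanism and that the iluvatar_gpu plugin parses it at initialization; this is an illustration, not code from this PR.

# Hedged sketch: enable the new ixdnn causal mode before Paddle and the
# custom-device plugin load, so the exported flag is read at init time.
import os

os.environ["FLAGS_ixdnn_causal_mode"] = "1"  # default declared above is 0

import paddle  # imported after the flag is set, on purpose

# From here on, flash-attention kernels that dispatch to ixdnn copy
# FLAGS_ixdnn_causal_mode into flashAttnInfo.causal_mode (see the kernel
# changes below).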
@@ -29,6 +29,7 @@
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
COMMON_DECLARE_int32(ixdnn_imp_mode);
COMMON_DECLARE_int32(ixdnn_causal_mode);

COMMON_DECLARE_bool(cudnn_deterministic);

@@ -253,7 +254,7 @@ void FlashAttnUnpaddedGradBaseKernel(
bool is_mha = (num_heads == num_heads_k);

int64_t total_q = dims[0];
bool is_unpad = (total_q == batch_size * max_seqlen_q) ? false : true;
bool is_unpad = true;
const int64_t head_size_rounded = head_size + 32 - head_size % 32;

DenseTensor q_padded, k_padded, v_padded, out_padded, dout_padded;
@@ -313,6 +314,7 @@ void FlashAttnUnpaddedGradBaseKernel(
flashAttnInfo.softmax_scale = std::sqrt(1.f / head_size);
flashAttnInfo.dropout_prob = dropout;
flashAttnInfo.is_causal = causal;
flashAttnInfo.causal_mode = FLAGS_ixdnn_causal_mode;
// flashAttnInfo.is_alibi = use_alibi;
// flashAttnInfo.alibi_mode = alibi_mode;
flashAttnInfo.return_softmax_lse = false;
@@ -718,6 +720,7 @@ void FlashAttnGradBaseKernel(
flashAttnInfo.softmax_scale = softmax_scale;
flashAttnInfo.dropout_prob = dropout;
flashAttnInfo.is_causal = causal;
flashAttnInfo.causal_mode = FLAGS_ixdnn_causal_mode;
// flashAttnInfo.is_alibi = use_alibi;
// flashAttnInfo.alibi_mode = alibi_mode;
flashAttnInfo.return_softmax_lse = false;
@@ -26,6 +26,7 @@
#include "paddle/phi/kernels/pad_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
COMMON_DECLARE_int32(ixdnn_imp_mode);
COMMON_DECLARE_int32(ixdnn_causal_mode);

namespace phi {
template <typename OutT>
@@ -100,7 +101,7 @@ void FlashAttnUnpaddedBaseKernel(
// TODO(umiswing): add shape check
// ixdnn
int64_t total_q = dims[0];
bool is_unpad = (total_q == batch_size * max_seqlen_q) ? false : true;
bool is_unpad = true;
const int64_t head_size_rounded = head_size + 32 - head_size % 32;

DenseTensor q_padded, k_padded, v_padded;
@@ -152,7 +153,7 @@ void FlashAttnUnpaddedBaseKernel(
flashAttnInfo.softmax_scale = std::sqrt(1.f / head_size);
flashAttnInfo.dropout_prob = is_test ? 0.0f : dropout;
flashAttnInfo.is_causal = causal;
flashAttnInfo.causal_mode = 1;
flashAttnInfo.causal_mode = FLAGS_ixdnn_causal_mode;
// flashAttnInfo.is_alibi = use_alibi;
// flashAttnInfo.alibi_mode = alibi_mode;
flashAttnInfo.return_softmax_lse = true;
@@ -283,6 +284,7 @@ void FlashAttnUnpaddedBaseKernel(
out->data(),
softmax_lse->data<float>()));

cudaDeviceSynchronize();
out->Resize({total_q, num_heads, head_size});

phi::dynload::cudnnDestroyFlashAttnDescriptor(flashAttnDesc);
@@ -511,6 +513,7 @@ void FlashAttnBaseKernel(
flashAttnInfo.softmax_scale = std::sqrt(1.f / head_size);
flashAttnInfo.dropout_prob = dropout;
flashAttnInfo.is_causal = causal;
flashAttnInfo.causal_mode = FLAGS_ixdnn_causal_mode;
// flashAttnInfo.is_alibi = use_alibi;
// flashAttnInfo.alibi_mode = alibi_mode;
flashAttnInfo.return_softmax_lse = true;
@@ -642,6 +645,7 @@ void FlashAttnBaseKernel(
out->data(),
softmax_lse->data<float>()));

cudaDeviceSynchronize();
phi::dynload::cudnnDestroyFlashAttnDescriptor(flashAttnDesc);
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnDestroyTensorDescriptor(q_desc));
54 changes: 3 additions & 51 deletions backends/iluvatar_gpu/patches/paddle-corex.patch
@@ -1,12 +1,10 @@
From bbd64f7c4396f12bc13020ba959e2cc89f1011bf Mon Sep 17 00:00:00 2001
From 23d11cd970ef7b9067edefcb979a4f298a44709a Mon Sep 17 00:00:00 2001
From: "tianyu.zhou" <[email protected]>
Date: Mon, 26 May 2025 14:21:47 +0800
Subject: [PATCH] Adapt for Iluvatar 0707.
Subject: [PATCH] Adapt for Iluvatar 0709.

Fix bug.
---
CMakeLists.txt | 2 +-
paddle/common/backend_header.h | 2 +-
.../operators/collective/recv_v2_op.cu.cc | 2 +-
.../operators/collective/send_v2_op.cu.cc | 2 +-
.../fluid/platform/device/gpu/nccl_helper.h | 2 +-
@@ -26,7 +24,6 @@ Fix bug.
paddle/phi/core/utils/data_type.h | 2 +-
paddle/phi/kernels/funcs/affine_grid_utils.h | 2 ++
paddle/phi/kernels/funcs/segmented_array.h | 8 +++++++
paddle/phi/kernels/funcs/select_impl.cu.h | 2 +-
paddle/phi/kernels/funcs/softmax_impl.h | 1 +
paddle/phi/kernels/gpu/c_embedding_kernel.cu | 2 +-
paddle/phi/kernels/gpu/elementwise_grad.h | 4 ++++
@@ -39,8 +36,7 @@ Fix bug.
paddle/phi/kernels/squeeze_kernel.cc | 2 ++
paddle/phi/kernels/strided_slice_kernel.cc | 2 ++
paddle/phi/kernels/unsqueeze_kernel.cc | 2 ++
paddle/utils/flat_hash_map.h | 5 ++++
35 files changed, 122 insertions(+), 37 deletions(-)
32 files changed, 115 insertions(+), 35 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b2c4e6a650..c51f3df1f0 100755
@@ -55,19 +51,6 @@ index b2c4e6a650..c51f3df1f0 100755
option(CINN_WITH_CUDNN "Compile CINN with CUDNN support" ON)
option(WITH_PIP_CUDA_LIBRARIES
"Paddle uses the CUDA library provided by NVIDIA" ON)
diff --git a/paddle/common/backend_header.h b/paddle/common/backend_header.h
index 240b04c3ad..82f5aca792 100644
--- a/paddle/common/backend_header.h
+++ b/paddle/common/backend_header.h
@@ -18,7 +18,7 @@
#include <cuda.h>
#endif

-#if defined(__CUDACC__) && CUDA_VERSION >= 11000
+#if defined(__CUDACC__) && (CUDA_VERSION >= 11000 || defined(PADDLE_WITH_COREX))
#define PADDLE_CUDA_BF16
#include <cuda_bf16.h>
#endif
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index ab866f015c..10a8111637 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -450,19 +433,6 @@ index dad852093e..8ebceb7074 100644

auto ptr = allocation->ptr();
allocations.emplace_back(std::move(allocation));
diff --git a/paddle/phi/kernels/funcs/select_impl.cu.h b/paddle/phi/kernels/funcs/select_impl.cu.h
index f96e392d94..56629c9b15 100644
--- a/paddle/phi/kernels/funcs/select_impl.cu.h
+++ b/paddle/phi/kernels/funcs/select_impl.cu.h
@@ -396,7 +396,7 @@ void SelectKernel(const KPDevice &dev_ctx,
using CT = int64_t; // set Count_data Type
const int t_size = sizeof(CT);

- const phi::GPUPlace &cuda_place = dev_ctx.GetPlace();
+ const phi::CustomPlace &cuda_place = dev_ctx.GetPlace();
phi::CPUPlace cpu_place = phi::CPUPlace();

// 1.1 get stored data num of per block
diff --git a/paddle/phi/kernels/funcs/softmax_impl.h b/paddle/phi/kernels/funcs/softmax_impl.h
index 8f6b0fdd32..b95ea07883 100644
--- a/paddle/phi/kernels/funcs/softmax_impl.h
@@ -703,23 +673,5 @@ index c30752337d..dc1ca3bdc5 100644
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(unsqueeze,
GPU,
diff --git a/paddle/utils/flat_hash_map.h b/paddle/utils/flat_hash_map.h
index b643fc1a57..960bc7d9b1 100644
--- a/paddle/utils/flat_hash_map.h
+++ b/paddle/utils/flat_hash_map.h
@@ -683,8 +683,13 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal {
}

size_t num_buckets_for_reserve(size_t num_elements) const {
+#ifdef PADDLE_WITH_COREX
+ return static_cast<size_t>(std::ceil(
+ num_elements / std::min((double)0.5, static_cast<double>(_max_load_factor))));
+#else
return static_cast<size_t>(std::ceil(
num_elements / std::min(0.5, static_cast<double>(_max_load_factor))));
+#endif
}
void rehash_for_other_container(const sherwood_v3_table &other) {
rehash(
--
2.34.1