Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
51da6b7
modify cmake for warpctc and warprnnt
jxwangmetax Sep 16, 2025
b60cba2
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 16, 2025
1abea54
modify conv for tf32 and fp32
jxwangmetax Sep 16, 2025
fc63171
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 16, 2025
f26987f
modify conv kernel
jxwangmetax Sep 16, 2025
a0cb0a7
modify library to static library
jxwangmetax Sep 16, 2025
d6579b8
Merge branch 'develop' into develop
jxwangmetax Sep 16, 2025
9a9372b
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 16, 2025
9cd1f2e
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 16, 2025
fb902f4
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 17, 2025
7b018df
modify kernel
jxwangmetax Sep 17, 2025
e61cf0d
modify fused_bias_dropout_residual_layer_norm
jxwangmetax Sep 17, 2025
0e5be1a
Merge branch 'develop' into develop
jxwangmetax Sep 17, 2025
2757fb7
modify compile
jxwangmetax Sep 18, 2025
74a263a
Merge branch 'develop' of https://github.com/jxwangmetax/PaddleCustom…
jxwangmetax Sep 18, 2025
88d5eae
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 19, 2025
b2b41c2
modify blas
jxwangmetax Sep 19, 2025
6ea2fab
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 22, 2025
6556cce
modify blas
jxwangmetax Sep 22, 2025
1cbe0d8
modify blas
jxwangmetax Sep 22, 2025
554b3cb
modify blas
jxwangmetax Sep 22, 2025
dfac884
modify context
jxwangmetax Sep 22, 2025
5845b41
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 22, 2025
93f75b7
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 24, 2025
e7278d0
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 26, 2025
36aab9d
modify kernels
jxwangmetax Sep 26, 2025
01586f9
Merge branch 'metax666:develop' into develop
jxwangmetax Sep 28, 2025
3531ed0
modify kernels
jxwangmetax Sep 29, 2025
f0562bb
Merge branch 'develop' into develop
jxwangmetax Sep 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/metax_gpu/kernels/impl/addmm_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ void AddmmKernel(const Context& dev_ctx,
y_dims[0]));

dev_ctx.template Alloc<T>(out);
if (out->numel() == 0) return;
auto blas = funcs::GetBlas<Context, T>(dev_ctx);

// calc broadcast dim
Expand Down
60 changes: 59 additions & 1 deletion backends/metax_gpu/patch/paddle.patch
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,21 @@ index d69eb67d6f..1d8b6e9375 100644
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

diff --git a/paddle/phi/kernels/funcs/embedding_grad.h b/paddle/phi/kernels/funcs/embedding_grad.h
index 461e6e2474..48a64ae9ce 100644
--- a/paddle/phi/kernels/funcs/embedding_grad.h
+++ b/paddle/phi/kernels/funcs/embedding_grad.h
@@ -143,8 +143,8 @@ void LaunchEmbeddingGradDeterministicKernel(const GPUContext& dev_ctx,
constexpr int kWarpSize = 64;
constexpr int kBlockDimY = 16;
#else
- constexpr int kWarpSize = 32;
- constexpr int kBlockDimY = 32;
+ constexpr int kWarpSize = 64;
+ constexpr int kBlockDimY = 16;
#endif
dim3 threads(kWarpSize, kBlockDimY);
dim3 grids(static_cast<int>((D + kWarpSize - 1) / kWarpSize));
diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index cb35feee32..64f5bd24ac 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
Expand Down Expand Up @@ -501,6 +516,49 @@ index 15e1a4a3c3..e4780538d7 100644
#include "paddle/phi/kernels/funcs/im2col.h"

namespace phi {
diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h
index e5361b836e..5ad238df08 100644
--- a/paddle/phi/kernels/funcs/math_cuda_utils.h
+++ b/paddle/phi/kernels/funcs/math_cuda_utils.h
@@ -175,12 +175,12 @@ struct KeyValuePair<half> {
#define WARP_SIZE_WIDTH_MASK 0x3f
typedef u_int64_t warp_mask_t;
#else
-#define FINAL_MASK 0xffffffff
-#define HALF_WARP 16
-#define WARP_SIZE 32
-#define WARP_SIZE_WIDTH 5
-#define WARP_SIZE_WIDTH_MASK 0x1f
-typedef unsigned warp_mask_t;
+#define FINAL_MASK 0xffffffffffffffffUL
+#define HALF_WARP 32
+#define WARP_SIZE 64
+#define WARP_SIZE_WIDTH 6
+#define WARP_SIZE_WIDTH_MASK 0x3f
+typedef u_int64_t warp_mask_t;
#endif

template <typename T>
@@ -200,19 +200,13 @@ __inline__ __device__ T BlockReduceSum(T val, warp_mask_t mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
int wid = threadIdx.x >> WARP_SIZE_WIDTH;
-
val = WarpReduceSum<T>(val, mask);
-
- __syncthreads();
if (lane == 0) shared[wid] = val;
-
__syncthreads();
-
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = WarpReduceSum<T>(val, mask);
-
return val;
}

diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cu b/paddle/phi/kernels/funcs/matrix_inverse.cu
index e101224970..a52eb6096f 100644
--- a/paddle/phi/kernels/funcs/matrix_inverse.cu
Expand Down Expand Up @@ -534,7 +592,7 @@ index 558d363b39..05da04b517 100644
#include "paddle/phi/kernels/funcs/scatter.cu.h"

diff --git a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu
index 8b0baf5f5f..260482f124 100644
index 047f52bd91..a05b34d3ba 100644
--- a/paddle/phi/kernels/funcs/multihead_matmul_functor.cu
+++ b/paddle/phi/kernels/funcs/multihead_matmul_functor.cu
@@ -27,7 +27,7 @@ namespace cub = hipcub;
Expand Down
Loading