Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 101 additions & 30 deletions ggml/src/ggml-cuda/mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
static constexpr int I = I_;
static constexpr int J = J_;

#if defined(GGML_USE_HIP)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
#if defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
T x[ne] = {0};

Expand Down Expand Up @@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(RDNA4)
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32;
T x[ne] = {0};
Expand Down Expand Up @@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
return -1;
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
Comment on lines +162 to +177
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please define only the actually used shapes of 16x4 and 16x16 and use NO_DEVICE_CODE instead of a static assert as was recently changed in the surrounding code.

}
#endif
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};
Expand Down Expand Up @@ -437,7 +437,20 @@ namespace ggml_cuda_mma {
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];

const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}else{
NO_DEVICE_CODE;
}
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
Expand Down Expand Up @@ -772,6 +785,36 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)

#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
true
);

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[1],
true,
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)

Comment on lines +788 to +817
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to my understanding currently unused, so please remove it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi I believe this is used in vec_dot_q8_0_q8_1_mma function which are called in Q4_0, Q5_0, Q8_0, MXFP4 etc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that function I'm only seeing 16x4 and 16x16 tiles, not 16x8.

Copy link
Contributor Author

@jiachengjason jiachengjason Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, the Q4_0 case, execution flows into vec_dot_q8_0_q8_1_mma in mmq.cuh. Inside vec_dot_q8_0_q8_1_mma function, tile A and tile B are shaped as 16×8 blocks. These tiles are forwarded to the mma function (the one shown here), where they are processed by the WMMA instructions.

#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
Expand All @@ -798,6 +841,7 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)

#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
Expand Down Expand Up @@ -842,4 +886,31 @@ namespace ggml_cuda_mma {
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
false
);
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif
}
}

8 changes: 7 additions & 1 deletion ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}

return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
if (amd_wmma_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}
}

return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
Loading
Loading