236 changes: 123 additions & 113 deletions ggml/src/ggml-cuda/mma.cuh
@@ -76,15 +76,23 @@ namespace ggml_cuda_mma {
// For the A/C matrices this means I major == row major, J major == column major.
// For the B matrix this means I major == column major, J major == row major.
// MIRRORED == Each data value is held exactly once per thread subgroup.
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
};
// Implemented mma combinations are:
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR

constexpr bool is_i_major(const data_layout dl) {
return dl == DATA_LAYOUT_I_MAJOR ||
dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
dl == DATA_LAYOUT_I_MAJOR_DUAL;
}
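
As a concrete illustration of the indexing contract (a minimal sketch, not part of the diff): every tile specialization exposes ne elements per thread plus get_i/get_j mappings, so a layout-agnostic load is a plain gather, exactly as in the fallback path of load_generic below. load_elementwise is a hypothetical helper name.

// Sketch: generic element-wise load against the tile indexing contract.
template <typename tile_t, typename T>
static __device__ __forceinline__ void load_elementwise(tile_t & t, const T * __restrict__ xs0, const int stride) {
#pragma unroll
    for (int l = 0; l < tile_t::ne; ++l) {
        // get_i/get_j translate the per-thread element index l into tile coordinates.
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
}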

template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
struct tile {};

@@ -115,9 +123,9 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 4 * (threadIdx.x / 16) + l;
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
return threadIdx.x % 32;
} else {
NO_DEVICE_CODE;
return -1;
@@ -132,9 +140,9 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
return 4 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
NO_DEVICE_CODE;
return -1;
@@ -171,28 +179,19 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
#elif defined(RDNA3)
static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
#endif // defined(RDNA4)
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 4) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
#if defined(RDNA4)
return 8 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return 2 * l + (threadIdx.x / 16);
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
if constexpr (supported()) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
@@ -201,7 +200,17 @@ namespace ggml_cuda_mma {

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
// matrix C
#if defined(RDNA3)
return 2 * l + (threadIdx.x / 16);
#else
return ne * (threadIdx.x / 16) + l;
#endif // defined(RDNA3)
} else if constexpr (I == 16 && J == 8) {
// mmq input for RDNA4
return ne * (threadIdx.x / 16) + l;
} else if constexpr (I == 16 && J == 4) {
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
@@ -293,12 +302,7 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3)
// RDNA3 has duplicated data as input.
static constexpr int ne = I * J / 32 * 2;
#else
static constexpr int ne = I * J / 32;
#endif // defined(RDNA3)
half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
@@ -317,14 +321,7 @@ namespace ggml_cuda_mma {

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return l;
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
return -1;
@@ -382,42 +379,19 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3)
// RDNA3 has duplicated data as input.
static constexpr int ne = I * J / 32 * 2;
#else
static constexpr int ne = I * J / 32;
#endif // defined(RDNA3)
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return l;
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
return -1;
}
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
#else
static constexpr int ne = I * J / WARP_SIZE;
@@ -458,6 +432,28 @@ namespace ggml_cuda_mma {
#endif // defined(AMD_WMMA_AVAILABLE)
};

template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;

static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
}

static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
}

static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
};
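
A hedged call-site sketch of how the new layout template parameters compose (tile_A, tile_B, tile_D and mma_example are illustrative names, not from this PR): on CDNA and RDNA4 the accumulator tile is J-major while the A/B inputs stay I-major, which is what the dl_ab/dl_d parameters on the mma overloads below express.

// Sketch: J-major accumulator with I-major inputs, assuming the int path.
using tile_A = tile<16,  8, int, DATA_LAYOUT_I_MAJOR>;
using tile_B = tile<16,  8, int, DATA_LAYOUT_I_MAJOR>;
using tile_D = tile<16, 16, int, DATA_LAYOUT_J_MAJOR>;

static __device__ __forceinline__ void mma_example(tile_D & D, const tile_A & A, const tile_B & B) {
    mma(D, A, B); // deduces dl_d = DATA_LAYOUT_J_MAJOR, dl_ab = DATA_LAYOUT_I_MAJOR
}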

template <int I_, int J_>
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
static constexpr int I = I_;
@@ -524,6 +520,42 @@ namespace ggml_cuda_mma {
}
};

template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;

static constexpr int ne = I * J / 32 * 2;

T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 4) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (supported()) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (supported()) {
return l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
};
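
What "dual" buys on RDNA3 (a sketch of the arithmetic, assuming a 32-wide wave): ne is doubled to I*J/32*2, get_i ignores the upper wave half, and get_j is simply l, so lanes n and n+16 address identical rows. In the 16x16 case each lane therefore owns one full 16-element row, held twice per wave, which is the duplicated-input format the RDNA3 WMMA instructions expect.

// Sketch: compile-time check of the dual layout's doubled fragment size.
using tile_dual = tile<16, 16, int, DATA_LAYOUT_I_MAJOR_DUAL>;
static_assert(tile_dual::ne == 16, "dual layout: one full row per lane, duplicated across wave halves");
// get_i(l) == threadIdx.x % 16  -> lane n and lane n+16 map to the same row
// get_j(l) == l                 -> the row sits contiguously in x[]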

#if defined(TURING_MMA_AVAILABLE)
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -569,55 +601,28 @@ namespace ggml_cuda_mma {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
} else {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
}
#elif defined(AMD_WMMA_AVAILABLE)
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#if defined(RDNA4)
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
#elif defined(RDNA3)
ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
#else
NO_DEVICE_CODE;
#endif // defined(RDNA4)
} else if constexpr (std::is_same_v<T, int>) {
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
#if defined(RDNA4)
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
#elif defined(RDNA3)
static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
xi[0] = xs[0];
xi[1] = xs[1];
#endif // defined(RDNA4)
} else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
#if defined(RDNA4)
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];

const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
#elif defined(RDNA3)
static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
// four contiguous 64-bit chunks per lane for the wider RDNA3 fragment
xi[0] = xs[0];
xi[1] = xs[1];
const int64_t * xs1 = xs + 2;
xi[2] = xs1[0];
xi[3] = xs1[1];
#endif // defined(RDNA4)
// All WMMA layouts have contiguous data when i-major.
if constexpr (is_i_major(dl)) {
// a single copy larger than ggml_cuda_get_max_cpy_bytes() would need 16-byte alignment, so split it into max-size chunks
constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
if constexpr (sizeof(t.x) > aligned_copy_bytes) {
static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
#pragma unroll
for (int i = 0; i < aligned_copy_count; ++i) {
ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
}
} else {
NO_DEVICE_CODE;
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
}
} else {
NO_DEVICE_CODE;
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
}
#else
#pragma unroll
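
A worked instance of the chunked i-major copy above (assumptions: ggml_cuda_get_max_cpy_bytes() == 16, and the RDNA3 dual tile tile<16, 8, int, DATA_LAYOUT_I_MAJOR_DUAL> with ne == 8, i.e. a 32-byte fragment where get_j(l) == l; load_chunked_example is a hypothetical helper name):

// Sketch: a 32-byte per-thread fragment split into two 16-byte aligned copies.
static __device__ __forceinline__ void load_chunked_example(int * x, const int * xs0, const int stride, const int row) {
    constexpr int ne = 8;                                         // assumed fragment size: 8 ints = 32 bytes
    constexpr int aligned_copy_bytes = 16;                        // assumed ggml_cuda_get_max_cpy_bytes()
    constexpr int chunks = ne * sizeof(int) / aligned_copy_bytes; // -> 2 chunks of 4 ints each
#pragma unroll
    for (int c = 0; c < chunks; ++c) {
        const int l0 = ne/chunks * c; // first element of chunk c; get_j(l0) == l0 for the dual layout
        ggml_cuda_memcpy_1<aligned_copy_bytes>(x + l0, xs0 + row*stride + l0);
    }
}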
@@ -660,9 +665,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}

template <typename T>
template <typename T, data_layout dl>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +837,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}

template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@@ -886,9 +892,10 @@ namespace ggml_cuda_mma {
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}


template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@@ -939,9 +946,10 @@ namespace ggml_cuda_mma {
NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}


template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@@ -967,8 +975,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}

template <data_layout dl_d, data_layout dl_ab>
static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
#if defined(AMD_MFMA_AVAILABLE)
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
int32x4_t * acc = (int32x4_t *) D.x;
@@ -1122,8 +1131,9 @@ namespace ggml_cuda_mma {
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
template <data_layout dl_d, data_layout dl_ab>
static __device__ __forceinline__ void mma(
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;