Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 73 additions & 34 deletions ggml/src/ggml-cuda/mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,25 @@ namespace ggml_cuda_mma {
// MIRRORED == Each data value is held exactly once per thread subgroup.
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
};
// Implemented mma combinations are:
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR

constexpr bool is_i_major(const data_layout dl) {
static constexpr bool is_i_major(const data_layout dl) {
return dl == DATA_LAYOUT_I_MAJOR ||
dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
dl == DATA_LAYOUT_I_MAJOR_DUAL;
dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
}

constexpr data_layout get_input_data_layout() {
#if defined(RDNA3)
return DATA_LAYOUT_I_MAJOR_DUAL;
static constexpr __device__ data_layout get_input_data_layout() {
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
return DATA_LAYOUT_I_MAJOR_MIRRORED;
#else
return DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}

template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
Expand Down Expand Up @@ -462,31 +460,35 @@ namespace ggml_cuda_mma {
}
};

template <int I_, int J_>
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
static constexpr int ne = I * J / (WARP_SIZE/4);

half2 x[ne] = {{0.0f, 0.0f}};
// RDNA3
static constexpr int ne = I * J / 32 * 2;

T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 8 && J == 4) return true;
if (I == 16 && J == 16) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 4) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int /*l*/) {
if constexpr (I == 8 && J == 4) {
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
if constexpr (supported()) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
if constexpr (supported()) {
return l;
} else {
NO_DEVICE_CODE;
Expand All @@ -496,10 +498,27 @@ namespace ggml_cuda_mma {
};

template <int I_, int J_>
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
#if defined(RDNA3)
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;

half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
}

static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
}

static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
}
#else // Volta
static constexpr int ne = I * J / (WARP_SIZE/4);

half2 x[ne] = {{0.0f, 0.0f}};
Expand All @@ -509,9 +528,9 @@ namespace ggml_cuda_mma {
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
static __device__ __forceinline__ int get_i(const int /*l*/) {
if constexpr (I == 8 && J == 4) {
return ((l / 2) * 4) + (threadIdx.x % 4);
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
Expand All @@ -520,43 +539,63 @@ namespace ggml_cuda_mma {

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return ((threadIdx.x / 16) * 2) + (l % 2);
return l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#endif // defined(RDNA3)
};

template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
template <int I_, int J_>
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;

static constexpr int ne = I * J / 32 * 2;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

T x[ne] = {0};
static constexpr __device__ bool supported() {
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
}

static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
}

static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
}
};

template <int I_, int J_>
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
static constexpr int ne = I * J / (WARP_SIZE/4);

half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 4) return true;
if (I == 8 && J == 4) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (supported()) {
return threadIdx.x % 16;
if constexpr (I == 8 && J == 4) {
return ((l / 2) * 4) + (threadIdx.x % 4);
} else {
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (supported()) {
return l;
if constexpr (I == 8 && J == 4) {
return ((threadIdx.x / 16) * 2) + (l % 2);
} else {
NO_DEVICE_CODE;
return -1;
Expand Down
Loading