4 changes: 2 additions & 2 deletions ggml-cuda/common.cuh
@@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S;
};

-static int get_mmq_x_max_host(const int cc) {
+static constexpr int get_mmq_x_max_host(int cc) {
#ifdef CUDA_USE_TENSOR_CORES
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else
@@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
}

// Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc) {
+static constexpr int get_mmq_y_host(int cc) {
return cc >= CC_VOLTA ? 128 : 64;
}

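Making these two helpers constexpr lets the MMQ code evaluate the tile sizes at compile time, e.g. as template arguments or array bounds. A minimal sketch of why that matters, using a hypothetical tile struct that is not part of this PR:

    // Hypothetical illustration: a constexpr helper can produce a compile-time
    // constant, which a plain runtime function cannot.
    template <int mmq_y>
    struct tile_x_sketch {
        int qs[mmq_y]; // the array bound must be a constant expression
    };

    // 700 is CC_VOLTA; this is only legal now that get_mmq_y_host() is constexpr.
    using tile_volta_sketch = tile_x_sketch<get_mmq_y_host(700)>;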
56 changes: 56 additions & 0 deletions ggml-cuda/mma.cuh
@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};

struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};

struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(x[0])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};

struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
GGML_CUDA_ASSUME(ret < K);
return ret;
}

__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
};

struct mma_int_C_I16J8 {
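For context, the new load() members centralize how MMQ tile fragments are filled: on GPUs with INT8_MMA_AVAILABLE the A tiles are loaded with a single ldmatrix instruction, while the B tiles keep the per-element 32-bit loads because, as the in-code comment notes, loading as 4 byte values is faster there. A rough usage sketch, not taken from this PR and assuming the mma_K8() multiply-accumulate member that mma_int_C_I16J8 provides in the surrounding mma.cuh:

    // Hypothetical sketch: load one A tile and one B tile from shared memory,
    // then run the int8 tensor core MMA. Scaling and write-back are omitted.
    __device__ void mul_mat_tile_sketch(const int * x_sh, const int * y_sh,
                                        const int stride_x, const int stride_y) {
        mma_int_A_I16K8 A;
        mma_int_B_J8K8  B;
        mma_int_C_I16J8 C;

        A.load(x_sh, stride_x); // ldmatrix path when INT8_MMA_AVAILABLE, per-element loop otherwise
        B.load(y_sh, stride_y); // always the per-element loop (the ldmatrix path is disabled)
        C.mma_K8(A, B);         // accumulate the int8 product into C.x
    }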