Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2203,7 +2203,7 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
}
}

static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {

Expand All @@ -2220,10 +2220,12 @@ static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
moe_counts.resize(n_as, 0);
cum_moe_counts.resize(n_as + 1);

bool is_ser = false;
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
else is_ser = true;
}
}
cum_moe_counts[0] = 0;
Expand All @@ -2244,16 +2246,20 @@ static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n

for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];

CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(),
cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

return is_ser;
}

static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];

CUDA_CHECK(cudaMemsetAsync((char *)dst->data, 0, ggml_nbytes(dst), ctx.stream()));

if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
ggml_is_quantized(src0->type) &&
ggml_backend_buffer_is_cuda(src0->buffer) &&
Expand Down Expand Up @@ -2361,7 +2367,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
std::vector<int> moe_counts, cum_moe_counts;
prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
if (is_ser) {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
}

ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
Expand Down Expand Up @@ -2519,13 +2528,16 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
auto local_src0 = *next->src[0];
local_src0.ne[2] = local_src0.ne[3] = 1;

CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream));

ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next,
(const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());

return true;
} else {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
CUDA_CHECK(cudaGetLastError());
Expand All @@ -2534,7 +2546,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
}


GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers");
Expand Down Expand Up @@ -2662,7 +2673,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
std::vector<int> moe_counts, cum_moe_counts;

prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
if (is_ser) {
if (fuse_down) {
CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream));
} else {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
}
}

for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = moe_counts[i02];
Expand Down
29 changes: 15 additions & 14 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,21 @@ static __global__ void mul_mat_vec_q(
char * cdst = (char *)dst + i2*nb2;
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
if (i02 < 0) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int rows_per_cuda_block = 1;
#else
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
const int row0 = rows_per_cuda_block*blockIdx.x;
if (threadIdx.y == 0) {
dst = (float *)cdst;
for (int j = 0; j < ncols_y; ++j) {
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = 0;
}
}
}
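// No in-kernel clearing for skipped rows (i02 < 0): the host code zeroes the
// destination tensor with cudaMemsetAsync, so we can return right away.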
//#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
// constexpr int rows_per_cuda_block = 1;
//#else
// constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
//#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
// const int row0 = rows_per_cuda_block*blockIdx.x;
// if (threadIdx.y == 0) {
// dst = (float *)cdst;
// for (int j = 0; j < ncols_y; ++j) {
// if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
// dst[j*nrows_dst + row0 + threadIdx.x] = 0;
// }
// }
// }
return;
}
const char * cx = (const char *)vx + i02*nb02;
Expand Down