Skip to content

Commit 0fb80c7

Browse files
committed
add restore kernel for moe transpose
1 parent 2ca4fe1 commit 0fb80c7

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ struct ggml_backend_opencl_context {
453453
cl_kernel kernel_mul_mat_f16_f32_tiled;
454454
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
455455
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
456-
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4;
456+
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
457457
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
458458
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
459459
cl_kernel kernel_convert_block_q4_0_noshuffle;
@@ -780,6 +780,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
780780
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
781781
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
782782
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
783+
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
783784
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
784785
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
785786
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -3338,6 +3339,11 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
33383339
tensor->ne[2] == 1 && tensor->ne[3] == 1;
33393340
}
33403341

3342+
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
3343+
int ne01 = tensor->ne[1];
3344+
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
3345+
}
3346+
33413347
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
33423348
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
33433349

@@ -3641,13 +3647,12 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
36413647
CL_CHECK(err);
36423648

36433649
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
3644-
if (strstr(tensor->name, "ffn") != NULL) {
3650+
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
3651+
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
3652+
36453653
int ne00 = tensor->ne[0];
36463654
int ne01 = tensor->ne[1];
36473655
int ne02 = tensor->ne[2];
3648-
3649-
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
3650-
36513656
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
36523657
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
36533658
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
@@ -3815,6 +3820,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
38153820
ggml_nbytes(tensor), NULL, &err);
38163821
CL_CHECK(err);
38173822

3823+
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
3824+
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
3825+
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
3826+
3827+
int ne00 = tensor->ne[0];
3828+
int ne01 = tensor->ne[1];
3829+
int ne02 = tensor->ne[2];
3830+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
3831+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
3832+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
3833+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
3834+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
3835+
3836+
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
3837+
size_t local_work_size[3] = {64, 2, 1};
3838+
3839+
cl_event evt;
3840+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
3841+
global_work_size, local_work_size, 0, NULL, &evt));
3842+
CL_CHECK(clWaitForEvents(1, &evt));
3843+
CL_CHECK(clEnqueueReadBuffer(
3844+
queue, data_device, CL_TRUE, offset,
3845+
size, data, 0, NULL, NULL));
3846+
CL_CHECK(clReleaseMemObject(data_device));
3847+
return;
3848+
}
3849+
#endif
38183850
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
38193851
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
38203852
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -7766,6 +7798,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
77667798

77677799
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
77687800

7801+
int tile_size = 320;
77697802
if (ne12 == 1) { // for gemv
77707803
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
77717804

@@ -7785,7 +7818,6 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
77857818
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
77867819

77877820
// preprocess router table
7788-
int tile_size = 320;
77897821
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
77907822
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
77917823
void * host_src2 = malloc(ne21 * nb21);
@@ -7842,7 +7874,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
78427874
if (ne12 == 1) {
78437875
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
78447876
} else {
7845-
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne02));
7877+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
78467878
}
78477879

78487880
// launch kernel

ggml/src/ggml-opencl/kernels/cvt.cl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,27 @@ kernel void kernel_restore_block_mxfp4(
183183
}
184184
}
185185

186+
kernel void kernel_restore_block_mxfp4_trans(
187+
__global uint4 * src_q,
188+
__global uchar * src_e,
189+
global struct block_mxfp4 * dst,
190+
uint ne00,
191+
uint ne01
192+
) {
193+
int i00 = get_global_id(1);
194+
uint i01 = get_global_id(0);
195+
uint i02 = get_global_id(2);
196+
197+
uint ne00_blk = ne00 / QK_MXFP4;
198+
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
199+
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
200+
201+
global struct block_mxfp4 * b = dst + dst_blk_offset;
202+
203+
((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
204+
b->e = src_e[src_blk_offset];
205+
}
206+
186207
//------------------------------------------------------------------------------
187208
// block_q8_0
188209
//------------------------------------------------------------------------------

ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#define QK_MXFP4 32
66
#define N_SIMDGROUP 2
77
#define SIMDGROUP_WIDTH 64
8-
#define TILE_SIZE 320
98

109
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
1110
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
@@ -76,7 +75,7 @@ __kernel void kernel_gemm_moe_mxfp4_f32(
7675
ulong offsetd,
7776
int ne00,
7877
int ne01,
79-
int ne02
78+
int tile_size
8079
) {
8180
uint i01 = get_global_id(0);
8281
uint i20 = get_global_id(2);
@@ -89,12 +88,12 @@ __kernel void kernel_gemm_moe_mxfp4_f32(
8988
ushort i1 = router.z;
9089
ushort tile_id = router.w;
9190

92-
if (tile_id * TILE_SIZE + i01 > ne01) { // handle edge case when ne01 is not multiple of TILE_SIZE
91+
if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size
9392
return;
9493
}
9594

9695
uint expert_offset = expert_id * ne00 * ne01 / 32;
97-
uint tile_offset = expert_offset + tile_id * TILE_SIZE + i01;
96+
uint tile_offset = expert_offset + tile_id * tile_size + i01;
9897

9998
__private float sum = 0.0f; // each thread calculate partial sum of one output
10099

@@ -157,7 +156,7 @@ __kernel void kernel_gemm_moe_mxfp4_f32(
157156
// 1 outputs per thread in subgroup 0
158157
if (sgid == 0) {
159158
dst = dst + (offsetd >> 2);
160-
dst[i01 + tile_id * TILE_SIZE + i1 * ne01] = sum;
159+
dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
161160
}
162161

163162
}

0 commit comments

Comments
 (0)