Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
};

uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}

static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
Expand Down Expand Up @@ -974,9 +977,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,

auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);

return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
Expand Down Expand Up @@ -1064,9 +1068,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,

auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);

return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
Expand Down Expand Up @@ -1689,14 +1694,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
gathered_count_ids_binding_size),
};

const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;

const uint32_t gather_total_wg = param_n_expert;
const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);
// n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_exeprt=256 at Qwen3.5-35B-A3B)
const uint32_t gather_wg_x = param_n_expert;

dispatches.push_back({
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y }
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 }
});

// params for mul_mat_id.wgsl
Expand Down Expand Up @@ -1748,7 +1750,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
uint32_t total_wg = wg_m * max_wg_n;

compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);

dispatches.push_back({
main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y }
Expand Down Expand Up @@ -2771,10 +2773,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
block_size, npr, nrows
};

const uint32_t total_wg_init = npr * nrows;
const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
uint32_t wg_x_init;
uint32_t wg_y_init;
const uint32_t total_wg_init = npr * nrows;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init);

std::vector<wgpu::BindGroupEntry> init_entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size)
Expand Down Expand Up @@ -2831,9 +2835,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out)
};

uint32_t wg_x_merge;
uint32_t wg_y_merge;
const uint32_t total_wg_merge = nm * nrows;
const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge);

dispatches.push_back({
argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge }
});
Expand Down Expand Up @@ -2953,9 +2959,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s

webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);

uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);

return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}

Expand Down
10 changes: 6 additions & 4 deletions ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,22 @@ struct Params{
var<uniform> params: Params;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(
@builtin(global_invocation_index) gindex: u32,
) {
if (gindex >= params.ne) {
return;
}

var i = gid.x;
var i = gindex;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
i = i % (params.src_ne1 * params.src_ne0);
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;

var j = gid.x;
var j = gindex;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
Expand Down
43 changes: 20 additions & 23 deletions ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,32 @@ var<workgroup> count:atomic<u32>;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
@builtin(local_invocation_id) local_id: vec3<u32>) {

let thread_id = local_id.x;
let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup
let own_expert = wg_id.x; // the expert assigned to this workgroup

if (own_expert < params.n_expert) {
if (thread_id == 0u) {
atomicStore(&count, 0);
}
if (thread_id == 0u) {
atomicStore(&count, 0);
}

workgroupBarrier();

for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
let row = i / params.n_expert_used;
let col = i % params.n_expert_used;
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
if (own_expert == expert) {
let pos = atomicAdd(&count, 1u);
let gathered_id = own_expert * params.n_tokens + pos;
global_gathered_expert_used[gathered_id] = col;
global_gathered_tokens[gathered_id] = row;
}
workgroupBarrier();

for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
let row = i / params.n_expert_used;
let col = i % params.n_expert_used;
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
if (own_expert == expert) {
let pos = atomicAdd(&count, 1u);
let gathered_id = own_expert * params.n_tokens + pos;
global_gathered_expert_used[gathered_id] = col;
global_gathered_tokens[gathered_id] = row;
}
}

workgroupBarrier();
workgroupBarrier();

if (thread_id == 0u) {
gathered_count_ids[own_expert] = atomicLoad(&count);
}
if (thread_id == 0u) {
gathered_count_ids[own_expert] = atomicLoad(&count);
}
}
Loading