Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 161 additions & 13 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct ggml_webgpu_shader_lib_context {
uint32_t sg_mat_n = 0;
uint32_t sg_mat_k = 0;
uint32_t max_subgroup_size = 0;
uint32_t n_experts = 0;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry one last change here: I want to try and avoid putting any op-specific information in this context, and instead derive it in the library as needed from the context. In this case, we can derive n_experts from the src0 tensor directly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. I'll update.

};

struct webgpu_pipeline {
Expand Down Expand Up @@ -664,7 +665,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
}
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
size_t bytes_per_kv = 0;
if (!key.kv_direct) {
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
}
Expand Down Expand Up @@ -701,10 +702,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
Expand Down Expand Up @@ -862,9 +863,12 @@ struct ggml_webgpu_mul_mat_shader_decisions {
struct ggml_webgpu_mul_mat_id_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
uint32_t n_experts;
int vectorized;

bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type;
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
vectorized == other.vectorized;
}
};

Expand All @@ -873,6 +877,8 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.n_experts);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
};
Expand Down Expand Up @@ -1023,6 +1029,8 @@ class ggml_webgpu_shader_lib {
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
mul_mat_id_pipelines; // src0_type/src1_type
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
mul_mat_id_vec_pipelines; // src0_type/src1_type

std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
Expand Down Expand Up @@ -1516,7 +1524,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);

auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
Expand Down Expand Up @@ -1633,10 +1641,10 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
1 :
0;

auto it = mul_mat_vec_pipelines.find(key);
if (it != mul_mat_vec_pipelines.end()) {
Expand Down Expand Up @@ -2012,6 +2020,11 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_mul_mat_id_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.n_experts = context.n_experts;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;

auto it = mul_mat_id_pipelines.find(key);
if (it != mul_mat_id_pipelines.end()) {
Expand Down Expand Up @@ -2041,14 +2054,12 @@ class ggml_webgpu_shader_lib {
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("FLOAT");
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("FLOAT");
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
variant += "_f16";
Expand All @@ -2064,12 +2075,32 @@ class ggml_webgpu_shader_lib {
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");

switch (context.src0->type) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
defines.push_back(type_upper + "_GRID");
break;
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
defines.push_back(type_upper + "_GRID");
defines.push_back(type_upper + "_TABLES");
break;
default:
break;
}

variant += std::string("_") + src0_name;
break;
}
}

defines.push_back("SCALAR");
// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");

// mul_mat_id is register-tile only.
const uint32_t tile_k =
Expand Down Expand Up @@ -2102,6 +2133,123 @@ class ggml_webgpu_shader_lib {
return mul_mat_id_pipelines[key];
}

webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_id_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.n_experts = context.n_experts;
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;

auto it = mul_mat_id_vec_pipelines.find(key);
if (it != mul_mat_id_vec_pipelines.end()) {
return it->second;
}

std::vector<std::string> defines;
std::string variant = "mul_mat_id_vec";
const char * shader_src = wgsl_mul_mat_id_vec;

// src1 type
switch (context.src1->type) {
case GGML_TYPE_F32:
defines.push_back("SRC1_INNER_TYPE=f32");
break;
case GGML_TYPE_F16:
defines.push_back("SRC1_INNER_TYPE=f16");
break;
default:
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
}

// src0 type
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f16";
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
switch (context.src0->type) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
defines.push_back(type_upper + "_GRID");
break;
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
defines.push_back(type_upper + "_GRID");
defines.push_back(type_upper + "_TABLES");
break;
default:
break;
}
break;
}
}

// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");

uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;

if (key.src0_type == GGML_TYPE_Q1_0) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}

// variant suffix for src1 type
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");

defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
if (key.vectorized) {
variant += "_vectorized";
}

defines.push_back(std::string("N_EXPERTS=") + std::to_string(context.n_experts));

auto processed = preprocessor.preprocess(shader_src, defines);

auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
decisions->wg_size = wg_size;
decisions->outputs_per_wg = outputs_per_wg;

webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
mul_mat_id_vec_pipelines[key] = pipeline;
return mul_mat_id_vec_pipelines[key];
}

webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool is_unary = context.dst->op == GGML_OP_UNARY;
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
Expand Down
75 changes: 74 additions & 1 deletion ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
Expand Down Expand Up @@ -1527,16 +1526,81 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}

static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
const uint32_t param_n_expert = (uint32_t) src0->ne[2];
const uint32_t param_n_expert_used = (uint32_t) dst->ne[1];

ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.src2 = src2;
shader_lib_ctx.dst = dst;
shader_lib_ctx.n_experts = src0->ne[2];
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;

webgpu_pipeline pipeline = ctx->shader_lib->get_mul_mat_id_vec_pipeline(shader_lib_ctx);

std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
param_n_expert,
param_n_expert_used,
(uint32_t) src1->ne[1],
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
};

std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0),
ggml_webgpu_tensor_binding_size(ctx, src0)),
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1),
ggml_webgpu_tensor_binding_size(ctx, src1)),
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2),
ggml_webgpu_tensor_binding_size(ctx, src2)),
ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst),
ggml_webgpu_tensor_binding_size(ctx, dst)),
};

uint32_t wg_x = 1;
uint32_t wg_y = 1;

auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());

const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * param_n_expert_used;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);

return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}

static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
// we can use mat-vec fast path
if (dst->ne[2] == 1) {
return ggml_webgpu_mul_mat_id_vec(ctx, src0, src1, src2, dst);
}

ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.src2 = src2;
shader_lib_ctx.dst = dst;
shader_lib_ctx.n_experts = src0->ne[2];
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;

// Get or create pipeline
Expand Down Expand Up @@ -3879,6 +3943,15 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
supports_op = true;
break;
default:
Expand Down
Loading
Loading