Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,20 @@
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256

// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256

#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256

// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
// Requires at least two (and multiple of 2) k-quant blocks per tile
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512

// default size for legacy matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256
Expand Down Expand Up @@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
bool src_overlap;

bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
src_overlap == other.src_overlap;
}
};

Expand Down Expand Up @@ -749,36 +759,25 @@ class ggml_webgpu_shader_lib {
std::vector<std::string> defines;
std::string variant = "mul_mat_vec";

// src1 type (vector)
switch (context.src1->type) {
case GGML_TYPE_F32:
defines.push_back("SRC1_INNER_TYPE=f32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC1_INNER_TYPE=f16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
}

// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f16";
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

defines.push_back("BYTE_HELPERS");
Expand All @@ -790,12 +789,35 @@ class ggml_webgpu_shader_lib {
}
}

// src1 type (vector)
switch (context.src1->type) {
case GGML_TYPE_F32:
defines.push_back("SRC1_INNER_TYPE=f32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC1_INNER_TYPE=f16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
}

// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");

uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;

if (key.src0_type >= GGML_TYPE_Q2_K) {
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}

defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
Expand Down Expand Up @@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {

webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.src_overlap = context.src_overlap,
};

Expand Down
Loading
Loading