70 changes: 46 additions & 24 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -42,11 +42,20 @@
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256

// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_TILE_K 256
+#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
+
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
+
+// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
+#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
+// Requires at least two (and multiple of 2) k-quant blocks per tile
+#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512

// default size for legacy matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256
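
Aside: the constraints stated in the comments above can be written out as compile-time assertions. A minimal sketch, assuming QK_K == 256 (ggml's k-quant block size) and the constant values defined in this header; illustrative only, not part of the change:

#include <cstdint>

// Constants copied from this header; kQKK is the assumed k-quant block size.
constexpr uint32_t kWgSize        = 256; // WEBGPU_MUL_MAT_VEC_WG_SIZE
constexpr uint32_t kFloatOutputs  = 64;  // WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG
constexpr uint32_t kKQuantOutputs = 8;   // WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG
constexpr uint32_t kKQuantTileK   = 512; // WEBGPU_MUL_MAT_VEC_K_Q_TILE_K
constexpr uint32_t kQKK           = 256; // QK_K (assumed)

// Outputs per workgroup must be a multiple of 4 and divide the workgroup size.
static_assert(kFloatOutputs % 4 == 0 && kWgSize % kFloatOutputs == 0, "float path");

// K-quant path: 32 threads cooperate on each output.
static_assert(kWgSize / kKQuantOutputs == 32, "k-quant thread count");

// Each k-quant tile must cover at least two blocks, and a multiple of 2.
static_assert(kKQuantTileK / kQKK >= 2 && (kKQuantTileK / kQKK) % 2 == 0, "k-quant tile");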
@@ -183,7 +192,8 @@ struct ggml_webgpu_binary_pipeline_key {
bool src_overlap;

bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
-return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
+return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
+       src_overlap == other.src_overlap;
}
};
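
Keys like this are typically used to memoize compiled pipelines. A hypothetical sketch of such a cache, assuming an std::unordered_map and a boost-style hash combiner; the library's actual cache container and hash function are not shown in this diff:

#include <cstddef>
#include <functional>
#include <unordered_map>

// Hypothetical hash for the pipeline key; folds exactly the fields that
// operator== above compares. Not the library's code.
struct ggml_webgpu_binary_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_binary_pipeline_key & k) const {
        size_t h = std::hash<int>()(static_cast<int>(k.type));
        auto mix = [&h](size_t v) { h ^= v + 0x9e3779b9 + (h << 6) + (h >> 2); };
        mix(std::hash<int>()(static_cast<int>(k.op)));
        mix(std::hash<bool>()(k.inplace));
        mix(std::hash<bool>()(k.overlap));
        mix(std::hash<bool>()(k.src_overlap));
        return h;
    }
};

// std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline,
//                    ggml_webgpu_binary_pipeline_key_hash> pipeline_cache;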

@@ -731,36 +741,25 @@ class ggml_webgpu_shader_lib {
std::vector<std::string> defines;
std::string variant = "mul_mat_vec";

-// src1 type (vector)
-switch (context.src1->type) {
-case GGML_TYPE_F32:
-defines.push_back("SRC1_INNER_TYPE=f32");
-variant += "_f32";
-break;
-case GGML_TYPE_F16:
-defines.push_back("SRC1_INNER_TYPE=f16");
-variant += "_f16";
-break;
-default:
-GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
-}

// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f16";
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

defines.push_back("BYTE_HELPERS");
@@ -772,12 +771,35 @@
}
}

+// src1 type (vector)
+switch (context.src1->type) {
+case GGML_TYPE_F32:
+defines.push_back("SRC1_INNER_TYPE=f32");
+variant += "_f32";
+break;
+case GGML_TYPE_F16:
+defines.push_back("SRC1_INNER_TYPE=f16");
+variant += "_f16";
+break;
+default:
+GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
+}

// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");

uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
-uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
-uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
+uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;

+if (key.src0_type >= GGML_TYPE_Q2_K) {
+tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
+outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
+} else if (key.src0_type >= GGML_TYPE_Q4_0) {
+tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
+outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
+}

defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
@@ -1043,10 +1065,10 @@ class ggml_webgpu_shader_lib {

webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
.type        = context.dst->type,
.op          = context.dst->op,
.inplace     = context.inplace,
.overlap     = context.overlap,
.src_overlap = context.src_overlap,
};

28 changes: 22 additions & 6 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -20,6 +20,7 @@
#include <condition_variable>
#include <cstdint>
#include <cstring>
+#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
@@ -776,7 +777,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
std::cout << "\nggml_webgpu: gpu breakdown:\n";
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
<< pct << "%)\n";
}
#endif

@@ -836,7 +838,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);

return flags;
}
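
The overlap helpers are not shown in this hunk. For intuition, here is the usual half-open range test they would reduce to, assuming both tensors resolve to byte ranges in the same backing buffer; a hypothetical helper, not the library's exact code:

// Hypothetical sketch: two half-open byte ranges [a_off, a_off + a_len) and
// [b_off, b_off + b_len) intersect iff each starts before the other ends.
// The real helper would also have to compare backing buffers first.
static bool tensor_ranges_overlap(size_t a_off, size_t a_len, size_t b_off, size_t b_len) {
    return a_off < b_off + b_len && b_off < a_off + a_len;
}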
@@ -1079,12 +1081,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
use_fast = (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q6_K:
use_fast = true;
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
use_fast = !is_vec;
break;
default:
break;
}
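
Restated as a standalone predicate for the f32-activation case; a sketch that mirrors the switch above (assumes ggml.h for ggml_type), not a drop-in replacement:

// Fast-path rule above: float, legacy-quant, and q6_K weights always take the
// fast path; q2_K..q5_K only for mat-mat, since no fast mat-vec exists yet.
static bool webgpu_use_fast_f32(ggml_type src0_type, bool is_vec) {
    switch (src0_type) {
        case GGML_TYPE_F32: case GGML_TYPE_F16:
        case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1:
        case GGML_TYPE_Q6_K:
            return true;
        case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K:
            return !is_vec;
        default:
            return false;
    }
}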
@@ -1153,8 +1169,8 @@
};

// Calculate workgroup dimensions
uint32_t wg_x = 1;
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;

if (use_fast && is_vec) {
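
The dispatch-size computation is collapsed in this view, but its shape is the standard one: when the 1D workgroup count would exceed maxComputeWorkgroupsPerDimension, fold it into two dimensions. A generic sketch under that assumption, not the PR's exact code:

#include <algorithm>
#include <cstdint>

// Fold an oversized 1D dispatch into 2D; neither dimension exceeds the limit.
// wg_x * wg_y may exceed `total`, so the shader must bounds-check its index.
static void split_workgroups(uint32_t total, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
    wg_x = std::min(total, max_per_dim);
    wg_y = (total + wg_x - 1) / wg_x; // ceil(total / wg_x)
}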
@@ -1410,7 +1426,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
uint32_t offset_merged_src0 = 0;
uint32_t offset_merged_src1 = 0;
if (flags.src_overlap) {
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
}
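
A worked example of the merge with hypothetical numbers, for f32 tensors (ggml_type_size == 4 bytes):

// Hypothetical values: src0 aligned at byte 256, src1 aligned at byte 1024.
//   min_off            = min(256, 1024)   = 256
//   offset_merged_src0 = (256  - 256) / 4 = 0   elements
//   offset_merged_src1 = (1024 - 256) / 4 = 192 elements
// Both sources can then be bound as a single buffer range starting at min_off,
// with the per-source element offsets applied inside the shader.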
@@ -1419,7 +1435,7 @@
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
offset_merged_src0,
offset_merged_src1,
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),