Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 112 additions & 119 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4

// default size for legacy matrix multiplication
// default size for reg-tile matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256

// Same hash combine function as in boost
Expand Down Expand Up @@ -93,6 +93,8 @@ struct ggml_webgpu_shader_lib_context {
uint32_t sg_mat_k = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
bool supports_dot_product = false;
std::string vendor;
};

struct webgpu_pipeline {
Expand Down Expand Up @@ -850,31 +852,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(

/** Matrix Multiplication **/

struct ggml_webgpu_legacy_mul_mat_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;

bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type;
}
};

struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
return seed;
}
};

struct ggml_webgpu_mul_mat_vec_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
bool use_mmvq;

bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
use_mmvq == other.use_mmvq;
}
};

Expand All @@ -884,6 +870,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.use_mmvq);
return seed;
}
};
Expand All @@ -894,6 +881,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions {
uint32_t vec_size;
};

struct ggml_webgpu_quantize_q8_pipeline_key {
ggml_type src0_type;

bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; }
};

struct ggml_webgpu_quantize_q8_pipeline_key_hash {
size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src0_type);
return seed;
}
};

struct ggml_webgpu_mul_mat_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
Expand Down Expand Up @@ -1051,6 +1052,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash {
}
};

/** MMVQ **/

inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
const ggml_tensor * src1,
bool supports_dot_product,
const std::string & vendor) {
if (src1->ne[1] == 1) {
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
if (supports_dp4a && supports_dot_product) {
switch (src1->type) {
case GGML_TYPE_F32:
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q4_K:
return src0->ne[0] % 4 == 0;
default:
break;
}
break;
default:
break;
}
}
}
return false;
}

class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
Expand Down Expand Up @@ -1099,14 +1130,12 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline,
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
flash_attn_blk_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec)
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
mul_mat_vec_pipelines; // fast mat-vec (n==1)
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
quantize_q8_pipelines;
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
mul_mat_id_pipelines; // src0_type/src1_type
Expand Down Expand Up @@ -1631,7 +1660,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);

auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
Expand Down Expand Up @@ -1744,6 +1773,44 @@ class ggml_webgpu_shader_lib {
return pad_pipelines[key];
}

webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_quantize_q8_pipeline_key key = {};
key.src0_type = context.src0->type;

auto it = quantize_q8_pipelines.find(key);
if (it != quantize_q8_pipelines.end()) {
return it->second;
}
const char * shader_src = wgsl_quantize_q8;
std::vector<std::string> defines;
std::string variant = "quantize_q8";

uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;

defines.push_back("SRC1_INNER_TYPE=f32");
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("Q8_1_T");

defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";

auto processed = preprocessor.preprocess(shader_src, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
quantize_q8_pipelines[key] = pipeline;
return quantize_q8_pipelines[key];
}

webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
key.src0_type = context.src0->type;
Expand All @@ -1752,6 +1819,8 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);

auto it = mul_mat_vec_pipelines.find(key);
if (it != mul_mat_vec_pipelines.end()) {
Expand Down Expand Up @@ -1788,6 +1857,19 @@ class ggml_webgpu_shader_lib {
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
switch (context.src0->type) {
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
if (key.use_mmvq) {
defines.push_back("LEGACY_QUANTS");
}
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q4_K:
if (key.use_mmvq) {
defines.push_back("K_QUANTS");
}
break;
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_S:
Expand Down Expand Up @@ -1840,6 +1922,11 @@ class ggml_webgpu_shader_lib {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}

if (key.use_mmvq) {
defines.push_back("MMVQ");
defines.push_back("Q8_1_T");
}

defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
Expand Down Expand Up @@ -2018,100 +2105,6 @@ class ggml_webgpu_shader_lib {
return mul_mat_fast_pipelines[key];
}

webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_legacy_mul_mat_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;

auto it = mul_mat_legacy_pipelines.find(key);
if (it != mul_mat_legacy_pipelines.end()) {
return it->second;
}

std::vector<std::string> defines;
std::string variant = "mul_mat";

switch (context.src1->type) {
case GGML_TYPE_F32:
defines.push_back("SRC1_TYPE=f32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC1_TYPE=f16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
}

const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
const char * src0_name = src0_traits->type_name;

switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_TYPE=f32");
defines.push_back("FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_TYPE=f16");
defines.push_back("FLOAT");
variant += "_f16";
break;
default:
{
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

switch (context.src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
{
// Quantized types using u32 buffers for portability.
defines.push_back("SRC0_TYPE=u32");
defines.push_back("U32_DEQUANT_HELPERS");
break;
}
default:
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
}

defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);
defines.push_back(type_upper + "_SCALE_MIN");
defines.push_back(type_upper + "_TABLES");
defines.push_back(type_upper + "_GRID");

variant += std::string("_") + src0_name;
break;
}
}

auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);

auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;

webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
mul_mat_legacy_pipelines[key] = pipeline;
return mul_mat_legacy_pipelines[key];
}

webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = mul_mat_id_gather_pipelines.find(1);
if (it != mul_mat_id_gather_pipelines.end()) {
Expand Down
Loading
Loading