Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 325 additions & 0 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions {
uint32_t mul_mat_wg_size;
};

/** Cpy **/

// Cache key selecting a compiled cpy (tensor copy/cast) pipeline variant.
struct ggml_webgpu_cpy_pipeline_key {
    ggml_type src_type;  // element type of the source tensor
    ggml_type dst_type;  // element type of the destination tensor

    bool operator==(const ggml_webgpu_cpy_pipeline_key & rhs) const {
        return (src_type == rhs.src_type) && (dst_type == rhs.dst_type);
    }
};

struct ggml_webgpu_cpy_pipeline_key_hash {
size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
return seed;
}
};

/** Glu **/

// Cache key selecting a compiled GLU pipeline variant.
struct ggml_webgpu_glu_pipeline_key {
    ggml_glu_op glu_op;  // which gated-linear-unit op the shader implements
    ggml_type   type;    // element type of the tensors
    bool        split;   // true when gate and value come from separate inputs

    bool operator==(const ggml_webgpu_glu_pipeline_key & rhs) const {
        return (glu_op == rhs.glu_op) && (type == rhs.type) && (split == rhs.split);
    }
};

struct ggml_webgpu_glu_pipeline_key_hash {
size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.glu_op);
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.split);
return seed;
}
};

/** Rope **/

// Cache key selecting a compiled RoPE pipeline variant.
struct ggml_webgpu_rope_pipeline_key {
    ggml_type type;     // element type of the tensors
    bool      inplace;  // true when the op writes back into its input
    bool      has_ff;   // true when a frequency-factors tensor (src2) is present

    bool operator==(const ggml_webgpu_rope_pipeline_key & rhs) const {
        return (type == rhs.type) && (inplace == rhs.inplace) && (has_ff == rhs.has_ff);
    }
};

struct ggml_webgpu_rope_pipeline_key_hash {
size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.has_ff);
return seed;
}
};

/** SoftMax **/

// Cache key selecting a compiled soft_max pipeline variant.
struct ggml_webgpu_soft_max_pipeline_key {
    ggml_type mask_type;  // element type of the mask tensor (meaningful only if has_mask)
    bool      has_mask;   // true when a mask tensor (src1) is present
    bool      has_sink;   // true when a sink tensor (src2) is present
    bool      inplace;    // true when the op writes back into its input

    bool operator==(const ggml_webgpu_soft_max_pipeline_key & rhs) const {
        return (mask_type == rhs.mask_type) && (has_mask == rhs.has_mask) &&
               (has_sink == rhs.has_sink) && (inplace == rhs.inplace);
    }
};

struct ggml_webgpu_soft_max_pipeline_key_hash {
size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.mask_type);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sink);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};

class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
Expand Down Expand Up @@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
rope_pipelines;
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
soft_max_pipelines;

public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
Expand Down Expand Up @@ -1679,6 +1774,236 @@ class ggml_webgpu_shader_lib {
return flash_attn_pipelines[key];
}

// Returns the copy-shader pipeline for the src/dst element-type combination
// of the given context, compiling and caching it on first use.
// Aborts on unsupported type combinations.
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_cpy_pipeline_key key = {
        .src_type = context.src0->type,
        .dst_type = context.dst->type,
    };

    // Fast path: pipeline already compiled for this key.
    auto it = cpy_pipelines.find(key);
    if (it != cpy_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string variant = "cpy";

    switch (key.src_type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported src type for cpy shader");
    }

    switch (key.dst_type) {
        case GGML_TYPE_F32:
            defines.push_back("DST_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("DST_F16");
            variant += "_f16";
            break;
        case GGML_TYPE_I32:
            defines.push_back("DST_I32");
            variant += "_i32";
            break;
        default:
            GGML_ABORT("Unsupported dst type for cpy shader");
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_cpy, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Insert and return in one step (avoids the redundant second hash lookup
    // of assigning and then re-indexing the map).
    return cpy_pipelines[key] = pipeline;
}

// Returns the GLU-shader pipeline matching the op, element type, and
// split/fused layout of the given context, compiling and caching it on
// first use. Aborts on unsupported ops or types.
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_glu_pipeline_key key = {
        .glu_op = ggml_get_glu_op(context.dst),
        .type   = context.dst->type,
        // A second source tensor means gate and value are split across inputs.
        .split  = (context.src1 != nullptr),
    };

    // Fast path: pipeline already compiled for this key.
    auto it = glu_pipelines.find(key);
    if (it != glu_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string variant = "glu";

    switch (key.glu_op) {
        case GGML_GLU_OP_REGLU:
            defines.push_back("OP_REGLU");
            variant += "_reglu";
            break;
        case GGML_GLU_OP_GEGLU:
            defines.push_back("OP_GEGLU");
            variant += "_geglu";
            break;
        case GGML_GLU_OP_SWIGLU:
            defines.push_back("OP_SWIGLU");
            variant += "_swiglu";
            break;
        case GGML_GLU_OP_SWIGLU_OAI:
            defines.push_back("OP_SWIGLU_OAI");
            variant += "_swiglu_oai";
            break;
        case GGML_GLU_OP_GEGLU_ERF:
            defines.push_back("OP_GEGLU_ERF");
            variant += "_geglu_erf";
            break;
        case GGML_GLU_OP_GEGLU_QUICK:
            defines.push_back("OP_GEGLU_QUICK");
            variant += "_geglu_quick";
            break;
        default:
            GGML_ABORT("Unsupported GLU op");
    }
    switch (key.type) {
        case GGML_TYPE_F32:
            defines.push_back("TYPE_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("TYPE_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported type for GLU shader");
    }

    if (key.split) {
        variant += "_split";
    } else {
        defines.push_back("NO_SPLIT");
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_glu, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Insert and return in one step (avoids the redundant second hash lookup
    // of assigning and then re-indexing the map).
    return glu_pipelines[key] = pipeline;
}

// Returns the RoPE-shader pipeline matching the element type, in-place flag,
// and presence of a frequency-factors tensor (src2) in the given context,
// compiling and caching it on first use. Aborts on unsupported types.
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_rope_pipeline_key key = {
        .type    = context.dst->type,
        .inplace = context.inplace,
        .has_ff  = (context.src2 != nullptr),
    };

    // Fast path: pipeline already compiled for this key.
    auto it = rope_pipelines.find(key);
    if (it != rope_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string variant = "rope";

    switch (key.type) {
        case GGML_TYPE_F32:
            defines.push_back("TYPE_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("TYPE_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported type for ROPE shader");
    }

    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }

    if (key.has_ff) {
        defines.push_back("FF_FUNC");
        variant += "_ff";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_rope, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Insert and return in one step (avoids the redundant second hash lookup
    // of assigning and then re-indexing the map).
    return rope_pipelines[key] = pipeline;
}

// Returns the soft_max-shader pipeline matching the mask presence/type, sink
// presence, and in-place flag of the given context, compiling and caching it
// on first use. When no mask is present, mask_type is normalized to
// GGML_TYPE_F32 so all maskless variants share one cache entry.
// Aborts on unsupported mask types.
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_soft_max_pipeline_key key = {
        .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
        .has_mask  = (context.src1 != nullptr),
        .has_sink  = (context.src2 != nullptr),
        .inplace   = context.inplace,
    };

    // Fast path: pipeline already compiled for this key.
    auto it = soft_max_pipelines.find(key);
    if (it != soft_max_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string variant = "soft_max";

    if (key.has_mask) {
        defines.push_back("HAS_MASK");
        switch (key.mask_type) {
            case GGML_TYPE_F32:
                defines.push_back("MASK_F32");
                variant += "_mask_f32";
                break;
            case GGML_TYPE_F16:
                defines.push_back("MASK_F16");
                variant += "_mask_f16";
                break;
            default:
                GGML_ABORT("Unsupported type for SOFT_MAX shader");
        }
    }

    if (key.has_sink) {
        defines.push_back("HAS_SINK");
        variant += "_sink";
    }

    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Insert and return in one step (avoids the redundant second hash lookup
    // of assigning and then re-indexing the map).
    return soft_max_pipelines[key] = pipeline;
}

private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,
Expand Down
Loading
Loading