-
Notifications
You must be signed in to change notification settings - Fork 19.2k
ggml-webgpu: add q4_0/q8_0 SET_ROWS #23760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
00d3b17
12910a9
91d3a3c
27ca550
338b57a
d30666a
acf393b
64c3e98
e221316
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,16 +84,16 @@ struct ggml_webgpu_shader_lib_context { | |
| ggml_tensor * src5; | ||
| ggml_tensor * dst; | ||
|
|
||
| uint32_t max_wg_size; | ||
| size_t wg_mem_limit_bytes = 0; | ||
| bool supports_subgroups = false; | ||
| bool supports_subgroup_matrix = false; | ||
| uint32_t sg_mat_m = 0; | ||
| uint32_t sg_mat_n = 0; | ||
| uint32_t sg_mat_k = 0; | ||
| uint32_t min_subgroup_size = 0; | ||
| uint32_t max_subgroup_size = 0; | ||
| bool supports_dot_product = false; | ||
| uint32_t max_wg_size; | ||
| size_t wg_mem_limit_bytes = 0; | ||
| bool supports_subgroups = false; | ||
| bool supports_subgroup_matrix = false; | ||
| uint32_t sg_mat_m = 0; | ||
| uint32_t sg_mat_n = 0; | ||
| uint32_t sg_mat_k = 0; | ||
| uint32_t min_subgroup_size = 0; | ||
| uint32_t max_subgroup_size = 0; | ||
| bool supports_dot_product = false; | ||
| std::string vendor; | ||
| }; | ||
|
|
||
|
|
@@ -166,9 +166,11 @@ struct ggml_webgpu_set_rows_pipeline_key { | |
| int dst_type; | ||
| int vec4; | ||
| int i64_idx; | ||
| int pair_blocks; | ||
|
|
||
| bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { | ||
| return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; | ||
| return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx && | ||
| pair_blocks == other.pair_blocks; | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -178,13 +180,15 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { | |
| ggml_webgpu_hash_combine(seed, key.dst_type); | ||
| ggml_webgpu_hash_combine(seed, key.vec4); | ||
| ggml_webgpu_hash_combine(seed, key.i64_idx); | ||
| ggml_webgpu_hash_combine(seed, key.pair_blocks); | ||
| return seed; | ||
| } | ||
| }; | ||
|
|
||
| struct ggml_webgpu_set_rows_shader_decisions { | ||
| bool vec4; | ||
| bool i64_idx; | ||
| bool pair_blocks; | ||
| uint32_t wg_size; | ||
| }; | ||
|
|
||
|
|
@@ -772,31 +776,30 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( | |
| (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); | ||
| const bool kv_vec_type_supported = | ||
| K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; | ||
| const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : | ||
| (uint32_t) ggml_blck_size(K->type); | ||
| const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && | ||
| context.src2->ne[0] % kv_vec_head_align == 0; | ||
| const uint32_t kv_vec_head_align = | ||
| K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); | ||
| const bool kv_vec_head_dims_aligned = | ||
| context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0; | ||
| // Compile with enough invocations to cover the largest reported subgroup. | ||
| const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && | ||
| kv_vec_head_dims_aligned && kv_vec_type_supported && | ||
| (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && | ||
| const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned && | ||
| kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && | ||
| (context.src2->type == K->type); | ||
| const bool tile_can_dispatch_all_q_rows = | ||
| context.max_subgroup_size > 0 && | ||
| context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; | ||
| const bool use_subgroup_matrix = | ||
| context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && | ||
| context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; | ||
| const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && | ||
| context.src0->ne[0] % context.sg_mat_k == 0 && | ||
| context.src2->ne[0] % context.sg_mat_n == 0; | ||
| const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && | ||
| V->type == GGML_TYPE_F16 && f16_vec4_aligned && | ||
| (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && | ||
| (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && | ||
| tile_can_dispatch_all_q_rows && !use_vec; | ||
|
|
||
| decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : | ||
| use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : | ||
| use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : | ||
| GGML_WEBGPU_FLASH_ATTN_PATH_NONE; | ||
| decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : | ||
| use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : | ||
| use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : | ||
| GGML_WEBGPU_FLASH_ATTN_PATH_NONE; | ||
|
|
||
| if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { | ||
| return decisions; | ||
|
|
@@ -1131,9 +1134,9 @@ class ggml_webgpu_shader_lib { | |
| ggml_webgpu_flash_attn_blk_pipeline_key_hash> | ||
| flash_attn_blk_pipelines; | ||
| std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash> | ||
| mul_mat_vec_pipelines; // fast mat-vec (n==1) | ||
| mul_mat_vec_pipelines; // fast mat-vec (n==1) | ||
| std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash> | ||
| mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) | ||
| mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) | ||
| std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash> | ||
| quantize_q8_pipelines; | ||
| std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed | ||
|
|
@@ -1264,10 +1267,13 @@ class ggml_webgpu_shader_lib { | |
| } | ||
|
|
||
| webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { | ||
| ggml_webgpu_set_rows_pipeline_key key = {}; | ||
| key.dst_type = context.dst->type; | ||
| key.vec4 = context.src0->ne[0] % 4 == 0; | ||
| key.i64_idx = context.src1->type == GGML_TYPE_I64; | ||
| const bool quantized = ggml_is_quantized(context.dst->type); | ||
| ggml_webgpu_set_rows_pipeline_key key = {}; | ||
| key.dst_type = context.dst->type; | ||
| key.vec4 = | ||
| (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0; | ||
| key.i64_idx = context.src1->type == GGML_TYPE_I64; | ||
| key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0); | ||
|
|
||
| auto it = set_rows_pipelines.find(key); | ||
| if (it != set_rows_pipelines.end()) { | ||
|
|
@@ -1286,6 +1292,14 @@ class ggml_webgpu_shader_lib { | |
| defines.push_back("DST_F16"); | ||
| variant += "_dstf16"; | ||
| break; | ||
| case GGML_TYPE_Q8_0: | ||
| defines.push_back("DST_Q8_0"); | ||
| variant += "_dstq8_0"; | ||
| break; | ||
| case GGML_TYPE_Q4_0: | ||
| defines.push_back("DST_Q4_0"); | ||
| variant += "_dstq4_0"; | ||
| break; | ||
| default: | ||
| GGML_ABORT("Unsupported dst type for set_rows shader"); | ||
| } | ||
|
|
@@ -1298,13 +1312,19 @@ class ggml_webgpu_shader_lib { | |
| defines.push_back("I64_IDX"); | ||
| variant += "_i64idx"; | ||
| } | ||
| if (key.pair_blocks) { | ||
| defines.push_back("PAIR_BLOCKS"); | ||
| variant += "_pair_blocks"; | ||
| } | ||
|
|
||
| defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); | ||
|
|
||
| auto processed = preprocessor.preprocess(wgsl_set_rows, defines); | ||
| auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); | ||
| const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows; | ||
| auto processed = preprocessor.preprocess(shader_source, defines); | ||
| auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); | ||
| decisions->vec4 = key.vec4; | ||
| decisions->i64_idx = key.i64_idx; | ||
| decisions->pair_blocks = key.pair_blocks; | ||
| decisions->wg_size = context.max_wg_size; | ||
| set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); | ||
| set_rows_pipelines[key].context = decisions; | ||
|
|
@@ -1660,7 +1680,7 @@ class ggml_webgpu_shader_lib { | |
| key.type = context.dst->type; | ||
| key.d_state = (int) context.src0->ne[0]; | ||
| key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && | ||
| ggml_webgpu_tensor_overlap(context.src1, context.src5); | ||
| ggml_webgpu_tensor_overlap(context.src1, context.src5); | ||
|
Comment on lines
1662
to
+1683
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Pretty sure this was re-indented the other way in a previous PR; is it conflicting settings or manual work?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I checked this in my environment, and the clangd version seems to be the main cause. I had been using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's a bit annoying that they've changed the formatting. My clangd version is 21.1.6.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I updated to 22.1.6 and this change is avoided. I'll plan on using the more recent clangd going forward! |
||
|
|
||
| auto it = ssm_scan_pipelines.find(key); | ||
| if (it != ssm_scan_pipelines.end()) { | ||
|
|
@@ -1819,7 +1839,7 @@ class ggml_webgpu_shader_lib { | |
| (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? | ||
| 1 : | ||
| 0; | ||
| key.use_mmvq = | ||
| key.use_mmvq = | ||
| ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); | ||
|
|
||
| auto it = mul_mat_vec_pipelines.find(key); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This went a bit funny. :)