diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 4c4eda1cbe5..60e98a60741 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -52,7 +52,7 @@ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 -// default size for legacy matrix multiplication +// default size for reg-tile matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 // Same hash combine function as in boost @@ -93,6 +93,8 @@ struct ggml_webgpu_shader_lib_context { uint32_t sg_mat_k = 0; uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; + std::string vendor; }; struct webgpu_pipeline { @@ -850,31 +852,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( /** Matrix Multiplication **/ -struct ggml_webgpu_legacy_mul_mat_pipeline_key { - ggml_type src0_type; - ggml_type src1_type; - - bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type; - } -}; - -struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { - size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.src0_type); - ggml_webgpu_hash_combine(seed, key.src1_type); - return seed; - } -}; - struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_mmvq == other.use_mmvq; } }; @@ -884,6 +870,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } }; @@ -894,6 +881,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t vec_size; }; +struct ggml_webgpu_quantize_q8_pipeline_key { + ggml_type src0_type; + + bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; } +}; + +struct ggml_webgpu_quantize_q8_pipeline_key_hash { + size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + return seed; + } +}; + struct ggml_webgpu_mul_mat_pipeline_key { ggml_type src0_type; ggml_type src1_type; @@ -1051,6 +1052,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash { } }; +/** MMVQ **/ + +inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, + const ggml_tensor * src1, + bool supports_dot_product, + const std::string & vendor) { + if (src1->ne[1] == 1) { + bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; + if (supports_dp4a && supports_dot_product) { + switch (src1->type) { + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + return src0->ne[0] % 4 == 0; + default: + break; + } + break; + default: + break; + } + } + } + return false; +} + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -1099,14 +1130,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_flash_attn_blk_pipeline_key_hash> flash_attn_blk_pipelines; - std::unordered_map - mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) std::unordered_map mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map + quantize_q8_pipelines; std::unordered_map mul_mat_id_gather_pipelines; // key is fixed std::unordered_map mul_mat_id_pipelines; // src0_type/src1_type @@ -1631,7 +1660,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1744,6 +1773,44 @@ class ggml_webgpu_shader_lib { return pad_pipelines[key]; } + webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_quantize_q8_pipeline_key key = {}; + key.src0_type = context.src0->type; + + auto it = quantize_q8_pipelines.find(key); + if (it != quantize_q8_pipelines.end()) { + return it->second; + } + const char * shader_src = wgsl_quantize_q8; + std::vector defines; + std::string variant = "quantize_q8"; + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + + defines.push_back("SRC1_INNER_TYPE=f32"); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("Q8_1_T"); + + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + + auto processed = preprocessor.preprocess(shader_src, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + quantize_q8_pipelines[key] = pipeline; + return quantize_q8_pipelines[key]; + } + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; @@ -1752,6 +1819,8 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; + key.use_mmvq = + ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1788,6 +1857,19 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); switch (context.src0->type) { + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + if (key.use_mmvq) { + defines.push_back("LEGACY_QUANTS"); + } + break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + if (key.use_mmvq) { + defines.push_back("K_QUANTS"); + } + break; case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_S: @@ -1840,6 +1922,11 @@ class ggml_webgpu_shader_lib { outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } + if (key.use_mmvq) { + defines.push_back("MMVQ"); + defines.push_back("Q8_1_T"); + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); @@ -2018,100 +2105,6 @@ class ggml_webgpu_shader_lib { return mul_mat_fast_pipelines[key]; } - webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - - auto it = mul_mat_legacy_pipelines.find(key); - if (it != mul_mat_legacy_pipelines.end()) { - return it->second; - } - - std::vector defines; - std::string variant = "mul_mat"; - - switch (context.src1->type) { - case GGML_TYPE_F32: - defines.push_back("SRC1_TYPE=f32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC1_TYPE=f16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); - } - - const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); - const char * src0_name = src0_traits->type_name; - - switch (context.src0->type) { - case GGML_TYPE_F32: - defines.push_back("SRC0_TYPE=f32"); - defines.push_back("FLOAT"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC0_TYPE=f16"); - defines.push_back("FLOAT"); - variant += "_f16"; - break; - default: - { - std::string type_upper = src0_name; - std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - - switch (context.src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_MXFP4: - { - // Quantized types using u32 buffers for portability. - defines.push_back("SRC0_TYPE=u32"); - defines.push_back("U32_DEQUANT_HELPERS"); - break; - } - default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } - } - - defines.push_back("BYTE_HELPERS"); - defines.push_back(type_upper + "_T"); - defines.push_back(type_upper); - defines.push_back(type_upper + "_SCALE_MIN"); - defines.push_back(type_upper + "_TABLES"); - defines.push_back(type_upper + "_GRID"); - - variant += std::string("_") + src0_name; - break; - } - } - - auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); - - auto decisions = std::make_shared(); - decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; - - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - mul_mat_legacy_pipelines[key] = pipeline; - return mul_mat_legacy_pipelines[key]; - } - webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = mul_mat_id_gather_pipelines.find(1); if (it != mul_mat_id_gather_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 921c12b41ac..da0600a430b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -181,6 +181,7 @@ struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroups = false; bool supports_subgroup_matrix = false; + bool supports_dot_product = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -210,6 +211,8 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + std::string vendor; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -1384,6 +1387,58 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + std::vector & dispatches) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline qq8_pipeline = ctx->shader_lib->get_quantize_q8_pipeline(shader_lib_ctx); + + // quantize_q8 pipeline + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t q8_src1_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t q8_src1_binding_size = + ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); + + std::vector q8_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector q8_entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), q8_src1_align_offset, q8_src1_binding_size) + }; + + auto q8_decisions = static_cast(qq8_pipeline.context.get()); + + uint32_t q8_wg_size = q8_decisions->wg_size; + uint32_t q8_wg_x = 1; + uint32_t q8_wg_y = 1; + const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; + const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); + + dispatches.push_back({ + qq8_pipeline, std::move(q8_params), std::move(q8_entries), { q8_wg_x, q8_wg_y } + }); +} + static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -1391,47 +1446,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); - // Determine if we should use fast path - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q6_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q1_0: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_MXFP4: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } + // use MMVQ path for mat-vec + bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, + ctx->global_ctx->vendor); ggml_webgpu_shader_lib_context shader_lib_ctx = {}; @@ -1446,16 +1463,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product; + shader_lib_ctx.vendor = ctx->global_ctx->vendor; // Get or create pipeline - webgpu_pipeline pipeline; + webgpu_pipeline pipeline; + std::vector dispatches; - if (use_fast && is_vec) { + if (is_vec) { + if (use_mmvq) { + ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); + } pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); - } else if (use_fast) { - pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } else { - pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx); + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } // Build params @@ -1479,25 +1500,31 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Build bind group entries - std::vector entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), - }; + std::vector entries = {}; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + if (use_mmvq) { + auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1]; + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset, + mmvq_qq8_entry.size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); // Calculate workgroup dimensions uint32_t wg_x = 1; uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (use_fast && is_vec) { + if (is_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - } else if (use_fast) { + } else { auto * decisions = static_cast(pipeline.context.get()); // Fast-path tiled/subgroup calculations @@ -1518,15 +1545,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, } uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - - } else { // legacy - auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_size = decisions->wg_size; - uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); + dispatches.push_back({ + pipeline, std::move(params), std::move(entries), { wg_x, wg_y } + }); + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, @@ -3582,6 +3607,22 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_MUL_MAT: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + bool use_mmvq = + ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, + ctx->webgpu_global_ctx->vendor); + if (use_mmvq) { + const size_t q8_src1_size = + src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + res = ROUNDUP_POW2(res + q8_src1_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; case GGML_OP_MUL_MAT_ID: { const ggml_tensor * src0 = tensor->src[0]; @@ -3707,12 +3748,16 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetInfo(&info); ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); + ctx->webgpu_global_ctx->vendor = info.vendor; wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); + // for dot4I8packed + ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature( + wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct); bool valid_subgroup_matrix_config = false; #ifndef __EMSCRIPTEN__ diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 372ea79bf9d..758efa17d77 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -95,11 +95,10 @@ struct q5_1 { }; #endif - #ifdef Q8_1_T struct q8_1 { d: f16, - m: f16, + s: f16, // d * sum(qs[i]) qs: array }; #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl deleted file mode 100644 index fcbefdeb802..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ /dev/null @@ -1,747 +0,0 @@ -enable f16; - -#define DECLARE_BYTE_LOADERS_SRC0 -#include "common_decls.tmpl" - - -#ifdef FLOAT -const BLOCK_SIZE = 1u; - -#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) -const BLOCK_SIZE = 32u; - -#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) -const BLOCK_SIZE = 256u; -#endif - -#ifdef FLOAT -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); -} -#endif - -#ifdef Q4_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q4_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_1 = src0[src0_idx_base + offset]; - let d = f32(block_q4_1.d); - let m = f32(block_q4_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q4_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - let qh_packed = load_u32_at_src0(block_byte_base + 2); - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_1 = src0[src0_idx_base + offset]; - let d = f32(block_q5_1.d); - let m = f32(block_q5_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q5_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; - let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q8_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q8_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_1 = src0[src0_idx_base + offset]; - let d = f32(block_q8_1.d); - let m = f32(block_q8_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = block_q8_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q2_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(block.scales[is / 4], is % 4); - is++; - let dl = d * f32(sc & 0xF); - let ml = m * f32(sc >> 4); - for (var l: u32 = 0u; l < 16; l++) { - let q_idx = q_b_idx + k + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 3; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - } - return sum; -} -#endif - -#ifdef Q3_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - - // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 108); - - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes - // Bytes 96-107: 12 bytes of scales (3 u32s) - let kmask1: u32 = 0x03030303; - let kmask2: u32 = 0x0f0f0f0f; - var scale_vals: array; - scale_vals[0] = load_u32_at_src0(block_byte_base + 96); - scale_vals[1] = load_u32_at_src0(block_byte_base + 100); - scale_vals[2] = load_u32_at_src0(block_byte_base + 104); - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - // Bytes 0-31: 32 bytes of hmask (8 u32s) - var hmask_vals: array; - for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 32-95: 64 bytes of qs (16 u32s) - var qs_vals: array; - for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var m: u32 = 1; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(scale_vals[is / 4], is % 4); - is++; - let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3; - sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; - src1_i++; - } - } - m <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q4_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 0xF; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef Q5_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var u: u32 = 1; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qh_byte = get_byte(block.qh[l / 4], l % 4); - let qs_val = (q_byte >> shift) & 0xF; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - u <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q6_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes - - // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 208); - - // Bytes 0-127: 128 bytes of ql (32 u32s) - var ql_vals: array; - for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 128-191: 64 bytes of qh (16 u32s) - var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4); - } - - // Bytes 192-207: 16 bytes of scales (4 u32s) - var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var qh_b_idx: u32 = 0; - var sc_b_idx: u32 = 0; - for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { - for (var l: u32 = 0; l < 32; l++) { - let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); - let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); - let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); - - let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; - let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; - let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; - let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; - - let is = l/16; - let is1 = sc_b_idx + is; - let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); - let is2 = sc_b_idx + is + 2; - let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); - let is3 = sc_b_idx + is + 4; - let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); - let is4 = sc_b_idx + is + 6; - let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); - - sum += d * f32(sc1) * q1 * src1[src1_i + l]; - sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; - sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; - sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; - } - src1_i += 128; - qh_b_idx += 32; - sc_b_idx += 8; - } - return sum; -} -#endif - -#ifdef IQ2_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0_offset = block_byte_base + 2 + ib * 2; - let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at_src0(aux0_offset); - let aux1 = load_u32_at_src0(aux1_offset); - let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; - for (var l: u32 = 0; l < 4; l++) { - let ig = get_byte(aux0, l) * 8; - let is = (aux1 >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += db * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var scale_vals = array( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - for (var l: u32 = 0; l < 4; l++) { - let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF; - let ig = (qs_val & 511) * 8; - let is = qs_val >> 9; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qs_vals : array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - - var qh_vals: array; - qh_vals[0] = load_u32_at_src0(block_byte_base + 66); - qh_vals[1] = load_u32_at_src0(block_byte_base + 70); - - var scale_vals: array; - scale_vals[0] = load_u32_at_src0(block_byte_base + 74); - scale_vals[1] = load_u32_at_src0(block_byte_base + 78); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib ++) { - let s = get_byte(scale_vals[ib / 4], ib % 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - let qs_w = qs_vals[ib]; - for (var l: u32 = 0; l < 4; l++) { - let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; - let ig = (get_byte(qs_w, l) | qh_b) * 8; - let signs = get_byte(qs_vals[ib + 8], l); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ3_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at_src0(sc_sign_offset); - let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; - for (var l: u32 = 0; l < 4; l++) { - let is = (sc_sign >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0); - let ig2 = get_byte(ig_val, 1); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3xxs_grid[ig1], j); - let g2 = get_byte(iq3xxs_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += db * f32(g1) * m1 * src1[src1_i]; - sum += db * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - return sum; -} -#endif - -#ifdef IQ3_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qh_vals = array( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sign_vals: array; - for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4); - } - - var scale_vals = load_u32_at_src0(block_byte_base + 106); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 4; ib++) { - let s = get_byte(scale_vals, ib); - let db = array( - d * (1.0 + 2.0 * f32(s & 0xF)), - d * (1.0 + 2.0 * f32(s >> 4)) - ); - for (var k: u32 = 0; k < 2; k++) { - let dl = db[k]; - let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); - let sign_w = sign_vals[ib * 2 + k]; - for (var l: u32 = 0; l < 4; l++) { - let signs = get_byte(sign_w, l); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); - let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3s_grid[ig1], j); - let g2 = get_byte(iq3s_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += dl * f32(g1) * m1 * src1[src1_i]; - sum += dl * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - } - return sum; -} -#endif - -#ifdef IQ1_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF; - let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4); - for (var l: u32 = 0; l < 4; l++) { - let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl * (f32(gs) + delta) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - - -#ifdef IQ1_M -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - - let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); - let d = f32(bitcast>(scale).x); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; - let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; - let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; - var dl = array( - d * f32(2 * s1 + 1), - d * f32(2 * s2 + 1) - ); - - let qh = block.qh[ib / 2] >> (16 * (ib % 2)); - var idx = array( - get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), - get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), - get_byte(block.qs[ib], 2) | ((qh) & 0x700), - get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) - ); - var delta = array( - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) - ); - for (var l: u32 = 0; l < 4; l++) { - let ig = idx[l] * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ4_NL -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 32; - var sum = 0.0; - var qs: array; - for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - for (var j: u32 = 0; j < 16; j++) { - let qsb = get_byte(qs[j / 4], j % 4); - sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - return sum; -} -#endif - -#ifdef IQ4_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = unpack2x16float(block.d_scales_h)[0]; - let scales_h = block.d_scales_h >> 16; - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); - let dl = d * (f32(ls) - 32.0); - for (var j: u32 = 0; j < 16; j++) { - let iqs = ib * 16 + j; - let qsb = get_byte(block.qs[iqs / 4], iqs % 4); - sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - src1_i += 16; - } - return sum; -} -#endif - -struct MulMatParams { - offset_src0: u32, // in elements/blocks - offset_src1: u32, // in elements/blocks - offset_dst: u32, // in elements/blocks - m: u32, - n: u32, - k: u32, - // all strides are in elements/blocks - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array; // M rows, K columns -@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns - -@group(0) @binding(3) var params: MulMatParams; - -@compute @workgroup_size(256) -fn main(@builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let global_idx = wg_linear * 256u + local_id.x; - - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_idx >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_idx / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_idx % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.m; // output row - let col = dst2_rem % params.m; // output column - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { - sum += multiply_add(src0_idx_base, src1_idx_base, i); - } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index a194cf40468..f0a7fbd059a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -3,10 +3,18 @@ enable subgroups; #endif enable f16; +#ifdef MMVQ +requires packed_4x8_integer_dot_product; +#endif + #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" +#ifdef MMVQ +#include "mul_mat_vec_q_acc.tmpl" +#else #include "mul_mat_vec_acc.tmpl" +#endif struct MulMatParams { offset_src0: u32, @@ -28,9 +36,14 @@ struct MulMatParams { }; @group(0) @binding(0) var src0: array; + +#ifdef MMVQ +@group(0) @binding(1) var src1q: array; +#else @group(0) @binding(1) var src1: array; -@group(0) @binding(2) var dst: array; +#endif +@group(0) @binding(2) var dst: array; // "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 @group(0) @binding(3) var params: MulMatParams; @@ -75,10 +88,15 @@ fn main( let src12_idx = dst2_idx; let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; +#ifdef MMVQ + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); +#else + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); +#endif #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 711c7e829d8..08753b9d643 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -436,7 +436,6 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src } #endif - #ifdef MUL_ACC_Q3_K #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl new file mode 100644 index 00000000000..3ef2f77ebe0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -0,0 +1,303 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +#ifdef LEGACY_QUANTS +#define BLOCK_SIZE 32 +#define THREADS_PER_BLOCK 4 +#elif K_QUANTS +#define BLOCK_SIZE 256 +#define THREADS_PER_BLOCK 16 +#endif + +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +#define Q8_BLOCK_SIZE 32 + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE_BYTES 18 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id); + + return vec2( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE_BYTES 20 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id); + + return vec2( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base)), + f32(load_f16_at_src0(block_byte_base + 2u)) + ); +} +fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE_BYTES 34 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + return vec2( + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)), + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1)) + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id * 2u], + src1q[block].qs[inner_id * 2u + 1], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[block].d); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds); +} +#endif + +#ifdef LEGACY_QUANTS +fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { + var row_sum = 0; + let a_repacked = repack_a(a_byte_base, b_inner_id); + + row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); + row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); + + return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); +} + +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let b_inner_id = thread_id % THREADS_PER_BLOCK; + let b_block_idx = src1q_idx_base + block; + + let b_repacked = repack_b_qs(b_block_idx, b_inner_id); + let b_ds = repack_b_dm(b_block_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE_BYTES 84 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4 { + let ih2 = tid / 8u; + let phase = tid % 2u; + let iq4_idx = 2u * ih2 + phase; + let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx; + let qs_shift = tid & 6u; + return vec4( + (load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4 { + let phase = tid % 2u; + return vec4( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[q8_block_idx].d); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base + 80u)), + f32(load_f16_at_src0(block_byte_base + 82u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { + let scale_byte = block_byte_base + tid; + let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); + return vec2(f32(scale & 0xFu), f32(scale >> 4u)); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE_BYTES 144 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4 { + let iq4 = tid / 4u; + let phase = tid % 2u; + let nibble = (tid >> 1u) % 2u; + let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase; + let qs_shift = 4u * nibble; + return vec4( + (load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4 { + let phase = tid % 2u; + return vec4( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[q8_block_idx].d), + f32(src1q[q8_block_idx].s), + ); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base + 0u)), + f32(load_f16_at_src0(block_byte_base + 2u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { + let sc_m_idx = tid / 2u; + let scales_byte_base = block_byte_base + 4u; + let scales0_3 = load_u32_at_src0_aligned(scales_byte_base); + let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u); + let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u); + + let byte_idx = sc_m_idx & 3u; + let is_high = sc_m_idx >= 4u; + + let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu; + let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u); + let scale = f32(select(sc_low, sc_high, is_high)); + + let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu; + let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u); + let min_val = f32(select(mn_low, mn_high, is_high)); + + return vec2(scale, min_val); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +} +#endif + +#ifdef K_QUANTS +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + + for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { + let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl new file mode 100644 index 00000000000..b3f1fa04b80 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -0,0 +1,173 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +requires packed_4x8_integer_dot_product; + +#include "common_decls.tmpl" + +struct Params { + offset_src1: u32, + stride_12: u32, + stride_13: u32, + ne0: u32, + ne2: u32, + ne3: u32, +}; + +#define SRC1_TYPE vec4 + +@group(0) @binding(0) var src1: array; +@group(0) @binding(1) var src1q: array; + +@group(0) @binding(2) var params: Params; + +#ifdef USE_SUBGROUP_REDUCTION +fn cluster_max_8(v: f32) -> f32 { + var r = v; + r = max(r, subgroupShuffleXor(r, 1u)); + r = max(r, subgroupShuffleXor(r, 2u)); + r = max(r, subgroupShuffleXor(r, 4u)); + return r; +} + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) +fn cluster_add_i4x8(v: i32) -> i32 { + var r= v; + r += subgroupShuffleXor(r, 1u); + r += subgroupShuffleXor(r, 2u); + r += subgroupShuffleXor(r, 4u); + return r; +} +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION +#define CLUSTER_SIZE 8 + +var partial_amaxs: array, WG_SIZE / CLUSTER_SIZE>; +var partial_sums: array, WG_SIZE / CLUSTER_SIZE>; +#endif + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + let thread_id = local_id.x; + let num_vec4 = params.ne0 / 4u; + + let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne2 * params.ne3; + + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + if (wg_linear >= total_batches) { + return; + } + + let src13_idx = wg_linear / (params.ne2 * wg_per_vec); + let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; + let src11_wg_idx = wg_linear % wg_per_vec; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let src1_idx_vec4_base = src1_idx_base / 4u; + + let blocks_per_row = params.ne0 / 32u; + let blocks_per_wg = (WG_SIZE * 4u) / 32u; + let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; + let qs_idx = thread_id % 8u; + + // reduction + var q4 = vec4(0.0); + var q4_quants = 0u; + var thread_amax = 0.0; + + let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; + let is_valid = src11_vec4_idx < num_vec4; + +#ifdef USE_SUBGROUP_REDUCTION + + var d = 0.0; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3])); + } + + d = cluster_max_8(thread_amax) / 127.0; + + if (is_valid) { + let id = select(0.0, 1.0 / d, d > 0.0); + q4_quants = pack4xI8(vec4(round(q4 * id))); + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + src1q[src1q_idx].qs[qs_idx] = q4_quants; + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u); + let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum))); + + if (is_valid) { + if (qs_idx == 0u) { + src1q[src1q_idx].s = s; + } + } +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION + + var d = 0.0; + let cluster_id = thread_id / 8u; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3])); + partial_amaxs[cluster_id][qs_idx] = thread_amax; + } + + workgroupBarrier(); + + if (is_valid) { + let amax = max( + max( + max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])), + max( + max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7])) + ); + + d = amax / 127.0; + let id = select(0.0f, 1.0f / d, d > 0.0f); + + q4_quants = pack4xI8(vec4(round(q4 * id))); + src1q[src1q_idx].qs[qs_idx] = q4_quants; + + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + + partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u); + + workgroupBarrier(); + + if (is_valid) { + if (qs_idx == 0u) { + let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3] + + partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]); + src1q[src1q_idx].s = f16(s); + } + } + +#endif +#endif + +}