diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 369475eaf50..05b3844a003 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -42,11 +42,20 @@ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 // Matrix-vector multiplication parameters -#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 + // Must be multiple of 4 to work with vectorized paths, and must divide // mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_TILE_K 256 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 + +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 + +// Requires 32 threads per output (wg_size/outputs_per_wg == 32) +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 +// Requires at least two (and multiple of 2) k-quant blocks per tile +#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -183,7 +192,8 @@ struct ggml_webgpu_binary_pipeline_key { bool src_overlap; bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && + src_overlap == other.src_overlap; } }; @@ -731,29 +741,17 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "mul_mat_vec"; - // src1 type (vector) - switch (context.src1->type) { - case GGML_TYPE_F32: - defines.push_back("SRC1_INNER_TYPE=f32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC1_INNER_TYPE=f16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); - } - // src0 type (matrix row) switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; break; default: { @@ -761,6 +759,7 @@ class ggml_webgpu_shader_lib { const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); std::string src0_name = src0_traits->type_name; std::string type_upper = src0_name; + variant += "_" + src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); defines.push_back("BYTE_HELPERS"); @@ -772,12 +771,35 @@ class ggml_webgpu_shader_lib { } } + // src1 type (vector) + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); + } + // VEC/SCALAR controls defines.push_back(key.vectorized ? 
"VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K; - uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type >= GGML_TYPE_Q2_K) { + tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); @@ -1043,10 +1065,10 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, + .type = context.dst->type, + .op = context.dst->op, + .inplace = context.inplace, + .overlap = context.overlap, .src_overlap = context.src_overlap, }; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 19451618ec5..2fbf2e8a748 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -776,7 +777,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "\nggml_webgpu: gpu breakdown:\n"; for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? 
(kv.second / total_gpu * 100.0) : 0.0; - std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) + << pct << "%)\n"; } #endif @@ -836,7 +838,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 binary_overlap_flags flags = {}; flags.inplace = ggml_webgpu_tensor_equal(src0, dst); flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); return flags; } @@ -1079,12 +1081,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, use_fast = (src0->type == GGML_TYPE_F16); break; case GGML_TYPE_F32: + // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q6_K: use_fast = true; break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat + use_fast = !is_vec; + break; default: break; } @@ -1153,8 +1169,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Calculate workgroup dimensions - uint32_t wg_x = 1; - uint32_t wg_y = 1; + uint32_t wg_x = 1; + uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { @@ -1410,7 +1426,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, uint32_t offset_merged_src0 = 0; uint32_t offset_merged_src1 = 0; if (flags.src_overlap) { - size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); } @@ -1419,7 +1435,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), offset_merged_src0, offset_merged_src1, (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 5c1074ebc10..f28c278c8b8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -101,3 +101,675 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif + +#ifdef INIT_SRC0_SHMEM_Q4_1 + +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
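+// Layout note (descriptive only, assuming src0 is bound as an array of f16): a
+// q4_1 block is 20 bytes = 10 f16 values laid out as [d, m, 8 x f16 of packed
+// nibbles]. Byte b of the packed data carries weight b in its low nibble and
+// weight b + 16 in its high nibble, and both dequantize as w = f16(q) * d + m,
+// which is what the loop below expands into shmem.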
+override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_lo = f16(q_byte & 0xF) * d + m; + let q_hi = f16((q_byte >> 4) & 0xF) * d + m; + shmem[shmem_idx + j * 2 + k] = q_lo; + shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + } + } + } + } +} +#endif + +#ifdef INIT_SRC0_SHMEM_Q5_0 + +// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +// tile_k is defined as 32u, so blocks_k ends up being 1 always +override BLOCKS_K = TILE_K / BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let qh0 = src0[scale_idx + 1u]; + let qh1 = src0[scale_idx + 2u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; + + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d; + + shmem[shmem_idx + j * 4u + k] = q_lo; // store first 
weight + shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + } + } + } + } +} +#endif + + +#ifdef INIT_SRC0_SHMEM_Q5_1 + +// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +// tile_k is defined as 32u, so blocks_k ends up being 1 always +override BLOCKS_K = TILE_K / BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + let qh0 = src0[scale_idx + 2u]; + let qh1 = src0[scale_idx + 3u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + + let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; + + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m; + + shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight + shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + } + } + } + } +} +#endif + +#ifdef INIT_SRC0_SHMEM_Q8_0 + +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
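+// Layout note (descriptive only, assuming src0 is bound as an array of f16): a
+// q8_0 block is 34 bytes = 17 f16 values laid out as [d, 16 x f16 each packing
+// two signed 8-bit weights]. Every byte is one weight and dequantizes as
+// w = f16(q) * d, which is what the loop below expands into shmem.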
+override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights +const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + + for (var j = 0u; j < F16_PER_THREAD; j+=2) { + let q_0 = src0[scale_idx + 1u + block_offset + j]; + let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + + let q_val = f16(q_byte) * d; + shmem[shmem_idx + j * 2 + k] = q_val; + } + } + } + } +} +#endif + +#ifdef INIT_SRC0_SHMEM_Q8_1 + +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights +const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + + for (var j = 0u; j < F16_PER_THREAD; j+=2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + + let q_val = f16(q_byte) * d + m; + shmem[shmem_idx + j * 2 + k] = q_val; + } + } + } + } +} +#endif + +#ifdef INIT_SRC0_SHMEM_Q2_K + +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 42u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + // Use standard thread layout instead of lane/row_group + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / 
BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx + 40u]; + let dmin = src0[scale_idx + 41u]; + + // Decode the element at position k_in_block + let block_of_32 = k_in_block / 32u; + let pos_in_32 = k_in_block % 32u; + + let q_b_idx = (block_of_32 / 4u) * 32u; + let shift = (block_of_32 % 4u) * 2u; + let k = (pos_in_32 / 16u) * 16u; + let l = pos_in_32 % 16u; + + let is = k_in_block / 16u; + + let sc_0 = src0[scale_idx + 2u * (is / 4u)]; + let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u]; + let sc_packed = bitcast(vec2(sc_0, sc_1)); + let sc = get_byte(sc_packed, is % 4u); + + let dl = d * f16(sc & 0xFu); + let ml = dmin * f16(sc >> 4u); + + let q_idx = q_b_idx + k + l; + let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte = get_byte(q_packed, q_idx % 4u); + let qs_val = (q_byte >> shift) & 3u; + + let q_val = f16(qs_val) * dl - ml; + shmem[elem_idx] = q_val; + } +} +#endif + +#ifdef INIT_SRC0_SHMEM_Q3_K + +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 55u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx + 54u]; + + // Load and unpack scales + let kmask1: u32 = 0x03030303u; + let kmask2: u32 = 0x0f0f0f0fu; + + var scale_vals: array; + for (var i: u32 = 0u; i < 4u; i++) { + let scale_0 = src0[scale_idx + 48u + (2u*i)]; + let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u); + scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u); + + // Load hmask and qs arrays + var hmask_vals: array; + for (var i: u32 = 0u; i < 8u; i++) { + let hmask_0 = src0[scale_idx + (2u*i)]; + let hmask_1 = src0[scale_idx + (2u*i) + 1u]; + hmask_vals[i] = bitcast(vec2(hmask_0, hmask_1)); + } + + var qs_vals: array; + for (var i: u32 = 0u; i < 16u; i++) { + let qs_0 = src0[scale_idx + 16u + (2u*i)]; + let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u]; + qs_vals[i] = bitcast(vec2(qs_0, qs_1)); + } + + let half = k_in_block / 128u; // 0 or 1 + let pos_in_half = k_in_block % 128u; // 0-127 + let shift_group = pos_in_half / 32u; // 0-3 + let pos_in_32 = pos_in_half % 32u; // 0-31 + let k_group = pos_in_32 / 16u; // 0 or 1 + let l = pos_in_32 % 16u; // 0-15 + + let q_b_idx = half * 32u; // 0 or 32 + let shift = shift_group * 2u; // 0, 2, 4, 6 + let k = k_group * 16u; // 0 or 16 + let is = k_in_block / 16u; // 0-15 + + // m increments every 32 elements across entire 256 element block + let m_shift = k_in_block / 32u; // 0-7 + let m: 
u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 + + let sc = get_byte(scale_vals[is / 4u], is % 4u); + let dl = d * (f16(sc) - 32.0); + + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + + let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u); + + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3u; + + let q_val = (f16(qs_val) - f16(hm)) * dl; + shmem[elem_idx] = q_val; + } +} + +#endif +#ifdef INIT_SRC0_SHMEM_Q4_K + +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 72u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let dmin = src0[scale_idx + 1u]; + + // Load packed scales + var scale_vals: array; + for (var i: u32 = 0u; i < 3u; i++) { + let scale_0 = src0[scale_idx + 2u + (2u*i)]; + let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + // Map k_in_block to loop structure: + // Outer loop over 64-element groups (alternating q_b_idx) + // Inner loop over 2 shifts per group + let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx) + let pos_in_64 = k_in_block % 64u; // 0-63 + let shift_group = pos_in_64 / 32u; // 0 or 1 + let l = pos_in_64 % 32u; // 0-31 + + let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 + let shift = shift_group * 4u; // 0 or 4 + let is = k_in_block / 32u; // 0-7 + + var sc: u32; + var mn: u32; + + if (is < 4u) { + let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); + let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); + let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); + let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + let q_idx = q_b_idx + l; + let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let q_byte = get_byte(q_packed, q_idx % 4u); + let qs_val = (q_byte >> shift) & 0xFu; + + let q_val = f16(qs_val) * dl - ml; + shmem[elem_idx] = q_val; + } +} +#endif + +#ifdef INIT_SRC0_SHMEM_Q5_K + +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 88u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + 
global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let dmin = src0[scale_idx + 1u]; + + // Load packed scales + var scale_vals: array; + for (var i: u32 = 0u; i < 3u; i++) { + let scale_0 = src0[scale_idx + 2u + (2u*i)]; + let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + // The original loop processes elements in groups of 64 + // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] + // But u increments EVERY 32 elements (after each l loop) + let group_of_64 = k_in_block / 64u; // 0-3 + let pos_in_64 = k_in_block % 64u; // 0-63 + let shift_group = pos_in_64 / 32u; // 0 or 1 + let l = pos_in_64 % 32u; // 0-31 + + let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 + let shift = shift_group * 4u; // 0 or 4 + let is = k_in_block / 32u; // 0-7 + + // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128) + let u_shift = k_in_block / 32u; // 0-7 + let u: u32 = 1u << u_shift; + + var sc: u32; + var mn: u32; + + if (is < 4u) { + let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); + let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); + let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); + let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + let q_idx = q_b_idx + l; + let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let q_byte = get_byte(q_packed, q_idx % 4u); + + let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; + let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u]; + let qh_packed = bitcast(vec2(qh_0, qh_1)); + + let qh_byte = get_byte(qh_packed, l % 4u); + + let qs_val = (q_byte >> shift) & 0xFu; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + + let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; + shmem[elem_idx] = q_val; + } +} +#endif +#ifdef INIT_SRC0_SHMEM_Q6_K + +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 105u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let half = k_in_block / 128u; + let pos_in_half = k_in_block % 128u; + let quarter = pos_in_half / 32u; + let l = pos_in_half % 32u; + + let ql_b_idx = half * 64u; + let qh_b_idx = half * 32u; + let sc_b_idx = half * 8u; + + // Load only ql13 word needed + let ql13_flat = ql_b_idx + l; + let ql13_word = ql13_flat / 4u; + let ql13 = bitcast(vec2( + src0[scale_idx + 2u * ql13_word], + src0[scale_idx + 2u * ql13_word + 1u] + )); + let ql13_b = get_byte(ql13, ql13_flat % 4u); + + // Load only ql24 word needed + let ql24_flat = ql_b_idx + l + 32u; + let ql24_word = 
ql24_flat / 4u; + let ql24 = bitcast(vec2( + src0[scale_idx + 2u * ql24_word], + src0[scale_idx + 2u * ql24_word + 1u] + )); + let ql24_b = get_byte(ql24, ql24_flat % 4u); + + // Load only qh word needed + let qh_flat = qh_b_idx + l; + let qh_word = qh_flat / 4u; + let qh = bitcast(vec2( + src0[scale_idx + 64u + 2u * qh_word], + src0[scale_idx + 64u + 2u * qh_word + 1u] + )); + let qh_b = get_byte(qh, qh_flat % 4u); + + let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); + let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); + let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0); + let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0); + + // Load only the scale word needed + let is = l / 16u; + let sc_idx = sc_b_idx + is + quarter * 2u; + let sc_word = sc_idx / 4u; + let sc = bitcast(vec2( + src0[scale_idx + 96u + 2u * sc_word], + src0[scale_idx + 96u + 2u * sc_word + 1u] + )); + let sc_val = get_byte_i32(sc, sc_idx % 4u); + + let d = src0[scale_idx + 104u]; + + var q_val: f16; + if (quarter == 0u) { + q_val = q1; + } else if (quarter == 1u) { + q_val = q2; + } else if (quarter == 2u) { + q_val = q3; + } else { + q_val = q4; + } + + shmem[elem_idx] = d * f16(sc_val) * q_val; + } +} + +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 761e3017c14..b1da421a691 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 { const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index f9ea95e07b9..860d84813af 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -84,6 +84,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { } #endif +#ifdef MUL_ACC_Q4_1 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 10u; +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + let m = f32(src0[scale_idx + 1u]); + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32((q_byte >> 4) & 0xF) * d + m; + let q_lo = f32(q_byte & 0xF) * d + m; + local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + } + } + } + 
return local_sum; +} +#endif + +#ifdef MUL_ACC_Q5_0 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 11u; +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + let qh0 = src0[scale_idx + 1u]; + let qh1 = src0[scale_idx + 2u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + + local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + } + + } + } + return local_sum; +} +#endif + + +#ifdef MUL_ACC_Q5_1 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 12u; +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + let m = src0[scale_idx + 1u]; + let qh0 = src0[scale_idx + 2u]; + let qh1 = src0[scale_idx + 3u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m); + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m); + + local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + } + + } + } + return local_sum; +} +#endif + + +#ifdef MUL_ACC_Q8_0 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 17u; +const WEIGHTS_PER_F16 = 2u; +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: 
u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 packs two consecutive 8-bit weights; this thread covers the 16 weights starting at 2 * block_offset
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 1 + block_offset + j];
+            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d;
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q8_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 18u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 packs two consecutive 8-bit weights; this thread covers the 16 weights starting at 2 * block_offset
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = src0[scale_idx + 1u];
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 2u + block_offset + j];
+            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d + f32(m);
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q6_K
+
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
+fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
+    let aligned = byte_offset & ~3u;
+    let idx = bbase + aligned / 2u;
+    return bitcast<u32>(vec2<f16>(src0[idx], src0[idx + 1u]));
+}
+
+fn byte_of(v: u32, b: u32) -> u32 {
+    return (v >> (b * 8u)) & 0xFFu;
+}
+
+fn sbyte_of(v: u32, b: u32) -> i32 {
+    let raw = i32((v >> (b * 8u)) & 0xFFu);
+    return select(raw, raw - 256, raw >= 128);
+}
+
+fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    let tid = tig / 2u;
+    let ix = tig % 2u;
+    let ip = tid / 8u;
+    let il = tid % 8u;
+    let l0 = 4u * il;
+    let is = 8u * ip + l0 / 16u;
+
+    let y_offset = 128u * ip + l0;
+    let q_offset_l = 64u * ip + l0;
+    let q_offset_h = 32u * ip + l0;
+
+    let nb = tile_size / BLOCK_SIZE;
+    let k_block_start = k_outer / BLOCK_SIZE;
+
+    // Aligned scale byte position (is can be odd)
+    let sc_base_byte = 192u + (is & ~3u);
+    let sc_byte_pos = is & 3u;
+
+    var local_sum = 0.0;
+
+    for (var i = ix; i < nb; i += 2u) {
+        let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+
+        let d_raw = load_u32_at(bbase, 208u);
+        let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+
+        let ql1_u32 = load_u32_at(bbase, q_offset_l);
+        let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
+        let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
+        let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
+        let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+
+        let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
+        let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
+        let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
+        let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
+
+        var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+
+        for (var l = 0u; l < 4u; l++) {
+            let y_base = i * BLOCK_SIZE + y_offset + l;
+            let yl0 = f32(shared_vector[y_base]);
+            let yl1 = f32(shared_vector[y_base + 32u]);
+            let yl2 = f32(shared_vector[y_base + 64u]);
+            let yl3 = f32(shared_vector[y_base + 96u]);
+
+            let q1b = byte_of(ql1_u32, l);
+            let q2b = byte_of(ql2_u32, l);
+            let qhb = byte_of(qh_u32, l);
+
+            let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
+            let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
+            let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32);
+            let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
+
+            sums[0] += yl0 * dq0;
+            sums[1] += yl1 * dq1;
+            sums[2] += yl2 * dq2;
+            sums[3] += yl3 * dq3;
+        }
+
+        local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
+                          sums[2] * f32(sc4) + sums[3] * f32(sc6));
+    }
+
+    return local_sum;
+}
+#endif
+
 struct MulMatParams {
     offset_src0: u32,
     offset_src1: u32,
@@ -191,4 +479,3 @@ fn main(
         dst[dst_idx / VEC_SIZE] = store_val(group_base);
     }
 }
-
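For reference: the F16_PER_BLOCK constants used by the new decoders are simply sizeof(block) / sizeof(f16) for the corresponding ggml quantization blocks, since the shaders index src0 as f16 values and recover packed bytes by bitcasting pairs of f16 to u32. A minimal host-side C++ sketch, with struct layouts paraphrased from ggml's block definitions (field names and order assumed to match ggml-common.h, QK_K assumed to be 256), checks that arithmetic at compile time:

// Sketch only: sanity-check the F16_PER_BLOCK values used by the WGSL decoders.
// Struct layouts are paraphrased from ggml's quant blocks (assumed field order).
#include <cstdint>

using f16 = std::uint16_t;  // storage-only half, 2 bytes

struct block_q4_1 { f16 d; f16 m; std::uint8_t qs[16]; };                                            // 10 f16
struct block_q5_0 { f16 d; std::uint8_t qh[4]; std::uint8_t qs[16]; };                               // 11 f16
struct block_q5_1 { f16 d; f16 m; std::uint8_t qh[4]; std::uint8_t qs[16]; };                        // 12 f16
struct block_q8_0 { f16 d; std::int8_t qs[32]; };                                                    // 17 f16
struct block_q8_1 { f16 d; f16 s; std::int8_t qs[32]; };                                             // 18 f16
struct block_q2_K { std::uint8_t scales[16]; std::uint8_t qs[64]; f16 d; f16 dmin; };                // 42 f16
struct block_q3_K { std::uint8_t hmask[32]; std::uint8_t qs[64]; std::uint8_t scales[12]; f16 d; };  // 55 f16
struct block_q4_K { f16 d; f16 dmin; std::uint8_t scales[12]; std::uint8_t qs[128]; };               // 72 f16
struct block_q5_K { f16 d; f16 dmin; std::uint8_t scales[12]; std::uint8_t qh[32]; std::uint8_t qs[128]; }; // 88 f16
struct block_q6_K { std::uint8_t ql[128]; std::uint8_t qh[64]; std::int8_t scales[16]; f16 d; };     // 105 f16

static_assert(sizeof(block_q4_1) / sizeof(f16) == 10, "q4_1: d + m + 16 bytes of nibbles");
static_assert(sizeof(block_q5_0) / sizeof(f16) == 11, "q5_0: d + 4 bytes qh + 16 bytes of nibbles");
static_assert(sizeof(block_q5_1) / sizeof(f16) == 12, "q5_1: d + m + qh + nibbles");
static_assert(sizeof(block_q8_0) / sizeof(f16) == 17, "q8_0: d + 32 signed bytes");
static_assert(sizeof(block_q8_1) / sizeof(f16) == 18, "q8_1: d + s + 32 signed bytes");
static_assert(sizeof(block_q2_K) / sizeof(f16) == 42, "q2_K");
static_assert(sizeof(block_q3_K) / sizeof(f16) == 55, "q3_K");
static_assert(sizeof(block_q4_K) / sizeof(f16) == 72, "q4_K");
static_assert(sizeof(block_q5_K) / sizeof(f16) == 88, "q5_K");
static_assert(sizeof(block_q6_K) / sizeof(f16) == 105, "q6_K");

If a block layout ever changes upstream, the matching F16_PER_BLOCK value and the hard-coded byte offsets in the WGSL decoders (for example the 192u scale base and 208u d offset in the q6_K path) have to move with it.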