Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 55 additions & 35 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ struct ggml_webgpu_shader_lib_context {
ggml_tensor * src5;
ggml_tensor * dst;

uint32_t max_wg_size;
size_t wg_mem_limit_bytes = 0;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
uint32_t sg_mat_k = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
bool supports_dot_product = false;
uint32_t max_wg_size;
size_t wg_mem_limit_bytes = 0;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
uint32_t sg_mat_k = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
bool supports_dot_product = false;
std::string vendor;
};

Expand Down Expand Up @@ -166,9 +166,11 @@ struct ggml_webgpu_set_rows_pipeline_key {
int dst_type;
int vec4;
int i64_idx;
int pair_blocks;

bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx &&
pair_blocks == other.pair_blocks;
}
};

Expand All @@ -178,13 +180,15 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.vec4);
ggml_webgpu_hash_combine(seed, key.i64_idx);
ggml_webgpu_hash_combine(seed, key.pair_blocks);
return seed;
}
};

struct ggml_webgpu_set_rows_shader_decisions {
bool vec4;
bool i64_idx;
bool pair_blocks;
uint32_t wg_size;
};

Expand Down Expand Up @@ -772,31 +776,30 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
(uint32_t) ggml_blck_size(K->type);
const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 &&
context.src2->ne[0] % kv_vec_head_align == 0;
const uint32_t kv_vec_head_align =
K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type);
const bool kv_vec_head_dims_aligned =
context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0;
// Compile with enough invocations to cover the largest reported subgroup.
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) &&
kv_vec_head_dims_aligned && kv_vec_type_supported &&
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool tile_can_dispatch_all_q_rows =
context.max_subgroup_size > 0 &&
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
const bool use_subgroup_matrix =
context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0;
const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
context.src0->ne[0] % context.sg_mat_k == 0 &&
context.src2->ne[0] % context.sg_mat_n == 0;
const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
tile_can_dispatch_all_q_rows && !use_vec;

decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;

if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
return decisions;
Expand Down Expand Up @@ -1131,9 +1134,9 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
flash_attn_blk_pipelines;
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
mul_mat_vec_pipelines; // fast mat-vec (n==1)
mul_mat_vec_pipelines; // fast mat-vec (n==1)
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
quantize_q8_pipelines;
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
Expand Down Expand Up @@ -1264,10 +1267,13 @@ class ggml_webgpu_shader_lib {
}

webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_rows_pipeline_key key = {};
key.dst_type = context.dst->type;
key.vec4 = context.src0->ne[0] % 4 == 0;
key.i64_idx = context.src1->type == GGML_TYPE_I64;
const bool quantized = ggml_is_quantized(context.dst->type);
ggml_webgpu_set_rows_pipeline_key key = {};
key.dst_type = context.dst->type;
key.vec4 =
(context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0;
key.i64_idx = context.src1->type == GGML_TYPE_I64;
Comment on lines +1272 to +1275
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This went a bit funny. :)

key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0);

auto it = set_rows_pipelines.find(key);
if (it != set_rows_pipelines.end()) {
Expand All @@ -1286,6 +1292,14 @@ class ggml_webgpu_shader_lib {
defines.push_back("DST_F16");
variant += "_dstf16";
break;
case GGML_TYPE_Q8_0:
defines.push_back("DST_Q8_0");
variant += "_dstq8_0";
break;
case GGML_TYPE_Q4_0:
defines.push_back("DST_Q4_0");
variant += "_dstq4_0";
break;
default:
GGML_ABORT("Unsupported dst type for set_rows shader");
}
Expand All @@ -1298,13 +1312,19 @@ class ggml_webgpu_shader_lib {
defines.push_back("I64_IDX");
variant += "_i64idx";
}
if (key.pair_blocks) {
defines.push_back("PAIR_BLOCKS");
variant += "_pair_blocks";
}

defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

auto processed = preprocessor.preprocess(wgsl_set_rows, defines);
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows;
auto processed = preprocessor.preprocess(shader_source, defines);
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
decisions->vec4 = key.vec4;
decisions->i64_idx = key.i64_idx;
decisions->pair_blocks = key.pair_blocks;
decisions->wg_size = context.max_wg_size;
set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
set_rows_pipelines[key].context = decisions;
Expand Down Expand Up @@ -1660,7 +1680,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);
Comment on lines 1662 to +1683
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty sure this was re-indented the other way in a previous PR, is it conflicting settings or manual work?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked this in my environment, and the clangd version seems to be the main cause. I had been using 22.1.2, but switching to 18.1.3 produced all the same formatting changes as this PR.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a bit annoying they've changed the formatting. My clangd version is 21.1.6

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated to 22.1.6 and this change is avoided. I'll plan on using the more recent clangd going forward!


auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
Expand Down Expand Up @@ -1819,7 +1839,7 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_mmvq =
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);

auto it = mul_mat_vec_pipelines.find(key);
Expand Down
11 changes: 8 additions & 3 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,11 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct
}

uint32_t threads;
if (decisions->vec4) {
if (ggml_is_quantized(dst->type)) {
const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type);
threads =
(src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row);
} else if (decisions->vec4) {
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
} else {
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
Expand Down Expand Up @@ -4045,8 +4049,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
break;
case GGML_OP_SET_ROWS:
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 ||
op->type == GGML_TYPE_Q4_0) &&
src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
break;
case GGML_OP_GET_ROWS:
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
Expand Down
5 changes: 2 additions & 3 deletions ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
return;
}

// getting the row from gid
let elems_per_row = params.ne0 / VEC_SIZE;
var i = gid.x / elems_per_row;

Expand Down Expand Up @@ -104,6 +103,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;

let col_idx = (gid.x % elems_per_row);
dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
let col_idx = gid.x % elems_per_row;
dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]);
}
Loading
Loading