Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Legend:
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
Expand Down
28 changes: 14 additions & 14 deletions docs/ops/WebGPU.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5023,20 +5023,20 @@
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[1024,12,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[2000,10,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[5438,3,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,1,1,1],v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[2,1,1,1],v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,2,1,1],v=0","support","0","no","WebGPU"
Expand Down
71 changes: 65 additions & 6 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,22 @@ struct ggml_webgpu_concat_pipeline_key_hash {
}
};

/** Repeat **/

struct ggml_webgpu_repeat_pipeline_key {
int type;

bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
};

struct ggml_webgpu_repeat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
return seed;
}
};

/** Binary **/

struct ggml_webgpu_binary_pipeline_key {
Expand Down Expand Up @@ -431,6 +447,8 @@ class ggml_webgpu_shader_lib {
binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
concat_pipelines; // type
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
Expand Down Expand Up @@ -1147,7 +1165,7 @@ class ggml_webgpu_shader_lib {
}

std::vector<std::string> defines;
std::string variant = "concat";
std::string variant = "concat";

switch (key.type) {
case GGML_TYPE_F32:
Expand All @@ -1164,15 +1182,56 @@ class ggml_webgpu_shader_lib {

defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
concat_pipelines[key] = pipeline;
pipeline.context = decisions;
concat_pipelines[key] = pipeline;
return concat_pipelines[key];
}

webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_repeat_pipeline_key key = {
.type = context.dst->type,
};

auto it = repeat_pipelines.find(key);
if (it != repeat_pipelines.end()) {
return it->second;
}

std::vector<std::string> defines;
std::string variant = "repeat";

switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_I32:
defines.push_back("TYPE_I32");
variant += "_i32";
break;
case GGML_TYPE_I16:
defines.push_back("TYPE_I16");
variant += "_i16";
break;
default:
GGML_ABORT("Unsupported type for repeat shader");
}

defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

auto processed = preprocessor.preprocess(wgsl_repeat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
repeat_pipelines[key] = pipeline;
return repeat_pipelines[key];
}

webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;
Expand Down
55 changes: 51 additions & 4 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1567,6 +1567,48 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}

static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);

std::vector<uint32_t> params = { ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src0->ne[0]),
(uint32_t) (src0->ne[1]),
(uint32_t) (src0->ne[2]),
(uint32_t) (src0->ne[3]),
(uint32_t) (dst->ne[0]),
(uint32_t) (dst->ne[1]),
(uint32_t) (dst->ne[2]) };

std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};

ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};

webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}

static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);

Expand Down Expand Up @@ -2158,6 +2200,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return ggml_webgpu_binary_op(ctx, src0, src1, node);
case GGML_OP_CONCAT:
return ggml_webgpu_concat(ctx, src0, src1, node);
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
Expand Down Expand Up @@ -2919,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
/* .iface = */ {
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
/* .alloc_buffer = */
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
},
/* .device = */
dev,
Expand Down Expand Up @@ -3000,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_CONCAT:
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
break;
case GGML_OP_REPEAT:
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
Expand Down
67 changes: 67 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
enable f16;

struct Params {
ne: u32,

offset_src0: u32,
offset_dst: u32,

stride_src0_0: u32,
stride_src0_1: u32,
stride_src0_2: u32,
stride_src0_3: u32,

a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
a_ne3: u32,

ne0: u32,
ne1: u32,
ne2: u32,
};

#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_I32
#define DataType i32
#endif
#ifdef TYPE_I16
// same size (16-bit) is sufficient for repeat
#define DataType f16
#endif

@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;

@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(2)
var<uniform> params: Params;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;

let a_i0 = i0 % params.a_ne0;
let a_i1 = i1 % params.a_ne1;
let a_i2 = i2 % params.a_ne2;
let a_i3 = i3 % params.a_ne3;

let a_index = a_i0 * params.stride_src0_0 +
a_i1 * params.stride_src0_1 +
a_i2 * params.stride_src0_2 +
a_i3 * params.stride_src0_3;

dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index];
}
}
Loading