Skip to content

Commit 2c82462

Browse files
author
Neha Abbas
committed
updated optimization, fixed errors
1 parent b566811 commit 2c82462

File tree

3 files changed

+181
-85
lines changed

3 files changed

+181
-85
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ struct webgpu_context_struct {
248248

249249
webgpu_pipeline memset_pipeline;
250250
webgpu_pipeline mul_mat_pipeline[30][2];
251-
webgpu_pipeline set_rows_pipeline;
251+
webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized)
252252
webgpu_pipeline get_rows_pipeline[30];
253253
webgpu_pipeline get_rows_f32_no_vec_pipeline;
254254
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
@@ -767,9 +767,20 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
767767
};
768768

769769
size_t max_wg_size = ctx->max_wg_size_x;
770-
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
770+
// number of threads needed with vec4 = (total number of rows in matrix) * (number of elements in a row / 4)
771+
uint32_t threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
772+
773+
webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][0];
774+
// if not evenly divisble by 4, use the non-vectorized version
775+
if (src->ne[0] % 4 != 0) {
776+
pipeline = ctx->set_rows_pipeline[0][1];
777+
// threads = number of rows
778+
threads = src->ne[1] * src->ne[2] * src->ne[3];
779+
}
780+
781+
uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;
771782

772-
return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs);
783+
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
773784
}
774785

775786
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
@@ -1620,7 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
16201631
}
16211632

16221633
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
1623-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
1634+
// create_pipeline(device, pipeline, shader_code, label, constants)
1635+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16, "set_rows_f16",
1636+
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
1637+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16_vec, "set_rows_f16_vec",
16241638
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
16251639
}
16261640

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#define(VARIANTS)
2+
3+
[
4+
{
5+
"SHADER_SUFFIX": "f16_vec",
6+
"REPLS": {
7+
"TYPE" : "vec4<f32>",
8+
"DST_TYPE": "vec4<f16>",
9+
"BLOCK_SIZE": 4
10+
},
11+
"DECLS": ["F16_VEC"]
12+
},
13+
{
14+
"SHADER_SUFFIX": "f16",
15+
"REPLS": {
16+
"TYPE" : "f32",
17+
"DST_TYPE": "f16",
18+
"BLOCK_SIZE": 1
19+
},
20+
"DECLS": ["F16"]
21+
}
22+
]
23+
24+
#end(VARIANTS)
25+
26+
#define(DECLS)
27+
28+
#decl(F16_VEC)
29+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
30+
let src_vec_index = (src_base + offset) / {{BLOCK_SIZE}};
31+
let dst_vec_index = (dst_base + offset) / {{BLOCK_SIZE}};
32+
dst[dst_vec_index] = vec4<f16>(src[src_vec_index]);
33+
}
34+
#enddecl(F16_VEC)
35+
36+
#decl(F16)
37+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
38+
dst[dst_base + offset] = f16(src[src_base + offset]);
39+
}
40+
#enddecl(F16)
41+
42+
#end(DECLS)
43+
44+
#define(SHADER)
45+
46+
enable f16;
47+
48+
DECLS
49+
50+
@group(0) @binding(0)
51+
var<storage, read_write> src: array<{{TYPE}}>;
52+
53+
@group(0) @binding(1)
54+
var<storage, read_write> idx: array<u32>;
55+
56+
@group(0) @binding(2)
57+
var<storage, read_write> dst: array<{{DST_TYPE}}>;
58+
59+
@group(0) @binding(3)
60+
var<storage, read_write> error: atomic<u32>;
61+
62+
struct Params {
63+
offset_src: u32, // in elements
64+
offset_idx: u32, // in elements
65+
offset_dst: u32, // in elements
66+
67+
// Strides (in elements)
68+
stride_src1: u32,
69+
stride_src2: u32,
70+
stride_src3: u32,
71+
72+
stride_idx0: u32,
73+
stride_idx1: u32,
74+
stride_idx2: u32,
75+
76+
stride_dst1: u32,
77+
stride_dst2: u32,
78+
stride_dst3: u32,
79+
80+
// Shape of src
81+
ne0: u32,
82+
n_rows: u32, // n_rows = ne1 = rows per slice
83+
ne2: u32,
84+
ne3: u32,
85+
86+
// Shape of idx
87+
idx1: u32,
88+
idx2: u32,
89+
};
90+
91+
@group(0) @binding(4)
92+
var<uniform> params: Params;
93+
94+
override wg_size: u32;
95+
@compute @workgroup_size(wg_size)
96+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
97+
98+
// Determine the total number of threads based on mode
99+
var max_threads: u32;
100+
var i: u32;
101+
if {{BLOCK_SIZE}} > 1 {
102+
// Vectorized: one thread per vector of elements
103+
// # of total rows to go through * (# of threads per row)
104+
max_threads = (params.n_rows * params.ne2 * params.ne3) * (params.ne0 / {{BLOCK_SIZE}});
105+
106+
// calculations are based off i being row, but when vectorized, it corresponds to a vector in a row
107+
// getting the row from gid
108+
i = gid.x / (params.ne0 / {{BLOCK_SIZE}});
109+
} else {
110+
// Non-vectorized: one thread per row
111+
// # of total rows in the matrix
112+
max_threads = params.n_rows * params.ne2 * params.ne3;
113+
i = gid.x; // i corresponds to the row
114+
}
115+
116+
if (gid.x >= max_threads) {
117+
return;
118+
}
119+
120+
121+
let i_src3 = i / (params.ne2 * params.n_rows);
122+
123+
i = i % (params.ne2 * params.n_rows);
124+
let i_src2 = i / params.n_rows;
125+
let i_src1 = i % params.n_rows;
126+
127+
let i_idx2 = i_src3 % params.idx2;
128+
let i_idx1 = i_src2 % params.idx1;
129+
let i_idx0 = i_src1;
130+
131+
let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
132+
133+
let idx_high_val = idx[idx_high];
134+
let idx_low_val = idx[idx_high + 1];
135+
136+
if (idx_low_val != 0) {
137+
// Upper bits of index are not zero, output will be incorrect
138+
atomicStore(&error, 1);
139+
return;
140+
}
141+
142+
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
143+
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
144+
145+
if {{BLOCK_SIZE}} > 1 {
146+
// Vectorized: one thread per vector of elements
147+
148+
// starts at what element of that row?
149+
let element_offset = (gid.x % (params.ne0 / {{BLOCK_SIZE}})) * {{BLOCK_SIZE}};
150+
copy_elements(i_src_row, i_dst_row, element_offset);
151+
152+
} else {
153+
// Non-vectorized: go through each element in row, copy one by one
154+
for (var i: u32 = 0; i < params.ne0; i++) {
155+
copy_elements(i_src_row, i_dst_row, i);
156+
}
157+
}
158+
159+
160+
}
161+
162+
#end(SHADER)
163+

ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl

Lines changed: 0 additions & 81 deletions
This file was deleted.

0 commit comments

Comments
 (0)