diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index e268adfb16b..67f1dc0928f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -50,13 +50,13 @@ var params: Params; @compute @workgroup_size(WG_SIZE) fn main( - @builtin(global_invocation_index) gindex: u32, + @builtin(global_invocation_id) gid: vec3, ) { - if (gindex >= params.ne) { + if (gid.x >= params.ne) { return; } - var i = gindex; + var i = gid.x; let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); let i2 = i / (params.src_ne1 * params.src_ne0); @@ -64,7 +64,7 @@ fn main( let i1 = i / params.src_ne0; let i0 = i % params.src_ne0; - var j = gindex; + var j = gid.x; let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); let j2 = j / (params.dst_ne1 * params.dst_ne0); @@ -80,4 +80,3 @@ fn main( dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); } -