Commit 3e08e70
authored
* [WebGPU] Support `__dp4a(int8x4, int8x4)` as a pure extern method
This patch adds the support of `__dp4a(int8x4, int8x4)` as a pure
extern method of WebGPU target. In the generated WGSL shader,
`int8x4` will be translated into `u32`, and `__dp4a(int8x4, int8x4)`
will be translated into the WGSL built-in function
`dot4I8Packed(u32, u32)`.
Here is an example to use `__dp4a` in WebGPU target:
```
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```
Issue: #16627
* Add validation
* Add `dot4I8Packed` to WebGPU lower intrinsic
* Implement builtin `dp4a` with `dot4I8Packed`
* Small fix
* Add missing comment
1 parent 0df4103 commit 3e08e70
1 file changed
+8
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
410 | 410 | | |
411 | 411 | | |
412 | 412 | | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
413 | 421 | | |
414 | 422 | | |
415 | 423 | | |
| |||
0 commit comments