Commit 3e08e70

[WebGPU] Implement tir.dp4a with WGSL built-in function dot4I8Packed (#16976)
* [WebGPU] Support `__dp4a(int8x4, int8x4)` as a pure extern method

This patch adds support for `__dp4a(int8x4, int8x4)` as a pure extern method for the WebGPU target. In the generated WGSL shader, `int8x4` is translated into `u32`, and `__dp4a(int8x4, int8x4)` is translated into the WGSL built-in function `dot4I8Packed(u32, u32)`.

Here is an example of using `__dp4a` with the WebGPU target:

```
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(
    A.shape,
    lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]),
    name="C",
)
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
# `tgt` is not defined in this snippet; it is assumed to be a WebGPU target set up by the caller.
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```

Issue: #16627

* Add validation
* Add `dot4I8Packed` to WebGPU lower intrinsic
* Implement builtin `dp4a` with `dot4I8Packed`
* Small fix
* Add missing comment
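To make the arithmetic behind this translation concrete, here is a minimal CPU-side sketch of what `dot4I8Packed(u32, u32)` computes and how the three-argument builtin `dp4a` adds an accumulator on top of it. The names `UnpackI8`, `PackI8x4`, `Dot4I8PackedRef`, and `Dp4aRef` are hypothetical helpers introduced only for illustration; they are not part of TVM or WGSL. The sketch assumes C++14 and does not model `i32` wraparound on overflow.

```cpp
#include <cstdint>

// Hypothetical helper: read byte `lane` (0..3) of a packed word and
// sign-extend it to 32 bits, without relying on narrowing casts.
constexpr int32_t UnpackI8(uint32_t packed, int lane) {
  const int32_t byte = static_cast<int32_t>((packed >> (8 * lane)) & 0xFFu);
  return byte >= 128 ? byte - 256 : byte;
}

// Hypothetical helper: pack four signed 8-bit values into one 32-bit word,
// with x0 in the lowest byte.
constexpr uint32_t PackI8x4(int8_t x0, int8_t x1, int8_t x2, int8_t x3) {
  return static_cast<uint32_t>(static_cast<uint8_t>(x0)) |
         (static_cast<uint32_t>(static_cast<uint8_t>(x1)) << 8) |
         (static_cast<uint32_t>(static_cast<uint8_t>(x2)) << 16) |
         (static_cast<uint32_t>(static_cast<uint8_t>(x3)) << 24);
}

// Reference for the value of dot4I8Packed(a, b): treat each word as four
// signed 8-bit components and return their dot product as a 32-bit integer.
constexpr int32_t Dot4I8PackedRef(uint32_t a, uint32_t b) {
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    sum += UnpackI8(a, lane) * UnpackI8(b, lane);
  }
  return sum;
}

// Reference for the expression generated for the builtin dp4a:
// dot4I8Packed(vec1, vec2) + acc.
constexpr int32_t Dp4aRef(uint32_t vec1, uint32_t vec2, int32_t acc) {
  return Dot4I8PackedRef(vec1, vec2) + acc;
}

// Compile-time check: (1*5 + 2*6 + 3*7 + 4*8) + 10 == 80.
static_assert(Dp4aRef(PackI8x4(1, 2, 3, 4), PackI8x4(5, 6, 7, 8), 10) == 80,
              "dp4a reference mismatch");
```

A reference like this can be handy for checking shader output on the host, for example by comparing each element of `C` against `Dot4I8PackedRef(A[i], B[i])`.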
1 parent 0df4103 commit 3e08e70

1 file changed: +8 -0 lines changed

src/target/source/codegen_webgpu.cc

Lines changed: 8 additions & 0 deletions
@@ -410,6 +410,14 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN
       this->EndScope(else_scope);
     }
     os << result;
+  } else if (op->op.same_as(builtin::dp4a())) {
+    // generate `dot4I8Packed(vec1, vec2) + acc` for the builtin `dp4a`
+    os << "dot4I8Packed(";
+    this->PrintExpr(op->args[0], os);
+    os << ", ";
+    this->PrintExpr(op->args[1], os);
+    os << ") + ";
+    this->PrintExpr(op->args[2], os);
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
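
To show the shape of the text that the added branch writes, here is a small stand-alone sketch that mirrors its sequence of stream writes. `EmitDp4a` is a hypothetical function that takes already-printed argument strings in place of `this->PrintExpr(op->args[i], os)`; it is not part of the TVM code generator.

```cpp
#include <sstream>
#include <string>

// Hypothetical stand-in for the new dp4a branch: the three strings play the
// role of the printed forms of op->args[0], op->args[1], and op->args[2].
std::string EmitDp4a(const std::string& vec1, const std::string& vec2,
                     const std::string& acc) {
  std::ostringstream os;
  os << "dot4I8Packed(";
  os << vec1;
  os << ", ";
  os << vec2;
  os << ") + ";
  os << acc;
  return os.str();
}
```

With the arguments `"a"`, `"b"`, and `"acc"`, this produces `dot4I8Packed(a, b) + acc`. The branch as shown writes no parentheses around the sum; whether the enclosing expression printer adds them is not visible in this diff.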
