From 796db602e367bf187261838653b3db456f47141c Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 19 Feb 2023 19:33:52 -0500 Subject: [PATCH] [WEB] WebGPU Codegen This PR provide an implementation of WebGPU codegen. Previously we relied on SPIRV codegen for WebGPU, which is deprecated in favor of the WGSL shading language. Pass limited testing on elementwise via chrome. Likely we will do future iterations. Also cleans up some legacy code organization in intrinsics. --- include/tvm/tir/op.h | 9 + include/tvm/topi/elemwise.h | 49 +- src/target/intrin_rule.cc | 16 + src/target/intrin_rule.h | 3 + src/target/source/codegen_c.h | 6 +- src/target/source/codegen_metal.cc | 2 +- src/target/source/codegen_source_base.cc | 3 +- src/target/source/codegen_webgpu.cc | 555 +++++++++++++++++++++++ src/target/source/codegen_webgpu.h | 92 ++++ src/target/source/intrin_rule_metal.cc | 17 - src/target/source/intrin_rule_webgpu.cc | 118 +++++ src/target/spirv/build_vulkan.cc | 16 +- src/target/spirv/intrin_rule_spirv.cc | 34 -- src/target/spirv/ir_builder.h | 18 +- src/tir/op/op.cc | 42 ++ web/emcc/webgpu_runtime.cc | 24 +- web/src/runtime.ts | 4 +- web/src/webgpu.ts | 9 +- web/tests/python/webgpu_rpc_test.py | 6 +- 19 files changed, 870 insertions(+), 153 deletions(-) create mode 100644 src/target/source/codegen_webgpu.cc create mode 100644 src/target/source/codegen_webgpu.h create mode 100644 src/target/source/intrin_rule_webgpu.cc diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 21bc7e7a5056..edfb31851872 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -678,6 +678,15 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span = Span()); +/*! + * \brief Fast_erf_float expression from Eigen + * + * \param arg The input expression. + * \param bits The number of bits in the type. + * \return The constructed expression. + */ +TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index f26105cb180b..49b50019f04d 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -455,54 +456,6 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", } } -/*! - * \brief Fast_erf_float expression from Eigen - * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290 - * \param arg The input expression. - * \param bits The number of bits in the type. - */ -inline PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) { - auto plus_4 = make_const(DataType::Float(bits), 4.f); - auto minus_4 = make_const(DataType::Float(bits), -4.f); - - // The monomial coefficients of the numerator polynomial (odd). - auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f); - auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f); - auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f); - auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f); - auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f); - auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f); - auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f); - - // The monomial coefficients of the denominator polynomial (even). - auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f); - auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f); - auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f); - auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f); - auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f); - - // clamp x - auto x = tvm::max(tvm::min(arg, plus_4), minus_4); - auto x2 = x * x; - - // Evaluate the numerator polynomial p. - auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - - // Evaluate the denominator polynomial p. - auto q = x2 * beta_8 + beta_6; - q = x2 * q + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - - return p / q; -} - /*! * \brief Fast_erf_float expression from Eigen */ diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 8c7ff1abad51..398e24d2510e 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -118,6 +118,22 @@ TVM_REGISTER_OP("tir.nearbyint") TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", DispatchPureExtern); +PrimExpr DispatchFastErf(const PrimExpr& e) { + LOG(WARNING) << "fast_erf will be used instead of erf"; + const CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 1); + PrimExpr arg = call->args[0]; + int bits = arg.dtype().bits(); + PrimExpr res; + if (arg.dtype().is_float() && (bits == 16 || bits == 32)) { + res = fast_erf_float_expr(arg, bits); + } else { + LOG(FATAL) << "Unsupported type in Metal fast_erf"; + } + return res; +} + } // namespace intrin namespace legalize { diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 6a517a9abd24..b7f5881b3a90 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -77,6 +77,9 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) { } } +// Dispatch ERF to fast erf when it is not available. +PrimExpr DispatchFastErf(const PrimExpr& e); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index be715ad3a049..40733808d61b 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -262,7 +262,7 @@ class CodeGenC : public ExprFunctor, */ void RegisterHandleType(const VarNode* buf_var, DataType t); // override - void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final; + void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override; /*! \brief reserves common C keywords */ void ReserveKeywordsAsUnique(); @@ -281,10 +281,10 @@ class CodeGenC : public ExprFunctor, const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); Integer constants_byte_alignment_ = 16; - - private: /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; + + private: /*! \brief set of volatile buf access */ std::unordered_set volatile_buf_; // deep comparison of PrimExpr diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index b3ca3eb46149..928d961d50ee 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // clear previous generated state. this->InitFuncState(f); // skip the first underscore, so SSA variable starts from _1 - name_supply_->FreshName("_"); + name_supply_->FreshName("v_"); // add to alloc buffer type. auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 75833fd93629..9c17458bf221 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -43,7 +43,8 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { } } SSAEntry e; - e.vid = name_supply_->FreshName("_"); + // use v_ prefix so it works for most systems + e.vid = name_supply_->FreshName("v_"); e.scope_id = static_cast(scope_mark_.size() - 1); ssa_assign_map_[src] = e; this->PrintIndent(); diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc new file mode 100644 index 000000000000..e4ccef88b62f --- /dev/null +++ b/src/target/source/codegen_webgpu.cc @@ -0,0 +1,555 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_webgpu.cc + */ +#include "codegen_webgpu.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "../../arith/pattern_match.h" +#include "../../runtime/meta_data.h" +#include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" + +namespace tvm { +namespace codegen { + +std::string CodeGenWebGPU::Finish() { + return decl_stream.str() + this->fwd_decl_stream.str() + stream.str(); +} + +void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { + CodeGenC::InitFuncState(f); + // analyze the data; + for (Var arg : f->params) { + if (arg.dtype().is_handle()) { + alloc_storage_scope_[arg.get()] = "global"; + } + } + std::fill(workgroup_size_, workgroup_size_ + 3, 1); +} + +CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} + +void CodeGenWebGPU::AddFunction(const PrimFunc& f) { + // clear previous generated state. + this->InitFuncState(f); + // skip the first underscore, so SSA variable starts from + name_supply_->FreshName("v_"); + // Setup the thread group info. + ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + + // add to alloc buffer type. + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; + + decl_stream << "//----------------------------------------\n" + << "// function: " << global_symbol.value() << "\n" + << "//----------------------------------------\n"; + + std::vector pod_args; + int num_buffer = 0; + // setup buffer argumemts + for (Var arg : f->params) { + DataType t = arg.dtype(); + if (t.is_handle()) { + auto* ptr = arg->type_annotation.as(); + ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + auto* prim = ptr->element_type.as(); + ICHECK(prim) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + DataType value_storage_type = prim->dtype; + if (value_storage_type == DataType::Bool()) { + // We need a physically addressable buffer type to support boolean tensors. + // The loaded byte is cast to bool inside the LoadNode visitor below. + value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); + } + std::string vid = AllocVarID(arg.get()); + this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") " + << "var " << vid << " : array<"; + this->PrintType(value_storage_type, this->decl_stream); + this->decl_stream << ">;\n"; + } else { + pod_args.push_back(arg); + } + } + + if (pod_args.size() != 0) { + // setup POD arguments + // TODO(tvm-team): store as a uniform, readonly buffer. + LOG(FATAL) << "Do not support pod arguments for now"; + } + // add to alloc buffer type. + // Function header. + this->stream << "fn main(\n" + << " @builtin(workgroup_id) blockIdx : vec3,\n" + << " @builtin(local_invocation_id) threadIdx : vec3\n" + << ") {\n"; + // the function scope. + int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; + // anotate workgroup + this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", " + << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n"; +} + +void CodeGenWebGPU::VisitStmt_(const AttrStmtNode* op) { + // record workgroup size + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag.length() != 0) { + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); + if (ts.rank == 1) { + ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; + ICHECK_LT(ts.dim_index, 3); + auto* sizeptr = op->value.as(); + ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size " + << " get " << op->value; + workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); + } + } + } + // normal operation + CodeGenC::VisitStmt_(op); +} + +void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + std::ostringstream os; + PrintType(iv->var.dtype(), os); + os << "(" << iv->thread_tag << ")"; + std::string tidx = os.str(); + this->MarkConst(tidx); + var_idmap_[iv->var.get()] = tidx; +} + +void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + LOG(FATAL) << "Cannot print handle type in WebGPU"; + } + if (t.is_void()) { + os << "void"; + return; + } + if (t == DataType::Bool()) { + os << "bool"; + return; + } + + if (lanes != 1) { + ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; + os << "vec" << lanes << "<"; + } + + if (t.is_float()) { + ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32"; + os << "f" << t.bits(); + } else if (t.is_uint()) { + os << "u" << t.bits(); + } else if (t.is_int()) { + os << "i" << t.bits(); + } else { + LOG(FATAL) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type"; + } + if (lanes != 1) { + os << ">"; + } +} + +void CodeGenWebGPU::PrintStorageSync(const CallNode* op) { + const std::string& sync = op->args[0].as()->value; + if (sync == "warp") { + this->PrintIndent(); + this->stream << "workgroupBarrier();\n"; + } else if (sync == "shared") { + this->PrintIndent(); + this->stream << "workgroupBarrier();\n"; + } else if (sync == "global") { + LOG(FATAL) << "global barrier not supported"; + } +} + +void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string& src, + DataType type) { + stream << "let " << target << " : "; + PrintType(type, stream); + stream << " = " << src << ";\n"; +} + +void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < op->lanes; ++i) { + if (i != 0) os << ", "; + os << v; + } + os << ')'; +} + +void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + if (op->op.same_as(builtin::reinterpret())) { + // generate bitcast(ARG) + os << "bitcast<"; + this->PrintType(op->dtype, os); + os << ">("; + this->PrintExpr(op->args[0], os); + os << ")"; + } else if (op->op.same_as(builtin::if_then_else())) { + // conditional that skips eval if cond evals to false + std::string result = name_supply_->FreshName("condval"); + std::string cond = PrintExpr(op->args[0]); + this->PrintIndent(); + this->stream << "var " << result << " : "; + PrintType(op->dtype, this->stream); + this->stream << ";\n"; + this->PrintIndent(); + this->stream << "if (" << cond << ") {\n"; + { + int then_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << result << " = " << PrintExpr(op->args[1]) << ";\n} else {\n"; + this->EndScope(then_scope); + } + { + int else_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << result << " = " << PrintExpr(op->args[2]) << ";\n}\n"; + this->EndScope(else_scope); + } + os << result; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenWebGPU::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) + PrintType(op->dtype, os); + os << "(" << PrintExpr(op->value) << ")"; +} + +void CodeGenWebGPU::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) + os << "select(" << PrintExpr(op->false_value) << ", " << PrintExpr(op->true_value) << ", " + << PrintExpr(op->condition) << ")"; +} + +void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) + if (op->dtype.bits() == 32) { + std::ostringstream temp; + if (op->dtype.is_int()) { + temp << op->value << "i"; + } else { + ICHECK(op->dtype.is_uint()); + temp << op->value << "u"; + } + this->MarkConst(temp.str()); + os << temp.str(); + } else { + this->PrintType(op->dtype, os); + os << "(" << op->value << ")"; + } +} + +void CodeGenWebGPU::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) + std::ostringstream temp; + temp << std::scientific << op->value; + if (op->dtype.bits() == 32) { + temp << 'f'; + } else if (op->dtype.bits() == 16) { + temp << 'h'; + } else { + LOG(FATAL) << "Unsupported floating point bits " << op->dtype.bits(); + } + MarkConst(temp.str()); + os << temp.str(); +} + +void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) + // NOTE: direct impl of load/store for correctness + // Each printing stmt must stand on their own after all preprocessing steps + // to ensure correctness in the case of nested-expression + // do not try to lift common printings from each case + ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + + int lanes = op->dtype.lanes(); + std::string buffer_vid = GetVarID(buffer_var.get()); + + if (value_dtype.lanes() == element_dtype.lanes()) { + // Direct buffer loading + // Special handle bool loading + if (value_dtype == DataType::Bool()) { + this->PrintType(value_dtype, os); + os << "("; + } else { + ICHECK(value_dtype == element_dtype); + } + ICHECK_EQ(index.dtype().lanes(), 1); + os << buffer_vid << "[" << this->PrintExpr(index) << "]"; + // Special handle bool loading + if (value_dtype == DataType::Bool()) { + os << ")"; + } + } else { + // Vector load from scalar buffer + ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; + ICHECK(value_dtype.element_of() == element_dtype) + << "WebGPU vector loading requires base type to match"; + arith::PVar base; + if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { + // vec3(buf[base + 0], buf[base + 1], buf[base + 2]); + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); + PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + os << "("; + for (int i = 0; i < lanes; ++i) { + if (i != 0) os << ", "; + os << buffer_vid << "[" << base_vid << " + " << i << "]"; + } + os << ")"; + } else { + // vec3(buf[index[0]], buf[index[1]], buf[index[2]]); + std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); + PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + for (int i = 0; i < lanes; ++i) { + if (i != 0) os << ", "; + os << buffer_vid << "[" << index_vid << "[" << i << "]]"; + } + os << ")"; + } + } +} + +void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { + // use ssa form. + if (print_ssa_form_) { + std::string value = PrintExpr(op->value); + ICHECK(!var_idmap_.count(op->var.get())); + var_idmap_[op->var.get()] = value; + } else { + PrintIndent(); + std::string value = PrintExpr(op->value); + this->stream << "let " << AllocVarID(op->var.get()) << " : "; + PrintType(op->var.dtype(), this->stream); + this->stream << " = " << value << ";\n"; + } + PrintStmt(op->body); +} + +void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { + CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + + std::string buffer_vid = GetVarID(buffer_var.get()); + + if (value_dtype.lanes() == element_dtype.lanes()) { + // must execute print expr first + // so we won't have recursive append to stream + std::string index_vid = PrintExpr(index); + std::string value_vid = PrintExpr(op->value); + // now print the assignment line. + this->PrintIndent(); + stream << buffer_vid << "[" << index_vid << "] = "; + // special explicit conversion of bool + if (value_dtype == DataType::Bool()) { + PrintType(element_dtype, stream); + stream << "("; + } else { + ICHECK(value_dtype == element_dtype); + } + stream << value_vid; + // Special handle bool store + if (value_dtype == DataType::Bool()) { + stream << ")"; + } + stream << ";\n"; + } else { + // Vector store into scalar buffer + ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; + ICHECK(value_dtype.element_of() == element_dtype) + << "WebGPU vector stire requires base type to match"; + std::string value_vid = PrintExpr(op->value); + arith::PVar base; + if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) { + // buf[base + 0] = value[0] + // buf[base + 1] = value[1] + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); + for (int i = 0; i < value_dtype.lanes(); ++i) { + this->PrintIndent(); + stream << buffer_vid << "[" << base_vid << " + " << i << "] = " << value_vid << "[" << i + << "];\n"; + } + } else { + // buf[index[0]] = value[0] + // buf[index[1]] = value[1] + std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); + for (int i = 0; i < value_dtype.lanes(); ++i) { + this->PrintIndent(); + stream << buffer_vid << "[" << index_vid << "[" << i << "]] = " << value_vid << "[" << i + << "];\n"; + } + } + } +} + +void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + + if (storage_scope.rank == runtime::StorageRank::kShared) { + this->decl_stream << "var " << vid << " : array<"; + PrintType(op->dtype, this->decl_stream); + this->decl_stream << ", " << constant_size << ">;\n"; + } else if (storage_scope.rank == runtime::StorageRank::kLocal) { + this->PrintIndent(); + this->stream << "var " << vid << " : array<"; + PrintType(op->dtype, this->stream); + this->stream << ", " << constant_size << ">;\n"; + } else { + LOG(FATAL) << "WebGPU: Do not support storage scope: " << storage_scope.to_string(); + } + this->PrintStmt(op->body); +} + +void CodeGenWebGPU::VisitStmt_(const ForNode* op) { + std::string extent = PrintExpr(op->extent); + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + ICHECK(is_zero(op->min)); + stream << "for (var "; + stream << vid << " : "; + PrintType(op->loop_var.dtype(), stream); + stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n"; + int for_scope = BeginScope(); + PrintStmt(op->body); + this->EndScope(for_scope); + PrintIndent(); + stream << "}\n"; +} + +void CodeGenWebGPU::VisitStmt_(const AssertStmtNode* op) { + // skip assert + PrintStmt(op->body); +} + +void CodeGenWebGPU::VisitStmt_(const AllocateConstNode* op) { + LOG(FATAL) << "WebGPU: do not support alloc const"; +} + +//------------------------------------------------- +// WebGPUSourceModule to enable export +//------------------------------------------------- +class WebGPUSourceModuleNode final : public runtime::ModuleNode { + public: + explicit WebGPUSourceModuleNode(std::unordered_map smap, + std::unordered_map fmap) + : smap_(smap), fmap_(fmap) {} + + const char* type_key() const final { return "webgpu"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; + return PackedFunc(nullptr); + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + LOG(FATAL) << "Not implemented"; + } + + void SaveToBinary(dmlc::Stream* stream) final { + stream->Write(fmap_); + stream->Write(smap_); + } + + std::string GetSource(const std::string& format) final { + std::ostringstream os; + for (auto kv : smap_) { + os << kv.second; + } + return os.str(); + } + + private: + // function information table. + std::unordered_map smap_; + // function information table. + std::unordered_map fmap_; +}; + +//------------------------------------------------- +// Build logic. +//------------------------------------------------- +runtime::Module BuildWebGPU(IRModule mod, Target target) { + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + bool output_ssa = false; + + std::unordered_map smap; + for (auto kv : mod->functions) { + CodeGenWebGPU cg(target); + ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; + std::string f_name = global_symbol.value(); + cg.Init(output_ssa); + cg.AddFunction(f); + std::string code = cg.Finish(); + smap[f_name] = code; + } + auto n = make_object(smap, ExtractFuncInfo(mod)); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { + return BuildWebGPU(mod, target); +}); + +} // namespace codegen +} // namespace tvm diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h new file mode 100644 index 000000000000..57f226ba8ad6 --- /dev/null +++ b/src/target/source/codegen_webgpu.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_webgpu.h + * \brief Generate WebGPU shaders in WGSL. + * + * This module generates WGSL shading langauge. + * See https://www.w3.org/TR/WGSL/ for the language reference. + */ +#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ +#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ + +#include + +#include + +#include "codegen_c.h" + +namespace tvm { +namespace codegen { + +/*! + * \brief WebGPU code generator. + * + * Note WGSL have a different syntax from normal C. + * We only leevrage the C for expression generation and + * write most of the language generations. + */ +class CodeGenWebGPU final : public CodeGenC { + public: + explicit CodeGenWebGPU(Target target); + // overrides + std::string Finish() final; + void AddFunction(const PrimFunc& f); // NOLINT(*) + void InitFuncState(const PrimFunc& f) final; + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + + // assignment printing + void PrintSSAAssign(const std::string& target, const std::string& src, DataType type) final; + + // overload visitor + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BufferLoadNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) + + // stmt printing + void VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; + void VisitStmt_(const ForNode* op) final; + void VisitStmt_(const AllocateNode* op) final; + void VisitStmt_(const AttrStmtNode* op) final; + void VisitStmt_(const AssertStmtNode* op) final; + void VisitStmt_(const AllocateConstNode* op) final; + + private: + /*! + * \brief Records the workgroup size of the kernel. + */ + uint32_t workgroup_size_[3]; + /*! + * \brief Storage type of bool values. + */ + DataType boolean_storage_type_{DataType::Int(8)}; + Target target_; +}; +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 7d7a5fb29a7c..dd924b925596 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -22,7 +22,6 @@ * \brief Metal intrinsic rules. */ #include -#include #include "../intrin_rule.h" @@ -94,22 +93,6 @@ TVM_REGISTER_OP("tir.cos").set_attr("metal.FLowerIntrinsic", TVM_REGISTER_OP("tir.cosh") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -// There is no erf function in Metal. When erf is used, we use fast_erf instead -static PrimExpr DispatchFastErf(const PrimExpr& e) { - LOG(WARNING) << " Metal doesn't have built-in erf function. fast_erf will be used instead."; - const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); - PrimExpr arg = call->args[0]; - int bits = arg.dtype().bits(); - bool isFloat = arg.dtype().is_float(); - PrimExpr res; - if (isFloat && (bits == 16 || bits == 32)) - res = topi::fast_erf_float_expr(arg, bits); - else - LOG(FATAL) << "Unsupported type in Metal fast_erf"; - return res; -} TVM_REGISTER_OP("tir.erf").set_attr("metal.FLowerIntrinsic", DispatchFastErf); } // namespace intrin diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc new file mode 100644 index 000000000000..81803059fc49 --- /dev/null +++ b/src/target/source/intrin_rule_webgpu.cc @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file intrin_rule_webgpu.cc + * \brief WebGPU intrinsic rules. + */ +#include +#include + +#include "../intrin_rule.h" + +namespace tvm { +namespace codegen { +namespace intrin { + +using tir::FLowerIntrinsic; + +// See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions + +struct ReturnAbs { + std::string operator()(DataType t, std::string name) const { return "abs"; } +}; + +TVM_REGISTER_OP("tir.fabs") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.acos") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.acosh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.asin") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.asinh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.atan") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.atan2") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.ceil") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.cos").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.cosh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.exp").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.exp2") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.floor") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.fma").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.log").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.log2") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.round") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.sin").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.sinh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.sqrt") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.tan").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.tanh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.trunc") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +// extra dispatch +TVM_REGISTER_OP("tir.erf").set_attr("webgpu.FLowerIntrinsic", DispatchFastErf); + +} // namespace intrin +} // namespace codegen +} // namespace tvm diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 94f1bf16a25e..dc1d8f865baa 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -97,7 +97,7 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) { +runtime::Module BuildSPIRV(IRModule mod, Target target) { using tvm::runtime::Registry; using tvm::runtime::VulkanShader; @@ -122,7 +122,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); - std::string entry = webgpu_restriction ? "main" : f_name; + std::string entry = f_name; VulkanShader shader = cg.BuildFunction(f, entry); @@ -144,12 +144,6 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) spirv_tools.ValidateShader(shader.data); } - if (webgpu_restriction) { - for (auto param : f->params) { - ICHECK(param.dtype().is_handle()) << "WebGPU does not yet support non-buffer arguments"; - } - } - if (postproc != nullptr) { TVMByteArray arr; arr.data = reinterpret_cast(dmlc::BeginPtr(shader.data)); @@ -168,11 +162,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) } TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { - return BuildSPIRV(mod, target, false); -}); - -TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { - return BuildSPIRV(mod, target, true); + return BuildSPIRV(mod, target); }); } // namespace codegen diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 0c65f1718a5d..ac304b92b6d7 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -100,40 +100,6 @@ TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", TVM_REGISTER_OP("tir.tanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); - -// WebGPU rules. -TVM_REGISTER_OP("tir.floor") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.ceil") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.round") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.nearbyint") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.trunc") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.fabs") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.exp").set_attr("webgpu.FLowerIntrinsic", - DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.log").set_attr("webgpu.FLowerIntrinsic", - DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.sqrt") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", - DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.tanh") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); } // namespace intrin namespace legalize { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index f1b5397b3757..d642484532f9 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -215,7 +215,7 @@ class InstrBuilder { * \brief add sequence of values to instruction * \param args The instruction sequence * \return reference to self. - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template InstrBuilder& AddSeq(Args&&... args) { @@ -328,7 +328,7 @@ class IRBuilder { * \brief Add code to debug segment. * \param op The operator * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void Debug(spv::Op op, Args&&... args) { @@ -339,7 +339,7 @@ class IRBuilder { * \brief Set the name of a value or label * \param obj The object to be named * \param name The name of the object - * \tparams Obj The type of the object being named. Typically a Label or Value. + * \tparam Obj The type of the object being named. Typically a Label or Value. */ template void SetName(Obj&& obj, const std::string& name) { @@ -350,7 +350,7 @@ class IRBuilder { * \brief Add Execution mode to a function. * \param func The function value * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void ExecutionMode(Value func, Args&&... args) { @@ -360,7 +360,7 @@ class IRBuilder { * \brief Add code to decorate segment. * \param op The operator * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void Decorate(spv::Op op, Args&&... args) { @@ -370,7 +370,7 @@ class IRBuilder { * \brief Add code to global segment. * \param op The operator * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void DeclareGlobal(spv::Op op, Args&&... args) { @@ -382,7 +382,7 @@ class IRBuilder { * \param op The operator * \param args The instruction sequence * \return The result SSA value. - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template Instr MakeInst(spv::Op op, Args&&... args) { @@ -395,7 +395,7 @@ class IRBuilder { * \param out_type The result type. * \param args The instruction sequence * \return The result SSA value. - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template Value MakeValue(spv::Op op, const SType& out_type, Args&&... args) { @@ -435,7 +435,7 @@ class IRBuilder { * \brief Build vector by concatenating components * * \param vec The vector component - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ Value Concat(const std::vector& vec); /*! diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 078e32ca57c7..828ab010831f 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -1026,4 +1026,46 @@ TVM_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) { return const_true(t.lanes(), span); }); +PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) { + auto plus_4 = make_const(DataType::Float(bits), 4.f); + auto minus_4 = make_const(DataType::Float(bits), -4.f); + + // The monomial coefficients of the numerator polynomial (odd). + auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f); + auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f); + auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f); + auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f); + auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f); + auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f); + auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f); + + // The monomial coefficients of the denominator polynomial (even). + auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f); + auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f); + auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f); + auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f); + auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f); + + // clamp x + auto x = tvm::max(tvm::min(arg, plus_4), minus_4); + auto x2 = x * x; + + // Evaluate the numerator polynomial p. + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + + // Evaluate the denominator polynomial p. + auto q = x2 * beta_8 + beta_6; + q = x2 * q + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + + return p / q; +} + } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 936c9938dd3a..17efcc8c70a7 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -38,7 +38,6 @@ #include #include "../../src/runtime/meta_data.h" -#include "../../src/runtime/vulkan/vulkan_shader.h" #include "../../src/runtime/workspace_pool.h" namespace tvm { @@ -150,9 +149,9 @@ WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore:: class WebGPUModuleNode final : public runtime::ModuleNode { public: - explicit WebGPUModuleNode(std::unordered_map smap, - std::unordered_map fmap, std::string source) - : smap_(smap), fmap_(fmap), source_(source) { + explicit WebGPUModuleNode(std::unordered_map smap, + std::unordered_map fmap) + : smap_(smap), fmap_(fmap) { auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader"); CHECK(fp != nullptr); create_shader_ = *fp; @@ -168,10 +167,7 @@ class WebGPUModuleNode final : public runtime::ModuleNode { std::ostringstream os; dmlc::JSONWriter writer(&os); info.Save(&writer); - TVMByteArray arr; - arr.data = reinterpret_cast(it->second.data.data()); - arr.size = it->second.data.size() * sizeof(it->second.data[0]); - return create_shader_(os.str(), arr); + return create_shader_(os.str(), it->second); } else { return PackedFunc(nullptr); } @@ -190,29 +186,27 @@ class WebGPUModuleNode final : public runtime::ModuleNode { private: // function information table. - std::unordered_map smap_; + std::unordered_map smap_; // function information table. std::unordered_map fmap_; // The source std::string source_; // Callback to get the GPU function. - TypedPackedFunc create_shader_; + TypedPackedFunc create_shader_; }; Module WebGPUModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::unordered_map smap; + std::unordered_map smap; std::unordered_map fmap; - std::string fmt; - stream->Read(&fmt); stream->Read(&fmap); stream->Read(&smap); - return Module(make_object(smap, fmap, "")); + return Module(make_object(smap, fmap)); } // for now webgpu is hosted via a vulkan module. -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = WebGPUDeviceAPI::Global(); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 8df382dbc837..b341a7d4b1a4 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -1037,8 +1037,8 @@ export class Instance implements Disposable { this.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => { return webGPUContext.getDeviceAPI(name); }); - this.registerFunc("wasm.WebGPUCreateShader", (info: string, data: Uint8Array) => { - return webGPUContext.createShader(info, data); + this.registerFunc("wasm.WebGPUCreateShader", (info: string, code: string) => { + return webGPUContext.createShader(info, code); }); this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { await webGPUContext.sync(); diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index 5de47c200dcc..faf6fac990c8 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -79,9 +79,9 @@ export class WebGPUContext { * Create a PackedFunc that runs the given shader * * @param info The function information in json. - * @param data The shader data(in SPIRV) + * @param code The shader data(in WGSL) */ - createShader(info: string, data: Uint8Array): Function { + createShader(info: string, code: string): Function { const finfo = JSON.parse(info); const layoutEntries: Array = []; for (let i = 0; i < finfo.arg_types.length; ++i) { @@ -102,16 +102,13 @@ export class WebGPUContext { entries: layoutEntries }); - const textDecoder = new TextDecoder("utf-8") - const codeString = textDecoder.decode(data.buffer) - const pipeline = this.device.createComputePipeline({ layout: this.device.createPipelineLayout({ bindGroupLayouts: [ bindGroupLayout ] }), compute: { module: this.device.createShaderModule({ - code: codeString + code: code }), entryPoint: "main" } diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index ac1a241a9baa..6e34a8a2b36c 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -37,12 +37,10 @@ def test_rpc(): # generate the wasm library target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm") runtime = Runtime("cpp", {"system-lib": True}) - if not tvm.runtime.enabled(target_host): - raise RuntimeError("Target %s is not enbaled" % target_host) n = 2048 A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") + B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i)) + 1.0), name="B") s = te.create_schedule(B.op) num_thread = 2 @@ -75,7 +73,7 @@ def check(remote): f1 = remote.system_lib() addone = f1.get_function("addone") addone(a, b) - np.testing.assert_equal(b.numpy(), a.numpy() + 1) + np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), atol=1e-5, rtol=1e-5) print("Test pass..") check(remote)