diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 21bc7e7a5056..edfb31851872 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -678,6 +678,15 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span = Span()); +/*! + * \brief Fast_erf_float expression from Eigen + * + * \param arg The input expression. + * \param bits The number of bits in the type. + * \return The constructed expression. + */ +TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index f26105cb180b..49b50019f04d 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -455,54 +456,6 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", } } -/*! - * \brief Fast_erf_float expression from Eigen - * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290 - * \param arg The input expression. - * \param bits The number of bits in the type. - */ -inline PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) { - auto plus_4 = make_const(DataType::Float(bits), 4.f); - auto minus_4 = make_const(DataType::Float(bits), -4.f); - - // The monomial coefficients of the numerator polynomial (odd). 
- auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f); - auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f); - auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f); - auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f); - auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f); - auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f); - auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f); - - // The monomial coefficients of the denominator polynomial (even). - auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f); - auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f); - auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f); - auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f); - auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f); - - // clamp x - auto x = tvm::max(tvm::min(arg, plus_4), minus_4); - auto x2 = x * x; - - // Evaluate the numerator polynomial p. - auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - - // Evaluate the denominator polynomial p. - auto q = x2 * beta_8 + beta_6; - q = x2 * q + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - - return p / q; -} - /*! 
* \brief Fast_erf_float expression from Eigen */ diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 8c7ff1abad51..398e24d2510e 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -118,6 +118,22 @@ TVM_REGISTER_OP("tir.nearbyint") TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", DispatchPureExtern); +PrimExpr DispatchFastErf(const PrimExpr& e) { + LOG(WARNING) << "fast_erf will be used instead of erf"; + const CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 1); + PrimExpr arg = call->args[0]; + int bits = arg.dtype().bits(); + PrimExpr res; + if (arg.dtype().is_float() && (bits == 16 || bits == 32)) { + res = fast_erf_float_expr(arg, bits); + } else { + LOG(FATAL) << "Unsupported type in Metal fast_erf"; + } + return res; +} + } // namespace intrin namespace legalize { diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 6a517a9abd24..b7f5881b3a90 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -77,6 +77,9 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) { } } +// Dispatch ERF to fast erf when it is not available. +PrimExpr DispatchFastErf(const PrimExpr& e); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index be715ad3a049..40733808d61b 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -262,7 +262,7 @@ class CodeGenC : public ExprFunctor, */ void RegisterHandleType(const VarNode* buf_var, DataType t); // override - void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final; + void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override; /*! 
\brief reserves common C keywords */ void ReserveKeywordsAsUnique(); @@ -281,10 +281,10 @@ class CodeGenC : public ExprFunctor, const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); Integer constants_byte_alignment_ = 16; - - private: /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; + + private: /*! \brief set of volatile buf access */ std::unordered_set volatile_buf_; // deep comparison of PrimExpr diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index b3ca3eb46149..928d961d50ee 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // clear previous generated state. this->InitFuncState(f); // skip the first underscore, so SSA variable starts from _1 - name_supply_->FreshName("_"); + name_supply_->FreshName("v_"); // add to alloc buffer type. auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 75833fd93629..9c17458bf221 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -43,7 +43,8 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { } } SSAEntry e; - e.vid = name_supply_->FreshName("_"); + // use v_ prefix so it works for most systems + e.vid = name_supply_->FreshName("v_"); e.scope_id = static_cast(scope_mark_.size() - 1); ssa_assign_map_[src] = e; this->PrintIndent(); diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc new file mode 100644 index 000000000000..e4ccef88b62f --- /dev/null +++ b/src/target/source/codegen_webgpu.cc @@ -0,0 +1,555 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_webgpu.cc + */ +#include "codegen_webgpu.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "../../arith/pattern_match.h" +#include "../../runtime/meta_data.h" +#include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" + +namespace tvm { +namespace codegen { + +std::string CodeGenWebGPU::Finish() { + return decl_stream.str() + this->fwd_decl_stream.str() + stream.str(); +} + +void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { + CodeGenC::InitFuncState(f); + // analyze the data; + for (Var arg : f->params) { + if (arg.dtype().is_handle()) { + alloc_storage_scope_[arg.get()] = "global"; + } + } + std::fill(workgroup_size_, workgroup_size_ + 3, 1); +} + +CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} + +void CodeGenWebGPU::AddFunction(const PrimFunc& f) { + // clear previous generated state. + this->InitFuncState(f); + // skip the first underscore, so SSA variable starts from + name_supply_->FreshName("v_"); + // Setup the thread group info. + ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + + // add to alloc buffer type. 
+ auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; + + decl_stream << "//----------------------------------------\n" + << "// function: " << global_symbol.value() << "\n" + << "//----------------------------------------\n"; + + std::vector pod_args; + int num_buffer = 0; + // setup buffer arguments + for (Var arg : f->params) { + DataType t = arg.dtype(); + if (t.is_handle()) { + auto* ptr = arg->type_annotation.as(); + ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + auto* prim = ptr->element_type.as(); + ICHECK(prim) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + DataType value_storage_type = prim->dtype; + if (value_storage_type == DataType::Bool()) { + // We need a physically addressable buffer type to support boolean tensors. + // The loaded byte is cast to bool inside the LoadNode visitor below. + value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); + } + std::string vid = AllocVarID(arg.get()); + this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") " + << "var " << vid << " : array<"; + this->PrintType(value_storage_type, this->decl_stream); + this->decl_stream << ">;\n"; + } else { + pod_args.push_back(arg); + } + } + + if (pod_args.size() != 0) { + // setup POD arguments + // TODO(tvm-team): store as a uniform, readonly buffer. + LOG(FATAL) << "Do not support pod arguments for now"; + } + // add to alloc buffer type. + // Function header. + this->stream << "fn main(\n" + << " @builtin(workgroup_id) blockIdx : vec3,\n" + << " @builtin(local_invocation_id) threadIdx : vec3\n" + << ") {\n"; + // the function scope. 
+ int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; + // annotate workgroup + this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", " + << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n"; +} + +void CodeGenWebGPU::VisitStmt_(const AttrStmtNode* op) { + // record workgroup size + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag.length() != 0) { + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); + if (ts.rank == 1) { + ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; + ICHECK_LT(ts.dim_index, 3); + auto* sizeptr = op->value.as(); + ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size " + << " get " << op->value; + workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); + } + } + } + // normal operation + CodeGenC::VisitStmt_(op); +} + +void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + std::ostringstream os; + PrintType(iv->var.dtype(), os); + os << "(" << iv->thread_tag << ")"; + std::string tidx = os.str(); + this->MarkConst(tidx); + var_idmap_[iv->var.get()] = tidx; +} + +void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + LOG(FATAL) << "Cannot print handle type in WebGPU"; + } + if (t.is_void()) { + os << "void"; + return; + } + if (t == DataType::Bool()) { + os << "bool"; + return; + } + + if (lanes != 1) { + ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; + os << "vec" << lanes << "<"; + } + + if (t.is_float()) { + ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32"; + os << "f" << t.bits(); + } else if (t.is_uint()) { + os << "u" << t.bits(); + } else if (t.is_int()) { + os << "i" << t.bits(); + 
} else { + LOG(FATAL) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type"; + } + if (lanes != 1) { + os << ">"; + } +} + +void CodeGenWebGPU::PrintStorageSync(const CallNode* op) { + const std::string& sync = op->args[0].as()->value; + if (sync == "warp") { + this->PrintIndent(); + this->stream << "workgroupBarrier();\n"; + } else if (sync == "shared") { + this->PrintIndent(); + this->stream << "workgroupBarrier();\n"; + } else if (sync == "global") { + LOG(FATAL) << "global barrier not supported"; + } +} + +void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string& src, + DataType type) { + stream << "let " << target << " : "; + PrintType(type, stream); + stream << " = " << src << ";\n"; +} + +void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < op->lanes; ++i) { + if (i != 0) os << ", "; + os << v; + } + os << ')'; +} + +void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + if (op->op.same_as(builtin::reinterpret())) { + // generate bitcast(ARG) + os << "bitcast<"; + this->PrintType(op->dtype, os); + os << ">("; + this->PrintExpr(op->args[0], os); + os << ")"; + } else if (op->op.same_as(builtin::if_then_else())) { + // conditional that skips eval if cond evals to false + std::string result = name_supply_->FreshName("condval"); + std::string cond = PrintExpr(op->args[0]); + this->PrintIndent(); + this->stream << "var " << result << " : "; + PrintType(op->dtype, this->stream); + this->stream << ";\n"; + this->PrintIndent(); + this->stream << "if (" << cond << ") {\n"; + { + int then_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << result << " = " << PrintExpr(op->args[1]) << ";\n} else {\n"; + this->EndScope(then_scope); + } + { + int else_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << result << " = " << 
PrintExpr(op->args[2]) << ";\n}\n"; + this->EndScope(else_scope); + } + os << result; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenWebGPU::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) + PrintType(op->dtype, os); + os << "(" << PrintExpr(op->value) << ")"; +} + +void CodeGenWebGPU::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) + os << "select(" << PrintExpr(op->false_value) << ", " << PrintExpr(op->true_value) << ", " + << PrintExpr(op->condition) << ")"; +} + +void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) + if (op->dtype.bits() == 32) { + std::ostringstream temp; + if (op->dtype.is_int()) { + temp << op->value << "i"; + } else { + ICHECK(op->dtype.is_uint()); + temp << op->value << "u"; + } + this->MarkConst(temp.str()); + os << temp.str(); + } else { + this->PrintType(op->dtype, os); + os << "(" << op->value << ")"; + } +} + +void CodeGenWebGPU::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) + std::ostringstream temp; + temp << std::scientific << op->value; + if (op->dtype.bits() == 32) { + temp << 'f'; + } else if (op->dtype.bits() == 16) { + temp << 'h'; + } else { + LOG(FATAL) << "Unsupported floating point bits " << op->dtype.bits(); + } + MarkConst(temp.str()); + os << temp.str(); +} + +void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) + // NOTE: direct impl of load/store for correctness + // Each printing stmt must stand on their own after all preprocessing steps + // to ensure correctness in the case of nested-expression + // do not try to lift common printings from each case + ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + + int lanes = op->dtype.lanes(); + std::string buffer_vid = GetVarID(buffer_var.get()); + 
+ if (value_dtype.lanes() == element_dtype.lanes()) { + // Direct buffer loading + // Special handle bool loading + if (value_dtype == DataType::Bool()) { + this->PrintType(value_dtype, os); + os << "("; + } else { + ICHECK(value_dtype == element_dtype); + } + ICHECK_EQ(index.dtype().lanes(), 1); + os << buffer_vid << "[" << this->PrintExpr(index) << "]"; + // Special handle bool loading + if (value_dtype == DataType::Bool()) { + os << ")"; + } + } else { + // Vector load from scalar buffer + ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; + ICHECK(value_dtype.element_of() == element_dtype) + << "WebGPU vector loading requires base type to match"; + arith::PVar base; + if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { + // vec3(buf[base + 0], buf[base + 1], buf[base + 2]); + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); + PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + os << "("; + for (int i = 0; i < lanes; ++i) { + if (i != 0) os << ", "; + os << buffer_vid << "[" << base_vid << " + " << i << "]"; + } + os << ")"; + } else { + // vec3(buf[index[0]], buf[index[1]], buf[index[2]]); + std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); + PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + for (int i = 0; i < lanes; ++i) { + if (i != 0) os << ", "; + os << buffer_vid << "[" << index_vid << "[" << i << "]]"; + } + os << ")"; + } + } +} + +void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { + // use ssa form. 
+ if (print_ssa_form_) { + std::string value = PrintExpr(op->value); + ICHECK(!var_idmap_.count(op->var.get())); + var_idmap_[op->var.get()] = value; + } else { + PrintIndent(); + std::string value = PrintExpr(op->value); + this->stream << "let " << AllocVarID(op->var.get()) << " : "; + PrintType(op->var.dtype(), this->stream); + this->stream << " = " << value << ";\n"; + } + PrintStmt(op->body); +} + +void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { + CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + + std::string buffer_vid = GetVarID(buffer_var.get()); + + if (value_dtype.lanes() == element_dtype.lanes()) { + // must execute print expr first + // so we won't have recursive append to stream + std::string index_vid = PrintExpr(index); + std::string value_vid = PrintExpr(op->value); + // now print the assignment line. 
+ this->PrintIndent(); + stream << buffer_vid << "[" << index_vid << "] = "; + // special explicit conversion of bool + if (value_dtype == DataType::Bool()) { + PrintType(element_dtype, stream); + stream << "("; + } else { + ICHECK(value_dtype == element_dtype); + } + stream << value_vid; + // Special handle bool store + if (value_dtype == DataType::Bool()) { + stream << ")"; + } + stream << ";\n"; + } else { + // Vector store into scalar buffer + ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; + ICHECK(value_dtype.element_of() == element_dtype) + << "WebGPU vector stire requires base type to match"; + std::string value_vid = PrintExpr(op->value); + arith::PVar base; + if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) { + // buf[base + 0] = value[0] + // buf[base + 1] = value[1] + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); + for (int i = 0; i < value_dtype.lanes(); ++i) { + this->PrintIndent(); + stream << buffer_vid << "[" << base_vid << " + " << i << "] = " << value_vid << "[" << i + << "];\n"; + } + } else { + // buf[index[0]] = value[0] + // buf[index[1]] = value[1] + std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); + for (int i = 0; i < value_dtype.lanes(); ++i) { + this->PrintIndent(); + stream << buffer_vid << "[" << index_vid << "[" << i << "]] = " << value_vid << "[" << i + << "];\n"; + } + } + } +} + +void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + + if (storage_scope.rank == runtime::StorageRank::kShared) { + this->decl_stream << "var " << vid << " : array<"; + PrintType(op->dtype, this->decl_stream); + this->decl_stream << ", " << 
constant_size << ">;\n"; + } else if (storage_scope.rank == runtime::StorageRank::kLocal) { + this->PrintIndent(); + this->stream << "var " << vid << " : array<"; + PrintType(op->dtype, this->stream); + this->stream << ", " << constant_size << ">;\n"; + } else { + LOG(FATAL) << "WebGPU: Do not support storage scope: " << storage_scope.to_string(); + } + this->PrintStmt(op->body); +} + +void CodeGenWebGPU::VisitStmt_(const ForNode* op) { + std::string extent = PrintExpr(op->extent); + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + ICHECK(is_zero(op->min)); + stream << "for (var "; + stream << vid << " : "; + PrintType(op->loop_var.dtype(), stream); + stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n"; + int for_scope = BeginScope(); + PrintStmt(op->body); + this->EndScope(for_scope); + PrintIndent(); + stream << "}\n"; +} + +void CodeGenWebGPU::VisitStmt_(const AssertStmtNode* op) { + // skip assert + PrintStmt(op->body); +} + +void CodeGenWebGPU::VisitStmt_(const AllocateConstNode* op) { + LOG(FATAL) << "WebGPU: do not support alloc const"; +} + +//------------------------------------------------- +// WebGPUSourceModule to enable export +//------------------------------------------------- +class WebGPUSourceModuleNode final : public runtime::ModuleNode { + public: + explicit WebGPUSourceModuleNode(std::unordered_map smap, + std::unordered_map fmap) + : smap_(smap), fmap_(fmap) {} + + const char* type_key() const final { return "webgpu"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; + return PackedFunc(nullptr); + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + LOG(FATAL) << "Not implemented"; + } + + void SaveToBinary(dmlc::Stream* stream) final { + stream->Write(fmap_); + stream->Write(smap_); + } + + std::string GetSource(const std::string& 
format) final { + std::ostringstream os; + for (auto kv : smap_) { + os << kv.second; + } + return os.str(); + } + + private: + // function information table. + std::unordered_map smap_; + // function information table. + std::unordered_map fmap_; +}; + +//------------------------------------------------- +// Build logic. +//------------------------------------------------- +runtime::Module BuildWebGPU(IRModule mod, Target target) { + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + bool output_ssa = false; + + std::unordered_map smap; + for (auto kv : mod->functions) { + CodeGenWebGPU cg(target); + ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; + std::string f_name = global_symbol.value(); + cg.Init(output_ssa); + cg.AddFunction(f); + std::string code = cg.Finish(); + smap[f_name] = code; + } + auto n = make_object(smap, ExtractFuncInfo(mod)); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { + return BuildWebGPU(mod, target); +}); + +} // namespace codegen +} // namespace tvm diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h new file mode 100644 index 000000000000..57f226ba8ad6 --- /dev/null +++ b/src/target/source/codegen_webgpu.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_webgpu.h + * \brief Generate WebGPU shaders in WGSL. + * + * This module generates WGSL shading language. + * See https://www.w3.org/TR/WGSL/ for the language reference. + */ +#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ +#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ + +#include + +#include + +#include "codegen_c.h" + +namespace tvm { +namespace codegen { + +/*! + * \brief WebGPU code generator. + * + * Note WGSL has a different syntax from normal C. + * We only leverage the C for expression generation and + * write most of the language generations. 
+ */ +class CodeGenWebGPU final : public CodeGenC { + public: + explicit CodeGenWebGPU(Target target); + // overrides + std::string Finish() final; + void AddFunction(const PrimFunc& f); // NOLINT(*) + void InitFuncState(const PrimFunc& f) final; + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + + // assignment printing + void PrintSSAAssign(const std::string& target, const std::string& src, DataType type) final; + + // overload visitor + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BufferLoadNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) + + // stmt printing + void VisitStmt_(const LetStmtNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; + void VisitStmt_(const ForNode* op) final; + void VisitStmt_(const AllocateNode* op) final; + void VisitStmt_(const AttrStmtNode* op) final; + void VisitStmt_(const AssertStmtNode* op) final; + void VisitStmt_(const AllocateConstNode* op) final; + + private: + /*! + * \brief Records the workgroup size of the kernel. + */ + uint32_t workgroup_size_[3]; + /*! + * \brief Storage type of bool values. 
+ */ + DataType boolean_storage_type_{DataType::Int(8)}; + Target target_; +}; +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 7d7a5fb29a7c..dd924b925596 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -22,7 +22,6 @@ * \brief Metal intrinsic rules. */ #include -#include #include "../intrin_rule.h" @@ -94,22 +93,6 @@ TVM_REGISTER_OP("tir.cos").set_attr("metal.FLowerIntrinsic", TVM_REGISTER_OP("tir.cosh") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -// There is no erf function in Metal. When erf is used, we use fast_erf instead -static PrimExpr DispatchFastErf(const PrimExpr& e) { - LOG(WARNING) << " Metal doesn't have built-in erf function. fast_erf will be used instead."; - const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); - PrimExpr arg = call->args[0]; - int bits = arg.dtype().bits(); - bool isFloat = arg.dtype().is_float(); - PrimExpr res; - if (isFloat && (bits == 16 || bits == 32)) - res = topi::fast_erf_float_expr(arg, bits); - else - LOG(FATAL) << "Unsupported type in Metal fast_erf"; - return res; -} TVM_REGISTER_OP("tir.erf").set_attr("metal.FLowerIntrinsic", DispatchFastErf); } // namespace intrin diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc new file mode 100644 index 000000000000..81803059fc49 --- /dev/null +++ b/src/target/source/intrin_rule_webgpu.cc @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file intrin_rule_webgpu.cc + * \brief WebGPU intrinsic rules. + */ +#include +#include + +#include "../intrin_rule.h" + +namespace tvm { +namespace codegen { +namespace intrin { + +using tir::FLowerIntrinsic; + +// See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions + +struct ReturnAbs { + std::string operator()(DataType t, std::string name) const { return "abs"; } +}; + +TVM_REGISTER_OP("tir.fabs") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.acos") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.acosh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.asin") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.asinh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.atan") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.atan2") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.ceil") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.cos").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.cosh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.exp").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.exp2") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.floor") + .set_attr("webgpu.FLowerIntrinsic", 
DispatchPureExtern); + +TVM_REGISTER_OP("tir.fma").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.log").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.log2") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.round") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.sin").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.sinh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.sqrt") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.tan").set_attr("webgpu.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.tanh") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +TVM_REGISTER_OP("tir.trunc") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); + +// extra dispatch +TVM_REGISTER_OP("tir.erf").set_attr("webgpu.FLowerIntrinsic", DispatchFastErf); + +} // namespace intrin +} // namespace codegen +} // namespace tvm diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 94f1bf16a25e..dc1d8f865baa 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -97,7 +97,7 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) { +runtime::Module BuildSPIRV(IRModule mod, Target target) { using tvm::runtime::Registry; using tvm::runtime::VulkanShader; @@ -122,7 +122,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); - std::string entry = webgpu_restriction ? 
"main" : f_name; + std::string entry = f_name; VulkanShader shader = cg.BuildFunction(f, entry); @@ -144,12 +144,6 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) spirv_tools.ValidateShader(shader.data); } - if (webgpu_restriction) { - for (auto param : f->params) { - ICHECK(param.dtype().is_handle()) << "WebGPU does not yet support non-buffer arguments"; - } - } - if (postproc != nullptr) { TVMByteArray arr; arr.data = reinterpret_cast(dmlc::BeginPtr(shader.data)); @@ -168,11 +162,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) } TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { - return BuildSPIRV(mod, target, false); -}); - -TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { - return BuildSPIRV(mod, target, true); + return BuildSPIRV(mod, target); }); } // namespace codegen diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 0c65f1718a5d..ac304b92b6d7 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -100,40 +100,6 @@ TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", TVM_REGISTER_OP("tir.tanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); - -// WebGPU rules. 
-TVM_REGISTER_OP("tir.floor") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.ceil") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.round") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.nearbyint") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.trunc") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.fabs") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.exp").set_attr("webgpu.FLowerIntrinsic", - DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.log").set_attr("webgpu.FLowerIntrinsic", - DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.sqrt") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", - DispatchGLSLPureIntrin); - -TVM_REGISTER_OP("tir.tanh") - .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); } // namespace intrin namespace legalize { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index f1b5397b3757..d642484532f9 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -215,7 +215,7 @@ class InstrBuilder { * \brief add sequence of values to instruction * \param args The instruction sequence * \return reference to self. - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template InstrBuilder& AddSeq(Args&&... args) { @@ -328,7 +328,7 @@ class IRBuilder { * \brief Add code to debug segment. * \param op The operator * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void Debug(spv::Op op, Args&&... 
args) { @@ -339,7 +339,7 @@ class IRBuilder { * \brief Set the name of a value or label * \param obj The object to be named * \param name The name of the object - * \tparams Obj The type of the object being named. Typically a Label or Value. + * \tparam Obj The type of the object being named. Typically a Label or Value. */ template void SetName(Obj&& obj, const std::string& name) { @@ -350,7 +350,7 @@ class IRBuilder { * \brief Add Execution mode to a function. * \param func The function value * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void ExecutionMode(Value func, Args&&... args) { @@ -360,7 +360,7 @@ class IRBuilder { * \brief Add code to decorate segment. * \param op The operator * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void Decorate(spv::Op op, Args&&... args) { @@ -370,7 +370,7 @@ class IRBuilder { * \brief Add code to global segment. * \param op The operator * \param args The instruction sequence - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template void DeclareGlobal(spv::Op op, Args&&... args) { @@ -382,7 +382,7 @@ class IRBuilder { * \param op The operator * \param args The instruction sequence * \return The result SSA value. - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template Instr MakeInst(spv::Op op, Args&&... args) { @@ -395,7 +395,7 @@ class IRBuilder { * \param out_type The result type. * \param args The instruction sequence * \return The result SSA value. - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ template Value MakeValue(spv::Op op, const SType& out_type, Args&&... 
args) { @@ -435,7 +435,7 @@ class IRBuilder { * \brief Build vector by concatenating components * * \param vec The vector component - * \tparams Args The positional arguments + * \tparam Args The positional arguments */ Value Concat(const std::vector& vec); /*! diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 078e32ca57c7..828ab010831f 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -1026,4 +1026,46 @@ TVM_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) { return const_true(t.lanes(), span); }); +PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) { + auto plus_4 = make_const(DataType::Float(bits), 4.f); + auto minus_4 = make_const(DataType::Float(bits), -4.f); + + // The monomial coefficients of the numerator polynomial (odd). + auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f); + auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f); + auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f); + auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f); + auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f); + auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f); + auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f); + + // The monomial coefficients of the denominator polynomial (even). + auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f); + auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f); + auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f); + auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f); + auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f); + + // clamp x + auto x = tvm::max(tvm::min(arg, plus_4), minus_4); + auto x2 = x * x; + + // Evaluate the numerator polynomial p. 
+ auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + + // Evaluate the denominator polynomial q. + auto q = x2 * beta_8 + beta_6; + q = x2 * q + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + + return p / q; +} + } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 936c9938dd3a..17efcc8c70a7 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -38,7 +38,6 @@ #include #include "../../src/runtime/meta_data.h" -#include "../../src/runtime/vulkan/vulkan_shader.h" #include "../../src/runtime/workspace_pool.h" namespace tvm { @@ -150,9 +149,9 @@ WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore:: class WebGPUModuleNode final : public runtime::ModuleNode { public: - explicit WebGPUModuleNode(std::unordered_map smap, - std::unordered_map fmap, std::string source) - : smap_(smap), fmap_(fmap), source_(source) { + explicit WebGPUModuleNode(std::unordered_map smap, + std::unordered_map fmap) + : smap_(smap), fmap_(fmap) { auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader"); CHECK(fp != nullptr); create_shader_ = *fp; @@ -168,10 +167,7 @@ class WebGPUModuleNode final : public runtime::ModuleNode { std::ostringstream os; dmlc::JSONWriter writer(&os); info.Save(&writer); - TVMByteArray arr; - arr.data = reinterpret_cast(it->second.data.data()); - arr.size = it->second.data.size() * sizeof(it->second.data[0]); - return create_shader_(os.str(), arr); + return create_shader_(os.str(), it->second); } else { return PackedFunc(nullptr); } @@ -190,29 +186,27 @@ class WebGPUModuleNode final : public runtime::ModuleNode { private: // function information table. - std::unordered_map smap_; + std::unordered_map smap_; // function information table. std::unordered_map fmap_; // The source std::string source_; // Callback to get the GPU function. 
- TypedPackedFunc create_shader_; + TypedPackedFunc create_shader_; }; Module WebGPUModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::unordered_map smap; + std::unordered_map smap; std::unordered_map fmap; - std::string fmt; - stream->Read(&fmt); stream->Read(&fmap); stream->Read(&smap); - return Module(make_object(smap, fmap, "")); + return Module(make_object(smap, fmap)); } // for now webgpu is hosted via a vulkan module. -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = WebGPUDeviceAPI::Global(); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 8df382dbc837..b341a7d4b1a4 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -1037,8 +1037,8 @@ export class Instance implements Disposable { this.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => { return webGPUContext.getDeviceAPI(name); }); - this.registerFunc("wasm.WebGPUCreateShader", (info: string, data: Uint8Array) => { - return webGPUContext.createShader(info, data); + this.registerFunc("wasm.WebGPUCreateShader", (info: string, code: string) => { + return webGPUContext.createShader(info, code); }); this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { await webGPUContext.sync(); diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index 5de47c200dcc..faf6fac990c8 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -79,9 +79,9 @@ export class WebGPUContext { * Create a PackedFunc that runs the given shader * * @param info The function information in json. 
- * @param data The shader data(in SPIRV) + * @param code The shader code (in WGSL) */ - createShader(info: string, data: Uint8Array): Function { + createShader(info: string, code: string): Function { const finfo = JSON.parse(info); const layoutEntries: Array = []; for (let i = 0; i < finfo.arg_types.length; ++i) { @@ -102,16 +102,13 @@ export class WebGPUContext { entries: layoutEntries }); - const textDecoder = new TextDecoder("utf-8") - const codeString = textDecoder.decode(data.buffer) - const pipeline = this.device.createComputePipeline({ layout: this.device.createPipelineLayout({ bindGroupLayouts: [ bindGroupLayout ] }), compute: { module: this.device.createShaderModule({ - code: codeString + code: code }), entryPoint: "main" } diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index ac1a241a9baa..6e34a8a2b36c 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -37,12 +37,10 @@ def test_rpc(): # generate the wasm library target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm") runtime = Runtime("cpp", {"system-lib": True}) - if not tvm.runtime.enabled(target_host): - raise RuntimeError("Target %s is not enbaled" % target_host) n = 2048 A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") + B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i)) + 1.0), name="B") s = te.create_schedule(B.op) num_thread = 2 @@ -75,7 +73,7 @@ def check(remote): f1 = remote.system_lib() addone = f1.get_function("addone") addone(a, b) - np.testing.assert_equal(b.numpy(), a.numpy() + 1) + np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), atol=1e-5, rtol=1e-5) print("Test pass..") check(remote)