diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake
index 7547af81eb1a..2c6b03daeccf 100644
--- a/cmake/modules/contrib/DNNL.cmake
+++ b/cmake/modules/contrib/DNNL.cmake
@@ -52,6 +52,7 @@ elseif(USE_DNNL STREQUAL "C_SRC")
   find_library(EXTERN_LIBRARY_DNNL dnnl)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
   tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl.cc
+                                      src/runtime/contrib/dnnl/dnnl_utils.cc
                                       src/runtime/contrib/cblas/dnnl_blas.cc)
   list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
   message(STATUS "Build with DNNL C source module: " ${EXTERN_LIBRARY_DNNL})
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 74cd19b3aaba..9a9ed5f83d97 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -50,6 +50,29 @@ namespace contrib {
 
 using namespace backend;
 
+/*!
+ * \brief Replace var expr which bind with args of call node
+ *
+ * \param args vector of expression (contains vars or constant nodes)
+ * \param cn call node which describe mapping of internal body vars with args
+ * \return updated vector of expressions
+ */
+static tvm::Array<Expr> BindToCallNodeArgs(const std::vector<Expr>& args, const CallNode* cn) {
+  tvm::Array<Expr> res;
+  for (const auto& arg : args) {
+    if (arg->IsInstance<ConstantNode>()) {
+      res.push_back(arg);
+    } else {
+      auto body_params = cn->op.as<FunctionNode>()->params;
+      auto found = std::find(body_params.begin(), body_params.end(), arg);
+      ICHECK(found != body_params.end());
+      auto idx = std::distance(body_params.begin(), found);
+      res.push_back(cn->args[idx]);
+    }
+  }
+  return res;
+}
+
 #ifndef USE_JSON_RUNTIME  // C source runtime
 inline size_t GetShape1DSize(const Type& type) {
   const auto shape = GetShape(type);
@@ -203,7 +226,8 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
 
     // Give the ndarray a unique name to ease the initialization of it at
    // runtime.
-    std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_);
+    std::string const_symbol = "dnnl_" + ext_func_id_;
+    std::string const_var_name = CreateConstVar(const_symbol, const_idx_);
     const_vars_.push_back(const_var_name);
     const_idx_++;
 
@@ -274,7 +298,8 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
       return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", GetArgumentNames(caller),
                           Conv2d(conv_call));
     } else if (pattern_name == "dnnl.conv2d_relu") {
-      const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
+      const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1,
+                                          (const std::vector<std::string>){"nn.conv2d", "nn.relu"});
       return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller),
                           Conv2d(conv_call));
     }
@@ -434,29 +459,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
 
 #else  // DNNL JSON runtime
 
-/*!
- * \brief Replace var expr which bind with args of call node
- *
- * \param args vector of expression (contains vars or constant nodes)
- * \param cn call node which describe mapping of internal body vars with args
- * \return updated vector of expressions
- */
-static tvm::Array<Expr> BindToCallNodeArgs(const std::vector<Expr>& args, const CallNode* cn) {
-  tvm::Array<Expr> res;
-  for (const auto& arg : args) {
-    if (arg->IsInstance<ConstantNode>()) {
-      res.push_back(arg);
-    } else {
-      auto body_params = cn->op.as<FunctionNode>()->params;
-      auto found = std::find(body_params.begin(), body_params.end(), arg);
-      ICHECK(found != body_params.end());
-      auto idx = std::distance(body_params.begin(), found);
-      res.push_back(cn->args[idx]);
-    }
-  }
-  return res;
-}
-
 /*! \brief Serializer to DNNL JSON runtime module */
 class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
   using JSONGraphNode = tvm::runtime::json::JSONGraphNode;