diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index dad481cfbf3f..8e12b48aac76 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -170,6 +170,22 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 * MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD - Values: Int ```(default=)``` - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass. +* MXNET_ENABLE_CUDA_GRAPHS + - Values: 0(false) or 1(true) ```(default=0)``` + - If set to `1`, MXNet will utilize CUDA graphs when executing models on the GPU when possible. + - For CUDA graphs execution, one needs to use either symbolic model or Gluon model hybridized with options `static_alloc` and `static_shape` set to True. +* MXNET_CUDA_GRAPHS_VERBOSE + - Values: 0(false) or 1(true) ```(default=0)``` + - If set to `1`, CUDA graphs executor will provide information about the graph being captured and executed. +* MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES + - Values: Int ```(default=0)``` + - The maximum number of log messages generated by CUDA graphs executor. +* MXNET_CUDA_GRAPHS_DBG_FILE + - Values: String ```(default='', to indicate no debug dot files should be created)``` + - The file prefix for '.dot' files for each graph created. Full path is -devN-{trn,inf}..dot . +* MXNET_CUDA_GRAPHS_DBG_FILE_FLAGS + - Values: Int ```(default=)``` + - A bitmask to enable various types of info in the debug '.dot' files. See cudaGraphDebugDotFlags in the CUDA runtime API doc for details. ## Control the Data Communication diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 2fec1768ea86..c936d3e84afa 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -357,6 +357,19 @@ using FNeedCalibrateInput = std::function(const NodeAttrs& attr */ using FNeedCalibrateOutput = std::function(const NodeAttrs& attrs)>; +#if MXNET_USE_CUDA + +/*! + * \brief Register a function to determine if + * the operator implementation is compatible + * with CUDA graphs. This requires the execution + * to stay the same as long as the shape and type + * of input stays the same. + */ +using FIsCUDAGraphsCompatible = std::function; + +#endif + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/src/imperative/attach_op_execs_pass.cc b/src/imperative/attach_op_execs_pass.cc index 4a8c51d107c7..732391fdd747 100644 --- a/src/imperative/attach_op_execs_pass.cc +++ b/src/imperative/attach_op_execs_pass.cc @@ -47,8 +47,10 @@ namespace exec { // FComputeExecutor and FStatefulComputeExecutor inherit from this class class StorageFallbackOpExecutor : public OpExecutor { public: - explicit StorageFallbackOpExecutor(std::vector mutate_idx) - : mutate_idx_(std::move(mutate_idx)) {} + explicit StorageFallbackOpExecutor(const NodeAttrs& attrs, + DispatchMode dispatch_mode, + std::vector mutate_idx) + : OpExecutor(attrs, dispatch_mode), mutate_idx_(std::move(mutate_idx)) {} void Setup() override { init_ = false; @@ -146,11 +148,13 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { return state_; } - explicit StatefulComputeExecutor(OpStatePtr state, + explicit StatefulComputeExecutor(const NodeAttrs& attrs, + DispatchMode dispatch_mode, + OpStatePtr state, FStatefulCompute fcompute, ExecType exec_type, const std::vector& mutate_idx) - : StorageFallbackOpExecutor(mutate_idx), + : StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx), state_(std::move(state)), fcompute_(std::move(fcompute)), exec_type_(exec_type) {} @@ -168,7 +172,7 @@ class StatefulComputeExExecutor : public OpExecutor { op_ctx.run_ctx = rctx; INVALIDATE_OUTPUTS(out_array, req); std::vector* pInArray = &in_array; - CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs_); + CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs); fcompute_(state_, op_ctx, *pInArray, req, out_array); } @@ -186,17 +190,17 @@ class StatefulComputeExExecutor : public OpExecutor { return state_; } - explicit StatefulComputeExExecutor(NodeAttrs attrs, + explicit StatefulComputeExExecutor(const NodeAttrs& attrs, + DispatchMode dispatch_mode, OpStatePtr state, FStatefulComputeEx fcompute, ExecType exec_type) - : attrs_(std::move(attrs)), + : OpExecutor(attrs, dispatch_mode), state_(std::move(state)), fcompute_(std::move(fcompute)), exec_type_(exec_type) {} private: - NodeAttrs attrs_; OpStatePtr state_; FStatefulComputeEx fcompute_; ExecType exec_type_; @@ -210,7 +214,7 @@ class FComputeExecutor : public StorageFallbackOpExecutor { op_ctx.run_ctx = rctx; INVALIDATE_OUTPUTS(out_array, req); PreFCompute(is_gpu); - fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + fcompute_(attrs, op_ctx, in_data_, req, out_data_); PostFCompute(is_gpu); } @@ -218,17 +222,16 @@ class FComputeExecutor : public StorageFallbackOpExecutor { return exec_type_; } - explicit FComputeExecutor(NodeAttrs attrs, + explicit FComputeExecutor(const NodeAttrs& attrs, + DispatchMode dispatch_mode, FCompute fcompute, ExecType exec_type, const std::vector& mutate_idx) - : StorageFallbackOpExecutor(mutate_idx), - attrs_(std::move(attrs)), + : StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx), fcompute_(std::move(fcompute)), exec_type_(exec_type) {} private: - NodeAttrs attrs_; FCompute fcompute_; ExecType exec_type_; }; @@ -240,8 +243,8 @@ class FComputeExExecutor : public OpExecutor { op_ctx.run_ctx = rctx; INVALIDATE_OUTPUTS(out_array, req); std::vector* pInArray = &in_array; - CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs_); - fcompute_(attrs_, op_ctx, *pInArray, req, out_array); + CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs); + fcompute_(attrs, op_ctx, *pInArray, req, out_array); } void Setup() override {} @@ -250,11 +253,13 @@ class FComputeExExecutor : public OpExecutor { return exec_type_; } - explicit FComputeExExecutor(NodeAttrs attrs, FComputeEx fcompute, ExecType exec_type) - : attrs_(std::move(attrs)), fcompute_(std::move(fcompute)), exec_type_(exec_type) {} + explicit FComputeExExecutor(const NodeAttrs& attrs, + DispatchMode dispatch_mode, + FComputeEx fcompute, + ExecType exec_type) + : OpExecutor(attrs, dispatch_mode), fcompute_(std::move(fcompute)), exec_type_(exec_type) {} private: - NodeAttrs attrs_; FComputeEx fcompute_; ExecType exec_type_; }; @@ -309,14 +314,15 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { ret[i] = std::make_shared( - inode.source->attrs, state, fcompute_ex, exec_type); + inode.source->attrs, dispatch_modes[i], state, fcompute_ex, exec_type); } else { FStatefulCompute fcompute = common::GetFCompute(op, "FStatefulCompute", vctx[i]); CHECK(fcompute != nullptr) << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; - ret[i] = std::make_shared(state, fcompute, exec_type, mutate_index); + ret[i] = std::make_shared( + inode.source->attrs, dispatch_modes[i], state, fcompute, exec_type, mutate_index); } } else if (is_layer_backward.get(op, false)) { CHECK_GE(inode.control_deps.size(), 1); @@ -327,25 +333,33 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, common::GetFCompute(op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared( - inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex, exec_type); + ret[i] = std::make_shared(inode.source->attrs, + dispatch_modes[i], + ret[fwd_id].get()->state(), + fcompute_ex, + exec_type); } else { FStatefulCompute fcompute = common::GetFCompute(op, "FStatefulCompute", vctx[i]); CHECK(fcompute != nullptr) << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; - ret[i] = std::make_shared( - ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index); + ret[i] = std::make_shared(inode.source->attrs, + dispatch_modes[i], + ret[fwd_id].get()->state(), + fcompute, + exec_type, + mutate_index); } } else { FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); FComputeEx fcomp_ex = common::GetFCompute(op, "FComputeEx", vctx[i]); if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared(inode.source->attrs, fcomp_ex, exec_type); + ret[i] = std::make_shared( + inode.source->attrs, dispatch_modes[i], fcomp_ex, exec_type); } else if (fcompute != nullptr) { ret[i] = std::make_shared( - inode.source->attrs, fcompute, exec_type, mutate_index); + inode.source->attrs, dispatch_modes[i], fcompute, exec_type, mutate_index); } else { LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name; } diff --git a/src/imperative/cuda_graphs.h b/src/imperative/cuda_graphs.h new file mode 100644 index 000000000000..c9e16d84e8b3 --- /dev/null +++ b/src/imperative/cuda_graphs.h @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file cuda_graphs.h + * \brief Wrappers for use of CUDA Graphs API + */ +#ifndef MXNET_IMPERATIVE_CUDA_GRAPHS_H_ +#define MXNET_IMPERATIVE_CUDA_GRAPHS_H_ + +#include +#include +#include +#include +#include +#include + +#include "./exec_pass.h" +#include "../common/cuda/utils.h" + +#if MXNET_USE_CUDA +#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10020) +#else +#define CUDA_GRAPHS_AVAILABLE (0) +#endif + +#if CUDA_GRAPHS_AVAILABLE + +namespace mxnet { +namespace cuda_graphs { + +inline std::string CudaDim3ToString(const dim3& dims) { + std::stringstream ss; + if (dims.z != 1) + ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")"; + else if (dims.y != 1) + ss << "(" << dims.x << "," << dims.y << ")"; + else + ss << "(" << dims.x << ")"; + return ss.str(); +} + +// Return the list of CUDA Graph nodes from a graph +inline std::vector GetCudaGraphNodes(cudaGraph_t cuda_graph) { + size_t numNodes; + CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast(nullptr), &numNodes)); + if (numNodes == 0) + return std::vector(); + std::vector graphNodes(numNodes); + CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes)); + return graphNodes; +} + +// Create a description of a CUDA Graph node +inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) { + std::stringstream ss; + + // The following introspection calls are made through the driver API in order to bypass + // problems that would arise if multiple statically-linked copies of the runtime exist. + + CUgraphNode cu_node = node; + CUgraphNodeType t; + CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t)); + switch (t) { + case CU_GRAPH_NODE_TYPE_KERNEL: { + CUDA_KERNEL_NODE_PARAMS kparams; + auto err = cuGraphKernelNodeGetParams(cu_node, &kparams); + if (err == CUDA_SUCCESS) { + ss << "GPUKernel@" << kparams.func; + dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ); + dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ); + ss << "<<>>"; + ss << "(..."; + if (kparams.sharedMemBytes != 0) + ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes; + ss << ")"; + } else { + ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err; + } + } break; + case CU_GRAPH_NODE_TYPE_MEMCPY: { + cudaMemcpy3DParms mparams = {}; + CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams)); + // If memcpy is seen, return without setting up runnable executor + switch (mparams.kind) { + case cudaMemcpyHostToHost: + ss << "Host->Host "; + break; + case cudaMemcpyHostToDevice: + ss << "Host->Device "; + break; + case cudaMemcpyDeviceToHost: + ss << "Device->Host "; + break; + case cudaMemcpyDeviceToDevice: + ss << "Device->Device "; + break; + default: + break; + } + ss << "Memcpy"; + } break; + case CU_GRAPH_NODE_TYPE_MEMSET: { + cudaMemsetParams mparams = {}; + CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams)); + if (mparams.height == 1 && mparams.elementSize == 1) { + ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value + << ", count=" << mparams.width << ")"; + } else { + if (mparams.elementSize == 1) + ss << "cudaMemset2D"; + else + ss << "MemSet"; + ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch + << ", value=" << mparams.value << ", width=" << mparams.width + << ", height=" << mparams.height << ")"; + } + } break; + case CU_GRAPH_NODE_TYPE_HOST: + ss << "Host (executable) node"; + break; + case CU_GRAPH_NODE_TYPE_GRAPH: + ss << "Node which executes an embedded graph"; + break; + case CU_GRAPH_NODE_TYPE_EMPTY: + ss << "Empty (no-op) node"; + break; + default: + ss << "Unknown/Invalid node type " << t; + } + return ss.str(); +} + +// CUDA Graphs are managed in RAII fashion by smart pointers below. +// Function objects (preferred for readability) provide the deleter function. +class CudaGraphDeleter { + public: + void operator()(cudaGraph_t graph) { + if (graph != nullptr) + CUDA_CALL(cudaGraphDestroy(graph)); + } +}; + +// CUDA Graphs Executors are managed in RAII fashion by smart pointers below. +// Function objects (preferred for readability) provide the deleter function. +class CudaGraphExecDeleter { + public: + void operator()(cudaGraphExec_t graph_exec) { + if (graph_exec != nullptr) + CUDA_CALL(cudaGraphExecDestroy(graph_exec)); + } +}; + +// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'), +// characterized by a starting index in the OpExecutor list and a number of ops. +class CudaGraphsSubSegExec { + public: + CudaGraphsSubSegExec(const std::vector>& exec_list, + const RunContext& rctx, + bool is_gpu, + bool verbose, + int from_op_idx, + int num_ops, + bool ops_are_cuda_graph_compatible = true) + : from_op_idx_(from_op_idx), + num_ops_(num_ops), + graph_(nullptr), + graph_exec_(nullptr), + graph_exec_id_(0) { + if (ops_are_cuda_graph_compatible) { + MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops); + MakeGraphExec(exec_list, rctx); + } + } + + void Update(const std::vector>& exec_list, + const RunContext& rctx, + bool is_gpu, + bool verbose) { + // Current executor should be Runnable with the same parameters + CHECK(IsRunnable()); + MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_); + + cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError; + cudaGraphNode_t error_node; + cudaError_t err = + cudaGraphExecUpdate(graph_exec_.get(), graph_.get(), &error_node, &update_result); + switch (err) { + case cudaErrorGraphExecUpdateFailure: + MakeGraphExec(exec_list, rctx); + break; + case cudaSuccess: + CHECK_EQ(update_result, cudaGraphExecUpdateSuccess); + break; + default: + // Respond normally to unusual cudaGraphExecUpdate() ret vals + CUDA_CALL(err); + } + } + + void RunSubSeg(const std::vector>& exec_list, + const RunContext& rctx, + bool is_gpu) { + if (IsRunnable()) { + auto s = rctx.get_stream(); + const cudaStream_t cu_s = mshadow::Stream::GetStream(s); + CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s)); + } else { + // No CUDA Graph could be made for this portion of the OpSegment. Run conventionally. + for (int i = 0; i != num_ops_; ++i) + exec_list[from_op_idx_ + i]->Run(rctx, is_gpu); + } + } + + bool IsRunnable() { + return graph_exec_ != nullptr; + } + + int NumGraphNodes() { + size_t numNodes; + CUDA_CALL(cudaGraphGetNodes(graph_.get(), static_cast(nullptr), &numNodes)); + return numNodes; + } + + private: + void MakeGraph(const std::vector>& exec_list, + const RunContext& rctx, + bool is_gpu, + bool verbose, + int from_op_idx, + int num_ops) { + auto s = rctx.get_stream(); + const cudaStream_t cu_s = mshadow::Stream::GetStream(s); + // Create CUDA Graph + // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers + // to sync their streams without disturbing this capture. + CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal)); + // Run those oprs in the sub segment while capturing- no actual GPU work is launched. + for (int i = 0; i != num_ops; ++i) + exec_list[from_op_idx + i]->Run(rctx, is_gpu); + cudaGraph_t cuda_graph = nullptr; + CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph)); + graph_.reset(cuda_graph, CudaGraphDeleter()); + + if (verbose) { + std::vector graph_nodes = GetCudaGraphNodes(cuda_graph); + size_t num_nodes = graph_nodes.size(); + LOG(INFO) << " Graph has " << num_nodes << " nodes:"; + for (size_t i = 0; i != num_nodes; ++i) { + LOG(INFO) << " node " << i << " = " << CudaGraphNodeToString(graph_nodes[i]); + } + } + } + + void MakeGraphExec(const std::vector>& exec_list, + const RunContext& rctx) { + // Note that this routine is not invoked when a graph executor is merely updated. + cudaGraphExec_t cuda_graph_exec; + cudaGraphNode_t error_node; + char log_buffer[1000]; + + CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(), &error_node, log_buffer, 1000)); + graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter()); + + // At this point we have a CUDA Graph executor + static int num_graph_creations = 0; + graph_exec_id_ = num_graph_creations++; + + static size_t max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0); + if (graph_exec_id_ < max_log_entries) { + LOG(INFO) << "Created CUDA graph " << graph_exec_id_; + if (num_graph_creations == max_log_entries) + LOG(INFO) << "Further CUDA graph creation log messages are suppressed."; + } + // Create a .dot file for graph visualization if requested + static std::string dotfile_base = dmlc::GetEnv("MXNET_CUDA_GRAPHS_DBG_FILE", std::string()); + if (dotfile_base.size() > 0) { +#if CUDA_VERSION >= 11030 + static int dotfile_flags = dmlc::GetEnv("MXNET_CUDA_GRAPHS_DBG_FILE_FLAGS", + static_cast(cudaGraphDebugDotFlagsVerbose)); + std::ostringstream filename; + const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train; + int dev_id = rctx.ctx.dev_id; + filename << dotfile_base << "-" + << "dev" << dev_id << "-" << (is_train ? "trn" : "inf") << "-" << graph_exec_id_ + << ".dot"; + CUDA_CALL(cudaGraphDebugDotPrint(graph_.get(), filename.str().c_str(), dotfile_flags)); +#else + [[maybe_unused]] static bool dot_file_unsupported = []() { // NOLINT + LOG(INFO) << "MXNET_CUDA_GRAPHS_DBG_FILE setting ignored- requires CUDA version >= 11.3"; + return true; + }(); +#endif // CUDA_VERSION >= 11030 + } + } + + int from_op_idx_; + int num_ops_; + using cudaGraphStruct_t = typename std::remove_pointer::type; + using cudaGraphExecStruct_t = typename std::remove_pointer::type; + std::shared_ptr graph_; + std::shared_ptr graph_exec_; + size_t graph_exec_id_; +}; + +// The CudaGraph executor and associated Tempspace ptrs for which it is valid. +struct CudaGraphInfo { + std::vector cuda_graph_subseg_execs; + bool has_been_run_conventionally = false; + std::vector tempspace_dptrs; +}; +// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and +// the state of the is_train flag of the OpContext. If the tempspace_dptrs change, we +// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph. +struct CudaGraphCacheKey { + cudaStream_t cu_s; + bool is_train; + // overload '<' so CudaGraphCacheKey can be used as a std::map key + bool operator<(const CudaGraphCacheKey& other) const { + return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train); + } +}; +using CudaGraphCache = std::map; + +class CudaGraphsExec { + public: + CudaGraphsExec(const std::vector>& exec_list, + bool is_gpu, + const char* opr_names) + : verbose_(false), is_enabled_(false) { + opr_names_ = opr_names ? std::string(opr_names) : std::string(); + if (is_gpu) { + is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false); + verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false); + SetTempSpaces(exec_list); + } + } + + void RunAll(const std::vector>& exec_list, + const RunContext& rctx, + bool is_gpu) { + // If this a CPU op or CUDA Graphs use isn't possible, run normally and return + if (!is_gpu || !is_enabled_) { + // Run all opr in the sub-graph + exec::OpExecutor::RunAll(exec_list, rctx, is_gpu); + return; + } + + // Also if we're in a warm-up period where tempspace pointers are likely + // to change, run normally and return + auto s = rctx.get_stream(); + const cudaStream_t cu_s = mshadow::Stream::GetStream(s); + // All the ops in the bulked segment will have the same setting of is_train as the first op + const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train; + const CudaGraphCacheKey key = {cu_s, is_train}; + // Look-up the CUDA Graph info for this combo of stream and is_train setting + // This may create a default-initialized new entry. + auto& cuda_graph_info = cache_[key]; + if (!cuda_graph_info.has_been_run_conventionally) { + // Run all opr in the sub-graph + exec::OpExecutor::RunAll(exec_list, rctx, is_gpu); + cuda_graph_info.has_been_run_conventionally = true; + return; + } + + // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors' + // (there might be more than one executor if some ops in the segment are not capturable) + auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s); + + // Executors exist, but the tempspace pts have changed, so update them in-place via 'recapture'. + if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 && + cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) { + // Update all runnable executors. Non-runnable executors launch their ops conventionally. + for (auto& subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) { + if (subseg_exec.IsRunnable()) + subseg_exec.Update(exec_list, rctx, is_gpu, verbose_); + } + } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) { + // No executors exist yet, so create them. + if (verbose_) + LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_; + // Make one or more CUDA Graphs, avoiding ops that are not compatible. + for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) { + int num_good_ops = 0; + for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) { + if (OpOK(exec_list[last_op_idx])) + num_good_ops++; + else + break; + } + if (num_good_ops > 0) { + CreateSubExecOverRegion(exec_list, + rctx, + is_gpu, + first_op_idx, + first_op_idx + num_good_ops, + &cuda_graph_info.cuda_graph_subseg_execs); + first_op_idx += num_good_ops; + } + if (first_op_idx != exec_list.size()) { + // We had to have hit an op that was not OK. + if (verbose_) { + LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]" + << " of op segment " << opr_names_; + } + CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false); + cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg); + first_op_idx++; + } + } + // During graph capture, the ops may be asking for the tempworkspace. This should + // not alter the base pointers, since this op seg has been executed before on this + // stream (i.e. on this gpu worker). Safest to double-check this though. + auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s); + if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs) + LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use."; + cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs; + } + // Now execute the CUDA Graph that we either just created or looked-up in the cache. + if (verbose_) { + int runnable_execs = 0; + int bypassed_ops = 0; + for (auto& subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) { + if (subseg_exec.IsRunnable()) { + LOG(INFO) << "Launching captured graph with " << subseg_exec.NumGraphNodes() << " nodes."; + runnable_execs++; + } else { + bypassed_ops++; + } + } + if (bypassed_ops > 0) + LOG(INFO) << " (bypassing " << bypassed_ops << " un-capturable ops)"; + } + for (auto& subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) + subseg_exec.RunSubSeg(exec_list, rctx, is_gpu); + } + + private: + // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx). If such a graph + // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a + // number of smaller graphs that avoid those ops with the memcpys. + void CreateSubExecOverRegion(const std::vector>& exec_list, + const RunContext& rctx, + bool is_gpu, + size_t from_op_idx, + size_t upto_op_idx, + std::vector* cuda_graph_subseg_execs) { + // Optimistically try to create a CUDA Graph of the entire op segment region + + int num_ops = upto_op_idx - from_op_idx; + CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops); + if (full_opseg.IsRunnable()) { + cuda_graph_subseg_execs->push_back(full_opseg); + } else { + if (verbose_) + LOG(INFO) << " Graph was not runnable- creating op sub-segments..."; + // Enter fall-back approach to making many sub-execs + for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx;) { + int num_good_ops = 0; + for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) { + CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1); + if (single_opseg.IsRunnable()) + num_good_ops++; + // Is it time to create a subseg exec from accumulated good ops? + if (num_good_ops > 0 && (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) { + if (verbose_) + LOG(INFO) << "Capturing CUDA graph of op sub segment[" << first_op_idx << ":" + << (first_op_idx + num_good_ops - 1) << "]" + << " of op segment " << opr_names_; + CudaGraphsSubSegExec good_opseg( + exec_list, rctx, is_gpu, verbose_, first_op_idx, num_good_ops); + CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation"; + cuda_graph_subseg_execs->push_back(good_opseg); + first_op_idx += num_good_ops; + } + // If the last single op was not runnable, use the exec to handle that op conventionally + if (!single_opseg.IsRunnable()) { + if (verbose_) { + LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]" + << " of op segment " << opr_names_; + // Generate throw-away exec in order to produce a diagnostic listing of graph nodes + CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1); + } + cuda_graph_subseg_execs->push_back(single_opseg); + first_op_idx++; + break; + } + } + } + } + } + + // Is the Op OK to make part of a CUDA Graph? + bool OpOK(const std::shared_ptr& exec) { + static auto& fgraphcompatible = Op::GetAttr("FIsCUDAGraphsCompatible"); + static auto& fcompute_ex = Op::GetAttr("FComputeEx"); + static auto& fstatefulcompute = Op::GetAttr("FStatefulCompute"); + static auto& fstatefulcompute_ex = Op::GetAttr("FStatefulComputeEx"); + const auto& attrs = exec->attrs; + if (attrs.op != nullptr) { + const auto f = fgraphcompatible.get(attrs.op, nullptr); + if (f != nullptr) { + return f(attrs, exec->op_ctx.is_train); + } + if (fstatefulcompute.get(attrs.op, nullptr) != nullptr || + fstatefulcompute_ex.get(attrs.op, nullptr) != nullptr) { + if (verbose_) { + LOG(INFO) << "Omitting stateful operator " << attrs.op->name << " from CUDA graph."; + } + return false; + } + if ((fcompute_ex.get(attrs.op, nullptr) != nullptr && + exec->dispatch_mode == DispatchMode::kFComputeEx) || + exec->dispatch_mode == DispatchMode::kFComputeFallback) { + if (verbose_) { + LOG(INFO) << "Omitting operator " << attrs.op->name + << " from CUDA graph due to dispatch mode " + << static_cast(exec->dispatch_mode); + } + return false; + } + } + for (auto& resource : exec->op_ctx.requested) { + if (!(resource.req.type == ResourceRequest::kTempSpace)) { + if (verbose_) { + LOG(INFO) << "Omitting operator " << attrs.op->name + << " from CUDA graph due to using the resource type " + << static_cast(resource.req.type); + } + return false; + } + } + return true; + } + + // Determine Tempspaces used by ops. Other resource uses disable CUDA Graphs. + void SetTempSpaces(const std::vector>& exec_list) { + // Gather info about the ops use of TempSpace. + if (is_enabled_) { + std::set tempspaces_set; + for (auto& exec : exec_list) { + for (auto& resource : exec->op_ctx.requested) { + if (resource.req.type == ResourceRequest::kTempSpace) { + tempspaces_set.insert(&resource); + } + } + } + tempspaces_.assign(tempspaces_set.begin(), tempspaces_set.end()); + } + } + + // Return the addresses of the gpu TempSpace areas + std::vector GetGPUTempspacePtrs(mshadow::Stream* s) { + std::vector ret; + for (const auto& resource : tempspaces_) { + // Ask for minimal allocation to get base pointer without increasing the size + auto* base_ptr = resource->get_space_typed(mshadow::Shape1(1), s).dptr_; + ret.push_back(static_cast(base_ptr)); + } + return ret; + } + + CudaGraphCache cache_; + std::vector tempspaces_; + std::string opr_names_; + bool verbose_; + bool is_enabled_; +}; + +} // namespace cuda_graphs +} // namespace mxnet + +#endif // CUDA_GRAPHS_AVAILABLE + +#endif // MXNET_IMPERATIVE_CUDA_GRAPHS_H_ diff --git a/src/imperative/exec_pass.h b/src/imperative/exec_pass.h index 7667d97632fc..02fa967a19b7 100644 --- a/src/imperative/exec_pass.h +++ b/src/imperative/exec_pass.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -84,6 +85,13 @@ class OpExecutor { std::vector req; /*! \brief runtime op context, contains allocated resources */ OpContext op_ctx; + /*! \brief attributes of the node */ + NodeAttrs attrs; + /*! \brief dispatch mode of the executor */ + DispatchMode dispatch_mode; + + explicit OpExecutor(NodeAttrs attrs, DispatchMode dispatch_mode) + : attrs(std::move(attrs)), dispatch_mode(dispatch_mode) {} /*! \brief virtual destructor */ virtual ~OpExecutor() {} /*! @@ -98,6 +106,17 @@ class OpExecutor { * \param rctx The runtime context passed in by environment. */ virtual void Run(RunContext rctx, bool is_gpu) = 0; + /*! + * \brief run the operators of a vector of execs, given runtime context on device. + * This function call does not synchronize the stream. + * \param rctx The runtime context passed in by environment. + */ + static void RunAll(const std::vector>& execs, + RunContext rctx, + bool is_gpu) { + for (auto& exec : execs) + exec->Run(rctx, is_gpu); + } /*! \return the execution type */ virtual ExecType exec_type() const = 0; /*! \return return engine variable for operator states */ diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index ce1a60fb2b20..7f90528f4793 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -27,6 +27,8 @@ #include #include +#include "./exec_pass.h" +#include "./cuda_graphs.h" #include "../c_api/c_api_common.h" #include "../common/exec_utils.h" #include "../common/utils.h" @@ -1248,6 +1250,21 @@ inline Engine::OprHandle CreateEngineOp( bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask; bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync; +#if CUDA_GRAPHS_AVAILABLE + // Provide initialized `cuda_graphs_exec`, which when captured + // by exec_fun, acts like a static variable inside the mutable closure. + cuda_graphs::CudaGraphsExec cuda_graphs_exec(execs, is_gpu, opr_names); + auto exec_fun = [cuda_graphs_exec, execs, is_async, is_gpu]( + RunContext ctx, + Engine::CallbackOnStart on_start, + Engine::CallbackOnComplete on_complete) mutable { + on_start(); + if (is_async) { + execs[0]->op_ctx.async_on_complete = on_complete; + } + // Run all opr in the sub-graph with CUDA graphs executor if possible + cuda_graphs_exec.RunAll(execs, ctx, is_gpu); +#else auto exec_fun = [execs, is_async, is_gpu](RunContext ctx, Engine::CallbackOnStart on_start, Engine::CallbackOnComplete on_complete) { @@ -1255,8 +1272,8 @@ inline Engine::OprHandle CreateEngineOp( if (is_async) { execs[0]->op_ctx.async_on_complete = on_complete; } - for (const auto& exec : execs) - exec->Run(ctx, is_gpu); + exec::OpExecutor::RunAll(execs, ctx, is_gpu); +#endif // call on complete only if it is async op if (!is_async) { if (is_gpu) { diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu index 802378839bc2..0b247900ce4c 100644 --- a/src/operator/contrib/adamw.cu +++ b/src/operator/contrib/adamw.cu @@ -45,15 +45,23 @@ void GetScaleFloat(mshadow::Stream* s, const TBlob& scale_blob, float* })} NNVM_REGISTER_OP(_adamw_update) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", adamw::MPUpdate>); NNVM_REGISTER_OP(_mp_adamw_update) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", adamw::MPUpdate>); NNVM_REGISTER_OP(_multi_adamw_update) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", adamw::multiMPUpdate); NNVM_REGISTER_OP(_multi_mp_adamw_update) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", adamw::multiMPUpdate); } // namespace adamw diff --git a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu index 482cbf6b8150..3702fed2a06a 100644 --- a/src/operator/contrib/index_array.cu +++ b/src/operator/contrib/index_array.cu @@ -82,7 +82,10 @@ void IndexArrayForwardGPU(const nnvm::NodeAttrs& attrs, } } -NNVM_REGISTER_OP(_contrib_index_array).set_attr("FCompute", IndexArrayForwardGPU); +NNVM_REGISTER_OP(_contrib_index_array) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) + .set_attr("FCompute", IndexArrayForwardGPU); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/multi_lamb.cu b/src/operator/contrib/multi_lamb.cu index 118ec6348ed7..c6bedfc861f8 100644 --- a/src/operator/contrib/multi_lamb.cu +++ b/src/operator/contrib/multi_lamb.cu @@ -268,9 +268,13 @@ void CallKernel2(Stream* s, } NNVM_REGISTER_OP(_multi_lamb_update) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) .set_attr("FCompute", MultiLAMBUpdate); NNVM_REGISTER_OP(_multi_mp_lamb_update) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) .set_attr("FCompute", MultiLAMBUpdate); } // namespace op diff --git a/src/operator/instance_norm.cu b/src/operator/instance_norm.cu index ca45dbbff386..ce11fbf3200d 100644 --- a/src/operator/instance_norm.cu +++ b/src/operator/instance_norm.cu @@ -28,9 +28,14 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(InstanceNorm).set_attr("FCompute", InstanceNormForward); +NNVM_REGISTER_OP(InstanceNorm) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) + .set_attr("FCompute", InstanceNormForward); NNVM_REGISTER_OP(_backward_instance_norm) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) .set_attr("FCompute", InstanceNormBackward); } // namespace op diff --git a/src/operator/leaky_relu.cu b/src/operator/leaky_relu.cu index d461949ed225..82ec59bfe907 100644 --- a/src/operator/leaky_relu.cu +++ b/src/operator/leaky_relu.cu @@ -28,9 +28,14 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(LeakyReLU).set_attr("FCompute", LeakyReLUCompute); +NNVM_REGISTER_OP(LeakyReLU) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) + .set_attr("FCompute", LeakyReLUCompute); NNVM_REGISTER_OP(_backward_LeakyReLU) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) .set_attr("FCompute", LeakyReLUGradCompute); } // namespace op diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 18f94cffd25b..0baa8e40c397 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -437,7 +437,6 @@ class DropoutOp { using namespace mshadow::expr; Stream* s = ctx.get_stream(); if (!this->dropout_passthrough_) { - this->dropout_passthrough_ = true; const TBlob& gdata = in_grad[dropout::kData]; const TBlob& grad = out_grad[dropout::kOut]; const TBlob& mask = out_data[dropout::kMask]; diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu index d6c97f5f09fc..414b82edcc65 100644 --- a/src/operator/nn/dropout.cu +++ b/src/operator/nn/dropout.cu @@ -28,7 +28,26 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(Dropout).set_attr("FStatefulCompute", DropoutCompute); +NNVM_REGISTER_OP(Dropout) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool is_train) { + // Dropout is a passthrough during inference for all impls + if (!is_train) + return true; +#if MXNET_USE_CUDNN_DROPOUT + // cuDNN impl is compatible during training as well + const DropoutParam& param = + nnvm::get(attrs.parsed); + real_t pkeep = 1.0f - param.p; + bool cudnn_off = + param.cudnn_off && param.cudnn_off.value(); + bool cudnn_available = pkeep > 0 && !cudnn_off; + return cudnn_available; +#else + return false; +#endif // MXNET_USE_CUDNN_DROPOUT + }) + .set_attr("FStatefulCompute", DropoutCompute); NNVM_REGISTER_OP(_backward_Dropout) .set_attr("FStatefulCompute", DropoutGradCompute); diff --git a/src/operator/numpy/linalg/np_eig.cu b/src/operator/numpy/linalg/np_eig.cu index 1f89106bab47..a217b6d4e0e7 100644 --- a/src/operator/numpy/linalg/np_eig.cu +++ b/src/operator/numpy/linalg/np_eig.cu @@ -28,11 +28,17 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_eig).set_attr("FCompute", EigOpForward); +NNVM_REGISTER_OP(_npi_eig) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", EigOpForward); #if MXNET_USE_CUSOLVER == 1 -NNVM_REGISTER_OP(_npi_eigh).set_attr("FCompute", EighOpForward); +NNVM_REGISTER_OP(_npi_eigh) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", EighOpForward); #endif diff --git a/src/operator/numpy/linalg/np_eigvals.cu b/src/operator/numpy/linalg/np_eigvals.cu index dc03805c54d0..be00d8c991d9 100644 --- a/src/operator/numpy/linalg/np_eigvals.cu +++ b/src/operator/numpy/linalg/np_eigvals.cu @@ -28,11 +28,17 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_eigvals).set_attr("FCompute", EigvalsOpForward); +NNVM_REGISTER_OP(_npi_eigvals) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", EigvalsOpForward); #if MXNET_USE_CUSOLVER == 1 -NNVM_REGISTER_OP(_npi_eigvalsh).set_attr("FCompute", EigvalshOpForward); +NNVM_REGISTER_OP(_npi_eigvalsh) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", EigvalshOpForward); #endif diff --git a/src/operator/numpy/linalg/np_norm_backward.cu b/src/operator/numpy/linalg/np_norm_backward.cu index 24d8783dba33..23a021d00ce5 100644 --- a/src/operator/numpy/linalg/np_norm_backward.cu +++ b/src/operator/numpy/linalg/np_norm_backward.cu @@ -26,6 +26,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_backward_npi_norm) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { + const NumpyNormParam& param = + nnvm::get(attrs.parsed); + return param.axis.value().ndim() == 2; + }) .set_attr("FCompute", NumpyNormComputeBackward); } // namespace op diff --git a/src/operator/numpy/linalg/np_norm_forward.cu b/src/operator/numpy/linalg/np_norm_forward.cu index 89267632d898..7399727324d0 100644 --- a/src/operator/numpy/linalg/np_norm_forward.cu +++ b/src/operator/numpy/linalg/np_norm_forward.cu @@ -25,7 +25,14 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_norm).set_attr("FCompute", NumpyNormComputeForward); +NNVM_REGISTER_OP(_npi_norm) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { + const NumpyNormParam& param = + nnvm::get(attrs.parsed); + return param.axis.value().ndim() == 2; + }) + .set_attr("FCompute", NumpyNormComputeForward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_boolean_mask_assign.cu b/src/operator/numpy/np_boolean_mask_assign.cu index 10f8612a3ffb..216e8ff2b839 100644 --- a/src/operator/numpy/np_boolean_mask_assign.cu +++ b/src/operator/numpy/np_boolean_mask_assign.cu @@ -273,9 +273,13 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_npi_boolean_mask_assign_scalar) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", NumpyBooleanAssignForwardGPU); NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", NumpyBooleanAssignForwardGPU); } // namespace op diff --git a/src/operator/numpy/np_constraint_check.cu b/src/operator/numpy/np_constraint_check.cu index 04a0a36f4043..26a5f0178c0b 100644 --- a/src/operator/numpy/np_constraint_check.cu +++ b/src/operator/numpy/np_constraint_check.cu @@ -38,6 +38,8 @@ void GetReduceOutput(mshadow::Stream* s, const TBlob& output_blob, boo } NNVM_REGISTER_OP(_npx_constraint_check) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", ConstraintCheckForward); } // namespace op diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index f2078146c78e..27858988432d 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -52,9 +52,15 @@ NNVM_REGISTER_OP(_npi_column_stack) NNVM_REGISTER_OP(_backward_np_column_stack) .set_attr("FCompute", NumpyColumnStackBackward); -NNVM_REGISTER_OP(_npi_tril_indices).set_attr("FCompute", TrilindicesOpForward); +NNVM_REGISTER_OP(_npi_tril_indices) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) + .set_attr("FCompute", TrilindicesOpForward); -NNVM_REGISTER_OP(_npi_roll).set_attr("FCompute", NumpyRollCompute); +NNVM_REGISTER_OP(_npi_roll) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) + .set_attr("FCompute", NumpyRollCompute); template <> void NumpyFlipForwardImpl(const OpContext& ctx, @@ -92,9 +98,15 @@ void NumpyFlipForwardImpl(const OpContext& ctx, }); } -NNVM_REGISTER_OP(_npi_flip).set_attr("FCompute", NumpyFlipForward); +NNVM_REGISTER_OP(_npi_flip) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) + .set_attr("FCompute", NumpyFlipForward); -NNVM_REGISTER_OP(_backward_npi_flip).set_attr("FCompute", NumpyFlipForward); +NNVM_REGISTER_OP(_backward_npi_flip) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) + .set_attr("FCompute", NumpyFlipForward); NNVM_REGISTER_OP(_npi_moveaxis).set_attr("FCompute", NumpyMoveaxisCompute); @@ -103,7 +115,22 @@ NNVM_REGISTER_OP(_npi_rollaxis).set_attr("FCompute", NumpyRollaxi NNVM_REGISTER_OP(_npi_rollaxis_backward) .set_attr("FCompute", NumpyRollaxisBackward); -NNVM_REGISTER_OP(_npi_rot90).set_attr("FCompute", NumpyRot90Compute); +NNVM_REGISTER_OP(_npi_rot90) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { + const auto& param = + nnvm::get(attrs.parsed); + // Should track code in NumpyRot90Compute() + int real_k(param.k); + real_k = real_k % 4; + if (real_k < 0) { + real_k += 4; + } + // Avoid NumpyRot90ComputeFlipIml(), + // which uses mshadow::Copy() + return real_k != 2; + }) + .set_attr("FCompute", NumpyRot90Compute); NNVM_REGISTER_OP(_npi_hsplit).set_attr("FCompute", HSplitOpForward); diff --git a/src/operator/numpy/np_nonzero_op.cu b/src/operator/numpy/np_nonzero_op.cu index 1499030dbe9b..597331e458ff 100644 --- a/src/operator/numpy/np_nonzero_op.cu +++ b/src/operator/numpy/np_nonzero_op.cu @@ -115,6 +115,8 @@ NNVM_REGISTER_OP(_npx_nonzero) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs& attrs, const bool) { return false; }) .set_attr("FComputeEx", NonzeroForwardGPU); } // namespace op diff --git a/src/operator/numpy/np_pad_op.cu b/src/operator/numpy/np_pad_op.cu index 01a7035ab42d..1b9f4f4d5a86 100644 --- a/src/operator/numpy/np_pad_op.cu +++ b/src/operator/numpy/np_pad_op.cu @@ -28,9 +28,17 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_pad).set_attr("FCompute", NumpyPadOpForward); +NNVM_REGISTER_OP(_npi_pad) + // Incompatible due to Copy(xpu_tensor, cpu_tensor) in NumpyPadOpForward + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyPadOpForward); -NNVM_REGISTER_OP(_backward_npi_pad).set_attr("FCompute", NumpyPadOpBackward); +NNVM_REGISTER_OP(_backward_npi_pad) + // Incompatible due to Copy(xpu_tensor, cpu_tensor) in NumpyPadOpBackward + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyPadOpBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_percentile_op.cu b/src/operator/numpy/np_percentile_op.cu index 13d076dd9b53..2dcc8294bb55 100644 --- a/src/operator/numpy/np_percentile_op.cu +++ b/src/operator/numpy/np_percentile_op.cu @@ -52,7 +52,10 @@ bool CheckInvalidInput(mshadow::Stream* s, return is_valid == 0; } -NNVM_REGISTER_OP(_npi_percentile).set_attr("FCompute", NumpyPercentileForward); +NNVM_REGISTER_OP(_npi_percentile) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyPercentileForward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_bernoulli_op.cu b/src/operator/numpy/random/np_bernoulli_op.cu index 8cdceb5bb4c8..eee89c1ea8d4 100644 --- a/src/operator/numpy/random/np_bernoulli_op.cu +++ b/src/operator/numpy/random/np_bernoulli_op.cu @@ -27,7 +27,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_bernoulli).set_attr("FCompute", NumpyBernoulliForward); +NNVM_REGISTER_OP(_npi_bernoulli) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyBernoulliForward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_exponential_op.cu b/src/operator/numpy/random/np_exponential_op.cu index 60809fbb91c5..8ad738639eae 100644 --- a/src/operator/numpy/random/np_exponential_op.cu +++ b/src/operator/numpy/random/np_exponential_op.cu @@ -28,6 +28,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_exponential) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", NumpyExponentialForward); NNVM_REGISTER_OP(_backward_broadcast_exponential) diff --git a/src/operator/numpy/random/np_gamma_op.cu b/src/operator/numpy/random/np_gamma_op.cu index 7e3cabc3a83f..0191fd597ec6 100644 --- a/src/operator/numpy/random/np_gamma_op.cu +++ b/src/operator/numpy/random/np_gamma_op.cu @@ -28,7 +28,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_gamma).set_attr("FCompute", NumpyGammaForward); +NNVM_REGISTER_OP(_npi_gamma) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyGammaForward); NNVM_REGISTER_OP(_backward_gamma_sample).set_attr("FCompute", NumpyGammaGrad); diff --git a/src/operator/numpy/random/np_multinomial_op.cu b/src/operator/numpy/random/np_multinomial_op.cu index 083b410a2d8a..575ad08b8184 100644 --- a/src/operator/numpy/random/np_multinomial_op.cu +++ b/src/operator/numpy/random/np_multinomial_op.cu @@ -41,6 +41,8 @@ void CheckPvalGPU(const OpContext& ctx, DType* input, int prob_length) { } NNVM_REGISTER_OP(_npi_multinomial) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", NumpyMultinomialForward); } // namespace op diff --git a/src/operator/numpy/random/np_normal_op.cu b/src/operator/numpy/random/np_normal_op.cu index db8746165c6e..525a0e14a4e4 100644 --- a/src/operator/numpy/random/np_normal_op.cu +++ b/src/operator/numpy/random/np_normal_op.cu @@ -27,12 +27,18 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_normal).set_attr("FCompute", NumpyNormalForward); +NNVM_REGISTER_OP(_npi_normal) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyNormalForward); NNVM_REGISTER_OP(_backward_broadcast_normal) .set_attr("FCompute", NormalReparamBackward); -NNVM_REGISTER_OP(_npi_normal_n).set_attr("FCompute", NumpyNormalForward); +NNVM_REGISTER_OP(_npi_normal_n) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyNormalForward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_pareto_op.cu b/src/operator/numpy/random/np_pareto_op.cu index 7618d2871099..82fcd1f4d066 100644 --- a/src/operator/numpy/random/np_pareto_op.cu +++ b/src/operator/numpy/random/np_pareto_op.cu @@ -27,7 +27,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_pareto).set_attr("FCompute", NumpyParetoForward); +NNVM_REGISTER_OP(_npi_pareto) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyParetoForward); NNVM_REGISTER_OP(_backward_broadcast_pareto) .set_attr("FCompute", ParetoReparamBackward); diff --git a/src/operator/numpy/random/np_power_op.cu b/src/operator/numpy/random/np_power_op.cu index 290442037eee..f7a6686769d0 100644 --- a/src/operator/numpy/random/np_power_op.cu +++ b/src/operator/numpy/random/np_power_op.cu @@ -27,7 +27,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_powerd).set_attr("FCompute", NumpyPowerForward); +NNVM_REGISTER_OP(_npi_powerd) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyPowerForward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_rayleigh_op.cu b/src/operator/numpy/random/np_rayleigh_op.cu index 586f17481e30..f67a2fe36ad7 100644 --- a/src/operator/numpy/random/np_rayleigh_op.cu +++ b/src/operator/numpy/random/np_rayleigh_op.cu @@ -27,7 +27,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_rayleigh).set_attr("FCompute", NumpyRayleighForward); +NNVM_REGISTER_OP(_npi_rayleigh) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyRayleighForward); NNVM_REGISTER_OP(_backward_broadcast_rayleigh) .set_attr("FCompute", RayleighReparamBackward); diff --git a/src/operator/numpy/random/np_weibull_op.cu b/src/operator/numpy/random/np_weibull_op.cu index 658be16e6333..4495bab39206 100644 --- a/src/operator/numpy/random/np_weibull_op.cu +++ b/src/operator/numpy/random/np_weibull_op.cu @@ -27,7 +27,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npi_weibull).set_attr("FCompute", NumpyWeibullForward); +NNVM_REGISTER_OP(_npi_weibull) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", NumpyWeibullForward); NNVM_REGISTER_OP(_backward_broadcast_weibull) .set_attr("FCompute", WeibullReparamBackward); diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu index 7fdc047630cb..5099301a1e4f 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cu +++ b/src/operator/tensor/elemwise_unary_op_basic.cu @@ -115,7 +115,10 @@ void ShapeComputeGPU(const nnvm::NodeAttrs& attrs, mshadow::Stream::GetStream(s)); } -NNVM_REGISTER_OP(shape_array).set_attr("FCompute", ShapeComputeGPU); +NNVM_REGISTER_OP(shape_array) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", ShapeComputeGPU); void SizeComputeGPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 90504301cc22..992054f860ef 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -957,7 +957,10 @@ NNVM_REGISTER_OP(batch_take).set_attr("FCompute", BatchTakeOpForw NNVM_REGISTER_OP(one_hot).set_attr("FCompute", OneHotOpForward); -NNVM_REGISTER_OP(gather_nd).set_attr("FCompute", GatherNDForwardGPU); +NNVM_REGISTER_OP(gather_nd) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", GatherNDForwardGPU); NNVM_REGISTER_OP(scatter_nd).set_attr("FCompute", ScatterNDForward); diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index 1f16e2d58251..a32143a31d60 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -88,6 +88,8 @@ NNVM_REGISTER_OP(_backward_linalg_maketrian) .set_attr("FCompute", LaOpBackward); NNVM_REGISTER_OP(_linalg_potri) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_potri) @@ -99,32 +101,56 @@ NNVM_REGISTER_OP(_linalg_inverse) NNVM_REGISTER_OP(_backward_linalg_inverse) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(_linalg_det).set_attr("FCompute", LaOpDetForward); +NNVM_REGISTER_OP(_linalg_det) + // Incompatibility comes from allocs made in linalg_batch_getrf(), called by det::op() + // see https://github.com/apache/incubator-mxnet/issues/19353 + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", LaOpDetForward); NNVM_REGISTER_OP(_backward_linalg_det) + // Incompatibility comes from allocs made in linalg_batch_getri(), + // called by linalg_batch_det_backward_helper, called by det_backward::op() + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", LaOpDetBackward); NNVM_REGISTER_OP(_linalg_slogdet) + // Incompatibility comes from allocs made in linalg_batch_getrf(), + // called by slogdet::op(). + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", LaOpDetForward); NNVM_REGISTER_OP(_backward_linalg_slogdet) + // Incompatibility comes from allocs made in linalg_batch_getri(), + // called by linalg_batch_det_backward_helper, called by slogdet_backward::op() + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", LaOpDetBackward); #if MXNET_USE_CUSOLVER == 1 NNVM_REGISTER_OP(_linalg_potrf) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_potrf) .set_attr("FCompute", LaOpBackward); NNVM_REGISTER_OP(_linalg_gelqf) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) .set_attr("FCompute", LaOpForward); NNVM_REGISTER_OP(_backward_linalg_gelqf) .set_attr("FCompute", LaOpBackward); -NNVM_REGISTER_OP(_linalg_syevd).set_attr("FCompute", LaOpForwSyevd); +NNVM_REGISTER_OP(_linalg_syevd) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", LaOpForwSyevd); NNVM_REGISTER_OP(_backward_linalg_syevd) .set_attr("FCompute", LaOpBackwSyevd); diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index b5bd1c96d25b..00007bd2e602 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -412,9 +412,15 @@ NNVM_REGISTER_OP(tile).set_attr("FCompute", TileOpForward); NNVM_REGISTER_OP(_backward_tile).set_attr("FCompute", TileOpBackward); -NNVM_REGISTER_OP(reverse).set_attr("FCompute", ReverseOpForward); +NNVM_REGISTER_OP(reverse) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", ReverseOpForward); -NNVM_REGISTER_OP(_backward_reverse).set_attr("FCompute", ReverseOpForward); +NNVM_REGISTER_OP(_backward_reverse) + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", ReverseOpForward); NNVM_REGISTER_OP(stack).set_attr("FCompute", StackOpForward); @@ -429,9 +435,17 @@ NNVM_REGISTER_OP(depth_to_space).set_attr("FCompute", DepthToSpac NNVM_REGISTER_OP(space_to_depth).set_attr("FCompute", SpaceToDepthOpForward); -NNVM_REGISTER_OP(_split_v2).set_attr("FCompute", SplitOpForwardGPU); - -NNVM_REGISTER_OP(_split_v2_backward).set_attr("FCompute", SplitOpBackward); +NNVM_REGISTER_OP(_split_v2) + // Incompatible due to Copy(xpu_tensor, cpu_tensor) in SplitOpForwardImpl + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", SplitOpForwardGPU); + +NNVM_REGISTER_OP(_split_v2_backward) + // Incompatible due to Copy(xpu_tensor, cpu_tensor) in SplitOpBackwardImpl + .set_attr("FIsCUDAGraphsCompatible", + [](const NodeAttrs&, const bool) { return false; }) + .set_attr("FCompute", SplitOpBackward); } // namespace op } // namespace mxnet diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 492055f4aef2..20a7f26e8686 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -18,6 +18,7 @@ import sys import os import time +import random import mxnet as mx import multiprocessing as mp from mxnet.test_utils import check_consistency, set_default_device, assert_almost_equal, rand_ndarray, environment @@ -28,7 +29,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) -from common import assert_raises_cudnn_not_satisfied, run_in_spawned_process +from common import assert_raises_cudnn_not_satisfied, run_in_spawned_process, random_seed from test_gluon import * from test_loss import * from test_numpy_loss import * @@ -595,3 +596,111 @@ def test_cudnn_dropout_reproducibility(): assert_almost_equal(a.grad, b.grad) +@mx.util.use_np +def test_cuda_graphs(): + class GraphTester(gluon.HybridBlock): + def __init__(self, function_to_test, **kwargs): + super(GraphTester, self).__init__(**kwargs) + self.f = function_to_test() + + def forward(self, *args): + # We need to isolate the operation to be fully inside the graph + # in order for graphs usage to be possible + copied_args = [mx.np.copy(a) for a in args] + outputs = self.f(*copied_args) + if isinstance(outputs, (list, tuple)): + return [mx.np.copy(o) for o in outputs] + else: + return mx.np.copy(outputs) + + class TestDesc: + def __init__(self, name, f, num_inputs=1, input_dim=4): + self.name = name + self.f = f + self.num_inputs = num_inputs + self.input_dim = input_dim + + def generate_inputs(self): + shape = tuple(_np.random.randint(4, 11, size=self.input_dim)) + ret = [mx.np.random.uniform(size=shape) for _ in range(self.num_inputs)] + for r in ret: + r.attach_grad() + return ret + + tested_ops = [ + TestDesc('add', lambda: (lambda x, y: x + y), num_inputs = 2), + TestDesc('add_scalar', lambda: (lambda x: x + 0.5)), + TestDesc('Conv', lambda: mx.gluon.nn.Conv2D(channels=32, kernel_size=(1,1))), + TestDesc('ConvTranspose', lambda: mx.gluon.nn.Conv2DTranspose(channels=32, kernel_size=(1,1))), + TestDesc('Dense', lambda: mx.gluon.nn.Dense(units=128)), + TestDesc('Activation', lambda: mx.gluon.nn.Activation('tanh')), + TestDesc('Dropout', lambda: mx.gluon.nn.Dropout(0.5)), + TestDesc('Flatten', lambda: mx.gluon.nn.Flatten()), + TestDesc('MaxPool', lambda: mx.gluon.nn.MaxPool2D()), + TestDesc('AvgPool', lambda: mx.gluon.nn.AvgPool2D()), + TestDesc('GlobalMaxPool', lambda: mx.gluon.nn.GlobalMaxPool2D()), + TestDesc('GlobalAvgPool', lambda: mx.gluon.nn.GlobalAvgPool2D()), + TestDesc('ReflectionPad2D', lambda: mx.gluon.nn.ReflectionPad2D()), + TestDesc('BatchNorm', lambda: mx.gluon.nn.BatchNorm()), + TestDesc('InstanceNorm', lambda: mx.gluon.nn.InstanceNorm()), + TestDesc('LayerNorm', lambda: mx.gluon.nn.LayerNorm()), + TestDesc('LeakyReLU', lambda: mx.gluon.nn.LeakyReLU(0.1)), + TestDesc('PReLU', lambda: mx.gluon.nn.PReLU()), + TestDesc('ELU', lambda: mx.gluon.nn.ELU()), + TestDesc('SELU', lambda: mx.gluon.nn.SELU()), + TestDesc('Swish', lambda: mx.gluon.nn.Swish()), + ] + + N = 10 + + with environment({'MXNET_ENABLE_CUDA_GRAPHS': '1', + 'MXNET_USE_FUSION': '0'}): + device = mx.gpu(0) + for test_desc in tested_ops: + print("Testing ", test_desc.name) + inputs = test_desc.generate_inputs() + inputsg = [i.copy() for i in inputs] + for i in inputsg: + i.attach_grad() + seed = random.randint(0, 10000) + net = GraphTester(test_desc.f) + netg = GraphTester(test_desc.f) + + # initialize parameters + net.initialize(device=device) + netg.initialize(device=device) + + net(*inputs) + + for p1, p2 in zip(net.collect_params().values(), netg.collect_params().values()): + p2.set_data(p1.data()) + + netg.hybridize(static_alloc=True, static_shape=True) + + print("Testing inference mode") + with random_seed(seed): + for _ in range(N): + assert_almost_equal(net(*inputs), netg(*inputsg)) + + mx.npx.waitall() + print("Testing training mode") + for _ in range(N): + with random_seed(seed): + with mx.autograd.record(): + out = net(*inputs) + out.backward() + + with random_seed(seed): + with mx.autograd.record(): + outg = netg(*inputsg) + outg.backward() + + assert_almost_equal(out, outg) + for i, ig in zip(inputs, inputsg): + assert_almost_equal(i.grad, ig.grad) + + for p1, p2 in zip(net.collect_params().values(), netg.collect_params().values()): + assert_almost_equal(p1.data(), p2.data()) + if p1.grad_req != 'null': + assert_almost_equal(p1.grad(), p2.grad()) + mx.npx.waitall()