Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f69b4a0
[1.x][FEATURE] CUDA graphs support (#19142)
ptrendx Sep 19, 2020
9ca54ab
Fix compile and test_cuda_graphs
DickJC123 Jun 10, 2021
0b4ed47
Fix lint
DickJC123 Jun 10, 2021
cc69486
Mark more ops as not CUDA Graphs compatible
DickJC123 Jun 25, 2021
e79c111
Mark some linalg ops as not CUDA Graphs compatible
DickJC123 May 28, 2021
4a2dae4
Marked 2 ops CUDA Graphs incompatible due to cpu->gpu copy
DickJC123 May 28, 2021
64e8555
Mark cuDNN Dropout as fully CUDA Graphs compatible. Reenable tests.
DickJC123 May 29, 2021
78215fa
clang-tidy fixes
DickJC123 Feb 17, 2022
a558922
More clang-tidy fixes
DickJC123 Feb 17, 2022
eaa7fc7
Avoid CUDA_CALL(e): improper macro expansion
DickJC123 Feb 18, 2022
c44cfc6
Add compile guard to Dropout's FIsCUDAGraphsCompatible def
DickJC123 Feb 18, 2022
5a2f847
Temporarily add '-s' to pytest serial tests
DickJC123 Feb 18, 2022
3b58b49
Fix DropoutOp.dropout_passthrough_ handling for CUDA Graphs
DickJC123 Feb 27, 2022
0d62083
Adapt test_gluon_gpu.py::test_cuda_graphs for gluon2.0
DickJC123 Feb 27, 2022
9517011
Merge remote-tracking branch 'mxnet/master' into backport_cuda_graphs
DickJC123 Feb 27, 2022
3591f50
Create CUDA Graph 'dot' files if MXNET_CUDA_GRAPHS_DBG_FILE=<file_pre…
DickJC123 Feb 27, 2022
0e105ec
Fix clang-tidy
DickJC123 Feb 27, 2022
d8d65c9
Fix more clang-tidy
DickJC123 Feb 27, 2022
26182fb
Skip test_np_standard_binary_funcs test of 0-dim array broadcast
DickJC123 Feb 17, 2022
6cc8ab8
Improve test_rnn_layers_fp{16,32} invocation
DickJC123 Feb 21, 2022
d06b139
Run test_rnn_layers_fp32 only when cuDNN is present
DickJC123 Feb 21, 2022
c5198c2
Fix potential out-of-bounds write in count_sketch.cu
DickJC123 Feb 22, 2022
e013a85
Add temp output to debug centos crash
DickJC123 Mar 1, 2022
7651c97
Mark InstanceNorm and LeakyRELU as not CUDA Graphs compatible
DickJC123 Mar 4, 2022
e704022
Ops calling FStatefulCompute* are not CUDA Graphs compatible by default
DickJC123 Mar 4, 2022
da59cff
Fix clang-tidy
DickJC123 Mar 4, 2022
45bb7b8
Revert "Add temp output to debug centos crash"
DickJC123 Mar 10, 2022
b7ecce2
Quiet 'unused variable' compilation warning
DickJC123 Mar 10, 2022
c609cce
Trigger CI
DickJC123 Mar 10, 2022
eaf61a0
Check of FCreateOpState removed given new check for FStatefulCompute*
DickJC123 Mar 15, 2022
f451027
Revert "Temporarily add '-s' to pytest serial tests"
DickJC123 Mar 15, 2022
0c1b645
Merge remote-tracking branch 'mxnet/master' into backport_cuda_graphs
DickJC123 Mar 15, 2022
d9d323d
Merge remote-tracking branch 'mxnet/master' into backport_cuda_graphs
DickJC123 Mar 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD
- Values: Int ```(default=<value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN>)```
- The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass.
* MXNET_ENABLE_CUDA_GRAPHS
- Values: 0(false) or 1(true) ```(default=0)```
- If set to `1`, MXNet will utilize CUDA graphs, when possible, while executing models on the GPU.
- For CUDA graphs execution, one needs to use either symbolic model or Gluon model hybridized with options `static_alloc` and `static_shape` set to True.
* MXNET_CUDA_GRAPHS_VERBOSE
- Values: 0(false) or 1(true) ```(default=0)```
- If set to `1`, the CUDA graphs executor will provide information about the graph being captured and executed.
* MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES
- Values: Int ```(default=0)```
- The maximum number of log messages generated by the CUDA graphs executor.
* MXNET_CUDA_GRAPHS_DBG_FILE
- Values: String ```(default='', to indicate no debug dot files should be created)```
- The file prefix for the '.dot' debug files created for each graph. The full path is `<prefix>-devN-{trn,inf}.<graphId>.dot`.
* MXNET_CUDA_GRAPHS_DBG_FILE_FLAGS
- Values: Int ```(default=<most verbose setting - includes all info>)```
- A bitmask to enable various types of info in the debug '.dot' files. See cudaGraphDebugDotFlags in the CUDA runtime API doc for details.

## Control the Data Communication

Expand Down
13 changes: 13 additions & 0 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,19 @@ using FNeedCalibrateInput = std::function<std::vector<int>(const NodeAttrs& attr
*/
using FNeedCalibrateOutput = std::function<std::vector<int>(const NodeAttrs& attrs)>;

#if MXNET_USE_CUDA

/*!
 * \brief Register a function to determine whether
 * the operator implementation is compatible
 * with CUDA graphs. Compatibility requires the
 * execution to stay the same as long as the shapes
 * and types of the inputs stay the same.
 */
using FIsCUDAGraphsCompatible = std::function<bool(const NodeAttrs& attrs, const bool is_train)>;

#endif

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
66 changes: 40 additions & 26 deletions src/imperative/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ namespace exec {
// FComputeExecutor and FStatefulComputeExecutor inherit from this class
class StorageFallbackOpExecutor : public OpExecutor {
public:
explicit StorageFallbackOpExecutor(std::vector<uint32_t> mutate_idx)
: mutate_idx_(std::move(mutate_idx)) {}
explicit StorageFallbackOpExecutor(const NodeAttrs& attrs,
DispatchMode dispatch_mode,
std::vector<uint32_t> mutate_idx)
: OpExecutor(attrs, dispatch_mode), mutate_idx_(std::move(mutate_idx)) {}

void Setup() override {
init_ = false;
Expand Down Expand Up @@ -146,11 +148,13 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_;
}

explicit StatefulComputeExecutor(OpStatePtr state,
explicit StatefulComputeExecutor(const NodeAttrs& attrs,
DispatchMode dispatch_mode,
OpStatePtr state,
FStatefulCompute fcompute,
ExecType exec_type,
const std::vector<uint32_t>& mutate_idx)
: StorageFallbackOpExecutor(mutate_idx),
: StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
state_(std::move(state)),
fcompute_(std::move(fcompute)),
exec_type_(exec_type) {}
Expand All @@ -168,7 +172,7 @@ class StatefulComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
INVALIDATE_OUTPUTS(out_array, req);
std::vector<NDArray>* pInArray = &in_array;
CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs_);
CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs);
fcompute_(state_, op_ctx, *pInArray, req, out_array);
}

Expand All @@ -186,17 +190,17 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_;
}

explicit StatefulComputeExExecutor(NodeAttrs attrs,
explicit StatefulComputeExExecutor(const NodeAttrs& attrs,
DispatchMode dispatch_mode,
OpStatePtr state,
FStatefulComputeEx fcompute,
ExecType exec_type)
: attrs_(std::move(attrs)),
: OpExecutor(attrs, dispatch_mode),
state_(std::move(state)),
fcompute_(std::move(fcompute)),
exec_type_(exec_type) {}

private:
NodeAttrs attrs_;
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
Expand All @@ -210,25 +214,24 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
op_ctx.run_ctx = rctx;
INVALIDATE_OUTPUTS(out_array, req);
PreFCompute(is_gpu);
fcompute_(attrs_, op_ctx, in_data_, req, out_data_);
fcompute_(attrs, op_ctx, in_data_, req, out_data_);
PostFCompute(is_gpu);
}

ExecType exec_type() const override {
return exec_type_;
}

explicit FComputeExecutor(NodeAttrs attrs,
explicit FComputeExecutor(const NodeAttrs& attrs,
DispatchMode dispatch_mode,
FCompute fcompute,
ExecType exec_type,
const std::vector<uint32_t>& mutate_idx)
: StorageFallbackOpExecutor(mutate_idx),
attrs_(std::move(attrs)),
: StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
fcompute_(std::move(fcompute)),
exec_type_(exec_type) {}

private:
NodeAttrs attrs_;
FCompute fcompute_;
ExecType exec_type_;
};
Expand All @@ -240,8 +243,8 @@ class FComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
INVALIDATE_OUTPUTS(out_array, req);
std::vector<NDArray>* pInArray = &in_array;
CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs_);
fcompute_(attrs_, op_ctx, *pInArray, req, out_array);
CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs);
fcompute_(attrs, op_ctx, *pInArray, req, out_array);
}

void Setup() override {}
Expand All @@ -250,11 +253,13 @@ class FComputeExExecutor : public OpExecutor {
return exec_type_;
}

explicit FComputeExExecutor(NodeAttrs attrs, FComputeEx fcompute, ExecType exec_type)
: attrs_(std::move(attrs)), fcompute_(std::move(fcompute)), exec_type_(exec_type) {}
explicit FComputeExExecutor(const NodeAttrs& attrs,
DispatchMode dispatch_mode,
FComputeEx fcompute,
ExecType exec_type)
: OpExecutor(attrs, dispatch_mode), fcompute_(std::move(fcompute)), exec_type_(exec_type) {}

private:
NodeAttrs attrs_;
FComputeEx fcompute_;
ExecType exec_type_;
};
Expand Down Expand Up @@ -309,14 +314,15 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state,
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
inode.source->attrs, state, fcompute_ex, exec_type);
inode.source->attrs, dispatch_modes[i], state, fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute =
common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute, exec_type, mutate_index);
ret[i] = std::make_shared<StatefulComputeExecutor>(
inode.source->attrs, dispatch_modes[i], state, fcompute, exec_type, mutate_index);
}
} else if (is_layer_backward.get(op, false)) {
CHECK_GE(inode.control_deps.size(), 1);
Expand All @@ -327,25 +333,33 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state,
common::GetFCompute<FStatefulComputeEx>(op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex, exec_type);
ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs,
dispatch_modes[i],
ret[fwd_id].get()->state(),
fcompute_ex,
exec_type);
} else {
FStatefulCompute fcompute =
common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(
ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
ret[i] = std::make_shared<StatefulComputeExecutor>(inode.source->attrs,
dispatch_modes[i],
ret[fwd_id].get()->state(),
fcompute,
exec_type,
mutate_index);
}
} else {
FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<FComputeExExecutor>(inode.source->attrs, fcomp_ex, exec_type);
ret[i] = std::make_shared<FComputeExExecutor>(
inode.source->attrs, dispatch_modes[i], fcomp_ex, exec_type);
} else if (fcompute != nullptr) {
ret[i] = std::make_shared<FComputeExecutor>(
inode.source->attrs, fcompute, exec_type, mutate_index);
inode.source->attrs, dispatch_modes[i], fcompute, exec_type, mutate_index);
} else {
LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
}
Expand Down
Loading