diff --git a/src/operator/nn/mkldnn/mkldnn_concat-inl.h b/src/operator/nn/mkldnn/mkldnn_concat-inl.h index d3866cc3d23d..aa2b50b42679 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_concat-inl.h @@ -20,13 +20,13 @@ /*! * \file mkldnn_concat-inl.h * \brief - * \author Wenting Jiang + * \author */ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include #include #include "../concat-inl.h" @@ -40,25 +40,20 @@ class MKLDNNConcatFwd { public: mkldnn::concat::primitive_desc fwd_pd; - MKLDNNConcatFwd(int concat_dim, const std::vector &data_md) - : fwd_pd(concat_dim, data_md) { - data.resize(data_md.size()); + MKLDNNConcatFwd(int concat_dim, const std::vector &data_md) + : fwd_pd(concat_dim, data_md, CpuEngine::Get()->get_engine()) { + fwd_ = std::make_shared(fwd_pd); } - void SetNewMem(const std::vector &in_data, const mkldnn::memory &output); - const mkldnn::concat &GetFwd() const; private: - std::shared_ptr fwd; - std::vector> data; - std::vector data_mem; - std::shared_ptr out; + std::shared_ptr fwd_; }; static MKLDNNConcatFwd &GetConcatForward( int concat_dim, const std::vector &in_data, - const std::vector &data_md) { + const std::vector &data_md) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else @@ -79,5 +74,5 @@ static MKLDNNConcatFwd &GetConcatForward( } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc index 7b266efc2a14..c8843a575945 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat.cc +++ b/src/operator/nn/mkldnn/mkldnn_concat.cc @@ -20,37 +20,16 @@ /*! * \file mkldnn_concat.cc * \brief - * \author Wenting Jiang + * \author */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "mkldnn_concat-inl.h" namespace mxnet { namespace op { -void MKLDNNConcatFwd::SetNewMem(const std::vector &in_data, - const mkldnn::memory &output) { - CHECK_EQ(in_data.size(), data.size()); - for (size_t i = 0; i < data.size(); i++) { - if (this->data[i] == nullptr) { - this->data[i] = std::shared_ptr( - new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); - this->data_mem.push_back(*this->data[i]); - } else { - this->data[i]->set_data_handle(in_data[i]->get_data_handle()); - } - } - if (this->out == nullptr) - this->out = std::shared_ptr( - new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out->set_data_handle(output.get_data_handle()); - - if (this->fwd == nullptr) fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out)); -} - -const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd; } +const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd_; } void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, @@ -58,24 +37,28 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]); const ConcatParam& param = nnvm::get(attrs.parsed); - int num_in_data = param.num_args; - int concat_dim = param.dim; - std::vector data_md; + const int num_in_data = param.num_args; + const int concat_dim = param.dim; + std::vector data_md; std::vector data_mem; data_md.reserve(num_in_data); data_mem.reserve(num_in_data); for (int i = 0; i < num_in_data; i++) { const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData(); - mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc(); - data_md.push_back(tmp_pd); + mkldnn::memory::desc tmp_md = tmp_mem->get_desc(); + data_md.push_back(tmp_md); data_mem.push_back(tmp_mem); } MKLDNNConcatFwd &fwd = GetConcatForward(concat_dim, in_data, data_md); mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut], - fwd.fwd_pd.dst_primitive_desc(), + fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]); - fwd.SetNewMem(data_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + std::unordered_map net_args; + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + for (int i = 0; i < num_in_data; i++) { + net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *data_mem[i]}); + } + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); CommitOutput(out_data[concat_enum::kOut], out_mem); MKLDNNStream::Get()->Submit(); } @@ -86,11 +69,9 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& outputs) { TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]); const ConcatParam& param = nnvm::get(attrs.parsed); - int num_in_data = param.num_args; - int axis_ = param.dim; - auto engine = CpuEngine::Get()->get_engine(); - auto gz_mem = inputs[0].GetMKLDNNData(); - mkldnn::memory::primitive_desc gz_pd = gz_mem->get_primitive_desc(); + const int num_in_data = param.num_args; + const int axis = param.dim; + const auto gradz_mem = inputs[0].GetMKLDNNData(); /* init the offset */ mkldnn::memory::dims offsets(outputs[0].shape().ndim()); for (auto &v : offsets) { @@ -99,19 +80,22 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, for (int i = 0; i < num_in_data; i++) { mkldnn::memory::dims diff_src_tz(outputs[i].shape().begin(), outputs[i].shape().end()); - auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc(); - auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]); - // create view from gy to gxs[i] - std::shared_ptr view_pd; - view_pd.reset(new mkldnn::view::primitive_desc(gz_pd, diff_src_tz, offsets)); - // create reorder primitive from gy to gxs[i] - mkldnn::reorder::primitive_desc reorder_pd( - view_pd.get()->dst_primitive_desc(), diff_src_mpd); - offsets[axis_] += diff_src_tz[axis_]; - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder( - reorder_pd, *gz_mem, *gradi_mem_.second)); - CommitOutput(outputs[i], gradi_mem_); + auto diff_src_md = outputs[i].GetMKLDNNData()->get_desc(); + auto gradi_mem = CreateMKLDNNMem(outputs[i], diff_src_md, req[i]); + + auto from_md = gradz_mem->get_desc().submemory_desc(diff_src_tz, offsets); + auto from_mem = new mkldnn::memory(from_md, gradz_mem->get_engine(), + gradz_mem->get_data_handle()); + offsets[axis] += diff_src_tz[axis]; + + std::unordered_map net_args({ + {MKLDNN_ARG_FROM, *gradz_mem}, + {MKLDNN_ARG_TO, *gradi_mem.second} + }); + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*from_mem, *gradi_mem.second), net_args); + CommitOutput(outputs[i], gradi_mem); } + MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_copy.cc b/src/operator/nn/mkldnn/mkldnn_copy.cc index a7c280e1e713..b16923014c9c 100644 --- a/src/operator/nn/mkldnn/mkldnn_copy.cc +++ b/src/operator/nn/mkldnn/mkldnn_copy.cc @@ -18,16 +18,15 @@ */ /*! - * \file mkldnn_softmax.cc + * \file mkldnn_copy.cc * \brief - * \author Da Zheng + * \author */ -#include "../softmax-inl.h" #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 namespace mxnet { namespace op { @@ -47,9 +46,9 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, // We should try and force the input memory has the same format // as the input output. If not, we'll have to reorder memory. auto out_mem = out_data.GetMKLDNNData(); - in_mem = data.GetMKLDNNData(out_mem ->get_primitive_desc()); + in_mem = data.GetMKLDNNData(out_mem ->get_desc()); if (in_mem == nullptr) - in_mem = data.GetMKLDNNDataReorder(out_mem->get_primitive_desc()); + in_mem = data.GetMKLDNNDataReorder(out_mem->get_desc()); MKLDNNSum(*out_mem, *in_mem, *out_mem); } else { const_cast(out_data).CopyFrom(*in_mem); diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index ec97c9306076..bba76a3cc570 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -44,25 +44,17 @@ namespace mxnet { namespace op { #if MXNET_USE_MKLDNN == 1 -/* For sum */ -void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &inputs, const OpReqType &req, - const NDArray &out_data); - -/* For copy */ -void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &in_data, const OpReqType &req, - const NDArray &out_data); +void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const NDArray &input, + const OpReqType &req, + const NDArray &output); -/* For concat */ -void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data); -void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); +void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const NDArray &input, + const OpReqType &req, + const NDArray &output); #endif #if MXNET_USE_MKLDNN == 100 @@ -122,6 +114,26 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &c const std::vector &req, const std::vector &out_data); +/* For sum */ +void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &inputs, const OpReqType &req, + const NDArray &out_data); + +/* For copy */ +void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); + +/* For concat */ +void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); +void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, const mkldnn::memory &out); diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc index 69b6728fc0b5..2fa6e8e34164 100644 --- a/src/operator/nn/mkldnn/mkldnn_sum.cc +++ b/src/operator/nn/mkldnn/mkldnn_sum.cc @@ -54,7 +54,7 @@ void MKLDNNSum(const mkldnn::memory &arr1, in_mem2 = tmp_memory2; } mkldnn::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine()); - std::unordered_map args = { + mkldnn_args_map_t args = { { MKLDNN_ARG_MULTIPLE_SRC, *in_mem1 }, { MKLDNN_ARG_MULTIPLE_SRC + 1, *in_mem2 }, { MKLDNN_ARG_DST, out }, @@ -62,33 +62,25 @@ void MKLDNNSum(const mkldnn::memory &arr1, MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::sum(sum_pd), args); } -#endif - -#if MXNET_USE_MKLDNN == 1 class MKLDNNSumFwd { public: mkldnn::sum::primitive_desc fwd_pd; MKLDNNSumFwd(const std::vector &scales, - const std::vector &data_md) - : fwd_pd(scales, data_md) { - data_.resize(data_md.size()); + const std::vector &data_md) + : fwd_pd(scales, data_md, CpuEngine::Get()->get_engine()) { + fwd_ = std::make_shared(fwd_pd); } - void SetNewMem(const std::vector &in_data, const mkldnn::memory &output); - const mkldnn::sum &GetFwd() const { return *fwd_; } private: std::shared_ptr fwd_; - std::vector> data_; - std::vector data_mem_; - std::shared_ptr out_; }; static MKLDNNSumFwd &GetSumForward( const std::vector &scales, const std::vector &in_data, - const std::vector &data_md) { + const std::vector &data_md) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else @@ -105,35 +97,12 @@ static MKLDNNSumFwd &GetSumForward( return it->second; } -void MKLDNNSumFwd::SetNewMem(const std::vector &in_data, - const mkldnn::memory &output) { - auto num_inputs = data_.size(); - CHECK_EQ(in_data.size(), num_inputs); - for (index_t i = 0; i < static_cast(num_inputs); ++i) { - if (this->data_[i] == nullptr) { - this->data_[i] = std::shared_ptr( - new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); - this->data_mem_.push_back(*this->data_[i]); - } else { - this->data_[i]->set_data_handle(in_data[i]->get_data_handle()); - } - } - if (this->out_ == nullptr) - this->out_ = std::shared_ptr( - new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out_->set_data_handle(output.get_data_handle()); - - if (this->fwd_ == nullptr) - this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_)); -} - void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const OpReqType &req, const NDArray &out_data) { TmpMemMgr::Get()->Init(ctx.requested[0]); - auto num_inputs = inputs.size(); - std::vector data_md; + const int num_inputs = inputs.size(); + std::vector data_md; std::vector data_mem; std::vector scales(num_inputs, 1); std::vector in_bufs(num_inputs); @@ -141,7 +110,7 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, data_md.reserve(num_inputs); data_mem.reserve(num_inputs); - for (index_t i = 0; i < static_cast(num_inputs); ++i) { + for (int i = 0; i < num_inputs; ++i) { const mkldnn::memory *in_mem; if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) { in_bufs[i] = inputs[i].Reorder2Default(); @@ -150,18 +119,22 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, in_bufs[i] = inputs[i]; in_mem = inputs[i].GetMKLDNNData(); } - mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc(); - data_md.push_back(tmp_pd); + mkldnn::memory::desc tmp_md = in_mem->get_desc(); + data_md.push_back(tmp_md); data_mem.push_back(in_mem); } MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md); mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data, - fwd.fwd_pd.dst_primitive_desc(), + fwd.fwd_pd.dst_desc(), req, &in_bufs[0]); - fwd.SetNewMem(data_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + mkldnn_args_map_t net_args; + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + for (int i = 0; i < num_inputs; ++i) { + net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *data_mem[i]}); + } + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); CommitOutput(out_data, out_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index c5e30c68de7e..e1aad91b0bac 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -43,7 +43,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (SupportMKLDNNSum(inputs[0]) && SupportMKLDNNSum(inputs[1])) { MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]); return; @@ -67,7 +67,7 @@ static inline bool ElemwiseAddStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1); bool ret = ElemwiseBinaryOp::PreferDenseStorageType( attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (dev_mask == mshadow::cpu::kDevMask && !MKLDNNEnvSet()) { *dispatch_mode = DispatchMode::kFComputeFallback; } else if (dev_mask == mshadow::cpu::kDevMask @@ -82,7 +82,7 @@ static inline bool ElemwiseAddStorageType(const nnvm::NodeAttrs& attrs, MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .set_attr("FInferStorageType", ElemwiseAddStorageType) .set_attr("FCompute", ElemwiseBinaryOp::Compute) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) #endif .set_attr("FComputeEx", ElemwiseAddEx) @@ -120,7 +120,7 @@ static void _backward_ElemwiseAddEx(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 2U); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (inputs[0].IsMKLDNNData()) { MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]); MKLDNNCopy(attrs, ctx, inputs[0], req[1], outputs[1]); @@ -145,7 +145,7 @@ static inline bool ElemwiseAddBackwardStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 2); bool ret = ElemwiseStorageType<1, 2, true, true, true>(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (dev_mask == mshadow::cpu::kDevMask && !MKLDNNEnvSet()) { *dispatch_mode = DispatchMode::kFComputeFallback; } else if (dev_mask == mshadow::cpu::kDevMask) { @@ -164,7 +164,7 @@ NNVM_REGISTER_OP(_backward_add) return std::vector >{{0, 0}, {0, 1}}; }) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 2ffe3eaa233d..496b2a65517d 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -568,5 +568,73 @@ def test_weight_async_reorder(): for output in mod.get_outputs(): output.wait_to_read() +@with_seed() +def test_concat(): + def ref_concat(a, b, axis): + return np.concatenate((a, b), axis=axis) + + a_sym = mx.sym.Variable("a") + b_sym = mx.sym.Variable("b") + dshape = rand_shape_nd(4) + a_shape = tuple(dshape) + b_shape = tuple(dshape) + + for axis in range(0, 4): + z = mx.sym.concat(a_sym, b_sym, dim=axis) + a = np.random.uniform(-1, 1, a_shape) + b = np.random.uniform(-1, 1, b_shape) + exe = z.simple_bind(ctx=mx.cpu(), a=a_shape, b=b_shape) + out = exe.forward(is_train=False, a=a, b=b) + ref_out = ref_concat(a, b, axis=axis) + out = out[0].asnumpy() + assert_almost_equal(out, ref_out) + + def check_concat_training(stype): + data_shape = rand_shape_nd(4) + for density in [1.0, 0.5, 0.0]: + a_sym = mx.sym.Variable('a') + b_sym = mx.sym.Variable('b') + sym = mx.sym.concat(a_sym, b_sym, dim=1) + a = rand_ndarray(shape=data_shape, stype=stype, density=density) + b = rand_ndarray(shape=data_shape, stype=stype, density=density) + in_location = [a, b] + check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, atol=5e-3) + stypes = ['row_sparse', 'default'] + for stype in stypes: + check_concat_training(stype) + +@with_seed() +def test_elemwise_add(): + def ref_add(a, b): + return np.add(a, b) + + a_sym = mx.sym.Variable("a") + b_sym = mx.sym.Variable("b") + dshape = rand_shape_nd(4) + a_shape = tuple(dshape) + b_shape = tuple(dshape) + z = mx.sym.elemwise_add(a_sym, b_sym) + a = np.random.uniform(-1, 1, a_shape) + b = np.random.uniform(-1, 1, b_shape) + exe = z.simple_bind(ctx=mx.cpu(), a=a_shape, b=b_shape) + out = exe.forward(is_train=False, a=a, b=b) + ref_out = ref_add(a, b) + out = out[0].asnumpy() + assert_almost_equal(out, ref_out, rtol=1e-6, atol=1e-6) + + def check_elemwise_add_training(stype): + data_shape = rand_shape_nd(4) + for density in [1.0, 0.5, 0.0]: + a_sym = mx.sym.Variable('a') + b_sym = mx.sym.Variable('b') + sym = mx.sym.elemwise_add(a_sym, b_sym) + a = rand_ndarray(shape=data_shape, stype=stype, density=density) + b = rand_ndarray(shape=data_shape, stype=stype, density=density) + in_location = [a, b] + check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, atol=5e-3) + stypes = ['row_sparse', 'default'] + for stype in stypes: + check_elemwise_add_training(stype) + if __name__ == '__main__': install.test_mkldnn_install()