4 changes: 4 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -93,6 +93,10 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                          const NDArray &in_data, const OpReqType &req,
                          const NDArray &out_data);
void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &out_data);

/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
84 changes: 84 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -42,6 +42,18 @@ static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(bool is_train,
  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
}

static mkldnn::softmax_backward::primitive_desc GetSoftmaxBwdPd(
    const mkldnn::memory &diff_mem,
    const mkldnn::memory &data_mem,
    const int axis,
    const mkldnn::softmax_forward::primitive_desc &hint_fwd_pd) {
  mkldnn::memory::desc diff_md = diff_mem.get_desc();
  mkldnn::memory::desc data_md = data_mem.get_desc();
  auto cpu_engine = CpuEngine::Get()->get_engine();
  auto desc = mkldnn::softmax_backward::desc(diff_md, data_md, axis);
  return mkldnn::softmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd);
}


bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
                          const NDArray &data,
@@ -131,6 +143,78 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs,
  stream->Submit();
}

class MKLDNNSoftmaxBwd {
 public:
  mkldnn::softmax_backward::primitive_desc pd;

  MKLDNNSoftmaxBwd(const mkldnn::memory &diff_mem,
                   const mkldnn::memory &data_mem,
                   const int axis,
                   const mkldnn::softmax_forward::primitive_desc &hint_fwd_pd) :
      pd(GetSoftmaxBwdPd(diff_mem, data_mem, axis, hint_fwd_pd)) {
    bwd_ = std::make_shared<mkldnn::softmax_backward>(pd);
  }

  const mkldnn::softmax_backward &GetBwd() const {
    return *bwd_;
  }

 private:
  std::shared_ptr<mkldnn::softmax_backward> bwd_;
};

static MKLDNNSoftmaxBwd &GetSoftmaxBwd(const SoftmaxParam &param,
                                       const int real_axis,
                                       const std::vector<NDArray> &data,
                                       const std::vector<NDArray> &output) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<MKLDNNSoftmaxSignature, MKLDNNSoftmaxBwd, OpHash> bwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature, MKLDNNSoftmaxBwd, OpHash> bwds;
#endif

  MKLDNNSoftmaxSignature key(param);
  key.AddSign(real_axis);
  key.AddSign(data);
  key.AddSign(output);

  auto it = bwds.find(key);
  if (it == bwds.end()) {
    auto diff_mem = data[0].GetMKLDNNData();
    auto data_mem = data[1].GetMKLDNNData();
    auto fwd_pd = GetSoftmaxFwdPd(true, real_axis, *data_mem);
    MKLDNNSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
    it = AddToCache(&bwds, key, bwd);
  }
  return it->second;
}

void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &out_data) {
  if (req[0] == kNullOp) return;
  CHECK_EQ(in_data.size(), 2U);
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
  auto diff_mem = in_data[0].GetMKLDNNData();
  auto data_mem = in_data[1].GetMKLDNNData();
  auto bwd = GetSoftmaxBwd(param, axis, in_data, out_data);

  auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
Member:
1. Please check whether you want to support req=kAddTo.
2. The softmax backward primitive should support in-place calculation, so there should be no need to create an additional buffer.

Contributor Author:
You are right; CreateMKLDNNMem is used here to support kAddTo.

Member:
Does the original softmax backward support kAddTo?

Contributor Author:
Yes, the original softmax backward supports kAddTo: SoftmaxGradCompute --> SoftmaxGrad --> KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
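The reply above refers to MXNet's per-element KERNEL_ASSIGN macro, which dispatches on the request type: it overwrites the output for kWriteTo/kWriteInplace and accumulates for kAddTo. A reduced, standalone C++ sketch of that dispatch follows; the enum values and the AssignReq helper are illustrative stand-ins, not MXNet's exact definitions.

#include <cassert>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Per-element store that honors the request type, in the spirit of
// MXNet's KERNEL_ASSIGN: overwrite for kWriteTo/kWriteInplace,
// accumulate for kAddTo, and do nothing for kNullOp.
inline void AssignReq(float *out, OpReqType req, float val) {
  switch (req) {
    case kNullOp:       break;
    case kWriteTo:
    case kWriteInplace: *out = val;  break;
    case kAddTo:        *out += val; break;
  }
}

int main() {
  float grad = 1.5f;
  AssignReq(&grad, kAddTo, 2.0f);  // accumulate: grad becomes 3.5
  assert(grad == 3.5f);
  return 0;
}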

  MKLDNNStream *stream = MKLDNNStream::Get();
  mkldnn_args_map_t args = {
    { MKLDNN_ARG_DST, *data_mem },
    { MKLDNN_ARG_DIFF_DST, *diff_mem },
    { MKLDNN_ARG_DIFF_SRC, *out_mem.second }
  };

  stream->RegisterPrimArgs(bwd.GetBwd(), args);
  CommitOutput(out_data[0], out_mem);
  stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif
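On the MKL-DNN side, the review thread above boils down to this: when req is kAddTo, CreateMKLDNNMem hands the primitive a temporary diff_src buffer, and CommitOutput folds that buffer into the existing gradient instead of overwriting it. Below is a minimal standalone sketch of that accumulate step, assuming the MKL-DNN 1.x C++ API; the AccumulateGrad helper and its buffer names are illustrative, not MXNet's internals.

#include <mkldnn.hpp>
#include <vector>

// Accumulate a freshly computed gradient into an existing one, mirroring
// the effect of CommitOutput for req == kAddTo: existing_grad += fresh_grad.
void AccumulateGrad(const mkldnn::engine &eng, mkldnn::stream &strm,
                    mkldnn::memory &existing_grad,  // the output NDArray's buffer
                    mkldnn::memory &fresh_grad) {   // temporary diff_src buffer
  std::vector<float> scales{1.f, 1.f};  // plain elementwise sum, no scaling
  std::vector<mkldnn::memory::desc> srcs{existing_grad.get_desc(),
                                         fresh_grad.get_desc()};
  auto pd = mkldnn::sum::primitive_desc(existing_grad.get_desc(), scales, srcs, eng);
  // The sum primitive may run in place when dst aliases the first source,
  // so the result lands directly in existing_grad.
  mkldnn::sum(pd).execute(strm, {{MKLDNN_ARG_MULTIPLE_SRC, existing_grad},
                                 {MKLDNN_ARG_MULTIPLE_SRC + 1, fresh_grad},
                                 {MKLDNN_ARG_DST, existing_grad}});
  strm.wait();
}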
41 changes: 39 additions & 2 deletions src/operator/nn/softmax.cc
@@ -41,7 +41,6 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs) {
  // It seems MKLDNN softmax doesn't support training.
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -54,6 +53,23 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                  inputs, req, outputs);
}

static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {
    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
    MKLDNNRun(MKLDNNSoftmaxBackward, attrs, ctx, inputs, req, outputs);
    auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::mul, mxnet_op::softmax_bwd>;
    MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
    return;
  }
  FallBackCompute(SoftmaxGradCompute<cpu, op::mshadow_op::mul, mxnet_op::softmax_bwd>, attrs, ctx,
                  inputs, req, outputs);
}

inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
                                      const int dev_mask,
                                      DispatchMode* dispatch_mode,
@@ -72,6 +88,23 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
                           out_attrs);
}

inline static bool SoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int> *in_attrs,
                                          std::vector<int> *out_attrs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (param.use_length.value() || softmax_has_dtype_override(attrs)) {
    auto& out_stype = out_attrs->at(0);
    return storage_type_assign(&out_stype, kDefaultStorage,
                               dispatch_mode, DispatchMode::kFCompute);
  }
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
                           out_attrs);
}
#endif


@@ -147,8 +180,12 @@ NNVM_REGISTER_OP(_backward_softmax)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxGradComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxGradStorageType)
#endif
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                        mxnet_op::softmax_bwd>);

} // namespace op
} // namespace mxnet