add mkldnn softmax backward #17170
TaoLv merged 4 commits into apache:master from rongzha1:add_softmax_bwd
Conversation
    { MKLDNN_ARG_DIFF_SRC, *out_mem.second },
  };

  stream->RegisterPrimArgs(bwd_pd, args);
Will change it to cache the primitive pd.
I mean here you need to give a primitive, not a primitive descriptor. Please check the definition of RegisterPrimArgs.
Yes, you're right. Although it is a pd, it calls the implicit constructor

inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}

to get a primitive.
I will make it clearer by adding an explicit constructor call before it is used.
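For reference, a minimal sketch of the explicit form, reusing the names from the diff above (bwd_pd is assumed here to be the cached softmax_backward primitive descriptor):

// Construct the primitive explicitly from the cached primitive descriptor,
// instead of relying on the implicit primitive(const primitive_desc &)
// conversion at the RegisterPrimArgs call site.
mkldnn::softmax_backward bwd_prim(bwd_pd);
stream->RegisterPrimArgs(bwd_prim, args);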
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs) {
  // It seems MKLDNN softmax doesn't support training.
Will remove this outdated comment.
                           std::vector<int> *in_attrs,
                           std::vector<int> *out_attrs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), (param.use_length.value()) ? 2U : 1U);
This check makes the backward check fail.
Will restore this check and add a separate function for the backward pass.
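A hypothetical sketch of what such a separate backward-side check could look like. The exact expected input count is an assumption (backward receives the incoming gradient plus the forward output, and optionally a length array); only the size checks are shown and the remaining inference logic is omitted:

// Hypothetical backward-side check; the forward path keeps
// CHECK_EQ(in_attrs->size(), param.use_length.value() ? 2U : 1U),
// while backward expects one more input (the incoming gradient).
static bool SoftmaxGradCheckAttrs(const nnvm::NodeAttrs& attrs,
                                  std::vector<int>* in_attrs,
                                  std::vector<int>* out_attrs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), param.use_length.value() ? 3U : 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  return true;
}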
  std::shared_ptr<mkldnn::softmax_backward> bwd_;
};

typedef ParamOpSign<SoftmaxParam> MKLDNNSoftmaxSignature;
Exactly the same; an extra line was added during rebase. Will remove this duplicate line. Thanks.
void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType>& req,
Suggested change:
-                           const std::vector<OpReqType>& req,
+                           const std::vector<OpReqType> &req,
  auto data_mem = in_data[1].GetMKLDNNData();
  auto bwd = GetSoftmaxBwd(param, axis, in_data, out_data);

  auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
- Please check whether you want to support req=kAddTo.
- The softmax backward primitive should support in-place calculation, so there is no need to create an additional buffer.
You are right; the memory creation (CreateMKLDNNMem) is there to support kAddTo.
Does the original softmax backward support kAddTo?
Yes, the original softmax backward supports kAddTo: SoftmaxGradCompute --> SoftmaxGrad --> KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
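For context, a sketch of the CreateMKLDNNMem/CommitOutput pattern that makes kAddTo work on the MKLDNN path. The memory variable names other than out_mem are assumptions, and GetBwd() is assumed to return the cached softmax_backward primitive:

// When req[0] == kAddTo, CreateMKLDNNMem hands back a temporary buffer and
// CommitOutput later sums that buffer into out_data[0]; for kWriteTo /
// kWriteInplace it can return the destination memory directly.
auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
mkldnn_args_map_t args = {
  { MKLDNN_ARG_DST,      *out_data_mem },    // forward output (assumed name)
  { MKLDNN_ARG_DIFF_DST, *diff_dst_mem },    // incoming gradient (assumed name)
  { MKLDNN_ARG_DIFF_SRC, *out_mem.second },  // gradient buffer to fill
};
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(out_data[0], out_mem);
MKLDNNStream::Get()->Submit();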
@TaoLv have all concerns been resolved now?
* add mkldnn softmax backward
* add primitive cache for softmax bwd
* fix preci failed test
* rm duplicate line
Description
Add MKLDNN softmax backward implementation.
Unit tests pass.
Should fix #13365
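As a rough illustration of where the new kernel plugs in, here is a heavily hedged sketch of the usual MXNet CPU dispatch wrapper (function names such as SoftmaxGradComputeExCPU and SupportMKLDNNSoftmax are assumptions, not necessarily the names used in this PR): run the MKLDNN path when the inputs support it, otherwise fall back to the default implementation.

static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  // inputs[0] is the output gradient, inputs[1] the forward output (assumed order).
  if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {  // assumed helper
    MKLDNNSoftmaxBackward(attrs, ctx, inputs, req, outputs);
    return;
  }
  // Fall back to the default CPU gradient kernel for unsupported cases.
  FallBackCompute(SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                     mxnet_op::softmax_bwd>,
                  attrs, ctx, inputs, req, outputs);
}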
Checklist
Essentials
Changes
Comments
@PatricZhao @TaoLv