[RELAY][PASS] FoldScaleAxis Forward#2020
Conversation
@jroesch @ZihengJiang @yzhliu @zhiics @FrozenGene @merrymercy @srkreddy1238 @masahi please review
Docstring fragment under review:

```python
Parameters
----------
data : relay.Expr
```

Code under review:

```cpp
if (!rhs.defined()) return rhs;
AxesSet ret;
size_t i = 0, j = 0;
while (i < lhs.size() && j < rhs.size()) {
```
Out of curiosity, are both lhs and rhs always sorted?
Yes, this is a requirement of the axis set. Thanks for pointing this out; I will add a comment block about it.
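Since both sets are required to be sorted, the loop above is a standard two-pointer merge intersection. A minimal Python sketch of the same logic (illustrative names, not the Relay API):

```python
# Two-pointer intersection of two sorted axis lists, mirroring the C++
# loop above: advance the pointer at the smaller element, and keep an
# axis only when it appears in both sets.

def intersect_axes(lhs, rhs):
    ret = []
    i, j = 0, 0
    while i < len(lhs) and j < len(rhs):
        if lhs[i] < rhs[j]:
            i += 1
        elif lhs[i] > rhs[j]:
            j += 1
        else:
            ret.append(lhs[i])  # common axis
            i += 1
            j += 1
    return ret

# intersect_axes([0, 1, 3], [1, 2, 3]) -> [1, 3]
```

The sortedness requirement is what makes the single linear pass correct; an unsorted set would need hashing or a sort first.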
```cpp
/*!
 * \brief The transform function: transform an old call to
 *  a new one given the new args.
 */
std::unordered_map<const Node*, AxesSet> message_;
// Update the message stored at the node.
void Update(const Expr& node, const AxesSet& axes) {
  // We run intersection of messages:
```
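The intersection-based update can be sketched in Python (a hypothetical sketch — strings stand in for IR nodes, and the real pass stores `AxesSet` messages keyed by node pointer):

```python
# Hypothetical sketch of the message update: each node carries at most
# one set of candidate axes, and a new message is merged in by
# intersection, so only axes every consumer agrees on survive.

messages = {}  # node -> sorted list of candidate axes

def update(node, axes):
    if node in messages:
        # Conflicting demands: keep only the axes both messages share.
        messages[node] = sorted(set(messages[node]) & set(axes))
    else:
        messages[node] = sorted(axes)

update("conv_input", [0, 1, 3])
update("conv_input", [1, 2, 3])
# messages["conv_input"] is now [1, 3]
```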
```cpp
Expr new_e = ExprMutator::VisitExpr_(op);
if (!checked_type.same_as(new_e->checked_type_)) {
  // new_call and new_var's code is only going to be valid for VarNode/CallNode.
  // Compiler optimization will likely fold these away for other nodes.
```
```cpp
// Conv2D consumes the scale axis during transformation.
STuple Conv2DForwardTransform(const Call& ref_call,
                              const AxesSet& expected_axes,
```
expected_axes is unused. Should it be attached to rnode?
```python
axis : None or List[int]
    Axes to remove.
    If axes = [] or axes = None, remove all axes of dimension 1.
```
change the comment also
I do not get what you mean.
axes should be changed to axis here
Thanks @zhiics @ZihengJiang for the helpful reviews, please check again.
```cpp
  }
  ++j;
} else {
  if (i >= base) {
}
```
```cpp
void VisitExpr_(const TupleGetItemNode* op) {
  // pass, do nothing
```
Why don't you visit inside? Maybe there is optimization opportunity inside.
```cpp
void VisitExpr_(const IfNode* op) {
  ExprVisitor::VisitExpr_(op);
  // do pass through condition.
```
```cpp
// AddSub
Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
```
Do we need to add a check here?
Thanks @MarisaKirisame @ZihengJiang, I have made follow-up changes, please check again.
* [RELAY][PASS] FoldScaleAxis Forward
* Introduce helper function type_as
* Update per review comment
* Fix according to comments
This is a first serious attempt to implement an NN-related optimization pass on Relay. This PR contains the following changes:
Hopefully, this can serve as an example of how optimizations can be done in the NNVMv2 (Relay) IR, and of how Relay's infrastructure allows optimizations to be written in a more principled fashion.
Goal
Fold the scaling of an axis (usually caused by BatchNorm) into the weight of the conv2d that follows. For example:
Old:
Transformed:
Further constant folding can then fold the multiplication into the weight, removing the scaling from the network.
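Although the Old/Transformed code is elided above, the identity the pass relies on can be illustrated in plain Python with a dense layer in place of conv2d (a hypothetical sketch, not the Relay API): scaling the input channels before the layer gives the same result as pre-scaling the weight along its input axis, and the latter multiplication is a constant-foldable operation on the weight.

```python
# Hypothetical sketch of the fold: dense(x * scale, w) == dense(x, w * scale)
# where scale is broadcast along the weight's input-channel axis. The same
# identity holds for conv2d over its input-channel axis.

def dense(x, w):
    # y[j] = sum_k x[k] * w[j][k]
    return [sum(xk * wjk for xk, wjk in zip(x, row)) for row in w]

x = [1.0, 2.0, 3.0]
scale = [2.0, 0.5, 4.0]          # per-input-channel scale (e.g. from BatchNorm)
w = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]

# Old: scale the data, then run the layer.
old = dense([xi * si for xi, si in zip(x, scale)], w)

# Transformed: pre-scale the weight along its input axis instead.
w_folded = [[wjk * sk for wjk, sk in zip(row, scale)] for row in w]
new = dense(x, w_folded)

assert old == new  # both are [40.0, 85.0]
```

Since `w` and `scale` are both constants after the transform, `w_folded` can be computed once at compile time, which is exactly what the follow-up constant folding does.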
The Algorithm
So far only the forward direction is implemented. The general idea is that we transform an Expr into a tuple of (value, axes, scale), where the final result is recovered by multiplying value by scale along the recorded axes. We can then propagate this signal along and fold the scale where necessary. However, a scale may never be consumed if no dense/conv2d follows the multiplication.
To make sure every scale we send out can eventually be consumed, we run a backward "preparation phase", which propagates the demand for potential axis scaling back to the inputs.
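The invariant behind the (value, axes, scale) tuple can be sketched in plain Python (a hypothetical sketch covering only a 2-D tensor scaled along axis 1; the real pass handles general tensors and axis sets):

```python
# The forward pass rewrites an expression into (value, axes, scale) such
# that the original result equals value multiplied by scale, broadcast
# along the recorded axes. Only the 2-D, axes == [1] case is sketched.

def apply_scale(value, axes, scale):
    assert axes == [1], "sketch only covers the axis-1 case"
    return [[v * s for v, s in zip(row, scale)] for row in value]

original = [[2.0, 6.0], [4.0, 12.0]]
value = [[1.0, 3.0], [2.0, 6.0]]
axes, scale = [1], [2.0, 2.0]

# A dense/conv2d consumer can absorb `scale` into its weight; if no such
# consumer exists, the multiplication must be re-materialized, which is
# why the backward preparation phase tracks where scales can be consumed.
assert apply_scale(value, axes, scale) == original
```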
The new pass is more general than the FoldScaleAxis in NNVM.