[RELAY][PASS] General OpFusion.#2090
Conversation
|
cc @masahi @junrushao1994 @ZihengJiang @jroesch @zhiics @yzhliu please review |
| class Creator; | ||
| }; | ||
|
|
||
| // Creator a post donimator tree of the dataflow |
| * \param edge_pattern The edge pattern across all the parents. | ||
| * \return The least common acenstor of thw two. | ||
| */ | ||
| static Node* LCA(Node* lhs, |
There was a problem hiding this comment.
Could we use a more descriptive name?
| /*! | ||
| * \brief Check all the node between src and sink satisfies fcond. | ||
| * | ||
| * The heck does not include, source and sink. |
|
|
||
|
|
||
| Expr FuseOps(const Expr& expr) { | ||
| Expr FuseOps(const Expr& expr, int fuse_opt_level) { |
There was a problem hiding this comment.
I believe the pass should still take a module for the time being, and in the global var case recursively apply fusion. Our experiments with adding data types and using iterative structures requires that we apply these optimizations to all functions which are reachable from the main one.
There was a problem hiding this comment.
I agree, however, I think we could wrap a FuseOps for envs via applies all, or get a list of functions via reachable analysis and then apply this Op.
Since this pass optionally requires the env, I feel having a reachable analysis to pull out the list of functions then do foreach apply is a better approach
| !new_var->type_annotation.defined()); | ||
|
|
||
| if (!need_update_type && !need_update_var && !need_update_call) return new_e; | ||
| bool need_update_fn = ( |
There was a problem hiding this comment.
We can punt this conversation to the future, but @MarisaKirisame and I feel like this style of resolver is going to continue to cause problems as we continue. We can chat further in person.
| // The parameters of the function. | ||
| Array<Var> params; | ||
| // The arguments to call the parameters. | ||
| Array<Expr> arguments; |
There was a problem hiding this comment.
call the parameters ?
Shouldn't it be call the function?
|
Could you elaborate the fusing condition? Update: sorry, I got it in the code. ignore me |
|
So we are able to fuse a residual block into one op in relay, have I got it right? |
| func = ir_pass.fuse_ops(func) | ||
|
|
||
| if cfg.pass_enabled("OpFusion"): | ||
| func = ir_pass.fuse_ops(func, opt_level=2) |
There was a problem hiding this comment.
should the opt_level here be the default opt_level instead of 2, though default is 2? Or should it be 1 since level 1 enables fusion?
| * \return The allocated object. | ||
| * \note The type T must be simple type, or only contain | ||
| * memory allocated from the same arena. | ||
| * Otherwise the destructor need to be called explicitly. |
| * \param arena The arena used for node allocation. | ||
| * \param graph The graph to be analyze. | ||
| * \return The dominator tree of the graph. | ||
| * \note This algorithm makes use of the fact that graph is DAG, |
There was a problem hiding this comment.
Just double check, a node will only have at most one parent in this DAG, right?
There was a problem hiding this comment.
A node can have multiple outputs in the DAG. A node will only have one parent in the dominator tree(it is immediate post-dominator)
There was a problem hiding this comment.
@tqchen Thanks. I misunderstood. I thought the node in the LCA algorithm was a graph node.
| Array<Expr> arguments; | ||
| // Get a new parameter or allocate an old one | ||
| Var GetOrAllocParam(const Expr& expr, const Type& type) { | ||
| // run linear scan as most fused group contain only a few inputs. |
|
Thanks @jroesch @masahi @junrushao1994 @zhiics this is merged |
|
@junrushao1994 No, all paths between the first conv op at the top and the elemwise add at the bottom need to contain only broadcast or elemwise ops, to be fused into a single op. Resblock has another conv ops on one of the paths, so it cannot be fused into a single op. |
| auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge> >(); | ||
| link->value.node = parent; | ||
| link->value.pattern = pattern; | ||
| current->outputs.Push(link); |
There was a problem hiding this comment.
Just to clarify my confusion, is the name parent a good one here? From the code above, it looks like an edge is pointing to a parent node.
Naturally I think of a parent node as a node being pointed from. To me consumer seems a better name here.
I understand that this parent will indeed become the parent node in the post dom tree representation. But in this forward data flow graph, this parent node seems more like a child node.
| * The combined edge pattern across all the parents. | ||
| * \return The least common acenstor of thw two. | ||
| */ | ||
| static Node* LeastCommonAcenstor( |
There was a problem hiding this comment.
Acenstor -> Ancestor !!
The comment above needs to be fixed too.
There was a problem hiding this comment.
Thanks for pointing this out, sorry this is merged after your comment, this is fixed by #2098
| // First we convert all chains of fusable ops into | ||
| // abstracted functions which we mark as primtive | ||
| // then we convert these primtive functions into | ||
| // new operators. |
|
This is the restriction of current schedule, which avoids to fuse the input, we can change the rules to do so |
This is a new version of OpFusion algorithm that is cleaner and general than the one used in NNVMv1.
Method
The main challenge of genenral fusor is to handle possible diamond shape branches, in the following graph, conv2d can be fused to elemwise add.
However, at the point of conv2d we do not necessarily know that all its future path will merge at the elemwise add. The new fusor algorithm applies post-dominator analysis. The immediate post-dominator of a node defined by the closest node where all the future path goes into. In the above case, the elemwise add is the immediate post-dominator of conv2d. The general algorithm is as follows:
Note that, because we run analysis on a DAG, we use a single pass post-dominator
tree construction algorithm via LCA, which is simpler than the full version that handles cycles.
The fusion algorithm traverses from each node and checks if it can be fused to its immediate post dominator. It has to check the following things:
satiesfies the fuse condition.
will still run correctly.