Add segment sum Op to relay and 7 corresponding TF Ops, fix scatter_add dynamic bug #7562
masahi merged 26 commits into apache:main
Conversation
@masahi @tkonolige @mbrookhart @ymwangg PTAL.
Nice, are you going to add a frontend?
Yes, do you prefer I add it in this PR or the next one? I want to add frontends for multiple framework ops based on this relay op.
Yes, I think it's better to add frontends (TF, PT) to make sure they are supported by this op.
@masahi I have added 3 TF Ops to the frontend, all of which use this op. Let me know if that's enough.
Can you also try PT EmbeddingBag?
Hey @masahi, upon closely reading the EmbeddingBag documentation, it seems that: (referencing the …)
Now all of these ops exist except … Let me know your thoughts on the best way to reuse existing code. After that, the implementation would only be a trivial few lines.
OK, let's do EmbeddingBag later, then.
tkonolige left a comment:
Looks pretty good. A couple of documentation improvements would be nice, though.
@tkonolige I have finished addressing your comments, please re-review.
Actually, I would like to add another related op in this PR. I will ping you after I am done with that.
@tkonolige @masahi, I am done with the PR. Please review/re-review.
tkonolige left a comment:
A couple of minor comments.
mbrookhart left a comment:
Overall LGTM. Could you add a direct test for scatter_add with dynamic inputs? That would help identify problems in the future.
@tkonolige int64 is not allowed with the TF sparse ops, so I added it to the relay op tests and the TF math ops.
assert len(inputs) == 3, "There should be 3 input tensors"
# Gather the rows of the data tensor selected by the sparse indices ...
data = _op.take(inputs[0], inputs[1], axis=0)
# ... then sum each gathered row into its output segment.
return _op.segment_sum(data, inputs[2])
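For reference, the TF sparse segment sum op this appears to convert reduces to exactly this take + segment_sum pair. A small NumPy check of the semantics (illustrative values, not from the PR):

```python
import numpy as np

data = np.array([[1., 2.], [3., 4.], [5., 6.]])
indices = np.array([0, 1])       # which rows of `data` participate
segment_ids = np.array([0, 0])   # both selected rows land in output segment 0

# tf.sparse.segment_sum(data, indices, segment_ids)
#   == segment_sum(take(data, indices, axis=0), segment_ids)
out = np.zeros((segment_ids.max() + 1, data.shape[1]))
np.add.at(out, segment_ids, data[indices])
# out == [[4., 6.]]
```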
This is OK for now, but we definitely want a fused implementation here, just like TF/PT/C2 have. I don't expect this to work for the huge embedding tables people want to use in practice.
I agree. When you say "fused implementation", do you mean that all of it happens in a single IR?
Do you have any examples of what a "fused implementation" is? Does this mean that in a fused implementation, the frontend will always just be a one-liner?
In this case, I understand we must do the take and the addition from segment_sum simultaneously for performance. So a fused implementation in that case would be a new op?
By "fused" I meant we shouldn't materialize the result of take, which can be huge. In a fused implementation, we need to look up indices and accumulate the sum on the fly. This is why PT has EmbeddingBag op, see their doc https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html.
Yes, a complicated op like this will likely not be feasible if we rely only on Relay-level op fusion. We need a dedicated sparse_segment_sum TOPI/Relay op.
I think he meant that scatter_nd exactly realizes the fused take and segment_sum above. I haven't put deep thought into this, but it made sense to me. That said, I remember parallelizing scatter_nd looked harder than scatter_add.
Yes, I am having a bit of a mental block understanding how take plus segment_sum is essentially scatter_nd. Would any of you mind writing some small pseudocode?
FWIW, I did a few variants of torch.nn.EmbeddingBag, c2::sparse_length_sum, etc. in TVM IR in https://github.com/ajtulloch/tvm/blob/4b98beb75ca1505ec81ddca358ad61282ab6a05b/topi/python/topi/x86/sparse.py#L162-L257, https://github.com/ajtulloch/tvm/blob/sparse-ops/topi/python/topi/sparse/sparse_lengths_sum.py#L45-L98, and https://github.com/ajtulloch/sparse-ads-baselines/blob/a495ea076882615d454d27a1a5b191ec675d3acc/lxu_cache_cpu_funcs.py#L8-L149, if that's of interest.
Thinking about this more, I believe the take is necessary if we are using scatter_nd. We could make a more generic version of scatter_nd and gather_nd that has indices in both the input and output buffers. That would cover this case.
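To make the equivalence concrete, a runnable NumPy sketch (scatter_nd_add is a hypothetical stand-in for an accumulate-mode scatter_nd, not a real TVM API):

```python
import numpy as np

def scatter_nd_add(out, out_indices, updates):
    # stand-in for scatter_nd with "add" accumulation on axis 0
    np.add.at(out, out_indices, updates)
    return out

params = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
indices = np.array([2, 0, 2])
segment_ids = np.array([0, 1, 1])

# take + segment_sum == take, then an accumulating scatter_nd:
taken = params[indices]                # the take; must be materialized here
out = scatter_nd_add(np.zeros((2, 2)), segment_ids, taken)
# out == [[3., 3.], [4., 4.]]
# A generic gather/scatter with indices on both the input and output
# buffers would fold the two indexing steps together and skip `taken`.
```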
OK, I'll merge this as it is, then.
Thanks @codeislife99 @tkonolige @mbrookhart |
…add dynamic bug (apache#7562)

* Add segment sum Op
* Remove unnecessary
* Documentation
* Black
* Add GPU
* Uncomment
* Add documentation
* Add dynamic tests
* Add TF Op
* Add Sparse Segment Sum
* Add test coverage
* PR Comments
* Int64 tests
* Add SparseSegmentSqrtN
* Add SparseSegmentSqrtNOp
* Deduplicate code
* Add SparseSegmentMean
* Parametrize Tests
* Remove
* Modularize
* Black
* Modularize Code
* Pylint
* PR Comments
* Add scatter add tests
* Remove Test

Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-251.us-east-2.compute.internal>
This PR adds the segment sum Op, which will serve as a generic op for multiple framework-specific ops:
TensorFlow -- tf.math.segment_sum, tf.sparse.segment_sum
Caffe2 -- SparseLengthsSum
PyTorch -- EmbeddingBag
Since this PR uses scatter_add, it also makes some small changes so that scatter_add works with dynamic inputs.
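For context, segment_sum sums each input row into the output row named by a segment_ids vector. A quick NumPy illustration of the semantics (values made up):

```python
import numpy as np

data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
segment_ids = np.array([0, 0, 1, 1])   # output segment for each input row

out = np.zeros((segment_ids.max() + 1, data.shape[1]), dtype=data.dtype)
np.add.at(out, segment_ids, data)      # out[segment_ids[i]] += data[i]
# out == [[ 4,  6],
#         [12, 14]]
```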