[RELAY][OP]Strided slice#1891
Conversation
4f8aae0 to
b0688f1
Compare
|
@MarisaKirisame @yzhliu @yuruofeifei @srkreddy1238 @tqchen please review. |
| TVM_ATTR_FIELD(begin) | ||
| .describe("Indices for begin of slice"); | ||
| TVM_ATTR_FIELD(end) | ||
| .describe("Indices for end of the slice"); |
There was a problem hiding this comment.
describe inclusive, exclusive?
| } | ||
|
|
||
| reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); | ||
| return true; |
There was a problem hiding this comment.
undef min/max somewhere
45efc2f to
32a2e69
Compare
|
please rebase against master after #1934 to make use of the newly introduced API and add test-case to make sure text format works. Thanks! |
15b6a24 to
4e2edf7
Compare
deb1df4 to
41b6e19
Compare
|
@tqchen this can be merged? Anything else need to be done? can you please review once again and let me know. Thanks. |
| std::vector<IndexExpr> oshape(dshape.size()); | ||
|
|
||
| for (size_t i = 0; i < num_axis; ++i) { | ||
| auto begin_range = reporter->Assert(stride_vec[i] < 0) ? -1 : 0; |
There was a problem hiding this comment.
using Assert here is not reliable, as if stride_vec is symbolic, then assert does not reflect anything.
|
|
||
| /*! \brief Attributes for StridedSlice operator */ | ||
| struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { | ||
| Array<IndexExpr> begin; |
There was a problem hiding this comment.
Let us change Array<IndexExpr> -> Array<Integer> for now, as integer is really what we can do reliably so far
| struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { | ||
| Array<IndexExpr> begin; | ||
| Array<IndexExpr> end; | ||
| Array<IndexExpr> stride; |
There was a problem hiding this comment.
stride->strides as per https://www.tensorflow.org/api_docs/python/tf/strided_slice
| auto begin = reporter->Assert(begin_vec[i] < 0) ? dshape[i] + begin_vec[i] : begin_vec[i]; | ||
| auto end = reporter->Assert(end_vec[i] < 0) ? dshape[i] + end_vec[i] : end_vec[i]; | ||
|
|
||
| begin = min(max(begin, begin_range), end_range); |
There was a problem hiding this comment.
begin and end is derived from dshape which is symbolic, so cannot use std::min/max
All other comments are fixed.
| verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,)) | ||
| verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128)) | ||
| verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) | ||
| def verify_strided_slice(data, begin, end, stride, output): |
c9e48d3 to
3c3f641
Compare
3c3f641 to
ba9864d
Compare
|
Given that there is still some gap and we need this OP in quickly, I am opening a followup #2094 which is based on this PR. |
|
Thanks @siju-samuel @MarisaKirisame |
#1799
Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from others in the community.