[ConvertLayout] Squeeze and reduce ops#7835
Conversation
1. Add FInferCorrectLayout for squeeze 2. Handle keep_dims = False for reduce ops
| getent group "${CI_BUILD_GID}" || addgroup --force-badname --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" | ||
| getent passwd "${CI_BUILD_UID}" || adduser --force-badname --gid "${CI_BUILD_GID}" --uid "${CI_BUILD_UID}" \ |
There was a problem hiding this comment.
Do not change unrelevant files. Please file another PR for this change if necessary.
There was a problem hiding this comment.
Thanks for your comment, I've solved all those issues, please take another look.
| return std::make_tuple(params->keepdims ? stripe(layout) : Layout(new_layout_string), | ||
| new_r_axes); |
There was a problem hiding this comment.
It seems to me that infer could be further improved, since the desired layouts of two branches (keepdims or not) are totally different. Could we make it like the following, so that you don't need to traverse the layout twice when keepdims is true?
std::string new_layout_string = "";
for (auto iter_var : layout->axes) {
// ...
if (layout_axis.IsPrimal()) {
if (params->keepdims || !old_r_dims.count(layout_dim)) {
new_layout_string += layout_dim;
}
axis_index++;
}
}
return std::make_tuple(new_layout_string, new_r_axes);If the above solution works, we could further consider merging infer and stripe to reduce duplicated logic:
auto infer = [&](const Layout& layout, const bool keepdims) { /* ... */ };
// ...
// Origin: std::tie(inferred_out, new_r_axes) = infer(new_in_layouts[0]);
// Origin: inferred_in = stripe(new_in_layouts[0]);
std::tie(inferred_out, new_r_axes) = infer(new_in_layouts[0], params->keepdims);
std::tie(inferred_in, std::ignore) = infer(new_in_layouts[0], false);| } | ||
| } | ||
|
|
||
| if (axis.size() == 0) { |
There was a problem hiding this comment.
Add comment saying nothing for squeeze, or simply put this into the above else branch to make it clearer.
There was a problem hiding this comment.
I've removed this check since it can be handled in following logic, and a case is added to cover nothing to squeeze case
| return Array<Array<Layout>>{{inferred_input}, {inferred_output}}; | ||
| } | ||
|
|
||
| auto kept = [&](size_t i, Array<Integer> axis) { |
There was a problem hiding this comment.
- Add comments, or name this function better.
- I'd suggest moving this function down to get closer to its callee to reduce the confusion.
comaniac
left a comment
There was a problem hiding this comment.
LGTM.
cc @anijain2305 to take a final look.
anijain2305
left a comment
There was a problem hiding this comment.
Thanks for the contribution. LGTM!
|
Thanks @lixiaoquan @anijain2305 |
@anijain2305 Could you please review? thanks.