From aec6febefb7618913e66969df82fa1773b10581c Mon Sep 17 00:00:00 2001 From: Li Xiaoquan Date: Thu, 2 Mar 2023 15:25:15 +0800 Subject: [PATCH] [Relay] Enhance EliminateCommonSubexpr to support Tuple argument If an argument of a call is a Tuple, we should check its fields. Different tuples with the same fields should be treated as same inputs --- .../transforms/eliminate_common_subexpr.cc | 28 ++++++++++++++-- .../test_pass_eliminate_common_subexpr.py | 33 +++++++++++++++++-- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index e9603575111d..9de1b86b17e1 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -65,8 +65,7 @@ class CommonSubexprEliminator : public MixedModeMutator { continue; } for (size_t i = 0; i < new_call->args.size(); i++) { - if (!new_call->args[i].same_as(candidate->args[i]) && - !IsEqualScalar(new_call->args[i], candidate->args[i])) { + if (!IsEquivalent(new_call->args[i], candidate->args[i])) { is_equivalent = false; break; } @@ -105,6 +104,31 @@ class CommonSubexprEliminator : public MixedModeMutator { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; runtime::TypedPackedFunc fskip_; + + private: + bool IsEquivalent(const Expr& arg, const Expr& candidate_arg) { + if (arg->IsInstance() && candidate_arg->IsInstance()) { + const TupleNode* arg_node = arg.as(); + const TupleNode* candidate_arg_node = candidate_arg.as(); + + if (arg_node->fields.size() != candidate_arg_node->fields.size()) { + return false; + } + + for (size_t i = 0; i < arg_node->fields.size(); i++) { + if (!arg_node->fields[i].same_as(candidate_arg_node->fields[i]) && + !IsEqualScalar(arg_node->fields[i], candidate_arg_node->fields[i])) { + return false; + } + } + } else { + if (!arg.same_as(candidate_arg) && !IsEqualScalar(arg, candidate_arg)) { + return false; + } + } + + return true; + } }; Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index ac519a98c7d3..a8ca5058ad7f 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Test eliminate common subexpr pass""" +import numpy as np import tvm from tvm import te @@ -116,6 +117,34 @@ def expected(): assert tvm.ir.structural_equal(z, expected()) +def test_tuple_arg(): + def before(): + x = relay.var("x", shape=(1, 16)) + y1 = relay.nn.relu(x) + y2 = relay.nn.relu(x) + y1 = relay.add(y1, relay.const(1.0, "float32")) + y2 = relay.add(y2, relay.const(1.0, "float32")) + c0 = relay.const(np.ones((1, 16)), "float32") + y1 = relay.concatenate([y1, c0], axis=0) + y2 = relay.concatenate([y2, c0], axis=0) + y = relay.add(y1, y2) + f = relay.Function([x], y) + return f + + def expected(): + x = relay.var("x", shape=(1, 16)) + y = relay.nn.relu(x) + y = relay.add(y, relay.const(1.0, "float32")) + c0 = relay.const(np.ones((1, 16)), "float32") + y = relay.concatenate([y, c0], axis=0) + y = relay.add(y, y) + f = relay.Function([x], y) + return run_opt_pass(f, transform.InferType()) + + z = before() + z = run_opt_pass(z, transform.EliminateCommonSubexpr()) + assert tvm.ir.structural_equal(z, expected()) + + if __name__ == "__main__": - test_simple() - test_callback() + tvm.testing.main()