diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 5bd76404a998..acc362758a7c 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -191,24 +191,53 @@ class SEqualReducer { /*! * \brief Reduce condition to comparison of two attribute values. + * * \param lhs The left operand. + * * \param rhs The right operand. + * + * \param paths The paths to the LHS and RHS operands. If + * unspecified, will attempt to identify the attribute's address + * within the most recent ObjectRef. In general, the paths only + * require explicit handling for computed parameters + * (e.g. `array.size()`) + * * \return the immediate check result. */ - bool operator()(const double& lhs, const double& rhs) const; - bool operator()(const int64_t& lhs, const int64_t& rhs) const; - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const; - bool operator()(const int& lhs, const int& rhs) const; - bool operator()(const bool& lhs, const bool& rhs) const; - bool operator()(const std::string& lhs, const std::string& rhs) const; - bool operator()(const DataType& lhs, const DataType& rhs) const; + bool operator()(const double& lhs, const double& rhs, + Optional paths = NullOpt) const; + bool operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const int& lhs, const int& rhs, Optional paths = NullOpt) const; + bool operator()(const bool& lhs, const bool& rhs, Optional paths = NullOpt) const; + bool operator()(const std::string& lhs, const std::string& rhs, + Optional paths = NullOpt) const; + bool operator()(const DataType& lhs, const DataType& rhs, + Optional paths = NullOpt) const; template ::value>::type> - bool operator()(const ENum& lhs, const ENum& rhs) const { + bool operator()(const ENum& lhs, const ENum& rhs, + Optional paths = NullOpt) const { using Underlying = typename std::underlying_type::type; static_assert(std::is_same::value, "Enum must have `int` as the underlying type"); - return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs); + return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs, paths); + } + + template , ObjectPath>>> + bool operator()(const T& lhs, const T& rhs, const Callable& callable) { + if (IsPathTracingEnabled()) { + ObjectPathPair current_paths = GetCurrentObjectPaths(); + ObjectPathPair new_paths = {callable(current_paths->lhs_path), + callable(current_paths->rhs_path)}; + return (*this)(lhs, rhs, new_paths); + } else { + return (*this)(lhs, rhs); + } } /*! @@ -310,7 +339,8 @@ class SEqualReducer { void RecordMismatchPaths(const ObjectPathPair& paths) const; private: - bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const; + bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address, + Optional paths = NullOpt) const; bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const ObjectPathPair* paths) const; @@ -321,7 +351,8 @@ class SEqualReducer { template static bool CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data); + const PathTracingData* tracing_data, + Optional paths = NullOpt); /*! \brief Internal class pointer. */ Handler* handler_ = nullptr; diff --git a/src/ir/module.cc b/src/ir/module.cc index 7a973da29dfa..4d5bebf70894 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -63,46 +63,46 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (!equal(this->attrs, other->attrs)) return false; + if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) { + return false; + } + + if (equal.IsPathTracingEnabled()) { + if ((functions.size() != other->functions.size()) || + (type_definitions.size() != other->type_definitions.size())) { + return false; + } + } - if (functions.size() != other->functions.size()) return false; - // Update GlobalVar remap + // Define remaps for GlobalVar and GlobalTypeVar based on their + // string name. Early bail-out is only performed when path-tracing + // is disabled, as the later equality checks on the member variables + // will provide better error messages. for (const auto& gv : this->GetGlobalVars()) { - if (!other->ContainGlobalVar(gv->name_hint)) return false; - if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + if (other->ContainGlobalVar(gv->name_hint)) { + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; + } } - // Checking functions - for (const auto& kv : this->functions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("functions") - ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; - if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; - } else { - if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (other->ContainGlobalTypeVar(gtv->name_hint)) { + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; } } - if (type_definitions.size() != other->type_definitions.size()) return false; - // Update GlobalTypeVar remap - for (const auto& gtv : this->GetGlobalTypeVars()) { - if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; - if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + // Checking functions and type definitions + if (!equal(this->functions, other->functions, + [](const auto& path) { return path->Attr("functions"); })) { + return false; } - // Checking type_definitions - for (const auto& kv : this->type_definitions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair type_paths = { - obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("type_definitions") - ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))}; - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false; - } else { - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; - } + if (!equal(this->type_definitions, other->type_definitions, + [](const auto& path) { return path->Attr("type_definitions"); })) { + return false; } + return true; } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 42726af9859a..66a347f6b8ba 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -109,51 +109,72 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { template /* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data) { + const PathTracingData* tracing_data, + Optional paths) { if (BaseValueEqual()(lhs, rhs)) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); - return false; } + + if (tracing_data && !tracing_data->first_mismatch->defined()) { + if (paths) { + *tracing_data->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); + } + } + return false; } -bool SEqualReducer::operator()(const double& lhs, const double& rhs) const { +bool SEqualReducer::operator()(const double& lhs, const double& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const { +bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const { +bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int& lhs, const int& rhs) const { +bool SEqualReducer::operator()(const int& lhs, const int& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const { +bool SEqualReducer::operator()(const bool& lhs, const bool& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const { +bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const { +bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, - const void* rhs_address) const { + const void* rhs_address, Optional paths) const { if (lhs == rhs) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_); - return false; } + + if (tracing_data_ && !tracing_data_->first_mismatch->defined()) { + if (paths) { + *tracing_data_->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_); + } + } + + return false; } const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const { diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index 4bb13ed77ad8..eca78d649b85 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -19,6 +19,7 @@ import pytest from tvm import te from tvm.runtime import ObjectPath +from tvm.script import tir as T, ir as I def consistent_equal(x, y, map_free_vars=False): @@ -394,13 +395,29 @@ def test_seq_length_mismatch(): assert rhs_path == expected_rhs_path +def test_ir_module_equal(): + def generate(n: int): + @I.ir_module + class module: + @T.prim_func + def func(A: T.Buffer(1, "int32")): + for i in range(n): + A[0] = A[0] + 1 + + return module + + # Equivalent IRModules should compare as equivalent, even though + # they have distinct GlobalVars, and GlobalVars usually compare by + # reference equality. + tvm.ir.assert_structural_equal(generate(16), generate(16)) + + # When there is a difference, the location should include the + # function name that caused the failure. + with pytest.raises(ValueError) as err: + tvm.ir.assert_structural_equal(generate(16), generate(32)) + + assert '.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0] + + if __name__ == "__main__": - test_exprs() - test_prim_func() - test_attrs() - test_array() - test_env_func() - test_stmt() - test_buffer_storage_scope() - test_buffer_load_store() - test_while() + tvm.testing.main()