diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 78e50a5eb5da..d5c8d077c669 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -32,7 +32,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype == d->cfg->int_dtype) { return LiteralDoc::Int(imm->value, imm_p->Attr("value")); } else if (dtype == DataType::Bool()) { - return LiteralDoc::Boolean(imm->value, imm_p->Attr("value")); + return TIR(d, DType2Str(dtype)) + ->Call({LiteralDoc::Boolean(imm->value, imm_p->Attr("value"))}); } else { return TIR(d, DType2Str(dtype))->Call({LiteralDoc::Int(imm->value, imm_p->Attr("value"))}); } diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 2806c7b2fc52..7826b5960bb3 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -314,7 +314,7 @@ def test_isnan(): y = te.var("y", "float16") assert str(tvm.tir.isnan(y)) == 'T.isnan(T.Cast("float32", y))' z = te.var("z", "int32") - assert str(tvm.tir.isnan(z)) == "False" + assert str(tvm.tir.isnan(z)) == "T.bool(False)" k = te.var("k", "int8x2") assert str(tvm.tir.isnan(k).dtype) == "uint1x2" diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 42eb2b6be4e9..63df2de23129 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -1154,7 +1154,7 @@ def test_reverse_compute_inline_producer_predicate_disallowed(): sch.reverse_compute_inline(sch.get_block("compute_4")) assert ( - "that cannot be implied by the synthesized predicate True of the new inlined block" + "that cannot be implied by the synthesized predicate T.bool(True) of the new inlined block" in str(e) ) diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 171d49b6191b..25272d912da2 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -606,7 +606,7 @@ def test_select(): obj = tir.Select(True, 0, 2) _assert_print( obj, - """T.Select(True, 0, 2) + """T.Select(T.bool(True), 0, 2) """, ) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 6f07b6a75aeb..2ede87c34069 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3444,7 +3444,9 @@ def func() -> None: def bool_cast(): @T.prim_func def func() -> None: + a = T.bool() T.evaluate(T.bool(T.int32(0))) + T.evaluate(a == T.bool(False)) return func