From 1ee7cb184400843b329f219de9768eff37237fa5 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 20 Mar 2023 16:51:53 -0700 Subject: [PATCH] Fix data type and add minimal reproducible test. Co-authored-by: Sunghyun Park --- src/tir/transforms/lower_tvm_builtin.cc | 17 +++++++++-------- .../test_tir_transform_lower_tvm_builtin.py | 19 ++++++++++++++++++- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 49023a5ad01f..d8df2cc55a0b 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -239,8 +239,9 @@ class BuiltinLower : public StmtExprMutator { } } } - PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes); + PrimExpr total_bytes = make_const(DataType::UInt(64), nbytes); for (size_t i = 0; i < op->extents.size(); ++i) { + // set total_bytes to uint64 to avoid overflow total_bytes = total_bytes * op->extents[i]; } ICHECK(device_type_.defined()) << "Unknown device type in current IR"; @@ -250,13 +251,13 @@ class BuiltinLower : public StmtExprMutator { Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error), op->body}); - Stmt alloca = LetStmt( - op->buffer_var, - Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), - {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), - cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}), - body); + Stmt alloca = + LetStmt(op->buffer_var, + Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), + {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), + total_bytes, IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}), + body); PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_), diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 76d6bb82cce3..d224a688d298 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -16,8 +16,8 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T import numpy as np -from tvm import testing @tvm.register_func("tvm.test_matmul") @@ -172,6 +172,23 @@ def build_tir(): tvm.testing.assert_allclose(a.numpy(), expected_value) +def test_lower_overflow_int32(): + @T.prim_func + def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32")): + T.func_attr({"global_symbol": "variance4", "tir.noalias": True}) + rxplaceholder_red = T.allocate([32], "float32", "global") + T_subtract = T.allocate([822083584], "float32", "global") + rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red) + rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data) + T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract) + for ax1, ax2 in T.grid(32, 25690112): + cse_var_1: T.int32 = ax1 * 25690112 + ax2 + T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1] + + func = variance4 + tvm.build(func, target="llvm") # should not crash + + if __name__ == "__main__": test_call_packed_return_non_i32() test_lower_packed_func()