diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 093d49ca2dd4..e497407a5877 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -802,6 +802,9 @@ class ForNode : public StmtNode { ForKind kind; /*! \brief The body of the for loop. */ Stmt body; + /*! \brief The additional termination condition of the for loop. */ + Optional test; + /*! * \brief Only valid when kind == ForKind::kThreadBinding * The context thread that this loop variable bounds to. @@ -823,6 +826,7 @@ class ForNode : public StmtNode { v->Visit("extent", &extent); v->Visit("kind", &kind); v->Visit("body", &body); + v->Visit("test", &test); v->Visit("thread_binding", &thread_binding); v->Visit("annotations", &annotations); v->Visit("span", &span); @@ -831,7 +835,8 @@ class ForNode : public StmtNode { bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) && - equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations); + equal(test, other->test) && equal(thread_binding, other->thread_binding) && + equal(annotations, other->annotations); } void SHashReduce(SHashReducer hash_reduce) const { @@ -840,6 +845,7 @@ class ForNode : public StmtNode { hash_reduce(extent); hash_reduce(kind); hash_reduce(body); + hash_reduce(test); hash_reduce(thread_binding); hash_reduce(annotations); } @@ -855,7 +861,7 @@ class ForNode : public StmtNode { class For : public Stmt { public: TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding = NullOpt, + Optional test = NullOpt, Optional thread_binding = NullOpt, Map annotations = Map(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 437e8f6610f4..e6a2c4af0bc4 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value): value = op.max(1, value) self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x)) - def for_range(self, begin, end, name="i", dtype="int32", kind="serial"): + def for_range(self, begin, end, name="i", test=None, dtype="int32", kind="serial"): """Create a for iteration scope. Parameters @@ -221,6 +221,9 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"): The name of iteration variable, if no input names, using typical index names i, j, k, then i_nidx + test : Expr, optional + The additional termination condition. + dtype : str, optional The data type of iteration variable. @@ -248,6 +251,10 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"): loop_var = _expr.Var(name, dtype=dtype) extent = end if begin == 0 else (end - begin) + if test is not None: + msg = "A general termination condition is only supported for a serial loop." + assert kind == "serial", msg + def _exit_cb(): if kind == "serial": kind_id = _stmt.ForKind.SERIAL @@ -259,7 +266,7 @@ def _exit_cb(): kind_id = _stmt.ForKind.UNROLLED else: raise ValueError("Unknown kind") - self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq())) + self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), test)) return WithScope(loop_var, _exit_cb) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 9e1ef56cca58..ea53911c45e2 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -138,6 +138,7 @@ def __init__( extent, kind, body, + test=None, thread_binding=None, annotations=None, span=None, @@ -149,6 +150,7 @@ def __init__( extent, kind, body, + test, thread_binding, annotations, span, diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2d6e1e464ef8..e154f2e50c12 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -541,16 +541,18 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - with ib.for_range(0, nkeep) as j: - # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out_scores[i, j] > -1.0): - with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size - # boxes - # TODO(masahi): Add TIR while loop to realize early exit from the outer loop - with ib.if_scope(num_valid_boxes_local[0] < max_output_size): - nms_inner_loop(ib, j) - with ib.else_scope(): + with ib.if_scope(max_output_size > 0): + # No need to do more iteration if we have already reached max_output_size + # boxes + with ib.for_range(0, nkeep, test=(num_valid_boxes_local[0] < max_output_size)) as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out_scores[i, j] > -1.0): + nms_inner_loop(ib, j) + + with ib.else_scope(): + with ib.for_range(0, nkeep) as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out_scores[i, j] > -1.0): nms_inner_loop(ib, j) with ib.if_scope(tx + 0 == 0): diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 4b0871ae2ce6..d54257f775ee 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -486,7 +486,12 @@ inline const char* ForKind2String(ForKind t) { Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { Doc doc; doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " - << Print(op->min + op->extent) << ")"; + << Print(op->min + op->extent); + if (op->test) { + doc << ", (" << Print(op->test.value()) << "))"; + } else { + doc << ")"; + } if (op->kind != ForKind::kSerial) { doc << " " << Doc::StrLiteral(ForKind2String(op->kind)); } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index e2a8553199f0..83f90b208bf2 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -980,7 +980,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { CodeGenLLVM::VisitStmt_(op); } else if (op->kind == ForKind::kParallel) { if (parallel_env_.penv == nullptr) { - CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, + CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, op->test, op->thread_binding, op->annotations), 0); } else { @@ -996,13 +996,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { parallel_env_.in_parallel_loop = true; if (parallel_env_.stride_pattern) { CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), - op->loop_var, op->body); + op->loop_var, op->body, op->test); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; PrimExpr begin = min(task_id * step, op->extent); PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), - llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body, + op->test); } parallel_env_.in_parallel_loop = false; ++parallel_env_.parallel_loop_count; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 1dd76f6b9d51..784650a0e68d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -661,7 +661,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { } void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const Var& loop_var, const Stmt& body) { + const Var& loop_var, const Stmt& body, Optional test) { using llvm::BasicBlock; BasicBlock* pre_block = builder_->GetInsertBlock(); BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_); @@ -673,8 +673,14 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va loop_value->addIncoming(begin, pre_block); ICHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, - md_very_likely_branch_); + + llvm::Value* less_than = CreateLT(loop_var.dtype(), loop_value, end); + llvm::Value* cond = less_than; + if (test) { + cond = builder_->CreateAnd(less_than, MakeValue(test.value())); + } + builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_); + builder_->SetInsertPoint(for_body); this->VisitStmt(body); var_map_.erase(loop_var.get()); @@ -1325,7 +1331,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { ICHECK(op->kind == ForKind::kSerial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body, + op->test); } void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 71583708da2c..3a94ae817d6e 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -291,7 +291,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes); // Create serial for void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const Var& loop_var, const Stmt& body); + const Var& loop_var, const Stmt& body, Optional test); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); // The IRBuilder. diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index af175c7f2208..b95988a8d6e4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -891,7 +891,12 @@ void CodeGenC::VisitStmt_(const ForNode* op) { ICHECK(is_zero(op->min)); stream << "for ("; PrintType(op->loop_var.dtype(), stream); - stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; + stream << ' ' << vid << " = 0; " << vid << " < " << extent; + if (op->test) { + std::string test = PrintExpr(op->test.value()); + stream << " && (" << test << ")"; + } + stream << "; ++" << vid << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 65b8660ca1fb..a0ff74c095df 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -277,7 +277,7 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_maploop_var.get()] = indexdiv(parent, extent); body = tir::Substitute(body, rmap); under_outer = false; - return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body, + return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body, op->test, op->thread_binding, op->annotations); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); @@ -332,7 +332,7 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_mapextent, body); } else { return For(op->loop_var, op->min, op->extent, IterVarTypeToForKind(attr->iter_type), - op->body, op->thread_binding, op->annotations); + op->body, op->test, op->thread_binding, op->annotations); } } return StmtMutator::VisitStmt_(op); @@ -414,7 +414,7 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type); } const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second; - return For(target->var, range->min, range->extent, kind, body, op->thread_binding, + return For(target->var, range->min, range->extent, kind, body, op->test, op->thread_binding, op->annotations); } }; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 74d1a19d2cfe..6e074f624cf7 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -968,8 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator { scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding, - op->annotations); + stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->test, + op->thread_binding, op->annotations); } } return stmt; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 92dc38797544..ac7753d5986d 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -129,7 +129,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding, Map annotations, Span span) { + Optional test, Optional thread_binding, + Map annotations, Span span) { ICHECK(min.defined()); ICHECK(extent.defined()); ICHECK(min.dtype().is_scalar()); @@ -143,6 +144,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, node->extent = std::move(extent); node->kind = kind; node->body = std::move(body); + node->test = std::move(test); node->thread_binding = std::move(thread_binding); node->annotations = std::move(annotations); node->span = std::move(span); @@ -150,9 +152,9 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, } TVM_REGISTER_GLOBAL("tir.For").set_body_typed( - [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, + [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, Optional test, Optional thread_binding, Optional> annotations, Span span) { - return For(loop_var, min, extent, static_cast(kind), body, thread_binding, + return For(loop_var, min, extent, static_cast(kind), body, test, thread_binding, annotations.value_or(Map()), span); }); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index e0ccb49fc454..9143dd580864 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -43,6 +43,9 @@ void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitExpr(op->min); this->VisitExpr(op->extent); this->VisitStmt(op->body); + if (op->test) { + this->VisitExpr(op->test.value()); + } } void StmtVisitor::VisitStmt_(const AllocateNode* op) { @@ -168,6 +171,11 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); + Optional test = NullOpt; + if (op->test) { + test = this->VisitExpr(op->test.value()); + } + if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { @@ -175,6 +183,9 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { n->min = std::move(min); n->extent = std::move(extent); n->body = std::move(body); + if (test) { + n->test = std::move(test); + } return Stmt(n); } } diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index 7bae0ce8ca75..1212a3afc267 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -142,6 +142,10 @@ class HoistCandidateSelector final : public StmtExprVisitor { HoistCandidateSelector() { InitRecorder(); } void VisitStmt_(const ForNode* op) final { + if (op->test) { + // Do not hoist if this is a while loop + return; + } // If already recording complete, // then stop tracing if (RecordingComplete()) { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index cbae3f95ec68..fb6f288f5f11 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -149,7 +149,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, + return For(new_var, op->min, op->extent, op->kind, op->body, op->test, op->thread_binding, op->annotations); } else { defined_.insert(v.get()); diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index dc34626205a1..d0969d6f8aa5 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -221,7 +221,7 @@ class DataTypeRewriter : public StmtExprMutator { PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, - op->thread_binding, op->annotations); + op->test, op->thread_binding, op->annotations); } Stmt VisitStmt_(const AttrStmtNode* op) final { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 0b1429ca7efa..5603cb4f2061 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -444,7 +444,7 @@ class StoragePlanRewriter : public StmtExprMutator { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), + return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), op->test, op->thread_binding, op->annotations); } else { return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index c6e0b5c5f41e..f6821c71f9ae 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -125,7 +125,7 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->kind != ForKind::kUnrolled) { - return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, + return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, op->test, op->thread_binding, op->annotations); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 66f4ae329f69..c8b30e750471 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -365,7 +365,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorextent) && body.same_as(op->body)) { return GetRef(op); } else { - return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, + return For(op->loop_var, op->min, extent, op->kind, body, op->test, op->thread_binding, op->annotations); } } diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index b84ee09b9fd9..dd01cec0d0de 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,9 +173,99 @@ def check_target(target): check_target("cuda") +def test_binary_search(): + def binary_search(ib, n, i, Aptr, Bptr, Cptr): + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = n + v = Bptr[i] + num_loop = int(np.log2(n)) + 1 + + with ib.for_range(0, num_loop, test=(lo[0] < hi[0])) as _: + mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32") + with ib.if_scope(Aptr[mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + Cptr[i] = lo[0] + + def searchsorted_ir_cpu(A, B, C, n): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, n, name="i", kind="parallel") as i: + binary_search(ib, n, i, Aptr, Bptr, Cptr) + + body = ib.get() + + return body + + def searchsorted_ir_gpu(A, B, C, n): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = 32 + ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < n): + binary_search(ib, n, tid, Aptr, Bptr, Cptr) + + body = ib.get() + + return body + + n = 1024 + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + + def check_target(target, ir): + if not tvm.testing.device_enabled(target): + return + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: ir(ins[0], ins[1], outs[0], n), + name="searchsorted_ir", + dtype="int32", + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): + func = tvm.build(s, [A, B, C], target) + + ctx = tvm.context(target, 0) + a_np = np.random.uniform(size=n).astype(A.dtype) + b_np = np.random.uniform(size=n).astype(B.dtype) + a_np = np.sort(a_np) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(a, b, c) + ref = np.searchsorted(a_np, b_np) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("llvm", searchsorted_ir_cpu) + check_target("cuda", searchsorted_ir_gpu) + check_target("nvptx", searchsorted_ir_gpu) + + if __name__ == "__main__": test_prefetch() test_if() test_for() test_cpu() test_gpu() + test_binary_search() diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 9770857fb0b9..f22d70b8b3a1 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -236,6 +236,7 @@ def _merge_block(slist, body): op.extent, op.kind, body, + op.test, op.thread_binding, op.annotations, ) @@ -321,7 +322,14 @@ def _do_fold(stmt): op = stmt.body assert isinstance(op, tvm.tir.For) return tvm.tir.For( - op.loop_var, op.min, 2, op.kind, op.body, op.thread_binding, op.annotations + op.loop_var, + op.min, + 2, + op.kind, + op.body, + op.test, + op.thread_binding, + op.annotations, ) return None