Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 113 additions & 60 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
}
return CodeGenLLVM::Finish();
}
llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index,
int kind) {

CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf,
llvm::Value* index, int kind) {
if (kind < builtin::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
Expand All @@ -257,57 +258,87 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::
}
switch (kind) {
case builtin::kArrAddr: {
return builder_->CreateInBoundsGEP(buf, index);
return TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index));
}
case builtin::kArrData: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrShape: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(4);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrStrides: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(5);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrNDim: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(2);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeCode: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeBits: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(1)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeLanes: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrByteOffset: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(6);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrDeviceId: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrDeviceType: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kTVMValueContent: {
ICHECK_EQ(t.lanes(), 1);
ICHECK(t.is_handle() || t.bits() == 64);
if (t.is_int()) {
buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
return builder_->CreateInBoundsGEP(buf, index);
return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index));
} else if (t.is_float()) {
buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
return builder_->CreateInBoundsGEP(buf, index);
return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index));
} else {
ICHECK(t.is_handle());
buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
buf = builder_->CreateInBoundsGEP(buf, index);
return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo());
buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index);
return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()));
}
}
default:
LOG(FATAL) << "unknown field code";
return nullptr;
return TypedPointer();
}
}

Expand Down Expand Up @@ -373,7 +404,10 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string
llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) {
ICHECK(gv != nullptr);
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment()));
llvm::LoadInst* faddr =
builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment());
#else
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
#endif
Expand Down Expand Up @@ -485,34 +519,36 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
builder_->SetInsertPoint(compute_call_end);
}

// Packs the LLVM values currently bound to `vfields` (looked up in var_map_) into a newly
// created struct, and writes the struct's allocation size in bytes to *num_bytes.
// Returns a pointer to the packed struct; for an empty `vfields` it returns a null void*
// pointer and sets *num_bytes to 0.
// NOTE(review): this listing appears to be a diff rendered without +/- markers, so several
// statements occur twice. The llvm::Value* variants look like the pre-change lines and the
// TypedPointer variants (which pass the pointee type explicitly) like the post-change lines;
// confirm against the commit before treating either as current.
llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* num_bytes) {
CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array<Var>& vfields,
uint64_t* num_bytes) {
// Nothing to pack: report zero bytes and hand back a null void* pointer.
if (vfields.size() == 0) {
*num_bytes = 0U;
return llvm::Constant::getNullValue(t_void_p_);
return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_));
}
// Collect the LLVM type of each variable's value; every variable must already be in var_map_.
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
ICHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
// Create a struct type with one member per variable and alloca a single instance of it.
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::StructType* ctype = llvm::StructType::create(fields);
llvm::Value* cvalue = builder_->CreateAlloca(ctype, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
// Store each variable's value into its member: member i holds the value of vfields[i].
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)}));
}
// Report the struct's alloc size from the data layout. The first variant derives the struct
// type from the pointer's element type; the second uses the struct type directly.
*num_bytes = data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType());
return cdata;
*num_bytes = data_layout_->getTypeAllocSize(ctype);
return TypedPointer(ctype, cvalue);
}

// Loads every member of the closure struct pointed to by `cdata` and records the loaded
// value in *vmap, keyed by the node of the matching variable: member i maps to vfields[i].
// NOTE(review): this listing appears to be a diff rendered without +/- markers. The
// llvm::Value* signature and the single-argument CreateLoad/CreateInBoundsGEP call look like
// the pre-change lines; the TypedPointer signature and the calls that pass the member and
// struct types explicitly look like the post-change lines. Confirm against the commit.
void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array<Var>& vfields,
void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array<Var>& vfields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) {
(*vmap)[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)}));
// Typed variant: take the member type from the struct type carried in cdata.type, compute
// the member address from cdata.addr, then load with that explicit type.
llvm::Type* field_type = cdata.type->getStructElementType(i);
llvm::Value* field_addr =
builder_->CreateInBoundsGEP(cdata.type, cdata.addr, {ConstInt32(0), ConstInt32(i)});
(*vmap)[vfields[i].get()] = builder_->CreateLoad(field_type, field_addr);
}
}

Expand All @@ -525,21 +561,22 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
// allocate and setup the closure, call the closure.
Array<Var> vfields = tir::UndefinedVars(body, {});
uint64_t nbytes;
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
TypedPointer cdata = PackClosureData(vfields, &nbytes);
#if TVM_LLVM_VERSION >= 90
auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
auto launch_callee = RuntimeTVMParallelLaunch();
#endif
BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
launch_callee,
{f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* task_id = &(*it++);
llvm::Value* penv = &(*it++);
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
Expand All @@ -548,8 +585,9 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
par_env.task_id = Var("task_id", DataType::Int(32));
par_env.num_task = Var("num_task", DataType::Int(32));
new_vmap[par_env.task_id.get()] = task_id;
new_vmap[par_env.num_task.get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)}));
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
t_int32_,
builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}));
par_env.penv = penv;
std::swap(function_, f);
std::swap(parallel_env_, par_env);
Expand Down Expand Up @@ -592,14 +630,14 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
// allocate and setup the closure, call the closure.
uint64_t nbytes;
Array<Var> vfields = tir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
TypedPointer cdata = PackClosureData(vfields, &nbytes);
BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall(
finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)}));
finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)}));
// Setup the closure function.
BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
Expand Down Expand Up @@ -644,7 +682,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align));
llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align);
#else
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
#endif
Expand All @@ -656,8 +696,11 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* out =
WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); });
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* ctx =
builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment()));
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
llvm::Align(gv_mod_ctx_->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
gv_mod_ctx_->getAlignment());
#else
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment());
#endif
Expand All @@ -671,7 +714,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align));
llvm::Value* loaded_handle =
builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align);
#else
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
#endif
Expand All @@ -698,37 +744,44 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
llvm::Value* stack_value = MakeValue(args[1]);
llvm::Value* stack_tcode = MakeValue(args[2]);
llvm::Value* arg_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin));
llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(begin));
TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
#else
auto call_callee = RuntimeTVMFuncCall();
#endif
llvm::Value* call = builder_->CreateCall(
call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode});
call_callee,
{handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr});
llvm::BasicBlock* end_block = CheckCallSuccess(call);

// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Value* load_ptr =
builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo());
#if TVM_LLVM_VERSION >= 110
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8));
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
#else
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
#if TVM_LLVM_VERSION >= 110
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8);
#else
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8);
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8);
#endif

pc.end_block = end_block;
Expand Down Expand Up @@ -871,24 +924,24 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::tvm_struct_get())) {
ICHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* ref =
this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
TypedPointer ref =
CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
if (kind == builtin::kArrAddr) {
return builder_->CreatePointerCast(ref, t_void_p_);
return builder_->CreatePointerCast(ref.addr, t_void_p_);
} else {
return builder_->CreateLoad(ref);
return builder_->CreateLoad(ref.type, ref.addr);
}
} else if (op->op.same_as(builtin::tvm_struct_set())) {
ICHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
ICHECK(kind != builtin::kArrAddr);
if (value->getType()->isPointerTy()) {
value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType());
value = builder_->CreatePointerCast(value, ref.type);
}
builder_->CreateStore(value, ref);
builder_->CreateStore(value, ref.addr);
return ConstInt32(0);
} else if (op->op.same_as(builtin::tvm_stack_alloca())) {
ICHECK_EQ(op->args.size(), 2U);
Expand Down
6 changes: 3 additions & 3 deletions src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* RuntimeTVMParallelBarrier();
llvm::Value* CreateStaticHandle();
llvm::Value* GetPackedFuncHandle(const std::string& str);
llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t* num_bytes);
llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value* cdata, const Array<Var>& fields,
TypedPointer PackClosureData(const Array<Var>& fields, uint64_t* num_bytes);
TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(TypedPointer cdata, const Array<Var>& fields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call.
struct PackedCall {
Expand Down
Loading