Skip to content

Commit

Permalink
[llvm] Support real function with single scalar return value (#4452)
Browse files Browse the repository at this point in the history
* [llvm] Support real function with single scalar return value

* add comment
  • Loading branch information
lin-hitonami authored Mar 8, 2022
1 parent f714872 commit 2c04cb2
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 97 deletions.
119 changes: 65 additions & 54 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,54 @@ void CodeGenLLVM::visit(RangeForStmt *for_stmt) {
create_naive_range_for(for_stmt);
}

// Recovers a value of `type` from a 64-bit integer `val`.
//
// The integer is truncated to the width of the destination LLVM type and the
// resulting bits are reinterpreted as that type. Pointer types are rejected.
// Returns the LLVM value holding the converted result.
llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) {
  TI_ASSERT(!type->is<PointerType>());

  // Pick the LLVM type the caller expects to receive.
  llvm::Type *target_ty = nullptr;
  if (auto custom_int = type->cast<CustomIntType>()) {
    // Custom ints are materialized as a 32-bit integer of matching signedness.
    // NOTE(review): this is a fixed 32-bit width rather than the custom int's
    // compute type -- confirm this is intended for wider compute types.
    target_ty = custom_int->get_is_signed()
                    ? tlctx->get_data_type(PrimitiveType::i32)
                    : tlctx->get_data_type(PrimitiveType::u32);
  } else {
    target_ty = tlctx->get_data_type(type);
  }

  auto *half_ty = llvm::Type::getHalfTy(*llvm_context);
  if (target_ty == half_ty) {
    // Truncating straight to 16 bits would keep only the low 16 bits, which
    // are not a meaningful half value. Instead take the low 32 bits, view them
    // as a float, and narrow that float to half.
    auto *low32 =
        builder->CreateTrunc(val, llvm::Type::getIntNTy(*llvm_context, 32));
    auto *as_f32 =
        builder->CreateBitCast(low32, llvm::Type::getFloatTy(*llvm_context));
    return builder->CreateFPTrunc(as_f32, half_ty);
  }

  // General case: keep the low N bits and reinterpret them as the target type.
  auto target_bits = target_ty->getPrimitiveSizeInBits();
  auto *narrowed = builder->CreateTrunc(
      val, llvm::Type::getIntNTy(*llvm_context, target_bits));
  return builder->CreateBitCast(narrowed, target_ty);
}

// Packs `val` (a value of `type`) into a 64-bit integer.
//
// The value's bits are reinterpreted as an integer of the same width and then
// zero-extended to 64 bits. A half-precision value is first widened to float
// and packed as a 32-bit pattern.
// Returns the LLVM value holding the 64-bit integer.
llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) {
  // Width of the integer used to view the source bits: the compute type's
  // width for custom ints, otherwise the width of the LLVM type for `type`.
  int src_bits = 0;
  if (auto custom_int = type->cast<CustomIntType>()) {
    src_bits = data_type_bits(custom_int->get_compute_type());
  } else {
    src_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits();
  }

  llvm::Type *int_view_ty = nullptr;
  if (val->getType() == llvm::Type::getHalfTy(*llvm_context)) {
    // Half values are extended to float and carried as a 32-bit pattern.
    val = builder->CreateFPExt(val, tlctx->get_data_type<float>());
    int_view_ty = tlctx->get_data_type<int32>();
  } else {
    int_view_ty = llvm::Type::getIntNTy(*llvm_context, src_bits);
  }

  auto *as_int = builder->CreateBitCast(val, int_view_ty);
  return builder->CreateZExt(as_int, tlctx->get_data_type<int64>());
}

void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
auto raw_arg = call(builder.get(), "RuntimeContext_get_args", get_context(),
tlctx->get_constant(stmt->arg_id));
Expand All @@ -1029,32 +1077,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
TI_ASSERT(!stmt->ret_type->is<PointerType>());
if (auto cit = stmt->ret_type->cast<CustomIntType>()) {
if (cit->get_is_signed())
dest_ty = tlctx->get_data_type(PrimitiveType::i32);
else
dest_ty = tlctx->get_data_type(PrimitiveType::u32);
} else {
dest_ty = tlctx->get_data_type(stmt->ret_type);
}
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) {
// if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa
// which doesn't mean anything.
// So we truncate to 32 bits first and then fptrunc to half if applicable
auto truncated = builder->CreateTrunc(
raw_arg, llvm::Type::getIntNTy(*llvm_context, 32));
auto casted = builder->CreateBitCast(
truncated, llvm::Type::getFloatTy(*llvm_context));
llvm_val[stmt] =
builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context));
} else {
auto truncated = builder->CreateTrunc(
raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits));

llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty);
}
llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type);
}
}

Expand All @@ -1067,27 +1090,10 @@ void CodeGenLLVM::visit(ReturnStmt *stmt) {
TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value);
int idx{0};
for (auto &value : stmt->values) {
auto intermediate_bits = 0;
if (auto cit = value->ret_type->cast<CustomIntType>()) {
intermediate_bits = data_type_bits(cit->get_compute_type());
} else {
intermediate_bits =
tlctx->get_data_type(value->ret_type)->getPrimitiveSizeInBits();
}
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
llvm::Type *intermediate_type = nullptr;
if (llvm_val[value]->getType() == llvm::Type::getHalfTy(*llvm_context)) {
llvm_val[value] = builder->CreateFPExt(llvm_val[value],
tlctx->get_data_type<float>());
intermediate_type = tlctx->get_data_type<int32>();
} else {
intermediate_type =
llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
}
auto extended = builder->CreateZExt(
builder->CreateBitCast(llvm_val[value], intermediate_type), dest_ty);
create_call("LLVMRuntime_store_result",
{get_runtime(), extended, tlctx->get_constant<int32>(idx++)});
create_call(
"RuntimeContext_store_result",
{get_context(), bitcast_to_u64(llvm_val[value], value->ret_type),
tlctx->get_constant<int32>(idx++)});
}
}
}
Expand Down Expand Up @@ -2387,17 +2393,22 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) {
auto *new_ctx = builder->CreateAlloca(get_runtime_type("RuntimeContext"));
call("RuntimeContext_set_runtime", new_ctx, get_runtime());
for (int i = 0; i < stmt->args.size(); i++) {
auto *original = llvm_val[stmt->args[i]];
int src_bits = original->getType()->getPrimitiveSizeInBits();
auto *cast = builder->CreateBitCast(
original, llvm::Type::getIntNTy(*llvm_context, src_bits));
auto *val =
builder->CreateZExt(cast, llvm::Type::getInt64Ty(*llvm_context));
bitcast_to_u64(llvm_val[stmt->args[i]], stmt->args[i]->ret_type);
call("RuntimeContext_set_args", new_ctx,
llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
}

llvm_val[stmt] = create_call(llvm_func, {new_ctx});
llvm::Value *result_buffer = nullptr;
if (stmt->ret_type->is<PrimitiveType>() &&
!stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
result_buffer = builder->CreateAlloca(tlctx->get_data_type<uint64>());
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer);
create_call(llvm_func, {new_ctx});
auto *ret_val_u64 = builder->CreateLoad(result_buffer);
llvm_val[stmt] = bitcast_from_u64(ret_val_u64, stmt->ret_type);
} else {
create_call(llvm_func, {new_ctx});
}
}

TLANG_NAMESPACE_END
Expand Down
3 changes: 3 additions & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(FuncCallStmt *stmt) override;

llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type);
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

~CodeGenLLVM() override = default;
};

Expand Down
5 changes: 5 additions & 0 deletions taichi/program/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ struct RuntimeContext {
int32 cpu_thread_id;
// |is_device_allocation| is true iff args[i] is a DeviceAllocation*.
bool is_device_allocation[taichi_max_num_args_total]{false};
// The result buffer pointer was moved from LLVMRuntime to RuntimeContext
// because each real function needs its own place to store its result, whereas
// LLVMRuntime is shared among all functions. Each function has its own
// RuntimeContext, so the pointer is kept here.
uint64 *result_buffer;

static constexpr size_t extra_args_size = sizeof(extra_args);

Expand Down
1 change: 1 addition & 0 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ RuntimeContext &Kernel::LaunchContextBuilder::get_context() {
ctx_->runtime = llvm_program_impl->get_llvm_runtime();
}
#endif
ctx_->result_buffer = kernel_->program->result_buffer;
return *ctx_;
}

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ STRUCT_FIELD_ARRAY(PhysicalCoordinates, val);

STRUCT_FIELD_ARRAY(RuntimeContext, args);
STRUCT_FIELD(RuntimeContext, runtime);
STRUCT_FIELD(RuntimeContext, result_buffer)

int32 RuntimeContext_get_extra_args(RuntimeContext *ctx, int32 i, int32 j) {
return ctx->extra_args[i][j];
Expand Down Expand Up @@ -696,8 +697,8 @@ struct NodeManager {

extern "C" {

void LLVMRuntime_store_result(LLVMRuntime *runtime, u64 ret, u32 idx) {
runtime->set_result(taichi_result_buffer_ret_value_id + idx, ret);
// Writes the return value `ret` into the result buffer of `ctx`.
// `idx` selects the return value; the slot written is offset by
// taichi_result_buffer_ret_value_id.
void RuntimeContext_store_result(RuntimeContext *ctx, u64 ret, u32 idx) {
  const auto slot = taichi_result_buffer_ret_value_id + idx;
  ctx->result_buffer[slot] = ret;
}

void LLVMRuntime_profiler_start(LLVMRuntime *runtime, Ptr kernel_name) {
Expand Down
80 changes: 39 additions & 41 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,47 +22,45 @@ def run():
assert x[None] == 42


# @test_utils.test(arch=[ti.cpu, ti.gpu])
# def test_function_with_return():
# x = ti.field(ti.i32, shape=())
#
# @ti.experimental.real_func
# def foo(val: ti.i32) -> ti.i32:
# x[None] += val
# return val
#
# @ti.kernel
# def run():
# a = foo(40)
# foo(2)
# assert a == 40
#
# x[None] = 0
# run()
# assert x[None] == 42
#
#
# @test_utils.test(arch=[ti.cpu, ti.gpu])
# def test_call_expressions():
# x = ti.field(ti.i32, shape=())
#
# @ti.experimental.real_func
# def foo(val: ti.i32) -> ti.i32:
# if x[None] > 10:
# x[None] += 1
# x[None] += val
# return 0
#
# @ti.kernel
# def run():
# assert foo(15) == 0
# assert foo(10) == 0
#
# x[None] = 0
# run()
# assert x[None] == 26
#
#
# Checks that a real function can return a scalar to the calling kernel while
# also mutating a field as a side effect.
@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True)
def test_function_with_return():
x = ti.field(ti.i32, shape=())

# Adds `val` to the 0-D field `x` and returns `val` unchanged.
@ti.experimental.real_func
def foo(val: ti.i32) -> ti.i32:
x[None] += val
return val

@ti.kernel
def run():
# The first call's return value is captured and checked; the second call's
# return value is discarded.
a = foo(40)
foo(2)
assert a == 40

x[None] = 0
run()
# Both calls contributed to x: 40 + 2.
assert x[None] == 42


# Checks that real-function calls can be used directly as expressions inside a
# kernel (here as the operand of an assert) and that each call is executed.
@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_call_expressions():
x = ti.field(ti.i32, shape=())

# Adds `val` to `x`, plus an extra 1 when `x` is already above 10; always
# returns 0.
@ti.experimental.real_func
def foo(val: ti.i32) -> ti.i32:
if x[None] > 10:
x[None] += 1
x[None] += val
return 0

@ti.kernel
def run():
assert foo(15) == 0
assert foo(10) == 0

x[None] = 0
run()
# First call: 0 -> 15. Second call: 15 > 10, so 15 + 1 + 10 = 26.
assert x[None] == 26


@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True)
Expand Down

0 comments on commit 2c04cb2

Please sign in to comment.