From 2c04cb2fa14c716f70cf18894edaef76cf74c1bd Mon Sep 17 00:00:00 2001 From: Lin Jiang <90667349+lin-hitonami@users.noreply.github.com> Date: Tue, 8 Mar 2022 14:33:08 +0800 Subject: [PATCH] [llvm] Support real function with single scalar return value (#4452) * [llvm] Support real function with single scalar return value * add comment --- taichi/codegen/codegen_llvm.cpp | 119 +++++++++++++++++--------------- taichi/codegen/codegen_llvm.h | 3 + taichi/program/context.h | 5 ++ taichi/program/kernel.cpp | 1 + taichi/runtime/llvm/runtime.cpp | 5 +- tests/python/test_function.py | 80 +++++++++++---------- 6 files changed, 116 insertions(+), 97 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 40ae77525..241cb1516 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1019,6 +1019,54 @@ void CodeGenLLVM::visit(RangeForStmt *for_stmt) { create_naive_range_for(for_stmt); } +llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) { + llvm::Type *dest_ty = nullptr; + TI_ASSERT(!type->is()); + if (auto cit = type->cast()) { + if (cit->get_is_signed()) + dest_ty = tlctx->get_data_type(PrimitiveType::i32); + else + dest_ty = tlctx->get_data_type(PrimitiveType::u32); + } else { + dest_ty = tlctx->get_data_type(type); + } + auto dest_bits = dest_ty->getPrimitiveSizeInBits(); + if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) { + // if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa + // which doesn't mean anything. + // So we truncate to 32 bits first and then fptrunc to half if applicable + auto truncated = + builder->CreateTrunc(val, llvm::Type::getIntNTy(*llvm_context, 32)); + auto casted = builder->CreateBitCast(truncated, + llvm::Type::getFloatTy(*llvm_context)); + return builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context)); + } else { + auto truncated = builder->CreateTrunc( + val, llvm::Type::getIntNTy(*llvm_context, dest_bits)); + + return builder->CreateBitCast(truncated, dest_ty); + } +} + +llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) { + auto intermediate_bits = 0; + if (auto cit = type->cast()) { + intermediate_bits = data_type_bits(cit->get_compute_type()); + } else { + intermediate_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits(); + } + llvm::Type *dest_ty = tlctx->get_data_type(); + llvm::Type *intermediate_type = nullptr; + if (val->getType() == llvm::Type::getHalfTy(*llvm_context)) { + val = builder->CreateFPExt(val, tlctx->get_data_type()); + intermediate_type = tlctx->get_data_type(); + } else { + intermediate_type = llvm::Type::getIntNTy(*llvm_context, intermediate_bits); + } + return builder->CreateZExt(builder->CreateBitCast(val, intermediate_type), + dest_ty); +} + void CodeGenLLVM::visit(ArgLoadStmt *stmt) { auto raw_arg = call(builder.get(), "RuntimeContext_get_args", get_context(), tlctx->get_constant(stmt->arg_id)); @@ -1029,32 +1077,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { - TI_ASSERT(!stmt->ret_type->is()); - if (auto cit = stmt->ret_type->cast()) { - if (cit->get_is_signed()) - dest_ty = tlctx->get_data_type(PrimitiveType::i32); - else - dest_ty = tlctx->get_data_type(PrimitiveType::u32); - } else { - dest_ty = tlctx->get_data_type(stmt->ret_type); - } - auto dest_bits = dest_ty->getPrimitiveSizeInBits(); - if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) { - // if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa - // which doesn't mean anything. - // So we truncate to 32 bits first and then fptrunc to half if applicable - auto truncated = builder->CreateTrunc( - raw_arg, llvm::Type::getIntNTy(*llvm_context, 32)); - auto casted = builder->CreateBitCast( - truncated, llvm::Type::getFloatTy(*llvm_context)); - llvm_val[stmt] = - builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context)); - } else { - auto truncated = builder->CreateTrunc( - raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits)); - - llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty); - } + llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type); } } @@ -1067,27 +1090,10 @@ void CodeGenLLVM::visit(ReturnStmt *stmt) { TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value); int idx{0}; for (auto &value : stmt->values) { - auto intermediate_bits = 0; - if (auto cit = value->ret_type->cast()) { - intermediate_bits = data_type_bits(cit->get_compute_type()); - } else { - intermediate_bits = - tlctx->get_data_type(value->ret_type)->getPrimitiveSizeInBits(); - } - llvm::Type *dest_ty = tlctx->get_data_type(); - llvm::Type *intermediate_type = nullptr; - if (llvm_val[value]->getType() == llvm::Type::getHalfTy(*llvm_context)) { - llvm_val[value] = builder->CreateFPExt(llvm_val[value], - tlctx->get_data_type()); - intermediate_type = tlctx->get_data_type(); - } else { - intermediate_type = - llvm::Type::getIntNTy(*llvm_context, intermediate_bits); - } - auto extended = builder->CreateZExt( - builder->CreateBitCast(llvm_val[value], intermediate_type), dest_ty); - create_call("LLVMRuntime_store_result", - {get_runtime(), extended, tlctx->get_constant(idx++)}); + create_call( + "RuntimeContext_store_result", + {get_context(), bitcast_to_u64(llvm_val[value], value->ret_type), + tlctx->get_constant(idx++)}); } } } @@ -2387,17 +2393,22 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) { auto *new_ctx = builder->CreateAlloca(get_runtime_type("RuntimeContext")); call("RuntimeContext_set_runtime", new_ctx, get_runtime()); for (int i = 0; i < stmt->args.size(); i++) { - auto *original = llvm_val[stmt->args[i]]; - int src_bits = original->getType()->getPrimitiveSizeInBits(); - auto *cast = builder->CreateBitCast( - original, llvm::Type::getIntNTy(*llvm_context, src_bits)); auto *val = - builder->CreateZExt(cast, llvm::Type::getInt64Ty(*llvm_context)); + bitcast_to_u64(llvm_val[stmt->args[i]], stmt->args[i]->ret_type); call("RuntimeContext_set_args", new_ctx, llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val); } - - llvm_val[stmt] = create_call(llvm_func, {new_ctx}); + llvm::Value *result_buffer = nullptr; + if (stmt->ret_type->is() && + !stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { + result_buffer = builder->CreateAlloca(tlctx->get_data_type()); + call("RuntimeContext_set_result_buffer", new_ctx, result_buffer); + create_call(llvm_func, {new_ctx}); + auto *ret_val_u64 = builder->CreateLoad(result_buffer); + llvm_val[stmt] = bitcast_from_u64(ret_val_u64, stmt->ret_type); + } else { + create_call(llvm_func, {new_ctx}); + } } TLANG_NAMESPACE_END diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 339b965a6..e32ee3cdc 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -382,6 +382,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(FuncCallStmt *stmt) override; + llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type); + llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type); + ~CodeGenLLVM() override = default; }; diff --git a/taichi/program/context.h b/taichi/program/context.h index 4202a68e9..0d6d7eef4 100644 --- a/taichi/program/context.h +++ b/taichi/program/context.h @@ -24,6 +24,11 @@ struct RuntimeContext { int32 cpu_thread_id; // |is_device_allocation| is true iff args[i] is a DeviceAllocation*. bool is_device_allocation[taichi_max_num_args_total]{false}; + // We move the pointer of result buffer from LLVMRuntime to RuntimeContext + // because each real function need a place to store its result, but + // LLVMRuntime is shared among functions. So we moved the pointer to + // RuntimeContext which each function have one. + uint64 *result_buffer; static constexpr size_t extra_args_size = sizeof(extra_args); diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 265b4d9d7..582ee70c9 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -285,6 +285,7 @@ RuntimeContext &Kernel::LaunchContextBuilder::get_context() { ctx_->runtime = llvm_program_impl->get_llvm_runtime(); } #endif + ctx_->result_buffer = kernel_->program->result_buffer; return *ctx_; } diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 431b69ff0..0be7f410e 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -348,6 +348,7 @@ STRUCT_FIELD_ARRAY(PhysicalCoordinates, val); STRUCT_FIELD_ARRAY(RuntimeContext, args); STRUCT_FIELD(RuntimeContext, runtime); +STRUCT_FIELD(RuntimeContext, result_buffer) int32 RuntimeContext_get_extra_args(RuntimeContext *ctx, int32 i, int32 j) { return ctx->extra_args[i][j]; @@ -696,8 +697,8 @@ struct NodeManager { extern "C" { -void LLVMRuntime_store_result(LLVMRuntime *runtime, u64 ret, u32 idx) { - runtime->set_result(taichi_result_buffer_ret_value_id + idx, ret); +void RuntimeContext_store_result(RuntimeContext *ctx, u64 ret, u32 idx) { + ctx->result_buffer[taichi_result_buffer_ret_value_id + idx] = ret; } void LLVMRuntime_profiler_start(LLVMRuntime *runtime, Ptr kernel_name) { diff --git a/tests/python/test_function.py b/tests/python/test_function.py index 3a28bfe42..adf270833 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -22,47 +22,45 @@ def run(): assert x[None] == 42 -# @test_utils.test(arch=[ti.cpu, ti.gpu]) -# def test_function_with_return(): -# x = ti.field(ti.i32, shape=()) -# -# @ti.experimental.real_func -# def foo(val: ti.i32) -> ti.i32: -# x[None] += val -# return val -# -# @ti.kernel -# def run(): -# a = foo(40) -# foo(2) -# assert a == 40 -# -# x[None] = 0 -# run() -# assert x[None] == 42 -# -# -# @test_utils.test(arch=[ti.cpu, ti.gpu]) -# def test_call_expressions(): -# x = ti.field(ti.i32, shape=()) -# -# @ti.experimental.real_func -# def foo(val: ti.i32) -> ti.i32: -# if x[None] > 10: -# x[None] += 1 -# x[None] += val -# return 0 -# -# @ti.kernel -# def run(): -# assert foo(15) == 0 -# assert foo(10) == 0 -# -# x[None] = 0 -# run() -# assert x[None] == 26 -# -# +@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True) +def test_function_with_return(): + x = ti.field(ti.i32, shape=()) + + @ti.experimental.real_func + def foo(val: ti.i32) -> ti.i32: + x[None] += val + return val + + @ti.kernel + def run(): + a = foo(40) + foo(2) + assert a == 40 + + x[None] = 0 + run() + assert x[None] == 42 + + +@test_utils.test(arch=[ti.cpu, ti.gpu]) +def test_call_expressions(): + x = ti.field(ti.i32, shape=()) + + @ti.experimental.real_func + def foo(val: ti.i32) -> ti.i32: + if x[None] > 10: + x[None] += 1 + x[None] += val + return 0 + + @ti.kernel + def run(): + assert foo(15) == 0 + assert foo(10) == 0 + + x[None] = 0 + run() + assert x[None] == 26 @test_utils.test(arch=[ti.cpu, ti.cuda], debug=True)