-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[llvm] [aot] CUDA-AOT PR #2: Implemented AOTModuleLoader & AOTModuleB…
…uilder for LLVM-CUDA backend (#5087) * [llvm] [aot] Add LLVM-CPU AOT tests * Refactored AOT test framework * Fixed minor issue * Enabled LLVM CPU-AOT for arm64 architecture * Added aot unit tests programming guide * [llvm] [aot] CUDA-AOT PR #2: Implemented AOT Module Loader for LLVM-CUDA backend * Fixed typo * Fixed minor issue * Refactored AOT test framework * [llvm] [aot] Add LLVM-CUDA AOT tests * Added cuda device availability check
- Loading branch information
1 parent
8ab9b9f
commit 8e8a792
Showing
16 changed files
with
296 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#include "taichi/backends/cuda/aot_module_builder_impl.h" | ||
|
||
#include <algorithm> | ||
|
||
#include "taichi/backends/cuda/codegen_cuda.h" | ||
#include "taichi/llvm/launch_arg_info.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
namespace cuda { | ||
|
||
CodeGenLLVM::CompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) { | ||
auto cgen = CodeGenCUDA::make_codegen_llvm(kernel, /*ir=*/nullptr); | ||
return cgen->run_compilation(); | ||
} | ||
|
||
} // namespace cuda | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#pragma once | ||
|
||
#include "taichi/aot/module_builder.h" | ||
#include "taichi/llvm/llvm_offline_cache.h" | ||
#include "taichi/llvm/llvm_aot_module_builder.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
namespace cuda { | ||
|
||
class AotModuleBuilderImpl : public LlvmAotModuleBuilder { | ||
private: | ||
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override; | ||
}; | ||
|
||
} // namespace cuda | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#include "taichi/backends/cuda/aot_module_loader_impl.h" | ||
#include "taichi/llvm/llvm_aot_module_loader.h" | ||
|
||
#include "taichi/llvm/llvm_offline_cache.h" | ||
#include "taichi/llvm/llvm_program.h" | ||
#include "taichi/backends/cuda/codegen_cuda.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
namespace { | ||
|
||
class AotModuleImpl : public LlvmAotModule { | ||
public: | ||
explicit AotModuleImpl(const cuda::AotModuleParams ¶ms) | ||
: LlvmAotModule(params.module_path, params.program) { | ||
} | ||
|
||
private: | ||
FunctionType convert_module_to_function( | ||
const std::string &name, | ||
LlvmOfflineCache::KernelCacheData &&loaded) override { | ||
Arch arch = program_->config->arch; | ||
TI_ASSERT(arch == Arch::cuda); | ||
auto *tlctx = program_->get_llvm_context(arch); | ||
|
||
const auto &tasks = loaded.offloaded_task_list; | ||
std::vector<OffloadedTask> offloaded_tasks; | ||
offloaded_tasks.reserve(tasks.size()); | ||
for (const auto &t : tasks) { | ||
OffloadedTask ot{/*codegen=*/nullptr}; | ||
ot.name = t.name; | ||
ot.block_dim = t.block_dim; | ||
ot.grid_dim = t.grid_dim; | ||
offloaded_tasks.push_back(std::move(ot)); | ||
} | ||
|
||
CUDAModuleToFunctionConverter converter{tlctx, program_}; | ||
return converter.convert(name, loaded.args, std::move(loaded.owned_module), | ||
std::move(offloaded_tasks)); | ||
} | ||
|
||
std::unique_ptr<aot::KernelTemplate> make_new_kernel_template( | ||
const std::string &name) override { | ||
TI_NOT_IMPLEMENTED; | ||
return nullptr; | ||
} | ||
|
||
std::unique_ptr<aot::Field> make_new_field(const std::string &name) override { | ||
TI_NOT_IMPLEMENTED; | ||
return nullptr; | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace cuda { | ||
|
||
std::unique_ptr<aot::Module> make_aot_module(std::any mod_params) { | ||
auto mod = std::make_unique<AotModuleImpl>( | ||
std::any_cast<const AotModuleParams &>(mod_params)); | ||
return mod; | ||
} | ||
|
||
} // namespace cuda | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#pragma once | ||
|
||
#include "taichi/aot/module_loader.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
|
||
class LlvmProgramImpl; | ||
|
||
namespace cuda { | ||
|
||
struct TI_DLL_EXPORT AotModuleParams { | ||
std::string module_path; | ||
LlvmProgramImpl *program{nullptr}; | ||
}; | ||
|
||
TI_DLL_EXPORT std::unique_ptr<aot::Module> make_aot_module(std::any mod_params); | ||
|
||
} // namespace cuda | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include "taichi/llvm/llvm_aot_module_builder.h" | ||
|
||
#include <algorithm> | ||
#include "taichi/llvm/launch_arg_info.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
|
||
void LlvmAotModuleBuilder::dump(const std::string &output_dir, | ||
const std::string &filename) const { | ||
LlvmOfflineCacheFileWriter writer; | ||
writer.set_data(std::move(cache_)); | ||
writer.dump(output_dir); | ||
} | ||
|
||
void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier, | ||
Kernel *kernel) { | ||
auto compiled = compile_kernel(kernel); | ||
LlvmOfflineCache::KernelCacheData kcache; | ||
kcache.kernel_key = identifier; | ||
kcache.module = compiled.llvm_module.get(); | ||
kcache.owned_module = std::move(compiled.llvm_module); | ||
const auto &tasks = compiled.offloaded_tasks; | ||
kcache.args = infer_launch_args(kernel); | ||
kcache.offloaded_task_list.resize(tasks.size()); | ||
std::transform(tasks.begin(), tasks.end(), kcache.offloaded_task_list.begin(), | ||
[](const auto &t) -> LlvmOfflineCache::OffloadedTaskCacheData { | ||
LlvmOfflineCache::OffloadedTaskCacheData res; | ||
res.name = t.name; | ||
res.block_dim = t.block_dim; | ||
res.grid_dim = t.grid_dim; | ||
return res; | ||
}); | ||
cache_.kernels[identifier] = std::move(kcache); | ||
} | ||
|
||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#pragma once | ||
|
||
#include "taichi/aot/module_builder.h" | ||
#include "taichi/llvm/llvm_offline_cache.h" | ||
#include "taichi/codegen/codegen_llvm.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
|
||
class LlvmAotModuleBuilder : public AotModuleBuilder { | ||
public: | ||
void dump(const std::string &output_dir, | ||
const std::string &filename) const override; | ||
|
||
protected: | ||
void add_per_backend(const std::string &identifier, Kernel *kernel) override; | ||
virtual CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) = 0; | ||
|
||
private: | ||
mutable LlvmOfflineCache cache_; | ||
}; | ||
|
||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import os | ||
|
||
import taichi as ti | ||
|
||
|
||
def compile_aot(): | ||
ti.init(arch=ti.cuda) | ||
|
||
@ti.kernel | ||
def run(base: int, arr: ti.types.ndarray()): | ||
for i in arr: | ||
arr[i] = base + i | ||
|
||
arr = ti.ndarray(int, shape=16) | ||
run(42, arr) | ||
|
||
assert "TAICHI_AOT_FOLDER_PATH" in os.environ.keys() | ||
dir_name = str(os.environ["TAICHI_AOT_FOLDER_PATH"]) | ||
|
||
m = ti.aot.Module(ti.cuda) | ||
m.add_kernel(run, template_args={'arr': arr}) | ||
m.save(dir_name, 'cuda-aot') | ||
|
||
|
||
compile_aot() |
Oops, something went wrong.