diff --git a/taichi/backends/cpu/aot_module_builder_impl.h b/taichi/backends/cpu/aot_module_builder_impl.h index 1d81fa41d7c2e..039174aa88503 100644 --- a/taichi/backends/cpu/aot_module_builder_impl.h +++ b/taichi/backends/cpu/aot_module_builder_impl.h @@ -9,6 +9,11 @@ namespace lang { namespace cpu { class AotModuleBuilderImpl : public LlvmAotModuleBuilder { + public: + explicit AotModuleBuilderImpl(LlvmProgramImpl *prog) + : LlvmAotModuleBuilder(prog) { + } + private: CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override; }; diff --git a/taichi/backends/cpu/aot_module_loader_impl.cpp b/taichi/backends/cpu/aot_module_loader_impl.cpp index e2ff3b2ecf0f6..16c297dced325 100644 --- a/taichi/backends/cpu/aot_module_loader_impl.cpp +++ b/taichi/backends/cpu/aot_module_loader_impl.cpp @@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule { TI_NOT_IMPLEMENTED; return nullptr; } - - std::unique_ptr make_new_field(const std::string &name) override { - TI_NOT_IMPLEMENTED; - return nullptr; - } }; } // namespace diff --git a/taichi/backends/cuda/aot_module_builder_impl.h b/taichi/backends/cuda/aot_module_builder_impl.h index f0fdc74e14f9c..94ac89380d1e0 100644 --- a/taichi/backends/cuda/aot_module_builder_impl.h +++ b/taichi/backends/cuda/aot_module_builder_impl.h @@ -9,6 +9,11 @@ namespace lang { namespace cuda { class AotModuleBuilderImpl : public LlvmAotModuleBuilder { + public: + explicit AotModuleBuilderImpl(LlvmProgramImpl *prog) + : LlvmAotModuleBuilder(prog) { + } + private: CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override; }; diff --git a/taichi/backends/cuda/aot_module_loader_impl.cpp b/taichi/backends/cuda/aot_module_loader_impl.cpp index b08efdc9632da..69bf52d749772 100644 --- a/taichi/backends/cuda/aot_module_loader_impl.cpp +++ b/taichi/backends/cuda/aot_module_loader_impl.cpp @@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule { TI_NOT_IMPLEMENTED; return nullptr; } - - std::unique_ptr make_new_field(const std::string &name) override { - TI_NOT_IMPLEMENTED; - return nullptr; - } }; } // namespace diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index 1a583cda431b5..f36511cb27b5f 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -326,7 +326,7 @@ void SNode::set_snode_tree_id(int id) { snode_tree_id_ = id; } -int SNode::get_snode_tree_id() { +int SNode::get_snode_tree_id() const { return snode_tree_id_; } diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index 8a21721c2a7bc..da7560501d97f 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -354,7 +354,7 @@ class SNode { void set_snode_tree_id(int id); - int get_snode_tree_id(); + int get_snode_tree_id() const; static void reset_counter() { counter = 0; diff --git a/taichi/llvm/llvm_aot_module_builder.cpp b/taichi/llvm/llvm_aot_module_builder.cpp index d23ee5c47c564..664ee933893c9 100644 --- a/taichi/llvm/llvm_aot_module_builder.cpp +++ b/taichi/llvm/llvm_aot_module_builder.cpp @@ -2,6 +2,7 @@ #include #include "taichi/llvm/launch_arg_info.h" +#include "taichi/llvm/llvm_program.h" namespace taichi { namespace lang { @@ -34,5 +35,37 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier, cache_.kernels[identifier] = std::move(kcache); } +void LlvmAotModuleBuilder::add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, + bool is_scalar, + DataType dt, + std::vector shape, + int row_num, + int column_num) { + // Field refers to a leaf node(Place SNode) in a SNodeTree. + // It makes no sense to just serialize the leaf node or its corresponding + // branch. Instead, the minimal unit we have to serialize is the entire + // SNodeTree. Note that SNodeTree's uses snode_tree_id as its identifier, + // rather than the field's name. (multiple fields may end up referring to the + // same SNodeTree) + + // 1. Find snode_tree_id + int snode_tree_id = rep_snode->get_snode_tree_id(); + + // 2. Fetch Cache from the Program + // Kernel compilation is not allowed until all the Fields are finalized, + // so we finished SNodeTree compilation during AOTModuleBuilder construction. + // + // By the time "add_field_per_backend()" is called, + // SNodeTrees should have already been finalized, + // with compiled info stored in LlvmProgramImpl::cache_data_. + TI_ASSERT(prog_ != nullptr); + LlvmOfflineCache::FieldCacheData field_cache = + prog_->get_cached_field(snode_tree_id); + + // 3. Update AOT Cache + cache_.fields[snode_tree_id] = std::move(field_cache); +} + } // namespace lang } // namespace taichi diff --git a/taichi/llvm/llvm_aot_module_builder.h b/taichi/llvm/llvm_aot_module_builder.h index b88133a761783..857f237c4a73c 100644 --- a/taichi/llvm/llvm_aot_module_builder.h +++ b/taichi/llvm/llvm_aot_module_builder.h @@ -9,6 +9,9 @@ namespace lang { class LlvmAotModuleBuilder : public AotModuleBuilder { public: + explicit LlvmAotModuleBuilder(LlvmProgramImpl *prog) : prog_(prog) { + } + void dump(const std::string &output_dir, const std::string &filename) const override; @@ -16,8 +19,17 @@ class LlvmAotModuleBuilder : public AotModuleBuilder { void add_per_backend(const std::string &identifier, Kernel *kernel) override; virtual CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) = 0; + void add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, + bool is_scalar, + DataType dt, + std::vector shape, + int row_num, + int column_num) override; + private: mutable LlvmOfflineCache cache_; + LlvmProgramImpl *prog_ = nullptr; }; } // namespace lang diff --git a/taichi/llvm/llvm_aot_module_loader.cpp b/taichi/llvm/llvm_aot_module_loader.cpp index 5d725927388d7..99ca51f665363 100644 --- a/taichi/llvm/llvm_aot_module_loader.cpp +++ b/taichi/llvm/llvm_aot_module_loader.cpp @@ -17,6 +17,24 @@ class KernelImpl : public aot::Kernel { FunctionType fn_; }; +class FieldImpl : public aot::Field { + public: + explicit FieldImpl(const LlvmOfflineCache::FieldCacheData &field) + : field_(field) { + } + + explicit FieldImpl(LlvmOfflineCache::FieldCacheData &&field) + : field_(std::move(field)) { + } + + LlvmOfflineCache::FieldCacheData get_field() const { + return field_; + } + + private: + LlvmOfflineCache::FieldCacheData field_; +}; + } // namespace LlvmOfflineCache::KernelCacheData LlvmAotModule::load_kernel_from_cache( @@ -37,5 +55,42 @@ std::unique_ptr LlvmAotModule::make_new_kernel( return std::make_unique(fn); } +std::unique_ptr LlvmAotModule::make_new_field( + const std::string &name) { + // Check if "name" represents snode_tree_id. + // Avoid using std::atoi due to its poor error handling. + char *end; + int snode_tree_id = static_cast(strtol(name.c_str(), &end, 10 /*base*/)); + + TI_ASSERT(end != name.c_str()); + TI_ASSERT(*end == '\0'); + + // Load FieldCache + LlvmOfflineCache::FieldCacheData loaded; + auto ok = cache_reader_->get_field_cache(loaded, snode_tree_id); + TI_ERROR_IF(!ok, "Failed to load field with id={}", snode_tree_id); + + return std::make_unique(std::move(loaded)); +} + +void finalize_aot_field(aot::Module *aot_module, + aot::Field *aot_field, + uint64 *result_buffer) { + auto *llvm_aot_module = dynamic_cast(aot_module); + auto *aot_field_impl = dynamic_cast(aot_field); + + TI_ASSERT(llvm_aot_module != nullptr); + TI_ASSERT(aot_field_impl != nullptr); + + auto *llvm_prog = llvm_aot_module->get_program(); + const auto &field_cache = aot_field_impl->get_field(); + + int snode_tree_id = field_cache.tree_id; + if (!llvm_aot_module->is_snode_tree_initialized(snode_tree_id)) { + llvm_prog->initialize_llvm_runtime_snodes(field_cache, result_buffer); + llvm_aot_module->set_initialized_snode_tree(snode_tree_id); + } +} + } // namespace lang } // namespace taichi diff --git a/taichi/llvm/llvm_aot_module_loader.h b/taichi/llvm/llvm_aot_module_loader.h index b5e8f527cea67..1e4e093bcfc2c 100644 --- a/taichi/llvm/llvm_aot_module_loader.h +++ b/taichi/llvm/llvm_aot_module_loader.h @@ -6,6 +6,10 @@ namespace taichi { namespace lang { +TI_DLL_EXPORT void finalize_aot_field(aot::Module *aot_module, + aot::Field *aot_field, + uint64 *result_buffer); + class LlvmAotModule : public aot::Module { public: explicit LlvmAotModule(const std::string &module_path, @@ -27,6 +31,18 @@ class LlvmAotModule : public aot::Module { return 0; } + LlvmProgramImpl *const get_program() { + return program_; + } + + void set_initialized_snode_tree(int snode_tree_id) { + initialized_snode_tree_ids.insert(snode_tree_id); + } + + bool is_snode_tree_initialized(int snode_tree_id) { + return initialized_snode_tree_ids.count(snode_tree_id); + } + protected: virtual FunctionType convert_module_to_function( const std::string &name, @@ -38,8 +54,13 @@ class LlvmAotModule : public aot::Module { std::unique_ptr make_new_kernel( const std::string &name) override; + std::unique_ptr make_new_field(const std::string &name) override; + LlvmProgramImpl *const program_{nullptr}; std::unique_ptr cache_reader_{nullptr}; + + // To prevent repeated SNodeTree initialization + std::unordered_set initialized_snode_tree_ids; }; } // namespace lang diff --git a/taichi/llvm/llvm_offline_cache.h b/taichi/llvm/llvm_offline_cache.h index b5403982b10ba..54b7356903eb0 100644 --- a/taichi/llvm/llvm_offline_cache.h +++ b/taichi/llvm/llvm_offline_cache.h @@ -95,7 +95,7 @@ struct LlvmOfflineCache { std::unordered_map kernels; // key = kernel_name - TI_IO_DEF(kernels); + TI_IO_DEF(fields, kernels); }; class LlvmOfflineCacheFileReader { diff --git a/taichi/llvm/llvm_program.cpp b/taichi/llvm/llvm_program.cpp index eea60dad165f7..a805ade265e37 100644 --- a/taichi/llvm/llvm_program.cpp +++ b/taichi/llvm/llvm_program.cpp @@ -273,37 +273,22 @@ std::unique_ptr LlvmProgramImpl::compile_snode_tree_types_impl( } void LlvmProgramImpl::compile_snode_tree_types(SNodeTree *tree) { - compile_snode_tree_types_impl(tree); -} - -static LlvmOfflineCache::FieldCacheData construct_filed_cache_data( - const SNodeTree &tree, - const StructCompiler &struct_compiler) { - LlvmOfflineCache::FieldCacheData ret; - ret.tree_id = tree.id(); - ret.root_id = tree.root()->id; - ret.root_size = struct_compiler.root_size; - - const auto &snodes = struct_compiler.snodes; - for (size_t i = 0; i < snodes.size(); i++) { - LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data; - snode_cache_data.id = snodes[i]->id; - snode_cache_data.type = snodes[i]->type; - snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes; - snode_cache_data.chunk_size = snodes[i]->chunk_size; - - ret.snode_metas.emplace_back(std::move(snode_cache_data)); - } + auto struct_compiler = compile_snode_tree_types_impl(tree); + int snode_tree_id = tree->id(); + int root_id = tree->root()->id; - return ret; + // Add compiled result to Cache + cache_field(snode_tree_id, root_id, *struct_compiler); } void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) { - auto struct_compiler = compile_snode_tree_types_impl(tree); + compile_snode_tree_types(tree); + int snode_tree_id = tree->id(); - auto field_cache_data = construct_filed_cache_data(*tree, *struct_compiler); - initialize_llvm_runtime_snodes(field_cache_data, result_buffer); + TI_ASSERT(cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end()); + initialize_llvm_runtime_snodes(cache_data_.fields.at(snode_tree_id), + result_buffer); } uint64 LlvmProgramImpl::fetch_result_uint64(int i, uint64 *result_buffer) { @@ -365,12 +350,12 @@ void LlvmProgramImpl::print_list_manager_info(void *list_manager, std::unique_ptr LlvmProgramImpl::make_aot_module_builder() { if (config->arch == Arch::x64 || config->arch == Arch::arm64) { - return std::make_unique(); + return std::make_unique(this); } #if defined(TI_WITH_CUDA) if (config->arch == Arch::cuda) { - return std::make_unique(); + return std::make_unique(this); } #endif @@ -701,6 +686,33 @@ void LlvmProgramImpl::cache_kernel( kernel_cache.offloaded_task_list = std::move(offloaded_task_list); } +void LlvmProgramImpl::cache_field(int snode_tree_id, + int root_id, + const StructCompiler &struct_compiler) { + if (cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end()) { + // [TODO] check and update the Cache, instead of simply return. + return; + } + + LlvmOfflineCache::FieldCacheData ret; + ret.tree_id = snode_tree_id; + ret.root_id = root_id; + ret.root_size = struct_compiler.root_size; + + const auto &snodes = struct_compiler.snodes; + for (size_t i = 0; i < snodes.size(); i++) { + LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data; + snode_cache_data.id = snodes[i]->id; + snode_cache_data.type = snodes[i]->type; + snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes; + snode_cache_data.chunk_size = snodes[i]->chunk_size; + + ret.snode_metas.emplace_back(std::move(snode_cache_data)); + } + + cache_data_.fields[snode_tree_id] = std::move(ret); +} + void LlvmProgramImpl::dump_cache_data_to_disk() { if (config->offline_cache && !cache_data_.kernels.empty()) { LlvmOfflineCacheFileWriter writer{}; diff --git a/taichi/llvm/llvm_program.h b/taichi/llvm/llvm_program.h index 69378ee660bf1..2eec64dd8e7bd 100644 --- a/taichi/llvm/llvm_program.h +++ b/taichi/llvm/llvm_program.h @@ -118,10 +118,27 @@ class LlvmProgramImpl : public ProgramImpl { std::vector &&offloaded_task_list); + void cache_field(int snode_tree_id, + int root_id, + const StructCompiler &struct_compiler); + + LlvmOfflineCache::FieldCacheData get_cached_field(int snode_tree_id) const { + TI_ASSERT(cache_data_.fields.find(snode_tree_id) != + cache_data_.fields.end()); + return cache_data_.fields.at(snode_tree_id); + } + Device *get_compute_device() override { return device_.get(); } + /** + * Initializes the SNodes for LLVM based backends. + */ + void initialize_llvm_runtime_snodes( + const LlvmOfflineCache::FieldCacheData &field_cache_data, + uint64 *result_buffer); + private: std::unique_ptr clone_struct_compiler_initial_context( bool has_multiple_snode_trees, @@ -129,12 +146,6 @@ class LlvmProgramImpl : public ProgramImpl { std::unique_ptr compile_snode_tree_types_impl( SNodeTree *tree); - /** - * Initializes the SNodes for LLVM based backends. - */ - void initialize_llvm_runtime_snodes( - const LlvmOfflineCache::FieldCacheData &field_cache_data, - uint64 *result_buffer); uint64 fetch_result_uint64(int i, uint64 *result_buffer);