Skip to content

Commit

Permalink
[perf] [refactor] Reduce kernel launch context construction overhead …
Browse files Browse the repository at this point in the history
…(#3947)

* Set external array launch context in C++ scope

* Shorten execution path for ndarray in kernel_impl

* Fix a small problem with the shape argument.

* Auto Format

* Fix naming and int ptr type convention

* Add another args set method to reduce launch overhead.

* Also check the annotations in the Ndarray shortcut.

* Fix the ndarray_use_torch access method after merge with master branch.

* Auto Format

* Revert the shortcut loop in order to respect the original branch logic.

* Code formatting.

* Revise func names and argument lists.

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
turbo0628 and taichi-gardener authored Jan 6, 2022
1 parent 551af9c commit ef6237a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 20 deletions.
24 changes: 6 additions & 18 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,17 +515,13 @@ def func__(*args):
tmp = np.ascontiguousarray(v)
# Purpose: DO NOT GC |tmp|!
tmps.append(tmp)
launch_ctx.set_arg_external_array(
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp.ctypes.data),
tmp.nbytes, False)
tmp.nbytes, v.shape)
elif is_ndarray and not impl.get_runtime(
).ndarray_use_torch:
# Use ndarray's own memory allocator
tmp = v
launch_ctx.set_arg_external_array(
actual_argument_slot,
int(tmp.device_allocation_ptr()),
tmp.element_size() * tmp.nelement(), True)
launch_ctx.set_arg_ndarray(actual_argument_slot, v)
else:

def get_call_back(u, v):
Expand Down Expand Up @@ -560,18 +556,10 @@ def call_back():
gpu_v = v.cuda()
tmp = gpu_v
callbacks.append(get_call_back(v, gpu_v))
launch_ctx.set_arg_external_array(
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp.data_ptr()),
tmp.element_size() * tmp.nelement(), False)

shape = v.shape
max_num_indices = _ti_core.get_max_num_indices()
assert len(
shape
) <= max_num_indices, f"External array cannot have > {max_num_indices} indices"
for ii, s in enumerate(shape):
launch_ctx.set_extra_arg_int(actual_argument_slot, ii,
s)
tmp.element_size() * tmp.nelement(), v.shape)

elif isinstance(needed, MatrixType):
if id(needed.dtype) in primitive_types.real_type_ids:
for a in range(needed.n):
Expand Down
27 changes: 26 additions & 1 deletion taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {

void Kernel::LaunchContextBuilder::set_arg_external_array(
int arg_id,
uint64 ptr,
uintptr_t ptr,
uint64 size,
bool is_device_allocation) {
TI_ASSERT_INFO(
Expand All @@ -266,6 +266,31 @@ void Kernel::LaunchContextBuilder::set_arg_external_array(
ctx_->set_device_allocation(arg_id, is_device_allocation);
}

void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
int arg_id,
uintptr_t ptr,
uint64 size,
const std::vector<int64> &shape) {
this->set_arg_external_array(arg_id, ptr, size, false);
TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices,
"External array cannot have > {max_num_indices} indices");
for (uint64 i = 0; i < shape.size(); ++i) {
this->set_extra_arg_int(arg_id, i, shape[i]);
}
}

void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id,
const Ndarray &arr) {
intptr_t ptr = arr.get_device_allocation_ptr_as_int();
uint64 arr_size = arr.get_element_size() * arr.get_nelement();
this->set_arg_external_array(arg_id, ptr, arr_size, true);
TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices,
"External array cannot have > {max_num_indices} indices");
for (uint64 i = 0; i < arr.shape.size(); ++i) {
this->set_extra_arg_int(arg_id, i, arr.shape[i]);
}
}

void Kernel::LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) {
TI_ASSERT_INFO(!kernel_->args[arg_id].is_array,
"Assigning scalar value to external (numpy) array argument is "
Expand Down
10 changes: 9 additions & 1 deletion taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "taichi/ir/ir.h"
#include "taichi/program/arch.h"
#include "taichi/program/callable.h"
#include "taichi/program/ndarray.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -37,10 +38,17 @@ class Kernel : public Callable {
void set_extra_arg_int(int i, int j, int32 d);

void set_arg_external_array(int arg_id,
uint64 ptr,
uintptr_t ptr,
uint64 size,
bool is_device_allocation);

void set_arg_external_array_with_shape(int arg_id,
uintptr_t ptr,
uint64 size,
const std::vector<int64> &shape);

void set_arg_ndarray(int arg_id, const Ndarray &arr);

// Sets the |arg_id|-th arg in the context to the bits stored in |d|.
// This ignores the underlying kernel's |arg_id|-th arg type.
void set_arg_raw(int arg_id, uint64 d);
Expand Down
3 changes: 3 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ void export_lang(py::module &m) {
.def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float)
.def("set_arg_external_array",
&Kernel::LaunchContextBuilder::set_arg_external_array)
.def("set_arg_external_array_with_shape",
&Kernel::LaunchContextBuilder::set_arg_external_array_with_shape)
.def("set_arg_ndarray", &Kernel::LaunchContextBuilder::set_arg_ndarray)
.def("set_extra_arg_int",
&Kernel::LaunchContextBuilder::set_extra_arg_int);

Expand Down

0 comments on commit ef6237a

Please sign in to comment.