diff --git a/docs/cover-in-ci.lst b/docs/cover-in-ci.lst index 28ab1dbae1318..f40b6e00f706a 100644 --- a/docs/cover-in-ci.lst +++ b/docs/cover-in-ci.lst @@ -1,2 +1,3 @@ docs/lang/articles/basic docs/lang/articles/advanced +docs/lang/articles/kernels diff --git a/docs/lang/articles/kernels/kernel_function.md b/docs/lang/articles/kernels/kernel_function.md index 7231bbb44832b..0d57c689f16c4 100644 --- a/docs/lang/articles/kernels/kernel_function.md +++ b/docs/lang/articles/kernels/kernel_function.md @@ -135,7 +135,23 @@ print(x) # Prints [5, 7, 9] ### Return value -In Taichi, a kernel can have at most one return value, which can be a scalar, `ti.Matrix`, or `ti.Vector`. Here are the rules to follow when defining the return value of a kernel: +In Taichi, a kernel is allowed to have a maximum of one return value, which could either be a scalar, `ti.Matrix`, or `ti.Vector`. +Moreover, in the LLVM-based backends (CPU and CUDA backends), a return value could also be a `ti.Struct`. + +Here is an example of a kernel that returns a ti.Struct: + +```python +s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16) +s1 = ti.types.struct(a=ti.f32, b=s0) + +@ti.kernel +def foo() -> s1: + return s1(a=1, b=s0(a=ti.math.vec3(100, 0.2, 3), b=1)) + +print(foo()) # {'a': 1.0, 'b': {'a': [100.0, 0.2, 3.0], 'b': 1}} +``` + +When defining the return value of a kernel in Taichi, it is important to follow these rules: - Use type hint to specify the return value of a kernel. - Make sure that you have at most one return value in a kernel. @@ -276,14 +292,14 @@ Return values of a Taichi function can be scalars, `ti.Matrix`, `ti.Vector`, `ti ## A recap: Taichi kernel vs. Taichi function -| | **Kernel** | **Taichi Function** | -| ----------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | -| Call scope | Python scope | Taichi scope | -| Type hint arguments | Mandatory | Recommended | -| Type hint return values | Mandatory | Recommended | -| Return type | | | -| Maximum number of elements in arguments | | Unlimited | -| Maximum number of return values in a return statement | 1 | Unlimited | +| | **Kernel** | **Taichi Function** | +| ----------------------------------------------------- |-------------------------------------------------------------------------------------------------------------------| ------------------------------------------------------------ | +| Call scope | Python scope | Taichi scope | +| Type hint arguments | Mandatory | Recommended | +| Type hint return values | Mandatory | Recommended | +| Return type | | | +| Maximum number of elements in arguments | | Unlimited | +| Maximum number of return values in a return statement | 1 | Unlimited | ## Key terms diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 0f6603c44d7b9..badb714ad74e9 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -902,11 +902,11 @@ def construct_kernel_ret(self, launch_ctx, ret_type, index=()): ] if isinstance(ret_type, CompoundType): return ret_type.from_kernel_struct_ret(launch_ctx, index) - if id(ret_type) in primitive_types.integer_type_ids: + if ret_type in primitive_types.integer_types: if is_signed(cook_dtype(ret_type)): return launch_ctx.get_struct_ret_int(index) return launch_ctx.get_struct_ret_uint(index) - if id(ret_type) in primitive_types.real_type_ids: + if ret_type in primitive_types.real_types: return launch_ctx.get_struct_ret_float(index) raise TaichiRuntimeTypeError(f"Invalid return type on index={index}") diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 3ef23b01e8d19..75519279e6dfc 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1511,12 +1511,12 @@ def from_real_func_ret(self, func_ret, ret_index=()): ]) def from_kernel_struct_ret(self, launch_ctx, ret_index=()): - if id(self.dtype) in primitive_types.integer_type_ids: + if self.dtype in primitive_types.integer_types: if is_signed(cook_dtype(self.dtype)): get_ret_func = launch_ctx.get_struct_ret_int else: get_ret_func = launch_ctx.get_struct_ret_uint - elif id(self.dtype) in primitive_types.real_type_ids: + elif self.dtype in primitive_types.real_types: get_ret_func = launch_ctx.get_struct_ret_float else: raise TaichiRuntimeTypeError( diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index c1e50af8eb713..756fb16384a01 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -746,14 +746,14 @@ def from_kernel_struct_ret(self, launch_ctx, ret_index=()): d[name] = dtype.from_kernel_struct_ret(launch_ctx, ret_index + (index, )) else: - if id(dtype) in primitive_types.integer_type_ids: + if dtype in primitive_types.integer_types: if is_signed(cook_dtype(dtype)): d[name] = launch_ctx.get_struct_ret_int(ret_index + (index, )) else: d[name] = launch_ctx.get_struct_ret_uint(ret_index + (index, )) - elif id(dtype) in primitive_types.real_type_ids: + elif dtype in primitive_types.real_types: d[name] = launch_ctx.get_struct_ret_float(ret_index + (index, )) else: diff --git a/tests/python/test_return.py b/tests/python/test_return.py index 5e8c94f41c429..112cf9489ba40 100644 --- a/tests/python/test_return.py +++ b/tests/python/test_return.py @@ -1,4 +1,5 @@ import pytest +from pytest import approx import taichi as ti from tests import test_utils @@ -181,3 +182,20 @@ def foo() -> ti.types.vector(2, ti.u64): return ti.Vector([ti.u64(2**64 - 1), ti.u64(2**64 - 1)]) assert (foo()[0] == 2**64 - 1) + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_struct_ret_with_matrix(): + s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16) + s1 = ti.types.struct(a=ti.f32, b=s0) + + @ti.kernel + def foo() -> s1: + return s1(a=1, b=s0(a=ti.math.vec3([100, 0.2, 3]), b=65537)) + + ret = foo() + assert (ret.a == approx(1)) + assert (ret.b.a[0] == approx(100)) + assert (ret.b.a[1] == approx(0.2)) + assert (ret.b.a[2] == approx(3)) + assert (ret.b.b == 1)