diff --git a/docs/cover-in-ci.lst b/docs/cover-in-ci.lst
index 28ab1dbae1318..f40b6e00f706a 100644
--- a/docs/cover-in-ci.lst
+++ b/docs/cover-in-ci.lst
@@ -1,2 +1,3 @@
docs/lang/articles/basic
docs/lang/articles/advanced
+docs/lang/articles/kernels
diff --git a/docs/lang/articles/kernels/kernel_function.md b/docs/lang/articles/kernels/kernel_function.md
index 7231bbb44832b..0d57c689f16c4 100644
--- a/docs/lang/articles/kernels/kernel_function.md
+++ b/docs/lang/articles/kernels/kernel_function.md
@@ -135,7 +135,23 @@ print(x) # Prints [5, 7, 9]
### Return value
-In Taichi, a kernel can have at most one return value, which can be a scalar, `ti.Matrix`, or `ti.Vector`. Here are the rules to follow when defining the return value of a kernel:
+In Taichi, a kernel is allowed to have a maximum of one return value, which could either be a scalar, `ti.Matrix`, or `ti.Vector`.
+Moreover, in the LLVM-based backends (CPU and CUDA backends), a return value could also be a `ti.Struct`.
+
+Here is an example of a kernel that returns a ti.Struct:
+
+```python
+s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16)
+s1 = ti.types.struct(a=ti.f32, b=s0)
+
+@ti.kernel
+def foo() -> s1:
+ return s1(a=1, b=s0(a=ti.math.vec3(100, 0.2, 3), b=1))
+
+print(foo()) # {'a': 1.0, 'b': {'a': [100.0, 0.2, 3.0], 'b': 1}}
+```
+
+When defining the return value of a kernel in Taichi, it is important to follow these rules:
- Use type hint to specify the return value of a kernel.
- Make sure that you have at most one return value in a kernel.
@@ -276,14 +292,14 @@ Return values of a Taichi function can be scalars, `ti.Matrix`, `ti.Vector`, `ti
## A recap: Taichi kernel vs. Taichi function
-| | **Kernel** | **Taichi Function** |
-| ----------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
-| Call scope | Python scope | Taichi scope |
-| Type hint arguments | Mandatory | Recommended |
-| Type hint return values | Mandatory | Recommended |
-| Return type |
- Scalar
- `ti.Vector`
- `ti.Matrix`
| - Scalar
- `ti.Vector`
- `ti.Matrix`
- `ti.Struct`
- ...
|
-| Maximum number of elements in arguments | - 32 (OpenGL)
- 64 (otherwise)
| Unlimited |
-| Maximum number of return values in a return statement | 1 | Unlimited |
+| | **Kernel** | **Taichi Function** |
+| ----------------------------------------------------- |-------------------------------------------------------------------------------------------------------------------| ------------------------------------------------------------ |
+| Call scope | Python scope | Taichi scope |
+| Type hint arguments | Mandatory | Recommended |
+| Type hint return values | Mandatory | Recommended |
+| Return type | - Scalar
- `ti.Vector`
- `ti.Matrix`
- `ti.Struct`(Only on LLVM-based backends)
| - Scalar
- `ti.Vector`
- `ti.Matrix`
- `ti.Struct`
- ...
|
+| Maximum number of elements in arguments | - 32 (OpenGL)
- 64 (otherwise)
| Unlimited |
+| Maximum number of return values in a return statement | 1 | Unlimited |
## Key terms
diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py
index 0f6603c44d7b9..badb714ad74e9 100644
--- a/python/taichi/lang/kernel_impl.py
+++ b/python/taichi/lang/kernel_impl.py
@@ -902,11 +902,11 @@ def construct_kernel_ret(self, launch_ctx, ret_type, index=()):
]
if isinstance(ret_type, CompoundType):
return ret_type.from_kernel_struct_ret(launch_ctx, index)
- if id(ret_type) in primitive_types.integer_type_ids:
+ if ret_type in primitive_types.integer_types:
if is_signed(cook_dtype(ret_type)):
return launch_ctx.get_struct_ret_int(index)
return launch_ctx.get_struct_ret_uint(index)
- if id(ret_type) in primitive_types.real_type_ids:
+ if ret_type in primitive_types.real_types:
return launch_ctx.get_struct_ret_float(index)
raise TaichiRuntimeTypeError(f"Invalid return type on index={index}")
diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index 3ef23b01e8d19..75519279e6dfc 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -1511,12 +1511,12 @@ def from_real_func_ret(self, func_ret, ret_index=()):
])
def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
- if id(self.dtype) in primitive_types.integer_type_ids:
+ if self.dtype in primitive_types.integer_types:
if is_signed(cook_dtype(self.dtype)):
get_ret_func = launch_ctx.get_struct_ret_int
else:
get_ret_func = launch_ctx.get_struct_ret_uint
- elif id(self.dtype) in primitive_types.real_type_ids:
+ elif self.dtype in primitive_types.real_types:
get_ret_func = launch_ctx.get_struct_ret_float
else:
raise TaichiRuntimeTypeError(
diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py
index c1e50af8eb713..756fb16384a01 100644
--- a/python/taichi/lang/struct.py
+++ b/python/taichi/lang/struct.py
@@ -746,14 +746,14 @@ def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
d[name] = dtype.from_kernel_struct_ret(launch_ctx,
ret_index + (index, ))
else:
- if id(dtype) in primitive_types.integer_type_ids:
+ if dtype in primitive_types.integer_types:
if is_signed(cook_dtype(dtype)):
d[name] = launch_ctx.get_struct_ret_int(ret_index +
(index, ))
else:
d[name] = launch_ctx.get_struct_ret_uint(ret_index +
(index, ))
- elif id(dtype) in primitive_types.real_type_ids:
+ elif dtype in primitive_types.real_types:
d[name] = launch_ctx.get_struct_ret_float(ret_index +
(index, ))
else:
diff --git a/tests/python/test_return.py b/tests/python/test_return.py
index 5e8c94f41c429..112cf9489ba40 100644
--- a/tests/python/test_return.py
+++ b/tests/python/test_return.py
@@ -1,4 +1,5 @@
import pytest
+from pytest import approx
import taichi as ti
from tests import test_utils
@@ -181,3 +182,20 @@ def foo() -> ti.types.vector(2, ti.u64):
return ti.Vector([ti.u64(2**64 - 1), ti.u64(2**64 - 1)])
assert (foo()[0] == 2**64 - 1)
+
+
+@test_utils.test(arch=[ti.cpu, ti.cuda])
+def test_struct_ret_with_matrix():
+ s0 = ti.types.struct(a=ti.math.vec3, b=ti.i16)
+ s1 = ti.types.struct(a=ti.f32, b=s0)
+
+ @ti.kernel
+ def foo() -> s1:
+ return s1(a=1, b=s0(a=ti.math.vec3([100, 0.2, 3]), b=65537))
+
+ ret = foo()
+ assert (ret.a == approx(1))
+ assert (ret.b.a[0] == approx(100))
+ assert (ret.b.a[1] == approx(0.2))
+ assert (ret.b.a[2] == approx(3))
+ assert (ret.b.b == 1)