diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index e16d7fe1a..352bc7314 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -591,16 +591,39 @@ def inner(operands): return inner + @staticmethod + def build_static_short_circuit_and(operands): + for operand in operands: + if not operand.ptr: + return operand.ptr + return operands[-1].ptr + + @staticmethod + def build_static_short_circuit_or(operands): + for operand in operands: + if operand.ptr: + return operand.ptr + return operands[-1].ptr + @staticmethod def build_BoolOp(ctx, node): build_stmts(ctx, node.values) - ops = { - ast.And: ASTTransformer.build_short_circuit_and, - ast.Or: ASTTransformer.build_short_circuit_or, - } if impl.get_runtime().short_circuit_operators else { - ast.And: ASTTransformer.build_normal_bool_op(ti_ops.logical_and), - ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or), - } + if ctx.is_in_static_scope: + ops = { + ast.And: ASTTransformer.build_static_short_circuit_and, + ast.Or: ASTTransformer.build_static_short_circuit_or, + } + elif impl.get_runtime().short_circuit_operators: + ops = { + ast.And: ASTTransformer.build_short_circuit_and, + ast.Or: ASTTransformer.build_short_circuit_or, + } + else: + ops = { + ast.And: + ASTTransformer.build_normal_bool_op(ti_ops.logical_and), + ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or), + } op = ops.get(type(node.op)) node.ptr = op(node.values) return node.ptr diff --git a/tests/python/test_short_circuit.py b/tests/python/test_bool_op.py similarity index 75% rename from tests/python/test_short_circuit.py rename to tests/python/test_bool_op.py index 40e45998e..5a85de68d 100644 --- a/tests/python/test_short_circuit.py +++ b/tests/python/test_bool_op.py @@ -47,3 +47,21 @@ def func() -> ti.i32: return False or True assert func() == 1 + + +@ti.test(debug=True) +def test_static_or(): + @ti.kernel + def func() -> ti.i32: + return ti.static(0 or 3 or 5) + + assert func() == 3 + + +@ti.test(debug=True) +def test_static_and(): + @ti.kernel + def func() -> ti.i32: + return ti.static(5 and 2 and 0) + + assert func() == 0