diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index fd17ac9db1c67..5f1f22db25489 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1,4 +1,3 @@ -import functools import numbers from collections.abc import Iterable @@ -24,34 +23,54 @@ def _gen_swizzles(cls): swizzle_gen = SwizzleGenerator() # https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling - KEYMAP_SET = ['xyzw', 'rgba', 'stpq'] + KEYGROUP_SET = ['xyzw', 'rgba', 'stpq'] + + def make_valid_attribs_checker(key_group): + def check(instance, pattern): + valid_attribs = set(key_group[:instance.n]) + pattern_set = set(pattern) + diff = pattern_set - valid_attribs + if len(diff): + valid_attribs = tuple(sorted(valid_attribs)) + pattern = tuple(pattern) + raise TaichiSyntaxError( + f'vec{instance.n} only has ' + f'attributes={valid_attribs}, got={pattern}') - def add_single_swizzle_attrs(cls): - """Add property getter and setter for a single character in "xyzwrgbastpq". - """ - def prop_getter(index, instance): - return instance(index) + return check + + for key_group in KEYGROUP_SET: + for index, attr in enumerate(key_group): - @python_scope - def prop_setter(index, instance, value): - instance[index] = value + def gen_property(attr, attr_idx, key_group): + checker = make_valid_attribs_checker(key_group) - for key_group in KEYMAP_SET: - for index, key in enumerate(key_group): - prop = property(functools.partial(prop_getter, index), - functools.partial(prop_setter, index)) - setattr(cls, key, prop) + def prop_getter(instance): + checker(instance, attr) + return instance._get_entry_and_read([attr_idx]) + + @python_scope + def prop_setter(instance, value): + checker(instance, attr) + instance[attr_idx] = value - add_single_swizzle_attrs(cls) + return property(prop_getter, prop_setter) - for key_group in KEYMAP_SET: + prop = gen_property(attr, index, key_group) + setattr(cls, attr, prop) + + for key_group in KEYGROUP_SET: sw_patterns = swizzle_gen.generate(key_group, required_length=4) # len=1 accessors are handled specially above sw_patterns = filter(lambda p: len(p) > 1, sw_patterns) for pat in sw_patterns: # Create a function for value capturing def gen_property(pattern, key_group): + checker = make_valid_attribs_checker(key_group) + prop_key = ''.join(pattern) + def prop_getter(instance): + checker(instance, pattern) res = [] for ch in pattern: res.append(instance._get_entry(key_group.index(ch))) @@ -60,7 +79,9 @@ def prop_getter(instance): def prop_setter(instance, value): if len(pattern) != len(value): raise TaichiCompilationError( - 'values does not match the attribute') + f'value len does not match the swizzle pattern={prop_key}' + ) + checker(instance, pattern) for ch, val in zip(pattern, value): if in_python_scope(): instance[key_group.index(ch)] = val @@ -68,7 +89,6 @@ def prop_setter(instance, value): instance(key_group.index(ch))._assign(val) prop = property(prop_getter, prop_setter) - prop_key = ''.join(pattern) return prop_key, prop prop_key, prop = gen_property(pat, key_group) diff --git a/tests/python/test_vector_swizzle.py b/tests/python/test_vector_swizzle.py index 83883f3566b5d..588b3f1372460 100644 --- a/tests/python/test_vector_swizzle.py +++ b/tests/python/test_vector_swizzle.py @@ -1,3 +1,5 @@ +import re + import pytest import taichi as ti @@ -48,6 +50,38 @@ def foo(): foo() +@test_utils.test(debug=True) +def test_vector_swizzle2_taichi(): + @ti.kernel + def foo(): + v = ti.math.vec3(0, 0, 0) + v.brg += 1 + assert all(v.xyz == (1, 1, 1)) + v.x = 1 + v.g = 2 + v.p = 3 + v123 = ti.math.vec3(1, 2, 3) + v231 = ti.math.vec3(2, 3, 1) + v113 = ti.math.vec3(1, 1, 3) + assert all(v == v123) + assert all(v.xyz == v123) + assert all(v.rgb == v123) + assert all(v.stp == v123) + assert all(v.yzx == v231) + assert all(v.gbr == v231) + assert all(v.tps == v231) + assert all(v.xxz == v113) + assert all(v.rrb == v113) + assert all(v.ssp == v113) + v.bgr = v123 + v321 = ti.math.vec3(3, 2, 1) + assert all(v.xyz == v321) + assert all(v.rgb == v321) + assert all(v.stp == v321) + + foo() + + @test_utils.test(debug=True) def test_vector_dtype(): @ti.kernel @@ -59,3 +93,22 @@ def foo(): assert all(b == (1, 2, 3)) foo() + + +@test_utils.test() +def test_vector_invalid_swizzle_patterns(): + a = ti.math.vec2(1, 2) + with pytest.raises(ti.TaichiSyntaxError, + match=re.escape( + "vec2 only has attributes=('x', 'y'), got=('z',)")): + a.z = 3 + with pytest.raises( + ti.TaichiSyntaxError, + match=re.escape( + "vec2 only has attributes=('x', 'y'), got=('x', 'y', 'z')")): + a.xyz = [1, 2, 3] + + with pytest.raises(ti.TaichiCompilationError, + match=re.escape( + "value len does not match the swizzle pattern=xy")): + a.xy = [1, 2, 3]