Skip to content

Commit

Permalink
[lang] Add better error detection for swizzle patterens (#4860)
Browse files Browse the repository at this point in the history
* [lang] Add better error detection for swizzle patterens

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
k-ye and pre-commit-ci[bot] authored Apr 26, 2022
1 parent baedd61 commit 433b0e3
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 19 deletions.
58 changes: 39 additions & 19 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import numbers
from collections.abc import Iterable

Expand All @@ -24,34 +23,54 @@
def _gen_swizzles(cls):
swizzle_gen = SwizzleGenerator()
# https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling
KEYMAP_SET = ['xyzw', 'rgba', 'stpq']
KEYGROUP_SET = ['xyzw', 'rgba', 'stpq']

def make_valid_attribs_checker(key_group):
def check(instance, pattern):
valid_attribs = set(key_group[:instance.n])
pattern_set = set(pattern)
diff = pattern_set - valid_attribs
if len(diff):
valid_attribs = tuple(sorted(valid_attribs))
pattern = tuple(pattern)
raise TaichiSyntaxError(
f'vec{instance.n} only has '
f'attributes={valid_attribs}, got={pattern}')

def add_single_swizzle_attrs(cls):
"""Add property getter and setter for a single character in "xyzwrgbastpq".
"""
def prop_getter(index, instance):
return instance(index)
return check

for key_group in KEYGROUP_SET:
for index, attr in enumerate(key_group):

@python_scope
def prop_setter(index, instance, value):
instance[index] = value
def gen_property(attr, attr_idx, key_group):
checker = make_valid_attribs_checker(key_group)

for key_group in KEYMAP_SET:
for index, key in enumerate(key_group):
prop = property(functools.partial(prop_getter, index),
functools.partial(prop_setter, index))
setattr(cls, key, prop)
def prop_getter(instance):
checker(instance, attr)
return instance._get_entry_and_read([attr_idx])

@python_scope
def prop_setter(instance, value):
checker(instance, attr)
instance[attr_idx] = value

add_single_swizzle_attrs(cls)
return property(prop_getter, prop_setter)

for key_group in KEYMAP_SET:
prop = gen_property(attr, index, key_group)
setattr(cls, attr, prop)

for key_group in KEYGROUP_SET:
sw_patterns = swizzle_gen.generate(key_group, required_length=4)
# len=1 accessors are handled specially above
sw_patterns = filter(lambda p: len(p) > 1, sw_patterns)
for pat in sw_patterns:
# Create a function for value capturing
def gen_property(pattern, key_group):
checker = make_valid_attribs_checker(key_group)
prop_key = ''.join(pattern)

def prop_getter(instance):
checker(instance, pattern)
res = []
for ch in pattern:
res.append(instance._get_entry(key_group.index(ch)))
Expand All @@ -60,15 +79,16 @@ def prop_getter(instance):
def prop_setter(instance, value):
if len(pattern) != len(value):
raise TaichiCompilationError(
'values does not match the attribute')
f'value len does not match the swizzle pattern={prop_key}'
)
checker(instance, pattern)
for ch, val in zip(pattern, value):
if in_python_scope():
instance[key_group.index(ch)] = val
else:
instance(key_group.index(ch))._assign(val)

prop = property(prop_getter, prop_setter)
prop_key = ''.join(pattern)
return prop_key, prop

prop_key, prop = gen_property(pat, key_group)
Expand Down
53 changes: 53 additions & 0 deletions tests/python/test_vector_swizzle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest

import taichi as ti
Expand Down Expand Up @@ -48,6 +50,38 @@ def foo():
foo()


@test_utils.test(debug=True)
def test_vector_swizzle2_taichi():
@ti.kernel
def foo():
v = ti.math.vec3(0, 0, 0)
v.brg += 1
assert all(v.xyz == (1, 1, 1))
v.x = 1
v.g = 2
v.p = 3
v123 = ti.math.vec3(1, 2, 3)
v231 = ti.math.vec3(2, 3, 1)
v113 = ti.math.vec3(1, 1, 3)
assert all(v == v123)
assert all(v.xyz == v123)
assert all(v.rgb == v123)
assert all(v.stp == v123)
assert all(v.yzx == v231)
assert all(v.gbr == v231)
assert all(v.tps == v231)
assert all(v.xxz == v113)
assert all(v.rrb == v113)
assert all(v.ssp == v113)
v.bgr = v123
v321 = ti.math.vec3(3, 2, 1)
assert all(v.xyz == v321)
assert all(v.rgb == v321)
assert all(v.stp == v321)

foo()


@test_utils.test(debug=True)
def test_vector_dtype():
@ti.kernel
Expand All @@ -59,3 +93,22 @@ def foo():
assert all(b == (1, 2, 3))

foo()


@test_utils.test()
def test_vector_invalid_swizzle_patterns():
a = ti.math.vec2(1, 2)
with pytest.raises(ti.TaichiSyntaxError,
match=re.escape(
"vec2 only has attributes=('x', 'y'), got=('z',)")):
a.z = 3
with pytest.raises(
ti.TaichiSyntaxError,
match=re.escape(
"vec2 only has attributes=('x', 'y'), got=('x', 'y', 'z')")):
a.xyz = [1, 2, 3]

with pytest.raises(ti.TaichiCompilationError,
match=re.escape(
"value len does not match the swizzle pattern=xy")):
a.xy = [1, 2, 3]

0 comments on commit 433b0e3

Please sign in to comment.