diff --git a/tests/parser/exceptions/test_constancy_exception.py b/tests/parser/exceptions/test_constancy_exception.py index 4bd0b4fcb9..0cecb12ca7 100644 --- a/tests/parser/exceptions/test_constancy_exception.py +++ b/tests/parser/exceptions/test_constancy_exception.py @@ -48,17 +48,6 @@ def foo() -> int128: def foo() -> int128: x: address = create_minimal_proxy_to(0x1234567890123456789012345678901234567890, value=9) return 5""", - # test constancy in range expressions - """ -glob: int128 -@internal -def foo() -> int128: - self.glob += 1 - return 5 -@external -def bar(): - for i in range(self.foo(), self.foo() + 1): - pass""", """ glob: int128 @internal @@ -120,6 +109,18 @@ def foo(f: Foo) -> Foo: f.a[1] = [0, 1] return f """, + # test constancy in range expressions + """ +glob: int128 +@internal +def foo() -> int128: + self.glob += 1 + return 5 +@external +def bar(): + for i in range(self.foo(), self.foo() + 1): + pass + """, ], ) def test_immutability_violations(bad_code): diff --git a/tests/parser/syntax/test_for_range.py b/tests/parser/syntax/test_for_range.py index b2a9491058..81f4de82b3 100644 --- a/tests/parser/syntax/test_for_range.py +++ b/tests/parser/syntax/test_for_range.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import StructureException +from vyper.exceptions import ImmutableViolation, StructureException fail_list = [ ( @@ -12,7 +12,115 @@ def foo(): pass """, StructureException, - ) + ), + ( + """ +interface A: + def foo()-> uint256: nonpayable + +@external +def bar(x:address): + a: A = A(x) + for i in range(a.foo(), bound=12): + pass + """, + ImmutableViolation, + ), + ( + """ +interface A: + def foo()-> uint256: nonpayable + +@external +def bar(x:address): + a: A = A(x) + for i in range(max(a.foo(), 123), bound=12): + pass + """, + ImmutableViolation, + ), + ( + """ +interface A: + def foo()-> uint256: nonpayable + +@external +def bar(x:address): + a: A = A(x) + for i in range(a.foo(), a.foo() + 1): + pass + """, + ImmutableViolation, + ), + ( + """ +interface A: + def foo()-> uint256: nonpayable + +@external +def bar(x:address): + a: A = A(x) + for i in range(min(a.foo(), 123), min(a.foo(), 123) + 1): + pass + """, + ImmutableViolation, + ), + # Cannot call `pop()` in for range because it modifies state + ( + """ +arr: DynArray[uint256, 10] +@external +def test()-> (DynArray[uint256, 6], DynArray[uint256, 10]): + b: DynArray[uint256, 6] = [] + self.arr = [1,0] + for i in range(self.arr.pop(), self.arr.pop() + 2): + b.append(i) + return b, self.arr + """, + ImmutableViolation, + ), + ( + """ +@external +def bar(x:address): + for i in range(1 if raw_call( + x, + b'', + max_outsize=32, + ) == b"vyper" else 2, + bound=12 + ): + pass + """, + ImmutableViolation, + ), + ( + """ +@external +def foo(a: address): + for i in range(1 if convert(create_minimal_proxy_to(a), uint256) > 2 else 2, bound=12): + pass + """, + ImmutableViolation, + ), + ( + """ +@external +def foo(a: address): + for i in range(1 if convert(create_copy_of(a), uint256) > 2 else 2, bound=12): + pass + """, + ImmutableViolation, + ), + ( + """ +@external +def foo(a: address): + for i in range(1 if convert(create_from_blueprint(a), uint256) > 2 else 2, bound=12): + pass + """, + ImmutableViolation, + ), ] @@ -51,6 +159,59 @@ def kick_foos(): for foo in self.foos: foo.kick() """, + """ +interface A: + def foo()-> uint256: view + +@external +def bar(x:address): + a: A = A(x) + for i in range(a.foo(), bound=12): + pass + """, + """ +interface A: + def foo()-> uint256: view + +@external +def bar(x:address): + a: A = A(x) + for i in range(max(a.foo(), 123), bound=12): + pass + """, + """ +interface A: + def foo()-> uint256: view + +@external +def bar(x:address): + a: A = A(x) + for i in range(a.foo(), a.foo() + 1): + pass + """, + """ +interface A: + def foo()-> uint256: view + +@external +def bar(x:address): + a: A = A(x) + for i in range(min(a.foo(), 123), min(a.foo(), 123) + 1): + pass + """, + """ +@external +def bar(x:address): + for i in range(1 if raw_call( + x, + b'', + max_outsize=32, + is_static_call=True + ) == b"vyper" else 2, + bound=12 + ): + pass + """, ] diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c0c05325f2..765fe79608 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -404,6 +404,35 @@ def visit_For(self, node): if not type_list: raise TypeMismatch("Iterator values are of different types", node.iter) + # Check for state-modifying expressions in `range` expression + range_call_nodes = node.iter.get_descendants(vy_ast.Call) + for call_node in range_call_nodes: + call_type = get_exact_type_from_node(call_node.func) + func_name = call_node.get("func.id") + disallowed_builtins = ( + "create_minimal_proxy_to", + "create_copy_of", + "create_from_blueprint", + ) + if ( + # state-modifying internal and external calls + (isinstance(call_type, ContractFunctionT) and call_type.is_mutable) + # `pop` on dynamic arrays + or (isinstance(call_type, MemberFunctionT) and call_type.is_modifying) + # state-modifying builtin functions + or func_name in disallowed_builtins + # `raw_call` is handled specially due to the `is_static_call` kwarg + or ( + func_name == "raw_call" + and not {i.arg: i.value for i in call_node.keywords}.get( + "is_static_call", False + ) + ) + ): + raise ImmutableViolation( + "Cannot call state-modifying functions for `range` expression", call_node + ) + else: # iteration over a variable or literal list if isinstance(node.iter, vy_ast.List) and len(node.iter.elements) == 0: