Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[ux]: catch state modifying functions in range expr and loop iterator #3884

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 0 additions & 53 deletions tests/functional/builtins/codegen/test_raw_call.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pytest
from hexbytes import HexBytes

from tests.utils import ZERO_ADDRESS
from vyper import compile_code
from vyper.builtins.functions import eip1167_bytecode
from vyper.exceptions import ArgumentException, StateAccessViolation, TypeMismatch


def test_max_outsize_exceeds_returndatasize(get_contract):
Expand Down Expand Up @@ -599,54 +597,3 @@ def bar(f: uint256) -> Bytes[100]:
c.bar(15).hex() == "0423a132"
"000000000000000000000000000000000000000000000000000000000000000f"
)


uncompilable_code = [
(
"""
@external
@view
def foo(_addr: address):
raw_call(_addr, method_id("foo()"))
""",
StateAccessViolation,
),
(
"""
@external
def foo(_addr: address):
raw_call(_addr, method_id("foo()"), is_delegate_call=True, is_static_call=True)
""",
ArgumentException,
),
(
"""
@external
def foo(_addr: address):
raw_call(_addr, method_id("foo()"), is_delegate_call=True, value=1)
""",
ArgumentException,
),
(
"""
@external
def foo(_addr: address):
raw_call(_addr, method_id("foo()"), is_static_call=True, value=1)
""",
ArgumentException,
),
(
"""
@external
@view
def foo(_addr: address):
raw_call(_addr, 256)
""",
TypeMismatch,
),
]


@pytest.mark.parametrize("source_code,exc", uncompilable_code)
def test_invalid_type_exception(assert_compile_failed, get_contract, source_code, exc):
assert_compile_failed(lambda: get_contract(source_code), exc)
34 changes: 0 additions & 34 deletions tests/functional/codegen/features/iteration/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,37 +473,3 @@ def foo() -> DynArray[int256, 10]:
return
with pytest.raises(StaticAssertionException):
get_contract(code)


def test_for_range_start_double_eval(get_contract, tx_failed):
code = """
@external
def foo() -> (uint256, DynArray[uint256, 3]):
x:DynArray[uint256, 3] = [3, 1]
res: DynArray[uint256, 3] = empty(DynArray[uint256, 3])
for i:uint256 in range(x.pop(),x.pop(), bound = 3):
res.append(i)

return len(x), res
"""
c = get_contract(code)
length, res = c.foo()

assert (length, res) == (0, [1, 2])


def test_for_range_stop_double_eval(get_contract, tx_failed):
code = """
@external
def foo() -> (uint256, DynArray[uint256, 3]):
x:DynArray[uint256, 3] = [3, 3]
res: DynArray[uint256, 3] = empty(DynArray[uint256, 3])
for i:uint256 in range(x.pop(), bound = 3):
res.append(i)

return len(x), res
"""
c = get_contract(code)
length, res = c.foo()

assert (length, res) == (1, [0, 1, 2])
105 changes: 104 additions & 1 deletion tests/functional/syntax/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import pytest

from vyper import compiler
from vyper.exceptions import ArgumentException, StructureException, TypeMismatch, UnknownType
from vyper.exceptions import (
ArgumentException,
StateAccessViolation,
StructureException,
TypeMismatch,
UnknownType,
)

fail_list = [
(
Expand Down Expand Up @@ -322,6 +328,103 @@ def foo():
None,
"10.1",
),
(
"""
interface I:
def bar() -> uint256: payable

@external
def bar(t: address):
for i: uint256 in range(extcall I(t).bar(), bound=10):
pass
""",
StateAccessViolation,
"May not call state modifying function within a range expression or for loop iterator.",
None,
"extcall I(t).bar()",
),
(
"""
a: DynArray[uint256, 3]

@internal
def foo() -> uint256:
return self.a.pop()


@external
def bar():
for i: uint256 in range(2, self.foo(), bound=5):
pass
""",
StateAccessViolation,
"May not call state modifying function within a range expression or for loop iterator.",
None,
"self.foo()",
),
(
"""
interface I:
def bar() -> DynArray[uint256, 10]: nonpayable

@external
def bar(t: address):
for i: uint256 in extcall I(t).bar():
pass
""",
StateAccessViolation,
"May not call state modifying function within a range expression or for loop iterator.",
None,
"extcall I(t).bar()",
),
# Cannot call `pop()` in for range because it modifies state
(
"""
arr: DynArray[uint256, 10]
@external
def test()-> (DynArray[uint256, 6], DynArray[uint256, 10]):
b: DynArray[uint256, 6] = []
self.arr = [1,0]
for i: uint256 in range(self.arr.pop(), 20, bound=12):
b.append(i)
return b, self.arr
""",
StateAccessViolation,
"May not call state modifying function within a range expression or for loop iterator.",
None,
"self.arr.pop()",
),
(
"""
arr: DynArray[uint256, 10]
@external
def test()-> (DynArray[uint256, 6], DynArray[uint256, 10]):
b: DynArray[uint256, 6] = []
self.arr = [1,0]
for i: uint256 in range(5, self.arr.pop() + 2, bound=12):
b.append(i)
return b, self.arr
""",
StateAccessViolation,
"May not call state modifying function within a range expression or for loop iterator.",
None,
"self.arr.pop() + 2",
),
# Cannot call `pop()` in iterator because it modifies state
(
"""
a: DynArray[uint256, 3]

@external
def foo():
for i: uint256 in [1, 2, self.a.pop()]:
pass
""",
StateAccessViolation,
"May not call state modifying function within a range expression or for loop iterator.",
None,
"self.a.pop()",
),
]

for_code_regex = re.compile(r"for .+ in (.*):", re.DOTALL)
Expand Down
68 changes: 67 additions & 1 deletion tests/functional/syntax/test_raw_call.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import pytest

from vyper import compile_code
from vyper.exceptions import ArgumentException, InvalidType, SyntaxException, TypeMismatch
from vyper.exceptions import (
ArgumentException,
InvalidType,
StateAccessViolation,
SyntaxException,
TypeMismatch,
)

fail_list = [
(
Expand Down Expand Up @@ -33,6 +39,66 @@ def foo():
""",
InvalidType,
),
(
"""
@external
@view
def foo(_addr: address):
raw_call(_addr, method_id("foo()"))
""",
StateAccessViolation,
),
# non-static call cannot be used in a range expression
(
"""
@external
def foo(a: address):
for i: uint256 in range(
0,
extract32(raw_call(a, b"", max_outsize=32), 0, output_type=uint256),
bound = 12
):
pass
""",
StateAccessViolation,
),
# call cannot be both a delegate call and a static call
(
"""
@external
def foo(_addr: address):
raw_call(_addr, method_id("foo()"), is_delegate_call=True, is_static_call=True)
""",
ArgumentException,
),
# value cannot be passed for delegate call
(
"""
@external
def foo(_addr: address):
raw_call(_addr, method_id("foo()"), is_delegate_call=True, value=1)
""",
ArgumentException,
),
#
(
"""
@external
def foo(_addr: address):
raw_call(_addr, method_id("foo()"), is_static_call=True, value=1)
""",
ArgumentException,
),
# second argument should be Bytes
(
"""
@external
@view
def foo(_addr: address):
raw_call(_addr, 256)
""",
TypeMismatch,
),
]


Expand Down
24 changes: 24 additions & 0 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,28 @@ def _analyse_list_iter(self, iter_node, target_type):
# through folded constants
return _get_variable_access(iter_val)

def _check_for_loop_modifiability(self, iter_node: vy_ast.VyperNode):
args = None
if isinstance(iter_node, vy_ast.Call):
args = iter_node.args
else:
iter_val = iter_node.reduced()
if isinstance(iter_val, vy_ast.List):
args = iter_val.elements
else:
args = [iter_node]

for arg in args:
call_nodes = arg.get_descendants(vy_ast.Call, include_self=True)
for c in call_nodes:
func_type = c.func._metadata["type"]
if getattr(func_type, "is_modifying", False) or getattr(
func_type, "is_mutable", False
):
msg = "May not call state modifying function within a range expression "
msg += "or for loop iterator."
raise StateAccessViolation(msg, arg)

def visit_For(self, node):
if not isinstance(node.target.target, vy_ast.Name):
raise StructureException("Invalid syntax for loop iterator", node.target.target)
Expand Down Expand Up @@ -584,6 +606,8 @@ def visit_For(self, node):
for stmt in node.body:
self.visit(stmt)

self._check_for_loop_modifiability(node.iter)

def visit_If(self, node):
self.expr_visitor.visit(node.test, BoolT())
with self.namespace.enter_scope():
Expand Down