Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: guard against mutating code in non-mutable functions #3555

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions tests/parser/exceptions/test_constancy_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,111 @@ def foo():
for i in range(x):
pass""",
"""
f:int128
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test now throws for its other issue of call violation, which is checked first in visit_Expr. Hence, I removed it as there are already tests for call violation in this file.

from vyper.interfaces import ERC20

token: ERC20

@external
def a (x:int128):
self.f = 100
@view
def topup(amount: uint256):
assert self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
@view
def topup(amount: uint256):
x: bool = self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
def b():
self.a(10)""",
@view
def topup(amount: uint256):
x: bool = False
x = self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
@view
def topup(amount: uint256) -> bool:
return self.token.transferFrom(msg.sender, self, amount)
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
assert self.a.pop() > 123, "vyper"
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
x: uint256 = self.a.pop()
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
x: uint256 = 0
x = self.a.pop()
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo() -> uint256:
return self.a.pop()
""",
"""
@external
@view
def foo(x: address):
assert convert(
raw_call(
x,
b'',
max_outsize=32,
revert_on_failure=False
), uint256
) > 123, "vyper"
""",
"""
@external
@view
def foo(a: address):
x: address = create_minimal_proxy_to(a)
""",
"""
@external
@view
def foo(a: address):
x: address = empty(address)
x = create_copy_of(a)
""",
"""
@external
@view
def foo(a: address) -> address:
return create_from_blueprint(a)
""",
],
)
def test_statefulness_violations(bad_code):
Expand Down
42 changes: 40 additions & 2 deletions vyper/semantics/analysis/annotation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from vyper import ast as vy_ast
from vyper.exceptions import StructureException, TypeCheckFailure
from vyper.exceptions import StateAccessViolation, StructureException, TypeCheckFailure
from vyper.semantics.analysis.utils import (
get_common_types,
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
)
from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability

Check notice

Code scanning / CodeQL

Cyclic import

Import of module [vyper.semantics.types.function](1) begins an import cycle.


class _AnnotationVisitorBase:
Expand Down Expand Up @@ -136,6 +137,23 @@ def visit_Call(self, node, type_):
self.visit(node.func)

if isinstance(call_type, ContractFunctionT):
if (
call_type.mutability > StateMutability.VIEW
and self.func.mutability <= StateMutability.VIEW
):
raise StateAccessViolation(
f"Cannot call a mutating function from a {self.func.mutability.value} function",
node,
)

if (
self.func.mutability == StateMutability.PURE
and call_type.mutability != StateMutability.PURE
):
raise StateAccessViolation(
"Cannot call non-pure function from a pure function", node
)

# function calls
if call_type.is_internal:
self.func.called_functions.add(call_type)
Expand All @@ -157,10 +175,30 @@ def visit_Call(self, node, type_):
):
self.visit(value, arg_type)
elif isinstance(call_type, MemberFunctionT):
if call_type.is_modifying:
# it's a dotted function call like dynarray.pop()
expr_info = get_expr_info(node.func.value)
expr_info.validate_modification(node, self.func.mutability)

assert len(node.args) == len(call_type.arg_types)
for arg, arg_type in zip(node.args, call_type.arg_types):
self.visit(arg, arg_type)
else:
mutable_builtins = (
"raw_call",
"create_minimal_proxy_to",
"create_copy_of",
"create_from_blueprint",
)
if (
self.func.mutability <= StateMutability.VIEW
and node.get("func.id") in mutable_builtins
):
raise StateAccessViolation(
f"Cannot call a mutating builtin from a {self.func.mutability.value} function",
node,
)

# builtin functions
arg_types = call_type.infer_arg_types(node)
for arg, arg_type in zip(node.args, arg_types):
Expand Down
23 changes: 0 additions & 23 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,29 +518,6 @@ def visit_Expr(self, node):
if is_type_t(fn_type, StructT):
raise StructureException("Struct creation without assignment is disallowed", node)

if isinstance(fn_type, ContractFunctionT):
if (
fn_type.mutability > StateMutability.VIEW
and self.func.mutability <= StateMutability.VIEW
):
raise StateAccessViolation(
f"Cannot call a mutating function from a {self.func.mutability.value} function",
node,
)

if (
self.func.mutability == StateMutability.PURE
and fn_type.mutability != StateMutability.PURE
):
raise StateAccessViolation(
"Cannot call non-pure function from a pure function", node
)

if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying:
# it's a dotted function call like dynarray.pop()
expr_info = get_expr_info(node.value.func.value)
expr_info.validate_modification(node, self.func.mutability)

# NOTE: fetch_call_return validates call args.
return_value = fn_type.fetch_call_return(node.value)
if (
Expand Down