From 161928a45aaeb13bac4678266c2f46bee653a894 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 7 Aug 2023 15:53:02 +0800 Subject: [PATCH 1/8] move call validation to annotation file --- vyper/semantics/analysis/annotation.py | 42 ++++++++++++++++++++++++-- vyper/semantics/analysis/local.py | 23 -------------- 2 files changed, 40 insertions(+), 25 deletions(-) diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index d309f102cd..a3cd50ada0 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -1,12 +1,13 @@ from vyper import ast as vy_ast -from vyper.exceptions import StructureException, TypeCheckFailure +from vyper.exceptions import StateAccessViolation, StructureException, TypeCheckFailure from vyper.semantics.analysis.utils import ( get_common_types, get_exact_type_from_node, + get_expr_info, get_possible_types_from_node, ) from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t -from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability class _AnnotationVisitorBase: @@ -136,6 +137,23 @@ def visit_Call(self, node, type_): self.visit(node.func) if isinstance(call_type, ContractFunctionT): + if ( + call_type.mutability > StateMutability.VIEW + and self.func.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot call a mutating function from a {self.func.mutability.value} function", + node, + ) + + if ( + self.func.mutability == StateMutability.PURE + and call_type.mutability != StateMutability.PURE + ): + raise StateAccessViolation( + "Cannot call non-pure function from a pure function", node + ) + # function calls if call_type.is_internal: self.func.called_functions.add(call_type) @@ -157,10 +175,30 @@ def visit_Call(self, node, type_): ): self.visit(value, arg_type) elif isinstance(call_type, MemberFunctionT): + if call_type.is_modifying: + # it's a dotted function call like dynarray.pop() + expr_info = get_expr_info(node.func.value) + expr_info.validate_modification(node, self.func.mutability) + assert len(node.args) == len(call_type.arg_types) for arg, arg_type in zip(node.args, call_type.arg_types): self.visit(arg, arg_type) else: + mutable_builtins = ( + "raw_call", + "create_minimal_proxy_to", + "create_copy_of", + "create_from_blueprint", + ) + if ( + self.func.mutability <= StateMutability.VIEW + and node.get("func.id") in mutable_builtins + ): + raise StateAccessViolation( + f"Cannot call a mutating builtin from a {self.func.mutability.value} function", + node, + ) + # builtin functions arg_types = call_type.infer_arg_types(node) for arg, arg_type in zip(node.args, arg_types): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c0c05325f2..dd230d338d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -518,29 +518,6 @@ def visit_Expr(self, node): if is_type_t(fn_type, StructT): raise StructureException("Struct creation without assignment is disallowed", node) - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - # NOTE: fetch_call_return validates call args. return_value = fn_type.fetch_call_return(node.value) if ( From f112bee5d4d586912094dc02dda530601ce0801a Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 7 Aug 2023 15:53:08 +0800 Subject: [PATCH 2/8] add tests --- .../exceptions/test_constancy_exception.py | 105 +++++++++++++++++- 1 file changed, 100 insertions(+), 5 deletions(-) diff --git a/tests/parser/exceptions/test_constancy_exception.py b/tests/parser/exceptions/test_constancy_exception.py index 4bd0b4fcb9..fb99dc9273 100644 --- a/tests/parser/exceptions/test_constancy_exception.py +++ b/tests/parser/exceptions/test_constancy_exception.py @@ -76,16 +76,111 @@ def foo(): for i in range(x): pass""", """ -f:int128 +from vyper.interfaces import ERC20 + +token: ERC20 @external -def a (x:int128): - self.f = 100 +@view +def topup(amount: uint256): + assert self.token.transferFrom(msg.sender, self, amount) + """, + """ +from vyper.interfaces import ERC20 +token: ERC20 + +@external @view +def topup(amount: uint256): + x: bool = self.token.transferFrom(msg.sender, self, amount) + """, + """ +from vyper.interfaces import ERC20 + +token: ERC20 + @external -def b(): - self.a(10)""", +@view +def topup(amount: uint256): + x: bool = False + x = self.token.transferFrom(msg.sender, self, amount) + """, + """ +from vyper.interfaces import ERC20 + +token: ERC20 + +@external +@view +def topup(amount: uint256) -> bool: + return self.token.transferFrom(msg.sender, self, amount) + """, + """ +a: DynArray[uint256, 3] + +@external +@view +def foo(): + assert self.a.pop() > 123, "vyper" + """, + """ +a: DynArray[uint256, 3] + +@external +@view +def foo(): + x: uint256 = self.a.pop() + """, + """ +a: DynArray[uint256, 3] + +@external +@view +def foo(): + x: uint256 = 0 + x = self.a.pop() + """, + """ +a: DynArray[uint256, 3] + +@external +@view +def foo() -> uint256: + return self.a.pop() + """, + """ +@external +@view +def foo(x: address): + assert convert( + raw_call( + x, + b'', + max_outsize=32, + revert_on_failure=False + ), uint256 + ) > 123, "vyper" + """, + """ +@external +@view +def foo(a: address): + x: address = create_minimal_proxy_to(a) + """, + """ +@external +@view +def foo(a: address): + x: address = empty(address) + x = create_copy_of(a) + """, + """ +@external +@view +def foo(a: address) -> address: + return create_from_blueprint(a) + """, ], ) def test_statefulness_violations(bad_code): From acec274ce2ee391cba02227a995ae1471ecdd1e9 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 7 Aug 2023 17:38:52 +0800 Subject: [PATCH 3/8] exclude raw call --- vyper/semantics/analysis/annotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index a3cd50ada0..02fc0ffae4 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -184,8 +184,8 @@ def visit_Call(self, node, type_): for arg, arg_type in zip(node.args, call_type.arg_types): self.visit(arg, arg_type) else: + # note that mutability for`raw_call` is handled in its `build_IR` function mutable_builtins = ( - "raw_call", "create_minimal_proxy_to", "create_copy_of", "create_from_blueprint", From c7c7d3756bd6c73305e323544a521da7234c9c3d Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 7 Aug 2023 17:38:58 +0800 Subject: [PATCH 4/8] add test --- tests/parser/features/decorators/test_view.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/parser/features/decorators/test_view.py b/tests/parser/features/decorators/test_view.py index 25a78db279..280be07e61 100644 --- a/tests/parser/features/decorators/test_view.py +++ b/tests/parser/features/decorators/test_view.py @@ -1,3 +1,5 @@ +import pytest + from vyper.exceptions import FunctionDeclarationException @@ -28,3 +30,25 @@ def foo() -> num: assert_compile_failed( lambda: get_contract_with_gas_estimation_for_constants(code), FunctionDeclarationException ) + + +good_code = [ + """ +@external +@view +def foo(x: address): + assert convert( + raw_call( + x, + b'', + max_outsize=32, + is_static_call=True, + ), uint256 + ) > 123, "vyper" + """ +] + + +@pytest.mark.parametrize("code", good_code) +def test_view_call_compiles(get_contract, code): + get_contract(code) From bf58e4ee2a1d098ca4b1cb11e9a8160508713c9c Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 7 Aug 2023 21:38:24 +0800 Subject: [PATCH 5/8] fix test --- .../external_contracts/test_external_contract_calls.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/parser/features/external_contracts/test_external_contract_calls.py index b3cc6f5576..8f07d7f0e2 100644 --- a/tests/parser/features/external_contracts/test_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_external_contract_calls.py @@ -892,11 +892,6 @@ def set_lucky(_lucky: int128) -> int128: nonpayable @view def set_lucky_expr(arg1: address, arg2: int128): Foo(arg1).set_lucky(arg2) - -@external -@view -def set_lucky_stmt(arg1: address, arg2: int128) -> int128: - return Foo(arg1).set_lucky(arg2) """ assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation) From da88fe499cbc343648d3ea2b765d5a3354474cbe Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 7 Aug 2023 21:40:11 +0800 Subject: [PATCH 6/8] split test --- .../test_external_contract_calls.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/parser/features/external_contracts/test_external_contract_calls.py index 8f07d7f0e2..ab00edd2fe 100644 --- a/tests/parser/features/external_contracts/test_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_external_contract_calls.py @@ -881,7 +881,7 @@ def set_lucky(arg1: address, arg2: int128): print("Successfully executed an external contract call state change") -def test_constant_external_contract_call_cannot_change_state( +def test_constant_external_contract_call_cannot_change_state1( assert_compile_failed, get_contract_with_gas_estimation ): c = """ @@ -898,6 +898,23 @@ def set_lucky_expr(arg1: address, arg2: int128): print("Successfully blocked an external contract call from a constant function") +def test_constant_external_contract_call_cannot_change_state2( + assert_compile_failed, get_contract_with_gas_estimation +): + c = """ +interface Foo: + def set_lucky(_lucky: int128) -> int128: nonpayable + +@external +@view +def set_lucky_stmt(arg1: address, arg2: int128) -> int128: + return Foo(arg1).set_lucky(arg2) + """ + assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation) + + print("Successfully blocked an external contract call from a constant function") + + def test_external_contract_can_be_changed_based_on_address(get_contract): contract_1 = """ lucky: public(int128) From 199eece58d816c09e7346574c866e1f6398655cd Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 14 Aug 2023 12:41:40 +0800 Subject: [PATCH 7/8] add mutability attribute to builtins --- vyper/builtins/_signatures.py | 2 ++ vyper/builtins/functions.py | 11 ++++++++++- vyper/semantics/analysis/annotation.py | 20 +++++--------------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d39a4a085f..e4bc55ff4e 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -6,6 +6,7 @@ from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode from vyper.exceptions import CompilerPanic, TypeMismatch +from vyper.semantics.analysis.base import StateMutability from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -77,6 +78,7 @@ def decorator_fn(self, node, context): class BuiltinFunction: _has_varargs = False _kwargs: Dict[str, KwargSettings] = {} + mutability = StateMutability.PURE # helper function to deal with TYPE_DEFINITIONs def _validate_single(self, arg, expected_type): diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 685d832c01..51ae8ff254 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -51,7 +51,7 @@ UnfoldableNode, ZeroDivisionException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import StateMutability, VarInfo from vyper.semantics.analysis.utils import ( get_common_types, get_exact_type_from_node, @@ -1083,12 +1083,15 @@ class RawCall(BuiltinFunction): "revert_on_failure": KwargSettings(BoolT(), True, require_literal=True), } _return_type = None + mutability = StateMutability.NONPAYABLE def fetch_call_return(self, node): self._validate_arg_types(node) kwargz = {i.arg: i.value for i in node.keywords} + value = kwargz.get("value") + static_call = kwargz.get("is_static_call") outsize = kwargz.get("max_outsize") revert_on_failure = kwargz.get("revert_on_failure") revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True @@ -1101,6 +1104,11 @@ def fetch_call_return(self, node): if not isinstance(outsize, vy_ast.Int) or outsize.value < 0: raise + if static_call: + self.mutability = StateMutability.VIEW + elif value: + self.mutability = StateMutability.PAYABLE + if outsize.value: return_type = BytesT() return_type.set_min_length(outsize.value) @@ -1724,6 +1732,7 @@ class _CreateBase(BuiltinFunction): "salt": KwargSettings(BYTES32_T, empty_value), } _return_type = AddressT() + mutability = StateMutability.PAYABLE @process_inputs def build_IR(self, expr, args, kwargs, context): diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index 02fc0ffae4..f76e8dd2bc 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -136,7 +136,7 @@ def visit_Call(self, node, type_): node._metadata["type"] = node_type self.visit(node.func) - if isinstance(call_type, ContractFunctionT): + def _check_mutability(call_type): if ( call_type.mutability > StateMutability.VIEW and self.func.mutability <= StateMutability.VIEW @@ -154,6 +154,9 @@ def visit_Call(self, node, type_): "Cannot call non-pure function from a pure function", node ) + if isinstance(call_type, ContractFunctionT): + _check_mutability(call_type) + # function calls if call_type.is_internal: self.func.called_functions.add(call_type) @@ -184,20 +187,7 @@ def visit_Call(self, node, type_): for arg, arg_type in zip(node.args, call_type.arg_types): self.visit(arg, arg_type) else: - # note that mutability for`raw_call` is handled in its `build_IR` function - mutable_builtins = ( - "create_minimal_proxy_to", - "create_copy_of", - "create_from_blueprint", - ) - if ( - self.func.mutability <= StateMutability.VIEW - and node.get("func.id") in mutable_builtins - ): - raise StateAccessViolation( - f"Cannot call a mutating builtin from a {self.func.mutability.value} function", - node, - ) + _check_mutability(call_type) # builtin functions arg_types = call_type.infer_arg_types(node) From 65b5384b3ef46b481e0af2bcbc88f67af3fd9b70 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 14 Aug 2023 14:00:42 +0800 Subject: [PATCH 8/8] check mutability attr exists --- vyper/semantics/analysis/annotation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index f76e8dd2bc..2f59b89ca9 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -187,7 +187,8 @@ def _check_mutability(call_type): for arg, arg_type in zip(node.args, call_type.arg_types): self.visit(arg, arg_type) else: - _check_mutability(call_type) + if hasattr(call_type, "mutability"): + _check_mutability(call_type) # builtin functions arg_types = call_type.infer_arg_types(node)