From 6b9fff2fcc032176e257e5e252c916c06b9cee3a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 8 Feb 2024 23:14:09 -0500 Subject: [PATCH 01/14] rename validate_expected_type to infer_type and have it return a type it also tags the node with the inferred type --- .../builtins/folding/test_bitwise.py | 6 ++-- vyper/builtins/_signatures.py | 4 +-- vyper/builtins/functions.py | 11 +++--- vyper/semantics/analysis/local.py | 13 +++---- vyper/semantics/analysis/module.py | 2 +- vyper/semantics/analysis/utils.py | 35 +++++++++++-------- vyper/semantics/types/function.py | 10 +++--- vyper/semantics/types/module.py | 6 ++-- vyper/semantics/types/subscriptable.py | 8 ++--- vyper/semantics/types/user.py | 6 ++-- 10 files changed, 51 insertions(+), 50 deletions(-) diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index c1ff7674bb..892f0bcabc 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -4,7 +4,7 @@ from tests.utils import parse_and_fold from vyper.exceptions import InvalidType, OverflowException -from vyper.semantics.analysis.utils import validate_expected_type +from vyper.semantics.analysis.utils import infer_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T from vyper.utils import unsigned_to_signed @@ -55,7 +55,7 @@ def foo(a: uint256, b: uint256) -> uint256: # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case # more types are added). - validate_expected_type(new_node, UINT256_T) + _ = infer_type(new_node, UINT256_T) # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except OverflowException: # here: check the wrapped value matches runtime @@ -81,7 +81,7 @@ def foo(a: int256, b: uint256) -> int256: vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() - validate_expected_type(new_node, INT256_T) # force bounds check + _ = infer_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except (InvalidType, OverflowException): diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 6e6cf4c662..3d25b435da 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -10,7 +10,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -99,7 +99,7 @@ def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> N # for its side effects (will throw if is not a type) type_from_annotation(arg) else: - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type) def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 7575f4d77e..345b59197a 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -54,7 +54,7 @@ get_common_types, get_exact_type_from_node, get_possible_types_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.types import ( TYPE_T, @@ -508,8 +508,7 @@ def infer_arg_types(self, node, expected_return_typ=None): ret = [] prev_typeclass = None for arg in node.args: - validate_expected_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) - arg_t = get_possible_types_from_node(arg).pop() + arg_t = infer_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) current_typeclass = "String" if isinstance(arg_t, StringT) else "Bytes" if prev_typeclass and current_typeclass != prev_typeclass: raise TypeMismatch( @@ -865,7 +864,7 @@ def infer_kwarg_types(self, node): "Output type must be one of integer, bytes32 or address", node.keywords[0].value ) output_typedef = TYPE_T(output_type) - node.keywords[0].value._metadata["type"] = output_typedef + #node.keywords[0].value._metadata["type"] = output_typedef else: output_typedef = TYPE_T(BYTES32_T) @@ -2376,8 +2375,8 @@ def infer_kwarg_types(self, node): ret = {} for kwarg in node.keywords: kwarg_name = kwarg.arg - validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ) - ret[kwarg_name] = get_exact_type_from_node(kwarg.value) + typ = infer_type(kwarg.value, self._kwargs[kwarg_name].typ) + ret[kwarg_name] = typ return ret def fetch_call_return(self, node): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d96215ede0..77fa57c074 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -26,7 +26,7 @@ get_exact_type_from_node, get_expr_info, get_possible_types_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.data_locations import DataLocation @@ -254,7 +254,7 @@ def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): try: - validate_expected_type(msg_node, StringT(1024)) + _ = infer_type(msg_node, StringT(1024)) except TypeMismatch as e: raise InvalidType("revert reason must fit within String[1024]") from e self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) @@ -563,15 +563,10 @@ def scope_name(self): def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): - validate_expected_type(node, typ) + infer_type(node, expected_type=typ) - # recurse and typecheck in case we are being fed the wrong type for - # some reason. super().visit(node, typ) - # annotate - node._metadata["type"] = typ - if not isinstance(typ, TYPE_T): info = get_expr_info(node) # get_expr_info fills in node._expr_info @@ -793,7 +788,7 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return - # these guarantees should be provided by validate_expected_type + # these guarantees should be provided by infer_type assert isinstance(typ, TupleT) assert len(node.elements) == len(typ.member_types) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..787ec82c15 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -526,7 +526,7 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) # performs validate_expected_type + ExprVisitor().visit(node.value, type_) # performs type validation if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..c889e6ab75 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -224,7 +224,7 @@ def types_from_BinOp(self, node): # can be different types types_list = get_possible_types_from_node(node.left) # check rhs is unsigned integer - validate_expected_type(node.right, IntegerT.unsigneds()) + _ = infer_type(node.right, IntegerT.unsigneds()) else: types_list = get_common_types(node.left, node.right) @@ -319,7 +319,7 @@ def types_from_Constant(self, node): raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node) def types_from_IfExp(self, node): - validate_expected_type(node.test, BoolT()) + _ = infer_type(node.test, expected_type=BoolT()) types_list = get_common_types(node.body, node.orelse) if not types_list: @@ -529,14 +529,14 @@ def _validate_literal_array(node, expected): for item in node.elements: try: - validate_expected_type(item, expected.value_type) + _ = infer_type(item, expected.value_type) except (InvalidType, TypeMismatch): return False return True -def validate_expected_type(node, expected_type): +def infer_type(node, expected_type): """ Validate that the given node matches the expected type(s) @@ -551,8 +551,15 @@ def validate_expected_type(node, expected_type): Returns ------- - None + The inferred type. The inferred type must be a concrete type which + is compatible with the expected type (although the expected type may + be generic). """ + ret = _infer_type_helper(node, expected_type) + node._metadata["type"] = ret + return ret + +def _infer_type_helper(node, expected_type): if not isinstance(expected_type, tuple): expected_type = (expected_type,) @@ -561,15 +568,15 @@ def validate_expected_type(node, expected_type): for t in possible_tuple_types: if len(t.member_types) != len(node.elements): continue - for item_ast, item_type in zip(node.elements, t.member_types): + ret = [] + for item_ast, expected_item_type in zip(node.elements, t.member_types): try: - validate_expected_type(item_ast, item_type) - return + item_t = infer_type(item_ast, expected_type=expected_item_type) + ret.append(item_t) except VyperException: - pass - else: - # fail block - pass + break # go to fail block + else: + return TupleT(tuple(ret)) given_types = _ExprAnalyser().get_possible_types_from_node(node) @@ -579,11 +586,11 @@ def validate_expected_type(node, expected_type): if not isinstance(expected, (DArrayT, SArrayT)): continue if _validate_literal_array(node, expected): - return + return expected else: for given, expected in itertools.product(given_types, expected_type): if expected.compare_type(given): - return + return given # validation failed, prepare a meaningful error message if len(expected_type) > 1: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 62f9c60585..4f4fc82e5c 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -27,7 +27,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType @@ -542,7 +542,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: raise CallViolation("Cannot send ether to nonpayable function", kwarg_node) for arg, expected in zip(node.args, self.argument_types): - validate_expected_type(arg, expected) + infer_type(arg, expected) # TODO this should be moved to validate_call_args for kwarg in node.keywords: @@ -553,7 +553,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: f"`{kwarg.arg}=` specified but {self.name}() does not return anything", kwarg.value, ) - validate_expected_type(kwarg.value, kwarg_settings.typ) + infer_type(kwarg.value, kwarg_settings.typ) if kwarg_settings.require_literal: if not isinstance(kwarg.value, vy_ast.Constant): raise InvalidType( @@ -730,7 +730,7 @@ def _parse_args( value = funcdef.args.defaults[i - n_positional_args] if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): raise StateAccessViolation("Value must be literal or environment variable", value) - validate_expected_type(value, type_) + infer_type(value, expected_type=type_) keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) argnames.add(argname) @@ -788,7 +788,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: assert len(node.args) == len(self.arg_types) # validate_call_args postcondition for arg, expected_type in zip(node.args, self.arg_types): # CMC 2022-04-01 this should probably be in the validation module - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type=expected_type) return self.return_type diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 86840f4f91..0ef052a3da 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -13,7 +13,7 @@ from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.utils import ( check_modifiability, - validate_expected_type, + infer_type, validate_unique_method_ids, ) from vyper.semantics.data_locations import DataLocation @@ -83,8 +83,8 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": def _ctor_arg_types(self, node): validate_call_args(node, 1) - validate_expected_type(node.args[0], AddressT()) - return [AddressT()] + typ = infer_type(node.args[0], AddressT()) + return [typ] def _ctor_kwarg_types(self, node): return {} diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 635a1631a2..9dec62e136 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -35,9 +35,9 @@ def getter_signature(self) -> Tuple[Tuple, Optional[VyperType]]: def validate_index_type(self, node): # TODO: break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type - validate_expected_type(node, self.key_type) + infer_type(node, self.key_type) class HashMapT(_SubscriptableT): @@ -125,7 +125,7 @@ def count(self): def validate_index_type(self, node): # TODO break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type if isinstance(node, vy_ast.Int): if node.value < 0: @@ -133,7 +133,7 @@ def validate_index_type(self, node): if node.value >= self.length: raise ArrayIndexException("Index out of range", node) - validate_expected_type(node, IntegerT.any()) + infer_type(node, IntegerT.any()) def get_subscripted_type(self, node): return self.value_type diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 92a455e3d8..c3f169ac8d 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -16,7 +16,7 @@ ) from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type +from vyper.semantics.analysis.utils import check_modifiability, infer_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType from vyper.semantics.types.subscriptable import HashMapT @@ -270,7 +270,7 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": def _ctor_call_return(self, node: vy_ast.Call) -> None: validate_call_args(node, len(self.arguments)) for arg, expected in zip(node.args, self.arguments.values()): - validate_expected_type(arg, expected) + infer_type(arg, expected) def to_toplevel_abi_dict(self) -> list[dict]: return [ @@ -412,7 +412,7 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": key, ) - validate_expected_type(value, members.pop(key.id)) + infer_type(value, members.pop(key.id)) if members: raise VariableDeclarationException( From b3e2fd9c67eb43caaef04ee494368ac618763dc0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 12:36:07 -0500 Subject: [PATCH 02/14] update a comment --- vyper/semantics/analysis/local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 882989d776..cefbbb01d9 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -493,7 +493,7 @@ def _analyse_list_iter(self, iter_node, target_type): except (InvalidType, StructureException): raise InvalidType("Not an iterable type", iter_node) - # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # CMC 2024-02-09 TODO: use infer_type once we have DArrays # with generic length. if not isinstance(iter_type, (DArrayT, SArrayT)): raise InvalidType("Not an iterable type", iter_node) From 3fd9fb864100702eb01ed2c45f4cfb5c0ba8e9df Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 12:41:13 -0500 Subject: [PATCH 03/14] improve type inference for revert reason strings --- vyper/semantics/analysis/local.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index cefbbb01d9..d787ba6a41 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -313,17 +313,20 @@ def visit_AnnAssign(self, node): self.expr_visitor.visit(node.target, typ) def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: + if isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE": + # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + return + if isinstance(msg_node, vy_ast.Str): if not msg_node.value.strip(): raise StructureException("Reason string cannot be empty", msg_node) - self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) - elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - try: - _ = infer_type(msg_node, StringT(1024)) - except TypeMismatch as e: - raise InvalidType("revert reason must fit within String[1024]") from e - self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) - # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + try: + self.expr_visitor.visit(msg_node, StringT.any()) + except TypeMismatch as e: + # improve the error message + msg = "reason must be a string or the special `UNREACHABLE` value" + raise TypeMismatch(msg, msg_node) from e + def visit_Assert(self, node): if node.msg: From 1350838201764c6fe1436b4ad8811e2562eb4b86 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 12:48:47 -0500 Subject: [PATCH 04/14] format some comments --- vyper/builtins/functions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 345b59197a..854112520e 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -203,12 +203,14 @@ def infer_arg_types(self, node, expected_return_typ=None): target_type = type_from_annotation(node.args[1]) value_types = get_possible_types_from_node(node.args[0]) - # For `convert` of integer literals, we need to match type inference rules in - # convert.py codegen routines. + # For `convert` of integer literals, we need to match type inference + # rules in convert.py codegen routines. # TODO: This can probably be removed once constant folding for `convert` is implemented if len(value_types) > 1 and all(isinstance(v, IntegerT) for v in value_types): - # Get the smallest (and unsigned if available) type for non-integer target types - # (note this is different from the ordering returned by `get_possible_types_from_node`) + # Get the smallest (and unsigned if available) type for + # non-integer target types + # (note this is different from the ordering returned by + # `get_possible_types_from_node`) if not isinstance(target_type, IntegerT): value_types = sorted(value_types, key=lambda v: (v.is_signed, v.bits), reverse=True) else: From 2013d81945ed1a4fb634928d20015c356ff6ebbd Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 12:48:59 -0500 Subject: [PATCH 05/14] use the result of infer_type --- vyper/semantics/analysis/local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d787ba6a41..3dcbcba166 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -600,7 +600,7 @@ def scope_name(self): def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): - infer_type(node, expected_type=typ) + typ = infer_type(node, expected_type=typ) super().visit(node, typ) From 71aab61f71d9d49cdbd1044a9ff278a4ff8c6b9f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 12:49:57 -0500 Subject: [PATCH 06/14] remove length modification functions on bytestring --- vyper/semantics/types/bytestrings.py | 52 +++++----------------------- 1 file changed, 8 insertions(+), 44 deletions(-) diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index 96bb1bbf74..c559e625e9 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -6,6 +6,9 @@ from vyper.utils import ceil32 +UNKNOWN_LENGTH = object() + + class _BytestringT(VyperType): """ Private base class for single-value types which occupy multiple memory slots @@ -68,58 +71,19 @@ def validate_literal(self, node: vy_ast.Constant) -> None: @property def size_in_bytes(self): - # the first slot (32 bytes) stores the actual length, and then we reserve - # enough additional slots to store the data if it uses the max available length - # because this data type is single-bytes, we make it so it takes the max 32 byte - # boundary as it's size, instead of giving it a size that is not cleanly divisible by 32 - + # the first slot (32 bytes) stores the actual length, and then we + # reserve enough additional slots to store the data. allocate 32-byte + # aligned buffer for the data. return 32 + ceil32(self.length) - def set_length(self, length): - """ - Sets the exact length of the type. - - May only be called once, and only on a type that does not yet have - a fixed length. - """ - if self._length: - raise CompilerPanic("Type already has a fixed length") - self._length = length - self._min_length = length - - def set_min_length(self, min_length): - """ - Sets the minimum length of the type. - - May only be used to increase the minimum length. May not be called if - an exact length has been set. - """ - if self._length: - raise CompilerPanic("Type already has a fixed length") - if self._min_length > min_length: - raise CompilerPanic("Cannot reduce the min_length of ArrayValueType") - self._min_length = min_length - def compare_type(self, other): if not super().compare_type(other): return False - # CMC 2022-03-18 TODO this method should be refactored so it does not have side effects - - # when comparing two literals, both now have an equal min-length - if not self._length and not other._length: - min_length = max(self._min_length, other._min_length) - self.set_min_length(min_length) - other.set_min_length(min_length) + if UNKNOWN_LENGTH in (self._length, other._length): return True - # comparing a defined length to a literal causes the literal to have a fixed length - if self._length: - if not other._length: - other.set_length(max(self._length, other._min_length)) - return self._length >= other._length - - return other.compare_type(self) + return self._length >= other._length @classmethod def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": From d7993ec147328d7b98e68793a4f6f60c580a9805 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 4 Mar 2024 17:56:41 -0800 Subject: [PATCH 07/14] feat[lang]: allow downcasting of bytestrings allow conversion from Bytes/String types to shorter length types, e.g. convert `Bytes[20]` to `Bytes[19]` this will become important when we want to allow generic bytestrings inside the type system (`Bytes[...]`) which can only be user-instantiated by converting to a known length. --- vyper/builtins/_convert.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index 156bee418e..eedec8a5cc 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -10,6 +10,7 @@ bytes_data_ptr, clamp, clamp_basetype, + clamp_bytestring, clamp_le, get_bytearray_length, int_clamp, @@ -422,23 +423,25 @@ def to_address(expr, arg, out_typ): return IRnode.from_list(ret, out_typ) -# question: should we allow bytesM -> String? -@_input_types(BytesT) -def to_string(expr, arg, out_typ): - _check_bytes(expr, arg, out_typ, out_typ.maxlen) - +def _cast_bytestring(expr, arg, out_typ): + if isinstance(arg.typ, out_typ.__class__) and out_typ.maxlen <= arg.typ.maxlen: + _FAIL(arg.typ, out_typ, expr) + ret = ["seq"] + if out_typ.maxlen is None or out_typ.maxlen > arg.maxlen: + ret.append(clamp_bytestring(arg)) # NOTE: this is a pointer cast return IRnode.from_list(arg, typ=out_typ) -@_input_types(StringT) -def to_bytes(expr, arg, out_typ): - _check_bytes(expr, arg, out_typ, out_typ.maxlen) +# question: should we allow bytesM -> String? +@_input_types(BytesT, StringT) +def to_string(expr, arg, out_typ): + return _cast_bytestring(expr, arg, out_typ) - # TODO: more casts - # NOTE: this is a pointer cast - return IRnode.from_list(arg, typ=out_typ) +@_input_types(StringT, BytesT) +def to_bytes(expr, arg, out_typ): + return _cast_bytestring(expr, arg, out_typ) @_input_types(IntegerT) From 93e53c1af63581e56c8f6876c4174e4b79886a72 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 4 Mar 2024 18:05:53 -0800 Subject: [PATCH 08/14] fix direction of some comparisons --- vyper/builtins/_convert.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index eedec8a5cc..cb9a86884f 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -10,7 +10,6 @@ bytes_data_ptr, clamp, clamp_basetype, - clamp_bytestring, clamp_le, get_bytearray_length, int_clamp, @@ -424,13 +423,15 @@ def to_address(expr, arg, out_typ): def _cast_bytestring(expr, arg, out_typ): - if isinstance(arg.typ, out_typ.__class__) and out_typ.maxlen <= arg.typ.maxlen: + # can't convert Bytes[20] to Bytes[21] + if isinstance(arg.typ, out_typ.__class__) and arg.typ.maxlen <= out_typ.maxlen: _FAIL(arg.typ, out_typ, expr) ret = ["seq"] - if out_typ.maxlen is None or out_typ.maxlen > arg.maxlen: - ret.append(clamp_bytestring(arg)) + if out_typ.maxlen < arg.typ.maxlen: + ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]]) + ret.append(arg) # NOTE: this is a pointer cast - return IRnode.from_list(arg, typ=out_typ) + return IRnode.from_list(ret, typ=out_typ, location=arg.location, encoding=arg.encoding) # question: should we allow bytesM -> String? From b2e62a2299f7a7cce00cf87deb291769f56caf09 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 6 Mar 2024 08:08:25 -0800 Subject: [PATCH 09/14] fix existing tests, add tests for new functionality, add compile-time check --- .../builtins/codegen/test_convert.py | 61 +++++++++++++++++-- vyper/builtins/_convert.py | 6 +- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index 559e1448ef..c1063bd0e4 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -8,6 +8,7 @@ import eth.codecs.abi.exceptions import pytest +from vyper.compiler import compile_code from vyper.exceptions import InvalidLiteral, InvalidType, TypeMismatch from vyper.semantics.types import AddressT, BoolT, BytesM_T, BytesT, DecimalT, IntegerT, StringT from vyper.semantics.types.shortcuts import BYTES20_T, BYTES32_T, UINT, UINT160_T, UINT256_T @@ -560,14 +561,15 @@ def foo(x: {i_typ}) -> {o_typ}: assert_compile_failed(lambda: get_contract(code), TypeMismatch) -@pytest.mark.parametrize("typ", sorted(TEST_TYPES)) -def test_bytes_too_large_cases(get_contract, assert_compile_failed, typ): +@pytest.mark.parametrize("typ", sorted(BASE_TYPES)) +def test_bytes_too_large_cases(typ): code_1 = f""" @external def foo(x: Bytes[33]) -> {typ}: return convert(x, {typ}) """ - assert_compile_failed(lambda: get_contract(code_1), TypeMismatch) + with pytest.raises(TypeMismatch): + compile_code(code_1) bytes_33 = b"1" * 33 code_2 = f""" @@ -575,8 +577,59 @@ def foo(x: Bytes[33]) -> {typ}: def foo() -> {typ}: return convert({bytes_33}, {typ}) """ + with pytest.raises(TypeMismatch): + compile_code(code_2) - assert_compile_failed(lambda: get_contract(code_2, TypeMismatch)) + +@pytest.mark.parametrize("cls1,cls2", itertools.product((StringT, BytesT), (StringT, BytesT))) +def test_bytestring_conversions(cls1, cls2, get_contract, tx_failed): + typ1 = cls1(33) + typ2 = cls2(32) + + def bytestring(cls, string): + if cls == BytesT: + return string.encode("utf-8") + return string + + code_1 = f""" +@external +def foo(x: {typ1}) -> {typ2}: + return convert(x, {typ2}) + """ + c = get_contract(code_1) + + for i in range(33): # inclusive 32 + s = "1" * i + arg = bytestring(cls1, s) + out = bytestring(cls2, s) + assert c.foo(arg) == out + + with tx_failed(): + # TODO: sanity check it is convert which is reverting, not arg clamping + c.foo(bytestring(cls1, "1" * 33)) + + code_2_template = """ +@external +def foo() -> {typ}: + return convert({arg}, {typ}) + """ + + # test literals + for i in range(33): # inclusive 32 + s = "1" * i + arg = bytestring(cls1, s) + out = bytestring(cls2, s) + code = code_2_template.format(typ=typ2, arg=repr(arg)) + if cls1 == cls2: # ex.: can't convert "" to String[32] + with pytest.raises(InvalidType): + compile_code(code) + else: + c = get_contract(code) + assert c.foo() == out + + failing_code = code_2_template.format(typ=typ2, arg=bytestring(cls1, "1" * 33)) + with pytest.raises(TypeMismatch): + compile_code(failing_code) @pytest.mark.parametrize("n", range(1, 33)) diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index cb9a86884f..8f5f4c03e2 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -423,9 +423,13 @@ def to_address(expr, arg, out_typ): def _cast_bytestring(expr, arg, out_typ): - # can't convert Bytes[20] to Bytes[21] + # ban converting Bytes[20] to Bytes[21] if isinstance(arg.typ, out_typ.__class__) and arg.typ.maxlen <= out_typ.maxlen: _FAIL(arg.typ, out_typ, expr) + # can't downcast literals with known length (e.g. b"abc" to Bytes[2]) + if isinstance(expr, vy_ast.Constant) and arg.typ.maxlen > out_typ.maxlen: + _FAIL(arg.typ, out_typ, expr) + ret = ["seq"] if out_typ.maxlen < arg.typ.maxlen: ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]]) From 35ec41326c18311bfe04b21855fb4c0e5afaa91f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 8 Mar 2024 14:07:30 -0800 Subject: [PATCH 10/14] allow bytestrings with ellipsis length --- vyper/abi_types.py | 2 +- vyper/builtins/_convert.py | 3 +- vyper/builtins/_signatures.py | 6 +-- vyper/builtins/functions.py | 34 ++++---------- vyper/codegen/core.py | 2 +- vyper/semantics/analysis/local.py | 2 +- vyper/semantics/analysis/utils.py | 17 ++++--- vyper/semantics/types/__init__.py | 2 +- vyper/semantics/types/bytestrings.py | 62 +++++++++++++------------- vyper/semantics/types/function.py | 8 +--- vyper/semantics/types/module.py | 2 +- vyper/semantics/types/subscriptable.py | 11 ++++- vyper/semantics/types/utils.py | 11 +++-- 13 files changed, 77 insertions(+), 85 deletions(-) diff --git a/vyper/abi_types.py b/vyper/abi_types.py index 051f8db19f..894de98db0 100644 --- a/vyper/abi_types.py +++ b/vyper/abi_types.py @@ -199,7 +199,7 @@ def is_complex_type(self): class ABI_Bytes(ABIType): def __init__(self, bytes_bound): - if not bytes_bound >= 0: + if bytes_bound is not None and not bytes_bound >= 0: raise InvalidABIType("Negative bytes_bound provided to ABI_Bytes") self.bytes_bound = bytes_bound diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index aa53dee429..18de454e44 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -431,7 +431,8 @@ def _cast_bytestring(expr, arg, out_typ): _FAIL(arg.typ, out_typ, expr) ret = ["seq"] - if out_typ.maxlen < arg.typ.maxlen: + assert out_typ.maxlen is not None + if arg.typ.maxlen is None or out_typ.maxlen < arg.typ.maxlen: ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]]) ret.append(arg) # NOTE: this is a pointer cast diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index f2aa8b0d9c..78ad0f6322 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -7,11 +7,7 @@ from vyper.codegen.ir_node import IRnode from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode from vyper.semantics.analysis.base import Modifiability -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - infer_type, -) +from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node, infer_type from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 4d2a6bc263..ce91231842 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -292,11 +292,6 @@ class Slice(BuiltinFunctionT): def fetch_call_return(self, node): arg_type, _, _ = self.infer_arg_types(node) - if isinstance(arg_type, StringT): - return_type = StringT() - else: - return_type = BytesT() - # validate start and length are in bounds arg = node.args[0] @@ -325,20 +320,14 @@ def fetch_call_return(self, node): if length_literal is not None and start_literal + length_literal > arg_type.length: raise ArgumentException(f"slice out of bounds for {arg_type}", node) - # we know the length statically + return_cls = arg_type.__class__ if length_literal is not None: - return_type.set_length(length_literal) + return_type = return_cls(length_literal) else: - return_type.set_min_length(arg_type.length) + return_type = return_cls(arg_type.length) return return_type - def infer_arg_types(self, node, expected_return_typ=None): - self._validate_arg_types(node) - # return a concrete type for `b` - b_type = get_possible_types_from_node(node.args[0]).pop() - return [b_type, self._inputs[1][1], self._inputs[2][1]] - @process_inputs def build_IR(self, expr, args, kwargs, context): src, start, length = args @@ -492,12 +481,8 @@ def fetch_call_return(self, node): for arg_t in arg_types: length += arg_t.length - if isinstance(arg_types[0], (StringT)): - return_type = StringT() - else: - return_type = BytesT() - return_type.set_length(length) - return return_type + return_type_cls = arg_types[0].__class__ + return return_type_cls(length) def infer_arg_types(self, node, expected_return_typ=None): if len(node.args) < 2: @@ -865,7 +850,7 @@ def infer_kwarg_types(self, node): "Output type must be one of integer, bytes32 or address", node.keywords[0].value ) output_typedef = TYPE_T(output_type) - #node.keywords[0].value._metadata["type"] = output_typedef + # node.keywords[0].value._metadata["type"] = output_typedef else: output_typedef = TYPE_T(BYTES32_T) @@ -1080,8 +1065,7 @@ def fetch_call_return(self, node): raise if outsize.value: - return_type = BytesT() - return_type.set_min_length(outsize.value) + return_type = BytesT(outsize.value) if revert_on_failure: return return_type @@ -2404,9 +2388,7 @@ def fetch_call_return(self, node): # the output includes 4 bytes for the method_id. maxlen += 4 - ret = BytesT() - ret.set_length(maxlen) - return ret + return BytesT(maxlen) @staticmethod def _parse_method_id(method_id_literal): diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index ecf05d1a49..3dc4411092 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -736,7 +736,7 @@ def dummy_node_for_type(typ): def _check_assign_bytes(left, right): - if right.typ.maxlen > left.typ.maxlen: # pragma: nocover + if left.typ.maxlen is not None and right.typ.maxlen > left.typ.maxlen: # pragma: nocover raise TypeMismatch(f"Cannot cast from {right.typ} to {left.typ}") # stricter check for zeroing a byte array. diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 8367c06369..7852b56868 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -359,7 +359,6 @@ def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: msg = "reason must be a string or the special `UNREACHABLE` value" raise TypeMismatch(msg, msg_node) from e - def visit_Assert(self, node): if node.msg: self._validate_revert_reason(node.msg) @@ -639,6 +638,7 @@ def scope_name(self): def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): + # note: infer_type caches the resolved type on node._metadata["type"] typ = infer_type(node, expected_type=typ) super().visit(node, typ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 34b6023735..d117b3dcf8 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -21,7 +21,7 @@ from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType -from vyper.semantics.types.bytestrings import BytesT, StringT +from vyper.semantics.types.bytestrings import BytesT, StringT, _DynLength from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT from vyper.utils import checksum_encode, int_to_fourbytes @@ -566,6 +566,7 @@ def infer_type(node, expected_type): node._metadata["type"] = ret return ret + def _infer_type_helper(node, expected_type): if not isinstance(expected_type, tuple): expected_type = (expected_type,) @@ -614,7 +615,12 @@ def _infer_type_helper(node, expected_type): vy_ast.Name, include_self=True ): given = given_types[0] - raise TypeMismatch(f"Given reference has type {given}, expected {expected_str}", node) + hint = None + # TODO: refactor the suggestions code. compare_type could maybe return + # a suggestion if the type is close. + if isinstance(given, _DynLength) and given.length is None: + hint = f"did you mean `convert({node.node_source_code}, {expected_type[0]})`?" + raise TypeMismatch(f"Given reference has type {given}, expected {expected_str}", hint=hint) else: if len(given_types) == 1: given_str = str(given_types[0]) @@ -622,13 +628,12 @@ def _infer_type_helper(node, expected_type): types_str = sorted(str(i) for i in given_types) given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - suggestion_str = "" + hint = None if expected_type[0] == AddressT() and given_types[0] == BytesM_T(20): - suggestion_str = f" Did you mean {checksum_encode(node.value)}?" + hint = f" Did you mean `{checksum_encode(node.value)}`?" raise TypeMismatch( - f"Expected {expected_str} but literal can only be cast as {given_str}.{suggestion_str}", - node, + f"Expected {expected_str} but literal can only be cast as {given_str}.", hint=hint ) diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 59a20dd99f..fd2291301e 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -1,6 +1,6 @@ from . import primitives, subscriptable, user from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void -from .bytestrings import BytesT, StringT, _BytestringT +from .bytestrings import UNKNOWN_LENGTH, BytesT, StringT, _BytestringT from .function import MemberFunctionT from .module import InterfaceT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index c559e625e9..0a7b282290 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -6,53 +6,58 @@ from vyper.utils import ceil32 -UNKNOWN_LENGTH = object() +class _UnknownLength(object): + pass -class _BytestringT(VyperType): +UNKNOWN_LENGTH = _UnknownLength() + + +# TODO: make this a trait which DynArray also inherits from +class _DynLength(VyperType): + pass + + +class _BytestringT(_DynLength): """ Private base class for single-value types which occupy multiple memory slots and where a maximum length must be given via a subscript (string, bytes). - Types for literals have an inferred minimum length. For example, `b"hello"` - has a length of 5 of more and so can be used in an operation with `bytes[5]` - or `bytes[10]`, but not `bytes[4]`. Upon comparison to a fixed length type, - the minimum length is discarded and the type assumes the fixed length it was - compared against. + The length can be generic (for bytestrings which come from interfaces, + e.g. Bytes[...]). This is indicated with `_length is UNKNOWN_LENGTH`. Attributes ---------- _length : int The maximum allowable length of the data within the type. - _min_length: int - The minimum length of the data within the type. Used when the type - is applied to a literal definition. """ # this is a carveout because currently we allow dynamic arrays of # bytestrings, but not static arrays of bytestrings _as_darray = True _as_hashmap_key = True - _equality_attrs = ("_length", "_min_length") + _equality_attrs = ("_length",) _is_bytestring: bool = True - def __init__(self, length: int = 0) -> None: + def __init__(self, length: int | _UnknownLength = UNKNOWN_LENGTH) -> None: super().__init__() self._length = length - self._min_length = length def __repr__(self): - return f"{self._id}[{self.length}]" + length = self.length + if self.length is None: + length = "..." + return f"{self._id}[{length}]" @property def length(self): """ Property method used to check the length of a type. """ - if self._length: - return self._length - return self._min_length + if self._length is UNKNOWN_LENGTH: + return None + return self._length @property def maxlen(self): @@ -65,8 +70,7 @@ def validate_literal(self, node: vy_ast.Constant) -> None: super().validate_literal(node) if len(node.value) != self.length: - # should always be constructed with correct length - # at the point that validate_literal is called + # sanity check raise CompilerPanic("unreachable") @property @@ -80,9 +84,14 @@ def compare_type(self, other): if not super().compare_type(other): return False - if UNKNOWN_LENGTH in (self._length, other._length): + # can assign any Bytes[N] to Bytes[...] + if self._length is UNKNOWN_LENGTH: return True + # cannot assign Bytes[...] to Bytes[N] without going through convert() + if other._length is UNKNOWN_LENGTH: + return False + return self._length >= other._length @classmethod @@ -95,15 +104,10 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": if node.get("value.id") != cls._id: raise UnexpectedValue("Node id does not match type name") - length = get_index_value(node.slice) # type: ignore + length = get_index_value(node.slice) if length is None: - raise StructureException( - f"Cannot declare {cls._id} type without a maximum length, e.g. {cls._id}[5]", node - ) - - # TODO: pass None to constructor after we redo length inference on bytestrings - length = length or 0 + return cls(UNKNOWN_LENGTH) return cls(length) @@ -111,9 +115,7 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": def from_literal(cls, node: vy_ast.Constant) -> "_BytestringT": if not isinstance(node, cls._valid_literal): raise UnexpectedNodeType(f"Not a {cls._id}: {node}") - t = cls() - t.set_min_length(len(node.value)) - return t + return cls(len(node.value)) class BytesT(_BytestringT): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 83a77bd2d8..97f5172150 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -24,11 +24,7 @@ VarAccess, VarOffset, ) -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - infer_type, -) +from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node, infer_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType from vyper.semantics.types.primitives import BoolT @@ -788,7 +784,7 @@ def _parse_args( if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): raise StateAccessViolation("Value must be literal or environment variable", value) infer_type(value, expected_type=type_) - keyword_args.append(KeywordArg(argname, type_, default_value=value,ast_source=arg)) + keyword_args.append(KeywordArg(argname, type_, default_value=value, ast_source=arg)) argnames.add(argname) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 5daa3a9fa7..a3fae18fca 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -13,8 +13,8 @@ from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.utils import ( check_modifiability, - infer_type, get_exact_type_from_node, + infer_type, validate_unique_method_ids, ) from vyper.semantics.data_locations import DataLocation diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 9dec62e136..733534e31d 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -205,10 +205,17 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "SArrayT": value_type = type_from_annotation(node.value) if not value_type._as_array: - raise StructureException(f"arrays of {value_type} are not allowed!") + raise StructureException(f"arrays of {value_type} are not allowed!", node.value) - # note: validates index is a vy_ast.Int. + # note: validates index is a vy_ast.Int or vy_ast.Ellipsis. length = get_index_value(node.slice) + if length is None: + # CMC 2024-03-08 would it ever be useful to allow static arrays with + # abstract length? + raise StructureException( + "static arrays cannot be defined with generic length!", node.slice + ) + return cls(value_type, length) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 0546668900..3851f0168c 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional from vyper import ast as vy_ast from vyper.exceptions import ( @@ -167,7 +167,7 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: return typ_ -def get_index_value(node: vy_ast.VyperNode) -> int: +def get_index_value(node: vy_ast.VyperNode) -> Optional[int]: """ Return the literal value for a `Subscript` index. @@ -178,9 +178,9 @@ def get_index_value(node: vy_ast.VyperNode) -> int: Returns ------- - int + Optional[int] Literal integer value. - In the future, will return `None` if the subscript is an Ellipsis + Return `None` if the subscript is an Ellipsis """ # this is imported to improve error messages # TODO: revisit this! @@ -189,6 +189,9 @@ def get_index_value(node: vy_ast.VyperNode) -> int: if node.has_folded_value: node = node.get_folded_value() + if isinstance(node, vy_ast.Ellipsis): + return None + if not isinstance(node, vy_ast.Int): # even though the subscript is an invalid type, first check if it's a valid _something_ # this gives a more accurate error in case of e.g. a typo in a constant variable name From d8169e5d60b7e3d90f0d40a419ffc6e67e61e13b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 8 Mar 2024 17:26:18 -0800 Subject: [PATCH 11/14] wip - fix external call codegen test code: ``` interface Foo: def foo(xs: Bytes[...]) -> uint256: view def bar() -> Bytes[...]: view def baz(xs: Bytes[...]): nonpayable def qux() -> Bytes[...]: nonpayable @external def foo(f: Foo): xs: Bytes[10] = b"" t: uint256 = staticcall f.foo(xs) xs = convert(staticcall f.bar(), Bytes[10]) extcall f.baz(xs) extcall f.qux() ``` --- vyper/codegen/external_call.py | 50 +++++++++++++++++-------------- vyper/semantics/analysis/local.py | 2 +- vyper/semantics/analysis/utils.py | 14 ++++----- 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index 1e50886baf..bea18cfec0 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -18,7 +18,7 @@ from vyper.codegen.ir_node import Encoding, IRnode from vyper.evm.address_space import MEMORY from vyper.exceptions import TypeCheckFailure -from vyper.semantics.types import InterfaceT, TupleT +from vyper.semantics.types import VOID_TYPE, InterfaceT, TupleT from vyper.semantics.types.function import StateMutability @@ -30,7 +30,7 @@ class _CallKwargs: default_return_value: IRnode -def _pack_arguments(fn_type, args, context): +def _pack_arguments(fn_type, out_type, args, context): # abi encoding just treats all args as a big tuple args_tuple_t = TupleT([x.typ for x in args]) args_as_tuple = IRnode.from_list(["multi"] + [x for x in args], typ=args_tuple_t) @@ -40,12 +40,12 @@ def _pack_arguments(fn_type, args, context): dst_tuple_t = TupleT(fn_type.argument_types[: len(args)]) check_assign(dummy_node_for_type(dst_tuple_t), args_as_tuple) - if fn_type.return_type is not None: - return_abi_t = calculate_type_for_external_return(fn_type.return_type).abi_type + if out_type is not None: + out_abi_t = calculate_type_for_external_return(out_type).abi_type # we use the same buffer for args and returndata, # so allocate enough space here for the returndata too. - buflen = max(args_abi_t.size_bound(), return_abi_t.size_bound()) + buflen = max(args_abi_t.size_bound(), out_abi_t.size_bound()) else: buflen = args_abi_t.size_bound() @@ -74,18 +74,16 @@ def _pack_arguments(fn_type, args, context): return buf, pack_args, args_ofst, args_len -def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, expr): - return_t = fn_type.return_type - - if return_t is None: +def _unpack_returndata(buf, out_type, call_kwargs, contract_address, context, expr): + if out_type is None: return ["pass"], 0, 0 - wrapped_return_t = calculate_type_for_external_return(return_t) + wrapped_out = calculate_type_for_external_return(out_type) - abi_return_t = wrapped_return_t.abi_type + abi_out_t = wrapped_out.abi_type - min_return_size = abi_return_t.min_size() - max_return_size = abi_return_t.size_bound() + min_return_size = abi_out_t.min_size() + max_return_size = abi_out_t.size_bound() assert 0 < min_return_size <= max_return_size ret_ofst = buf @@ -95,7 +93,7 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp buf = IRnode.from_list( buf, - typ=wrapped_return_t, + typ=wrapped_out, location=MEMORY, encoding=encoding, annotation=f"{expr.node_source_code} returndata buffer", @@ -112,12 +110,12 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp ) unpacker.append(assertion) - assert isinstance(wrapped_return_t, TupleT) + assert isinstance(wrapped_out, TupleT) # unpack strictly - if needs_clamp(wrapped_return_t, encoding): - return_buf = context.new_internal_variable(wrapped_return_t) - return_buf = IRnode.from_list(return_buf, typ=wrapped_return_t, location=MEMORY) + if needs_clamp(wrapped_out, encoding): + return_buf = context.new_internal_variable(wrapped_out) + return_buf = IRnode.from_list(return_buf, typ=wrapped_out, location=MEMORY) # note: make_setter does ABI decoding and clamps unpacker.append(make_setter(return_buf, buf)) @@ -172,6 +170,13 @@ def _extcodesize_check(address): def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context): fn_type = call_expr.func._metadata["type"] + out_type = call_expr._metadata["type"] + if out_type is VOID_TYPE: + out_type = None # makes downstream logic cleaner + assert fn_type.return_type is None + else: + check_assign(dummy_node_for_type(out_type), dummy_node_for_type(fn_type.return_type)) + # sanity check assert fn_type.n_positional_args <= len(args_ir) <= fn_type.n_total_args @@ -182,10 +187,10 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con # a duplicate label exception will get thrown during assembly. ret.append(eval_once_check(_freshname(call_expr.node_source_code))) - buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context) + buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, out_type, args_ir, context) ret_unpacker, ret_ofst, ret_len = _unpack_returndata( - buf, fn_type, call_kwargs, contract_address, context, call_expr + buf, out_type, call_kwargs, contract_address, context, call_expr ) ret += arg_packer @@ -213,11 +218,10 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con ret.append(check_external_call(call_op)) - return_t = fn_type.return_type - if return_t is not None: + if out_type is not None: ret.append(ret_unpacker) - return IRnode.from_list(ret, typ=return_t, location=MEMORY) + return IRnode.from_list(ret, typ=out_type, location=MEMORY) def ir_for_external_call(call_expr, context): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 7852b56868..cbaaa39a69 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -637,7 +637,7 @@ def scope_name(self): return "module" def visit(self, node, typ): - if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): + if not isinstance(typ, TYPE_T): # note: infer_type caches the resolved type on node._metadata["type"] typ = infer_type(node, expected_type=typ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index d117b3dcf8..cb35acbaf1 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -20,7 +20,7 @@ from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace -from vyper.semantics.types.base import TYPE_T, VyperType +from vyper.semantics.types.base import TYPE_T, VyperType, map_void from vyper.semantics.types.bytestrings import BytesT, StringT, _DynLength from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT @@ -283,12 +283,10 @@ def types_from_StaticCall(self, node): def types_from_Call(self, node): # function calls, e.g. `foo()` or `MyStruct()` var = self.get_exact_type_from_node(node.func, include_type_exprs=True) - return_value = var.fetch_call_return(node) - if return_value: - if isinstance(return_value, list): - return return_value - return [return_value] - raise InvalidType(f"{var} did not return a value", node) + return_value = map_void(var.fetch_call_return(node)) + if isinstance(return_value, list): + return return_value + return [return_value] def types_from_Constant(self, node): # literal value (integer, string, etc) @@ -598,7 +596,7 @@ def _infer_type_helper(node, expected_type): else: for given, expected in itertools.product(given_types, expected_type): if expected.compare_type(given): - return given + return expected # validation failed, prepare a meaningful error message if len(expected_type) > 1: From a7556d8f8ec02072d373cd15d81d212990d580e9 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 14 Mar 2024 12:45:55 -0400 Subject: [PATCH 12/14] add a hint --- vyper/builtins/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index ce91231842..4490ee90a5 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -220,7 +220,10 @@ def infer_arg_types(self, node, expected_return_typ=None): # block conversions between same type if target_type.compare_type(value_type): - raise InvalidType(f"Value and target type are both '{target_type}'", node) + raise InvalidType( + f"Value and target type are both `{target_type}`", + hint="try removing the call to `convert()`", + ) return [value_type, TYPE_T(target_type)] From 0eedb384e2788a843ef31ed68864755f93596d72 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 19 Mar 2024 17:00:23 -0400 Subject: [PATCH 13/14] handle more cases in generic test code: ``` interface Foo: def foo(xs: Bytes[...]) -> uint256: view def bar() -> Bytes[...]: view def baz(xs: Bytes[...]): nonpayable def qux() -> Bytes[...]: nonpayable @external def foo(f: Foo): xs: Bytes[10] = b"" t: uint256 = staticcall f.foo(xs) xs = staticcall f.bar() extcall f.baz(xs) extcall f.qux() x: Bytes[10] = extcall f.qux() ``` --- vyper/codegen/core.py | 6 +++++- vyper/codegen/expr.py | 15 ++++++++------- vyper/codegen/external_call.py | 14 +++++++++----- vyper/semantics/analysis/utils.py | 4 +--- vyper/semantics/types/bytestrings.py | 2 +- 5 files changed, 24 insertions(+), 17 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 5c5d7759f7..3bac066a4d 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -727,7 +727,11 @@ def dummy_node_for_type(typ): def _check_assign_bytes(left, right): - if left.typ.maxlen is not None and right.typ.maxlen > left.typ.maxlen: # pragma: nocover + if ( + left.typ.maxlen is not None + and right.typ.maxlen is not None + and right.typ.maxlen > left.typ.maxlen + ): # pragma: nocover raise TypeMismatch(f"Cannot cast from {right.typ} to {left.typ}") # stricter check for zeroing a byte array. diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index d7afe6c7f6..8412359c76 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -702,26 +702,27 @@ def parse_Call(self): assert func_t.is_internal or func_t.is_constructor return self_call.ir_for_self_call(self.expr, self.context) - @classmethod - def handle_external_call(cls, expr, context): + def handle_external_call(self): # TODO fix cyclic import from vyper.builtins._signatures import BuiltinFunctionT - call_node = expr.value + call_node = self.expr.value assert isinstance(call_node, vy_ast.Call) func_t = call_node.func._metadata["type"] if isinstance(func_t, BuiltinFunctionT): - return func_t.build_IR(call_node, context) + return func_t.build_IR(call_node, self.context) - return external_call.ir_for_external_call(call_node, context) + return external_call.ir_for_external_call( + call_node, self.context, discard_output=self.is_stmt + ) def parse_ExtCall(self): - return self.handle_external_call(self.expr, self.context) + return self.handle_external_call() def parse_StaticCall(self): - return self.handle_external_call(self.expr, self.context) + return self.handle_external_call() def parse_List(self): typ = self.expr._metadata["type"] diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index bea18cfec0..0252e732a5 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -167,13 +167,15 @@ def _extcodesize_check(address): return IRnode.from_list(["assert", ["extcodesize", address]], error_msg="extcodesize is zero") -def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context): +def _external_call_helper( + contract_address, args_ir, call_kwargs, call_expr, context, discard_output +): fn_type = call_expr.func._metadata["type"] out_type = call_expr._metadata["type"] - if out_type is VOID_TYPE: + if out_type is VOID_TYPE or discard_output: + assert (out_type is VOID_TYPE) == (fn_type.return_type is None), out_type out_type = None # makes downstream logic cleaner - assert fn_type.return_type is None else: check_assign(dummy_node_for_type(out_type), dummy_node_for_type(fn_type.return_type)) @@ -224,7 +226,7 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con return IRnode.from_list(ret, typ=out_type, location=MEMORY) -def ir_for_external_call(call_expr, context): +def ir_for_external_call(call_expr, context, discard_output): from vyper.codegen.expr import Expr # TODO rethink this circular import contract_address = Expr.parse_value_expr(call_expr.func.value, context) @@ -234,5 +236,7 @@ def ir_for_external_call(call_expr, context): with contract_address.cache_when_complex("external_contract") as (b1, contract_address): return b1.resolve( - _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context) + _external_call_helper( + contract_address, args_ir, call_kwargs, call_expr, context, discard_output + ) ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index cb35acbaf1..43aa47059f 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -21,7 +21,7 @@ from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType, map_void -from vyper.semantics.types.bytestrings import BytesT, StringT, _DynLength +from vyper.semantics.types.bytestrings import BytesT, StringT from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT from vyper.utils import checksum_encode, int_to_fourbytes @@ -616,8 +616,6 @@ def _infer_type_helper(node, expected_type): hint = None # TODO: refactor the suggestions code. compare_type could maybe return # a suggestion if the type is close. - if isinstance(given, _DynLength) and given.length is None: - hint = f"did you mean `convert({node.node_source_code}, {expected_type[0]})`?" raise TypeMismatch(f"Given reference has type {given}, expected {expected_str}", hint=hint) else: if len(given_types) == 1: diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index a2bdd01d90..f45dbbf9f8 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -93,7 +93,7 @@ def compare_type(self, other): # cannot assign Bytes[...] to Bytes[N] without going through convert() if other._length is UNKNOWN_LENGTH: - return False + return True return self._length >= other._length From 6ec773024a21887445fade21773039f12dffcb89 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 19 Mar 2024 21:09:42 +0000 Subject: [PATCH 14/14] handle TYPE_T --- vyper/semantics/analysis/local.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index cbaaa39a69..13a9ec6d1c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -637,7 +637,9 @@ def scope_name(self): return "module" def visit(self, node, typ): - if not isinstance(typ, TYPE_T): + if isinstance(typ, TYPE_T): + node._metadata["type"] = typ + else: # note: infer_type caches the resolved type on node._metadata["type"] typ = infer_type(node, expected_type=typ)