diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index f63ef8484a..ebf4bf6f89 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -4,7 +4,7 @@ from tests.utils import parse_and_fold from vyper.exceptions import OverflowException, TypeMismatch -from vyper.semantics.analysis.utils import validate_expected_type +from vyper.semantics.analysis.utils import infer_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T from vyper.utils import unsigned_to_signed @@ -55,7 +55,7 @@ def foo(a: uint256, b: uint256) -> uint256: # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case # more types are added). - validate_expected_type(new_node, UINT256_T) + _ = infer_type(new_node, UINT256_T) # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except OverflowException: # here: check the wrapped value matches runtime @@ -82,7 +82,7 @@ def foo(a: int256, b: uint256) -> int256: vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() - validate_expected_type(new_node, INT256_T) # force bounds check + _ = infer_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except (TypeMismatch, OverflowException): diff --git a/vyper/abi_types.py b/vyper/abi_types.py index a95930b16d..38d0b43f04 100644 --- a/vyper/abi_types.py +++ b/vyper/abi_types.py @@ -154,7 +154,7 @@ def is_complex_type(self): class ABI_Bytes(ABIType): def __init__(self, bytes_bound): - if not bytes_bound >= 0: + if bytes_bound is not None and not bytes_bound >= 0: raise InvalidABIType("Negative bytes_bound provided to ABI_Bytes") self.bytes_bound = bytes_bound diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index a494e4a344..1b9be8af44 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -431,7 +431,8 @@ def _cast_bytestring(expr, arg, out_typ): _FAIL(arg.typ, out_typ, expr) ret = ["seq"] - if out_typ.maxlen < arg.typ.maxlen: + assert out_typ.maxlen is not None + if arg.typ.maxlen is None or out_typ.maxlen < arg.typ.maxlen: ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]]) ret.append(arg) # NOTE: this is a pointer cast diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d012e4a1cf..bac839a997 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -7,11 +7,7 @@ from vyper.codegen.ir_node import IRnode from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode from vyper.semantics.analysis.base import Modifiability -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - validate_expected_type, -) +from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node, infer_type from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -101,7 +97,7 @@ def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> N # for its side effects (will throw if is not a type) type_from_annotation(arg) else: - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type) def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 629ab684a8..296e0d1fa6 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -58,7 +58,7 @@ get_common_types, get_exact_type_from_node, get_possible_types_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.types import ( TYPE_T, @@ -207,12 +207,14 @@ def infer_arg_types(self, node, expected_return_typ=None): target_type = type_from_annotation(node.args[1]) value_types = get_possible_types_from_node(node.args[0]) - # For `convert` of integer literals, we need to match type inference rules in - # convert.py codegen routines. + # For `convert` of integer literals, we need to match type inference + # rules in convert.py codegen routines. # TODO: This can probably be removed once constant folding for `convert` is implemented if len(value_types) > 1 and all(isinstance(v, IntegerT) for v in value_types): - # Get the smallest (and unsigned if available) type for non-integer target types - # (note this is different from the ordering returned by `get_possible_types_from_node`) + # Get the smallest (and unsigned if available) type for + # non-integer target types + # (note this is different from the ordering returned by + # `get_possible_types_from_node`) if not isinstance(target_type, IntegerT): value_types = sorted(value_types, key=lambda v: (v.is_signed, v.bits), reverse=True) else: @@ -223,7 +225,10 @@ def infer_arg_types(self, node, expected_return_typ=None): # block conversions between same type if target_type.compare_type(value_type): - raise InvalidType(f"Value and target type are both '{target_type}'", node) + raise InvalidType( + f"Value and target type are both `{target_type}`", + hint="try removing the call to `convert()`", + ) return [value_type, TYPE_T(target_type)] @@ -296,11 +301,6 @@ class Slice(BuiltinFunctionT): def fetch_call_return(self, node): arg_type, _, _ = self.infer_arg_types(node) - if isinstance(arg_type, StringT): - return_type = StringT() - else: - return_type = BytesT() - # validate start and length are in bounds arg = node.args[0] @@ -329,20 +329,14 @@ def fetch_call_return(self, node): if length_literal is not None and start_literal + length_literal > arg_type.length: raise ArgumentException(f"slice out of bounds for {arg_type}", node) - # we know the length statically + return_cls = arg_type.__class__ if length_literal is not None: - return_type.set_length(length_literal) + return_type = return_cls(length_literal) else: - return_type.set_min_length(arg_type.length) + return_type = return_cls(arg_type.length) return return_type - def infer_arg_types(self, node, expected_return_typ=None): - self._validate_arg_types(node) - # return a concrete type for `b` - b_type = get_possible_types_from_node(node.args[0]).pop() - return [b_type, self._inputs[1][1], self._inputs[2][1]] - @process_inputs def build_IR(self, expr, args, kwargs, context): src, start, length = args @@ -498,12 +492,8 @@ def fetch_call_return(self, node): for arg_t in arg_types: length += arg_t.length - if isinstance(arg_types[0], (StringT)): - return_type = StringT() - else: - return_type = BytesT() - return_type.set_length(length) - return return_type + return_type_cls = arg_types[0].__class__ + return return_type_cls(length) def infer_arg_types(self, node, expected_return_typ=None): if len(node.args) < 2: @@ -515,8 +505,7 @@ def infer_arg_types(self, node, expected_return_typ=None): ret = [] prev_typeclass = None for arg in node.args: - validate_expected_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) - arg_t = get_possible_types_from_node(arg).pop() + arg_t = infer_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) current_typeclass = "String" if isinstance(arg_t, StringT) else "Bytes" if prev_typeclass and current_typeclass != prev_typeclass: raise TypeMismatch( @@ -855,7 +844,7 @@ def infer_kwarg_types(self, node): "Output type must be one of integer, bytes32 or address", node.keywords[0].value ) output_typedef = TYPE_T(output_type) - node.keywords[0].value._metadata["type"] = output_typedef + # node.keywords[0].value._metadata["type"] = output_typedef else: output_typedef = TYPE_T(BYTES32_T) @@ -1039,8 +1028,7 @@ def fetch_call_return(self, node): raise if outsize.value: - return_type = BytesT() - return_type.set_min_length(outsize.value) + return_type = BytesT(outsize.value) if revert_on_failure: return return_type @@ -2355,9 +2343,8 @@ def infer_kwarg_types(self, node): ret = {} for kwarg in node.keywords: kwarg_name = kwarg.arg - validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ) + typ = infer_type(kwarg.value, self._kwargs[kwarg_name].typ) - typ = get_exact_type_from_node(kwarg.value) if kwarg_name == "method_id" and isinstance(typ, BytesT): if typ.length != 4: raise InvalidLiteral("method_id must be exactly 4 bytes!", kwarg.value) @@ -2392,9 +2379,7 @@ def fetch_call_return(self, node): # the output includes 4 bytes for the method_id. maxlen += 4 - ret = BytesT() - ret.set_length(maxlen) - return ret + return BytesT(maxlen) @staticmethod def _parse_method_id(method_id_literal): diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 789cc77524..8858395c84 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -737,7 +737,11 @@ def dummy_node_for_type(typ): def _check_assign_bytes(left, right): - if right.typ.maxlen > left.typ.maxlen: # pragma: nocover + if ( + left.typ.maxlen is not None + and right.typ.maxlen is not None + and right.typ.maxlen > left.typ.maxlen + ): # pragma: nocover raise TypeMismatch(f"Cannot cast from {right.typ} to {left.typ}") # stricter check for zeroing a byte array. diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index a712a09f96..fc2ad90504 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -733,26 +733,27 @@ def parse_Call(self): return self_call.ir_for_self_call(self.expr, self.context) - @classmethod - def handle_external_call(cls, expr, context): + def handle_external_call(self): # TODO fix cyclic import from vyper.builtins._signatures import BuiltinFunctionT - call_node = expr.value + call_node = self.expr.value assert isinstance(call_node, vy_ast.Call) func_t = call_node.func._metadata["type"] if isinstance(func_t, BuiltinFunctionT): - return func_t.build_IR(call_node, context) + return func_t.build_IR(call_node, self.context) - return external_call.ir_for_external_call(call_node, context) + return external_call.ir_for_external_call( + call_node, self.context, discard_output=self.is_stmt + ) def parse_ExtCall(self): - return self.handle_external_call(self.expr, self.context) + return self.handle_external_call() def parse_StaticCall(self): - return self.handle_external_call(self.expr, self.context) + return self.handle_external_call() def parse_List(self): typ = self.expr._metadata["type"] diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index 331b991bfe..cf2041099b 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -20,7 +20,7 @@ from vyper.codegen.ir_node import Encoding, IRnode from vyper.evm.address_space import MEMORY from vyper.exceptions import TypeCheckFailure -from vyper.semantics.types import InterfaceT, TupleT +from vyper.semantics.types import VOID_TYPE, InterfaceT, TupleT from vyper.semantics.types.function import StateMutability @@ -32,7 +32,7 @@ class _CallKwargs: default_return_value: IRnode -def _pack_arguments(fn_type, args, context): +def _pack_arguments(fn_type, out_type, args, context): # abi encoding just treats all args as a big tuple args_tuple_t = TupleT([x.typ for x in args]) args_as_tuple = IRnode.from_list(["multi"] + [x for x in args], typ=args_tuple_t) @@ -42,12 +42,12 @@ def _pack_arguments(fn_type, args, context): dst_tuple_t = TupleT(fn_type.argument_types[: len(args)]) check_assign(dummy_node_for_type(dst_tuple_t), args_as_tuple) - if fn_type.return_type is not None: - return_abi_t = calculate_type_for_external_return(fn_type.return_type).abi_type + if out_type is not None: + out_abi_t = calculate_type_for_external_return(out_type).abi_type # we use the same buffer for args and returndata, # so allocate enough space here for the returndata too. - buflen = max(args_abi_t.size_bound(), return_abi_t.size_bound()) + buflen = max(args_abi_t.size_bound(), out_abi_t.size_bound()) else: buflen = args_abi_t.size_bound() @@ -78,13 +78,11 @@ def _pack_arguments(fn_type, args, context): return buf, pack_args, args_ofst, args_len -def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, expr): - return_t = fn_type.return_type - - if return_t is None: +def _unpack_returndata(buf, out_type, call_kwargs, contract_address, context, expr): + if out_type is None: return ["pass"], 0, 0 - wrapped_return_t = calculate_type_for_external_return(return_t) + wrapped_return_t = calculate_type_for_external_return(out_type) abi_return_t = wrapped_return_t.abi_type @@ -188,9 +186,18 @@ def _extcodesize_check(address): return IRnode.from_list(["assert", ["extcodesize", address]], error_msg="extcodesize is zero") -def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context): +def _external_call_helper( + contract_address, args_ir, call_kwargs, call_expr, context, discard_output +): fn_type = call_expr.func._metadata["type"] + out_type = call_expr._metadata["type"] + if out_type is VOID_TYPE or discard_output: + assert (out_type is VOID_TYPE) == (fn_type.return_type is None), out_type + out_type = None # makes downstream logic cleaner + else: + check_assign(dummy_node_for_type(out_type), dummy_node_for_type(fn_type.return_type)) + # sanity check assert fn_type.n_positional_args <= len(args_ir) <= fn_type.n_total_args @@ -201,10 +208,10 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con # a duplicate label exception will get thrown during assembly. ret.append(eval_once_check(_freshname(call_expr.node_source_code))) - buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context) + buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, out_type, args_ir, context) ret_unpacker, ret_ofst, ret_len = _unpack_returndata( - buf, fn_type, call_kwargs, contract_address, context, call_expr + buf, out_type, call_kwargs, contract_address, context, call_expr ) ret += arg_packer @@ -232,14 +239,13 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con ret.append(check_external_call(call_op)) - return_t = fn_type.return_type - if return_t is not None: + if out_type is not None: ret.append(ret_unpacker) - return IRnode.from_list(ret, typ=return_t, location=MEMORY) + return IRnode.from_list(ret, typ=out_type, location=MEMORY) -def ir_for_external_call(call_expr, context): +def ir_for_external_call(call_expr, context, discard_output): from vyper.codegen.expr import Expr # TODO rethink this circular import contract_address = Expr.parse_value_expr(call_expr.func.value, context) @@ -249,5 +255,7 @@ def ir_for_external_call(call_expr, context): with contract_address.cache_when_complex("external_contract") as (b1, contract_address): return b1.resolve( - _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context) + _external_call_helper( + contract_address, args_ir, call_kwargs, call_expr, context, discard_output + ) ) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 461326d72d..c734a2f35d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -35,8 +35,8 @@ get_exact_type_from_node, get_expr_info, get_possible_types_from_node, + infer_type, uses_state, - validate_expected_type, ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS @@ -368,17 +368,19 @@ def visit_AnnAssign(self, node): self.expr_visitor.visit(node.target, typ) def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: + if isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE": + # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + return + if isinstance(msg_node, vy_ast.Str): if not msg_node.value.strip(): raise StructureException("Reason string cannot be empty", msg_node) - self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) - elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - try: - validate_expected_type(msg_node, StringT(1024)) - except TypeMismatch as e: - raise InvalidType("revert reason must fit within String[1024]") from e - self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) - # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + try: + self.expr_visitor.visit(msg_node, StringT.any()) + except TypeMismatch as e: + # improve the error message + msg = "reason must be a string or the special `UNREACHABLE` value" + raise TypeMismatch(msg, msg_node) from e def visit_Assert(self, node): if node.msg: @@ -545,7 +547,7 @@ def _analyse_list_iter(self, iter_node, target_type): except (InvalidType, StructureException): raise InvalidType("Not an iterable type", iter_node) - # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # CMC 2024-02-09 TODO: use infer_type once we have DArrays # with generic length. if not isinstance(iter_type, (DArrayT, SArrayT)): raise InvalidType("Not an iterable type", iter_node) @@ -653,16 +655,14 @@ def scope_name(self): return "module" def visit(self, node, typ): - if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): - validate_expected_type(node, typ) + if isinstance(typ, TYPE_T): + node._metadata["type"] = typ + else: + # note: infer_type caches the resolved type on node._metadata["type"] + typ = infer_type(node, expected_type=typ) - # recurse and typecheck in case we are being fed the wrong type for - # some reason. super().visit(node, typ) - # annotate - node._metadata["type"] = typ - if not isinstance(typ, TYPE_T): info = get_expr_info(node) # get_expr_info fills in node._expr_info @@ -932,7 +932,7 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return - # these guarantees should be provided by validate_expected_type + # these guarantees should be provided by infer_type assert isinstance(typ, TupleT) assert len(node.elements) == len(typ.member_types) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 534af4d633..e6cf96fe4d 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -664,7 +664,8 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) # performs validate_expected_type + + ExprVisitor().visit(node.value, type_) # performs type validation if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 8727f3750d..8a11372fdd 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -22,7 +22,7 @@ from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace -from vyper.semantics.types.base import TYPE_T, VyperType +from vyper.semantics.types.base import TYPE_T, VyperType, map_void from vyper.semantics.types.bytestrings import BytesT, StringT from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT @@ -226,7 +226,7 @@ def types_from_BinOp(self, node): # can be different types types_list = get_possible_types_from_node(node.left) # check rhs is unsigned integer - validate_expected_type(node.right, IntegerT.unsigneds()) + _ = infer_type(node.right, IntegerT.unsigneds()) else: types_list = get_common_types(node.left, node.right) @@ -289,12 +289,10 @@ def types_from_StaticCall(self, node): def types_from_Call(self, node): # function calls, e.g. `foo()` or `MyStruct()` var = self.get_exact_type_from_node(node.func, include_type_exprs=True) - return_value = var.fetch_call_return(node) - if return_value: - if isinstance(return_value, list): - return return_value - return [return_value] - raise InvalidType(f"{var} did not return a value", node) + return_value = map_void(var.fetch_call_return(node)) + if isinstance(return_value, list): + return return_value + return [return_value] def types_from_Constant(self, node): # literal value (integer, string, etc) @@ -329,7 +327,7 @@ def types_from_Constant(self, node): raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node) def types_from_IfExp(self, node): - validate_expected_type(node.test, BoolT()) + _ = infer_type(node.test, expected_type=BoolT()) types_list = get_common_types(node.body, node.orelse) if not types_list: @@ -542,14 +540,14 @@ def _validate_literal_array(node, expected): for item in node.elements: try: - validate_expected_type(item, expected.value_type) + _ = infer_type(item, expected.value_type) except (InvalidType, TypeMismatch): return False return True -def validate_expected_type(node, expected_type): +def infer_type(node, expected_type): """ Validate that the given node matches the expected type(s) @@ -564,8 +562,16 @@ def validate_expected_type(node, expected_type): Returns ------- - None + The inferred type. The inferred type must be a concrete type which + is compatible with the expected type (although the expected type may + be generic). """ + ret = _infer_type_helper(node, expected_type) + node._metadata["type"] = ret + return ret + + +def _infer_type_helper(node, expected_type): if not isinstance(expected_type, tuple): expected_type = (expected_type,) @@ -574,15 +580,15 @@ def validate_expected_type(node, expected_type): for t in possible_tuple_types: if len(t.member_types) != len(node.elements): continue - for item_ast, item_type in zip(node.elements, t.member_types): + ret = [] + for item_ast, expected_item_type in zip(node.elements, t.member_types): try: - validate_expected_type(item_ast, item_type) - return + item_t = infer_type(item_ast, expected_type=expected_item_type) + ret.append(item_t) except VyperException: - pass - else: - # fail block - pass + break # go to fail block + else: + return TupleT(tuple(ret)) given_types = _ExprAnalyser().get_possible_types_from_node(node) @@ -592,11 +598,11 @@ def validate_expected_type(node, expected_type): if not isinstance(expected, (DArrayT, SArrayT)): continue if _validate_literal_array(node, expected): - return + return expected else: for given, expected in itertools.product(given_types, expected_type): if expected.compare_type(given): - return + return expected # validation failed, prepare a meaningful error message if len(expected_type) > 1: @@ -613,7 +619,10 @@ def validate_expected_type(node, expected_type): vy_ast.Name, include_self=True ): given = given_types[0] - raise TypeMismatch(f"Given reference has type {given}, expected {expected_str}", node) + hint = None + # TODO: refactor the suggestions code. compare_type could maybe return + # a suggestion if the type is close. + raise TypeMismatch(f"Given reference has type {given}, expected {expected_str}", hint=hint) else: if len(given_types) == 1: given_str = str(given_types[0]) @@ -621,13 +630,12 @@ def validate_expected_type(node, expected_type): types_str = sorted(str(i) for i in given_types) given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - suggestion_str = "" + hint = None if expected_type[0] == AddressT() and given_types[0] == BytesM_T(20): - suggestion_str = f" Did you mean {checksum_encode(node.value)}?" + hint = f" Did you mean `{checksum_encode(node.value)}`?" raise TypeMismatch( - f"Expected {expected_str} but literal can only be cast as {given_str}.{suggestion_str}", - node, + f"Expected {expected_str} but literal can only be cast as {given_str}.", hint=hint ) @@ -729,7 +737,7 @@ def validate_kwargs(node: vy_ast.Call, members: dict[str, VyperType], typeclass: raise InvalidAttribute(msg, kwarg, hint=hint) expected_type = members[argname] - validate_expected_type(kwarg.value, expected_type) + _ = infer_type(kwarg.value, expected_type) missing = OrderedSet(members.keys()) - OrderedSet(seen.keys()) if len(missing) > 0: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index b881f52b2b..6e1237db83 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -1,6 +1,6 @@ from . import primitives, subscriptable, user from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void -from .bytestrings import BytesT, StringT, _BytestringT +from .bytestrings import UNKNOWN_LENGTH, BytesT, StringT, _BytestringT from .function import ContractFunctionT, MemberFunctionT from .module import InterfaceT, ModuleT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index 02e3bb213f..dfac1abaeb 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -6,41 +6,49 @@ from vyper.utils import ceil32 -class _BytestringT(VyperType): +class _UnknownLength(object): + pass + + +UNKNOWN_LENGTH = _UnknownLength() + + +# TODO: make this a trait which DynArray also inherits from +class _DynLength(VyperType): + pass + + +class _BytestringT(_DynLength): """ Private base class for single-value types which occupy multiple memory slots and where a maximum length must be given via a subscript (string, bytes). - Types for literals have an inferred minimum length. For example, `b"hello"` - has a length of 5 of more and so can be used in an operation with `bytes[5]` - or `bytes[10]`, but not `bytes[4]`. Upon comparison to a fixed length type, - the minimum length is discarded and the type assumes the fixed length it was - compared against. + The length can be generic (for bytestrings which come from interfaces, + e.g. Bytes[...]). This is indicated with `_length is UNKNOWN_LENGTH`. Attributes ---------- _length : int The maximum allowable length of the data within the type. - _min_length: int - The minimum length of the data within the type. Used when the type - is applied to a literal definition. """ # this is a carveout because currently we allow dynamic arrays of # bytestrings, but not static arrays of bytestrings _as_darray = True _as_hashmap_key = True - _equality_attrs = ("_length", "_min_length") + _equality_attrs = ("_length",) _is_bytestring: bool = True - def __init__(self, length: int = 0) -> None: + def __init__(self, length: int | _UnknownLength = UNKNOWN_LENGTH) -> None: super().__init__() self._length = length - self._min_length = length def __repr__(self): - return f"{self._id}[{self.length}]" + length = self.length + if self.length is None: + length = "..." + return f"{self._id}[{length}]" def _addl_dict_fields(self): return {"length": self.length} @@ -50,9 +58,9 @@ def length(self): """ Property method used to check the length of a type. """ - if self._length: - return self._length - return self._min_length + if self._length is UNKNOWN_LENGTH: + return None + return self._length @property def maxlen(self): @@ -65,64 +73,29 @@ def validate_literal(self, node: vy_ast.Constant) -> None: super().validate_literal(node) if len(node.value) != self.length: - # should always be constructed with correct length - # at the point that validate_literal is called + # sanity check raise CompilerPanic("unreachable") @property def size_in_bytes(self): - # the first slot (32 bytes) stores the actual length, and then we reserve - # enough additional slots to store the data if it uses the max available length - # because this data type is single-bytes, we make it so it takes the max 32 byte - # boundary as it's size, instead of giving it a size that is not cleanly divisible by 32 - + # the first slot (32 bytes) stores the actual length, and then we + # reserve enough additional slots to store the data. allocate 32-byte + # aligned buffer for the data. return 32 + ceil32(self.length) - def set_length(self, length): - """ - Sets the exact length of the type. - - May only be called once, and only on a type that does not yet have - a fixed length. - """ - if self._length: - raise CompilerPanic("Type already has a fixed length") - self._length = length - self._min_length = length - - def set_min_length(self, min_length): - """ - Sets the minimum length of the type. - - May only be used to increase the minimum length. May not be called if - an exact length has been set. - """ - if self._length: - raise CompilerPanic("Type already has a fixed length") - if self._min_length > min_length: - raise CompilerPanic("Cannot reduce the min_length of ArrayValueType") - self._min_length = min_length - def compare_type(self, other): if not super().compare_type(other): return False - # CMC 2022-03-18 TODO this method should be refactored so it does not have side effects - - # when comparing two literals, both now have an equal min-length - if not self._length and not other._length: - min_length = max(self._min_length, other._min_length) - self.set_min_length(min_length) - other.set_min_length(min_length) + # can assign any Bytes[N] to Bytes[...] + if self._length is UNKNOWN_LENGTH: return True - # comparing a defined length to a literal causes the literal to have a fixed length - if self._length: - if not other._length: - other.set_length(max(self._length, other._min_length)) - return self._length >= other._length + # cannot assign Bytes[...] to Bytes[N] without going through convert() + if other._length is UNKNOWN_LENGTH: + return True - return other.compare_type(self) + return self._length >= other._length @classmethod def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": @@ -134,15 +107,10 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": if node.get("value.id") != cls._id: raise UnexpectedValue("Node id does not match type name") - length = get_index_value(node.slice) # type: ignore + length = get_index_value(node.slice) if length is None: - raise StructureException( - f"Cannot declare {cls._id} type without a maximum length, e.g. {cls._id}[5]", node - ) - - # TODO: pass None to constructor after we redo length inference on bytestrings - length = length or 0 + return cls(UNKNOWN_LENGTH) return cls(length) @@ -150,9 +118,7 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": def from_literal(cls, node: vy_ast.Constant) -> "_BytestringT": if not isinstance(node, cls._valid_literal): raise UnexpectedNodeType(f"Not a {cls._id}: {node}") - t = cls() - t.set_min_length(len(node.value)) - return t + return cls(len(node.value)) class BytesT(_BytestringT): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index ffeb5b7299..61e6913c29 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -27,8 +27,8 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, + infer_type, uses_state, - validate_expected_type, ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType @@ -640,7 +640,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: for arg, expected in zip(node.args, self.arguments): try: - validate_expected_type(arg, expected.typ) + infer_type(arg, expected.typ) except TypeMismatch as e: raise self._enhance_call_exception(e, expected.ast_source or self.ast_def) @@ -653,7 +653,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: f"`{kwarg.arg}=` specified but {self.name}() does not return anything", kwarg.value, ) - validate_expected_type(kwarg.value, kwarg_settings.typ) + infer_type(kwarg.value, kwarg_settings.typ) if kwarg_settings.require_literal: if not isinstance(kwarg.value, vy_ast.Constant): raise InvalidType( @@ -819,8 +819,8 @@ def _parse_args( value = funcdef.args.defaults[i - n_positional_args] if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): raise StateAccessViolation("Value must be literal or environment variable", value) - validate_expected_type(value, type_) - keyword_args.append(KeywordArg(argname, type_, ast_source=arg, default_value=value)) + infer_type(value, expected_type=type_) + keyword_args.append(KeywordArg(argname, type_, default_value=value, ast_source=arg)) argnames.add(argname) @@ -882,7 +882,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: assert len(node.args) == len(self.arg_types) # validate_call_args postcondition for arg, expected_type in zip(node.args, self.arg_types): # CMC 2022-04-01 this should probably be in the validation module - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type=expected_type) return self.return_type diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 498757b94e..4d335c3f5d 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -14,7 +14,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, - validate_expected_type, + infer_type, validate_unique_method_ids, ) from vyper.semantics.data_locations import DataLocation @@ -100,8 +100,8 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": def _ctor_arg_types(self, node): validate_call_args(node, 1) - validate_expected_type(node.args[0], AddressT()) - return [AddressT()] + typ = infer_type(node.args[0], AddressT()) + return [typ] def _ctor_kwarg_types(self, node): return {} diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 4068d815d2..bb3a677413 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -35,9 +35,9 @@ def getter_signature(self) -> Tuple[Tuple, Optional[VyperType]]: def validate_index_type(self, node): # TODO: break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type - validate_expected_type(node, self.key_type) + infer_type(node, self.key_type) class HashMapT(_SubscriptableT): @@ -127,7 +127,7 @@ def count(self): def validate_index_type(self, node): # TODO break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type node = node.reduced() @@ -137,7 +137,7 @@ def validate_index_type(self, node): if node.value >= self.length: raise ArrayIndexException("Index out of range", node) - validate_expected_type(node, IntegerT.any()) + infer_type(node, IntegerT.any()) def get_subscripted_type(self, node): return self.value_type @@ -213,10 +213,17 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "SArrayT": value_type = type_from_annotation(node.value) if not value_type._as_array: - raise StructureException(f"arrays of {value_type} are not allowed!") + raise StructureException(f"arrays of {value_type} are not allowed!", node.value) - # note: validates index is a vy_ast.Int. + # note: validates index is a vy_ast.Int or vy_ast.Ellipsis. length = get_index_value(node.slice) + if length is None: + # CMC 2024-03-08 would it ever be useful to allow static arrays with + # abstract length? + raise StructureException( + "static arrays cannot be defined with generic length!", node.slice + ) + return cls(value_type, length) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index bed5542785..709cfbb178 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -14,11 +14,7 @@ VariableDeclarationException, ) from vyper.semantics.analysis.base import Modifiability -from vyper.semantics.analysis.utils import ( - check_modifiability, - validate_expected_type, - validate_kwargs, -) +from vyper.semantics.analysis.utils import check_modifiability, infer_type, validate_kwargs from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType from vyper.semantics.types.subscriptable import HashMapT @@ -312,7 +308,7 @@ def _ctor_call_return(self, node: vy_ast.Call) -> None: validate_call_args(node, len(self.arguments)) for arg, expected in zip(node.args, self.arguments.values()): - validate_expected_type(arg, expected) + infer_type(arg, expected) def to_toplevel_abi_dict(self) -> list[dict]: return [ diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 2840c37bb6..6fbc7b886f 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -1,3 +1,5 @@ +from typing import Optional + from vyper import ast as vy_ast from vyper.compiler.settings import get_global_settings from vyper.exceptions import ( @@ -176,7 +178,7 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: return typ_ -def get_index_value(node: vy_ast.VyperNode) -> int: +def get_index_value(node: vy_ast.VyperNode) -> Optional[int]: """ Return the literal value for a `Subscript` index. @@ -187,9 +189,9 @@ def get_index_value(node: vy_ast.VyperNode) -> int: Returns ------- - int + Optional[int] Literal integer value. - In the future, will return `None` if the subscript is an Ellipsis + Return `None` if the subscript is an Ellipsis """ # this is imported to improve error messages # TODO: revisit this! @@ -197,6 +199,9 @@ def get_index_value(node: vy_ast.VyperNode) -> int: node = node.reduced() + if isinstance(node, vy_ast.Ellipsis): + return None + if not isinstance(node, vy_ast.Int): # even though the subscript is an invalid type, first check if it's a valid _something_ # this gives a more accurate error in case of e.g. a typo in a constant variable name