Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: infer expected types #3765

Draft
wants to merge 22 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6b9fff2
rename validate_expected_type to infer_type and have it return a type
charles-cooper Feb 9, 2024
c1b8979
Merge branch 'master' into feat/infer_expected_types
charles-cooper Feb 13, 2024
b3e2fd9
update a comment
charles-cooper Feb 13, 2024
3fd9fb8
improve type inference for revert reason strings
charles-cooper Feb 13, 2024
1350838
format some comments
charles-cooper Feb 13, 2024
2013d81
use the result of infer_type
charles-cooper Feb 13, 2024
71aab61
remove length modification functions on bytestring
charles-cooper Feb 13, 2024
d7993ec
feat[lang]: allow downcasting of bytestrings
charles-cooper Mar 5, 2024
93e53c1
fix direction of some comparisons
charles-cooper Mar 5, 2024
b2e62a2
fix existing tests, add tests for new functionality, add compile-time…
charles-cooper Mar 6, 2024
5348802
Merge branch 'master' into feat/bytestring-cast
charles-cooper Mar 8, 2024
f8689ab
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 8, 2024
37880cb
Merge branch 'feat/bytestring-cast' into feat/infer_expected_types
charles-cooper Mar 8, 2024
35ec413
allow bytestrings with ellipsis length
charles-cooper Mar 8, 2024
d8169e5
wip - fix external call codegen
charles-cooper Mar 9, 2024
45960ef
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 13, 2024
a7556d8
add a hint
charles-cooper Mar 14, 2024
0eedb38
handle more cases in generic
charles-cooper Mar 19, 2024
06217ef
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 19, 2024
6ec7730
handle TYPE_T
charles-cooper Mar 19, 2024
9988eb3
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 23, 2024
2e126d9
Merge branch 'master' into feat/infer_expected_types
charles-cooper Feb 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions tests/functional/builtins/codegen/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import eth.codecs.abi.exceptions
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import InvalidLiteral, InvalidType, TypeMismatch
from vyper.semantics.types import AddressT, BoolT, BytesM_T, BytesT, DecimalT, IntegerT, StringT
from vyper.semantics.types.shortcuts import BYTES20_T, BYTES32_T, UINT, UINT160_T, UINT256_T
Expand Down Expand Up @@ -560,23 +561,75 @@ def foo(x: {i_typ}) -> {o_typ}:
assert_compile_failed(lambda: get_contract(code), TypeMismatch)


@pytest.mark.parametrize("typ", sorted(TEST_TYPES))
def test_bytes_too_large_cases(get_contract, assert_compile_failed, typ):
@pytest.mark.parametrize("typ", sorted(BASE_TYPES))
def test_bytes_too_large_cases(typ):
code_1 = f"""
@external
def foo(x: Bytes[33]) -> {typ}:
return convert(x, {typ})
"""
assert_compile_failed(lambda: get_contract(code_1), TypeMismatch)
with pytest.raises(TypeMismatch):
compile_code(code_1)

bytes_33 = b"1" * 33
code_2 = f"""
@external
def foo() -> {typ}:
return convert({bytes_33}, {typ})
"""
with pytest.raises(TypeMismatch):
compile_code(code_2)

assert_compile_failed(lambda: get_contract(code_2, TypeMismatch))

@pytest.mark.parametrize("cls1,cls2", itertools.product((StringT, BytesT), (StringT, BytesT)))
def test_bytestring_conversions(cls1, cls2, get_contract, tx_failed):
typ1 = cls1(33)
typ2 = cls2(32)

def bytestring(cls, string):
if cls == BytesT:
return string.encode("utf-8")
return string

code_1 = f"""
@external
def foo(x: {typ1}) -> {typ2}:
return convert(x, {typ2})
"""
c = get_contract(code_1)

for i in range(33): # inclusive 32
s = "1" * i
arg = bytestring(cls1, s)
out = bytestring(cls2, s)
assert c.foo(arg) == out

with tx_failed():
# TODO: sanity check it is convert which is reverting, not arg clamping
c.foo(bytestring(cls1, "1" * 33))

code_2_template = """
@external
def foo() -> {typ}:
return convert({arg}, {typ})
"""

# test literals
for i in range(33): # inclusive 32
s = "1" * i
arg = bytestring(cls1, s)
out = bytestring(cls2, s)
code = code_2_template.format(typ=typ2, arg=repr(arg))
if cls1 == cls2: # ex.: can't convert "" to String[32]
with pytest.raises(InvalidType):
compile_code(code)
else:
c = get_contract(code)
assert c.foo() == out

failing_code = code_2_template.format(typ=typ2, arg=bytestring(cls1, "1" * 33))
with pytest.raises(TypeMismatch):
compile_code(failing_code)


@pytest.mark.parametrize("n", range(1, 33))
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/builtins/folding/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tests.utils import parse_and_fold
from vyper.exceptions import OverflowException, TypeMismatch
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import infer_type
from vyper.semantics.types.shortcuts import INT256_T, UINT256_T
from vyper.utils import unsigned_to_signed

Expand Down Expand Up @@ -55,7 +55,7 @@ def foo(a: uint256, b: uint256) -> uint256:
# force bounds check, no-op because validate_numeric_bounds
# already does this, but leave in for hygiene (in case
# more types are added).
validate_expected_type(new_node, UINT256_T)
_ = infer_type(new_node, UINT256_T)
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except OverflowException: # here: check the wrapped value matches runtime
Expand All @@ -82,7 +82,7 @@ def foo(a: int256, b: uint256) -> int256:
vyper_ast = parse_and_fold(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()
validate_expected_type(new_node, INT256_T) # force bounds check
_ = infer_type(new_node, INT256_T) # force bounds check
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except (TypeMismatch, OverflowException):
Expand Down
2 changes: 1 addition & 1 deletion vyper/abi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def is_complex_type(self):

class ABI_Bytes(ABIType):
def __init__(self, bytes_bound):
if not bytes_bound >= 0:
if bytes_bound is not None and not bytes_bound >= 0:
raise InvalidABIType("Negative bytes_bound provided to ABI_Bytes")

self.bytes_bound = bytes_bound
Expand Down
31 changes: 20 additions & 11 deletions vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,23 +422,32 @@ def to_address(expr, arg, out_typ):
return IRnode.from_list(ret, out_typ)


# question: should we allow bytesM -> String?
@_input_types(BytesT)
def to_string(expr, arg, out_typ):
_check_bytes(expr, arg, out_typ, out_typ.maxlen)
def _cast_bytestring(expr, arg, out_typ):
# ban converting Bytes[20] to Bytes[21]
if isinstance(arg.typ, out_typ.__class__) and arg.typ.maxlen <= out_typ.maxlen:
_FAIL(arg.typ, out_typ, expr)
# can't downcast literals with known length (e.g. b"abc" to Bytes[2])
if isinstance(expr, vy_ast.Constant) and arg.typ.maxlen > out_typ.maxlen:
_FAIL(arg.typ, out_typ, expr)

ret = ["seq"]
assert out_typ.maxlen is not None
if arg.typ.maxlen is None or out_typ.maxlen < arg.typ.maxlen:
ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]])
ret.append(arg)
# NOTE: this is a pointer cast
return IRnode.from_list(arg, typ=out_typ)
return IRnode.from_list(ret, typ=out_typ, location=arg.location, encoding=arg.encoding)


@_input_types(StringT)
def to_bytes(expr, arg, out_typ):
_check_bytes(expr, arg, out_typ, out_typ.maxlen)
# question: should we allow bytesM -> String?
@_input_types(BytesT, StringT)
def to_string(expr, arg, out_typ):
return _cast_bytestring(expr, arg, out_typ)

# TODO: more casts

# NOTE: this is a pointer cast
return IRnode.from_list(arg, typ=out_typ)
@_input_types(StringT, BytesT)
def to_bytes(expr, arg, out_typ):
return _cast_bytestring(expr, arg, out_typ)


@_input_types(IntegerT)
Expand Down
8 changes: 2 additions & 6 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
)
from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node, infer_type

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.
from vyper.semantics.types import TYPE_T, KwargSettings, VyperType
from vyper.semantics.types.utils import type_from_annotation

Expand Down Expand Up @@ -99,7 +95,7 @@
# for its side effects (will throw if is not a type)
type_from_annotation(arg)
else:
validate_expected_type(arg, expected_type)
infer_type(arg, expected_type)

def _validate_arg_types(self, node: vy_ast.Call) -> None:
num_args = len(self._inputs) # the number of args the signature indicates
Expand Down
53 changes: 18 additions & 35 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
get_common_types,
get_exact_type_from_node,
get_possible_types_from_node,
validate_expected_type,
infer_type,
)
from vyper.semantics.types import (
TYPE_T,
Expand Down Expand Up @@ -202,12 +202,14 @@ def infer_arg_types(self, node, expected_return_typ=None):
target_type = type_from_annotation(node.args[1])
value_types = get_possible_types_from_node(node.args[0])

# For `convert` of integer literals, we need to match type inference rules in
# convert.py codegen routines.
# For `convert` of integer literals, we need to match type inference
# rules in convert.py codegen routines.
# TODO: This can probably be removed once constant folding for `convert` is implemented
if len(value_types) > 1 and all(isinstance(v, IntegerT) for v in value_types):
# Get the smallest (and unsigned if available) type for non-integer target types
# (note this is different from the ordering returned by `get_possible_types_from_node`)
# Get the smallest (and unsigned if available) type for
# non-integer target types
# (note this is different from the ordering returned by
# `get_possible_types_from_node`)
if not isinstance(target_type, IntegerT):
value_types = sorted(value_types, key=lambda v: (v.is_signed, v.bits), reverse=True)
else:
Expand Down Expand Up @@ -290,11 +292,6 @@ class Slice(BuiltinFunctionT):
def fetch_call_return(self, node):
arg_type, _, _ = self.infer_arg_types(node)

if isinstance(arg_type, StringT):
return_type = StringT()
else:
return_type = BytesT()

# validate start and length are in bounds

arg = node.args[0]
Expand Down Expand Up @@ -323,20 +320,14 @@ def fetch_call_return(self, node):
if length_literal is not None and start_literal + length_literal > arg_type.length:
raise ArgumentException(f"slice out of bounds for {arg_type}", node)

# we know the length statically
return_cls = arg_type.__class__
if length_literal is not None:
return_type.set_length(length_literal)
return_type = return_cls(length_literal)
else:
return_type.set_min_length(arg_type.length)
return_type = return_cls(arg_type.length)

return return_type

def infer_arg_types(self, node, expected_return_typ=None):
self._validate_arg_types(node)
# return a concrete type for `b`
b_type = get_possible_types_from_node(node.args[0]).pop()
return [b_type, self._inputs[1][1], self._inputs[2][1]]

@process_inputs
def build_IR(self, expr, args, kwargs, context):
src, start, length = args
Expand Down Expand Up @@ -490,12 +481,8 @@ def fetch_call_return(self, node):
for arg_t in arg_types:
length += arg_t.length

if isinstance(arg_types[0], (StringT)):
return_type = StringT()
else:
return_type = BytesT()
return_type.set_length(length)
return return_type
return_type_cls = arg_types[0].__class__
return return_type_cls(length)

def infer_arg_types(self, node, expected_return_typ=None):
if len(node.args) < 2:
Expand All @@ -507,8 +494,7 @@ def infer_arg_types(self, node, expected_return_typ=None):
ret = []
prev_typeclass = None
for arg in node.args:
validate_expected_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any()))
arg_t = get_possible_types_from_node(arg).pop()
arg_t = infer_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any()))
current_typeclass = "String" if isinstance(arg_t, StringT) else "Bytes"
if prev_typeclass and current_typeclass != prev_typeclass:
raise TypeMismatch(
Expand Down Expand Up @@ -864,7 +850,7 @@ def infer_kwarg_types(self, node):
"Output type must be one of integer, bytes32 or address", node.keywords[0].value
)
output_typedef = TYPE_T(output_type)
node.keywords[0].value._metadata["type"] = output_typedef
# node.keywords[0].value._metadata["type"] = output_typedef
else:
output_typedef = TYPE_T(BYTES32_T)

Expand Down Expand Up @@ -1079,8 +1065,7 @@ def fetch_call_return(self, node):
raise

if outsize.value:
return_type = BytesT()
return_type.set_min_length(outsize.value)
return_type = BytesT(outsize.value)

if revert_on_failure:
return return_type
Expand Down Expand Up @@ -2372,8 +2357,8 @@ def infer_kwarg_types(self, node):
ret = {}
for kwarg in node.keywords:
kwarg_name = kwarg.arg
validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ)
ret[kwarg_name] = get_exact_type_from_node(kwarg.value)
typ = infer_type(kwarg.value, self._kwargs[kwarg_name].typ)
ret[kwarg_name] = typ
return ret

def fetch_call_return(self, node):
Expand Down Expand Up @@ -2403,9 +2388,7 @@ def fetch_call_return(self, node):
# the output includes 4 bytes for the method_id.
maxlen += 4

ret = BytesT()
ret.set_length(maxlen)
return ret
return BytesT(maxlen)

@staticmethod
def _parse_method_id(method_id_literal):
Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def dummy_node_for_type(typ):


def _check_assign_bytes(left, right):
if right.typ.maxlen > left.typ.maxlen: # pragma: nocover
if left.typ.maxlen is not None and right.typ.maxlen > left.typ.maxlen: # pragma: nocover
raise TypeMismatch(f"Cannot cast from {right.typ} to {left.typ}")

# stricter check for zeroing a byte array.
Expand Down
Loading
Loading