Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[ux]: allow conversion of bytestring literals to bytes_M #4480

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/functional/builtins/codegen/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,36 @@ def foo() -> {t_bytes}:
assert c2.foo() == test_data


@pytest.mark.parametrize("n", range(1, 32))
def test_literal_bytestrings_to_bytes_m(get_contract, n):
test_data = "1" * n
out = test_data.encode()

bytes_m_typ = f"bytes{n}"
contract_1 = f"""
@external
def foo() -> {bytes_m_typ}:
return convert(b"{test_data}", {bytes_m_typ})

@external
def bar() -> {bytes_m_typ}:
return convert("{test_data}", {bytes_m_typ})
"""

contract_2 = f"""
@external
def fubar(x: String[{n}]) -> {bytes_m_typ}:
return convert(x, {bytes_m_typ})
"""

c1 = get_contract(contract_1)
assert c1.foo() == out
assert c1.bar() == out

with pytest.raises(TypeMismatch):
compile_code(contract_2)


@pytest.mark.parametrize("i_typ,o_typ,val", generate_reverting_cases())
@pytest.mark.fuzzing
def test_conversion_failures(get_contract, assert_compile_failed, tx_failed, i_typ, o_typ, val):
Expand Down
11 changes: 11 additions & 0 deletions tests/functional/syntax/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,17 @@ def foo(): nonpayable
FOO: constant(Foo) = Foo(BAR)
BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF
""",
# conversion of literal bytestrings to bytes_M
"""
b: constant(bytes5) = convert(b"vyper", bytes5)
""",
"""
b: constant(bytes5) = convert("vyper", bytes5)
""",
"""
a: constant(Bytes[5]) = b"vyper"
b: constant(bytes5) = convert(a, bytes5)
""",
]


Expand Down
36 changes: 25 additions & 11 deletions vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,20 +363,34 @@ def to_decimal(expr, arg, out_typ):
raise CompilerPanic("unreachable")


@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BytesT, BoolT)
@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BytesT, StringT, BoolT)
def to_bytes_m(expr, arg, out_typ):
_check_bytes(expr, arg, out_typ, max_bytes_allowed=out_typ.m)

if isinstance(arg.typ, BytesT):
bytes_val = LOAD(bytes_data_ptr(arg))

# zero out any dirty bytes (which can happen in the last
# word of a bytearray)
len_ = get_bytearray_length(arg)
num_zero_bits = IRnode.from_list(["mul", ["sub", 32, len_], 8])
with num_zero_bits.cache_when_complex("bits") as (b, num_zero_bits):
arg = shl(num_zero_bits, shr(num_zero_bits, bytes_val))
arg = b.resolve(arg)
if isinstance(arg.typ, _BytestringT):
# handle literal bytestrings first
if isinstance(expr, vy_ast.Constant) and arg.typ.length <= out_typ.m:
val = expr.value
if isinstance(arg.typ, StringT):
val = val.encode("utf-8")

# bytes_m types are left padded with zeros
val = int(val.hex(), 16) << 8 * (out_typ.m - arg.typ.length)
arg = shl(256 - out_typ.m_bits, val)

elif isinstance(arg.typ, BytesT):
bytes_val = LOAD(bytes_data_ptr(arg))

# zero out any dirty bytes (which can happen in the last
# word of a bytearray)
len_ = get_bytearray_length(arg)
num_zero_bits = IRnode.from_list(["mul", ["sub", 32, len_], 8])
with num_zero_bits.cache_when_complex("bits") as (b, num_zero_bits):
arg = shl(num_zero_bits, shr(num_zero_bits, bytes_val))
arg = b.resolve(arg)

else:
_FAIL(arg.typ, out_typ, expr)

elif is_bytes_m_type(arg.typ):
# clamp if it's a downcast
Expand Down
4 changes: 4 additions & 0 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.utils import (
check_modifiability,
get_common_types,
get_exact_type_from_node,
get_possible_types_from_node,
Expand Down Expand Up @@ -194,6 +195,9 @@ def build_IR(self, expr, args, kwargs, context):
class Convert(BuiltinFunctionT):
_id = "convert"

def check_modifiability_for_call(self, node, modifiability):
return check_modifiability(node.args[0], modifiability)

def fetch_call_return(self, node):
_, target_typedef = self.infer_arg_types(node)

Expand Down