diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index 9947223e75..1c11938d50 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -662,6 +662,36 @@ def foo() -> {t_bytes}: assert c2.foo() == test_data +@pytest.mark.parametrize("n", range(1, 32)) +def test_literal_bytestrings_to_bytes_m(get_contract, n): + test_data = "1" * n + out = test_data.encode() + + bytes_m_typ = f"bytes{n}" + contract_1 = f""" +@external +def foo() -> {bytes_m_typ}: + return convert(b"{test_data}", {bytes_m_typ}) + +@external +def bar() -> {bytes_m_typ}: + return convert("{test_data}", {bytes_m_typ}) + """ + + contract_2 = f""" +@external +def fubar(x: String[{n}]) -> {bytes_m_typ}: + return convert(x, {bytes_m_typ}) + """ + + c1 = get_contract(contract_1) + assert c1.foo() == out + assert c1.bar() == out + + with pytest.raises(TypeMismatch): + compile_code(contract_2) + + @pytest.mark.parametrize("i_typ,o_typ,val", generate_reverting_cases()) @pytest.mark.fuzzing def test_conversion_failures(get_contract, assert_compile_failed, tx_failed, i_typ, o_typ, val): diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index db2accf359..5cec6492cb 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -324,6 +324,17 @@ def foo(): nonpayable FOO: constant(Foo) = Foo(BAR) BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF """, + # conversion of literal bytestrings to bytes_M + """ +b: constant(bytes5) = convert(b"vyper", bytes5) + """, + """ +b: constant(bytes5) = convert("vyper", bytes5) + """, + """ +a: constant(Bytes[5]) = b"vyper" +b: constant(bytes5) = convert(a, bytes5) + """, ] diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index a494e4a344..00cde54a36 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -363,20 +363,34 @@ def to_decimal(expr, arg, out_typ): raise CompilerPanic("unreachable") -@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BytesT, BoolT) +@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BytesT, StringT, BoolT) def to_bytes_m(expr, arg, out_typ): _check_bytes(expr, arg, out_typ, max_bytes_allowed=out_typ.m) - if isinstance(arg.typ, BytesT): - bytes_val = LOAD(bytes_data_ptr(arg)) - - # zero out any dirty bytes (which can happen in the last - # word of a bytearray) - len_ = get_bytearray_length(arg) - num_zero_bits = IRnode.from_list(["mul", ["sub", 32, len_], 8]) - with num_zero_bits.cache_when_complex("bits") as (b, num_zero_bits): - arg = shl(num_zero_bits, shr(num_zero_bits, bytes_val)) - arg = b.resolve(arg) + if isinstance(arg.typ, _BytestringT): + # handle literal bytestrings first + if isinstance(expr, vy_ast.Constant) and arg.typ.length <= out_typ.m: + val = expr.value + if isinstance(arg.typ, StringT): + val = val.encode("utf-8") + + # bytes_m types are left padded with zeros + val = int(val.hex(), 16) << 8 * (out_typ.m - arg.typ.length) + arg = shl(256 - out_typ.m_bits, val) + + elif isinstance(arg.typ, BytesT): + bytes_val = LOAD(bytes_data_ptr(arg)) + + # zero out any dirty bytes (which can happen in the last + # word of a bytearray) + len_ = get_bytearray_length(arg) + num_zero_bits = IRnode.from_list(["mul", ["sub", 32, len_], 8]) + with num_zero_bits.cache_when_complex("bits") as (b, num_zero_bits): + arg = shl(num_zero_bits, shr(num_zero_bits, bytes_val)) + arg = b.resolve(arg) + + else: + _FAIL(arg.typ, out_typ, expr) elif is_bytes_m_type(arg.typ): # clamp if it's a downcast diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 629ab684a8..979fcd8b53 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -55,6 +55,7 @@ ) from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.utils import ( + check_modifiability, get_common_types, get_exact_type_from_node, get_possible_types_from_node, @@ -194,6 +195,9 @@ def build_IR(self, expr, args, kwargs, context): class Convert(BuiltinFunctionT): _id = "convert" + def check_modifiability_for_call(self, node, modifiability): + return check_modifiability(node.args[0], modifiability) + def fetch_call_return(self, node): _, target_typedef = self.infer_arg_types(node)