diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 53e092019f..1a57025410 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -2,6 +2,8 @@ import pytest from hypothesis import given, settings +from vyper import ast as vy_ast +from vyper.builtins import functions as vy_fn from vyper.compiler.settings import OptimizationLevel from vyper.exceptions import ArgumentException, TypeMismatch @@ -432,3 +434,64 @@ def test_slice_bytes32_calldata_extended(get_contract, code, result): c.bar(3, "0x0001020304050607080910111213141516171819202122232425262728293031", 5).hex() == result ) + + +code_comptime = [ + ( + """ +@external +@view +def baz() -> Bytes[16]: + return slice(0x1234567891234567891234567891234567891234567891234567891234567891, 0, 16) + """, + "12345678912345678912345678912345", + ), + ( + """ +@external +@view +def baz() -> String[5]: + return slice("why hello! how are you?", 4, 5) + """, + "hello", + ), + ( + """ +@external +@view +def baz() -> Bytes[6]: + return slice(b'gm sir, how are you ?', 0, 6) + """, + "gm sir".encode("utf-8").hex(), + ), +] + + +@pytest.mark.parametrize("code,result", code_comptime) +def test_comptime(get_contract, code, result): + c = get_contract(code) + ret = c.baz() + if isinstance(ret, bytes): + assert ret.hex() == result + else: + assert ret == result + + +error_slice = [ + "slice(0x00, 0, 1)", + "slice(b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10', 10, 1)", + "slice(b'', 0, 1)", + 'slice("why hello! how are you?", 32, 1)', + 'slice("why hello! how are you?", -1, 1)', + 'slice("why hello! how are you?", 4, 0)', + 'slice("why hello! how are you?", 0, 33)', + 'slice("why hello! how are you?", 16, 10)', +] + + +@pytest.mark.parametrize("code", error_slice) +def test_slice_error(code): + vyper_ast = vy_ast.parse_to_ast(code) + old_node = vyper_ast.body[0].value + with pytest.raises(ArgumentException): + vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node) diff --git a/tests/functional/builtins/folding/test_slice.py b/tests/functional/builtins/folding/test_slice.py new file mode 100644 index 0000000000..da5d1166dc --- /dev/null +++ b/tests/functional/builtins/folding/test_slice.py @@ -0,0 +1,115 @@ +import string + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from vyper import ast as vy_ast +from vyper.builtins import functions as vy_fn +from vyper.exceptions import ArgumentException + + +@pytest.mark.fuzzing +@settings(max_examples=50) +@given( + bytes_in=st.binary(max_size=32), + start=st.integers(min_value=0, max_value=31), + length=st.integers(min_value=1, max_value=32), +) +def test_slice_bytes32(get_contract, bytes_in, start, length): + as_hex = "0x" + str(bytes_in.hex()).zfill(64) + length = min(32 - start, length) + + source = f""" +@external +def foo(bytes_in: bytes32) -> Bytes[{length}]: + return slice(bytes_in, {start}, {length}) + """ + contract = get_contract(source) + + vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})") + old_node = vyper_ast.body[0].value + new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node) + + start *= 2 + length *= 2 + assert ( + contract.foo(as_hex) + == new_node.value + == bytes.fromhex(as_hex[2:][start : (start + length)]) + ) + + +@pytest.mark.fuzzing +@settings(max_examples=50) +@given( + bytes_in=st.binary(max_size=31), + start=st.integers(min_value=0, max_value=31), + length=st.integers(min_value=1, max_value=32), +) +def test_slice_bytesnot32(bytes_in, start, length): + if not len(bytes_in): + as_hex = "0x00" + else: + as_hex = "0x" + bytes_in.hex() + length = min(32, 32 - start, length) + + vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})") + old_node = vyper_ast.body[0].value + with pytest.raises(ArgumentException): + vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node) + + +@pytest.mark.fuzzing +@settings(max_examples=50) +@given( + bytes_in=st.binary(min_size=1, max_size=100), + start=st.integers(min_value=0, max_value=99), + length=st.integers(min_value=1, max_value=100), +) +def test_slice_dynbytes(get_contract, bytes_in, start, length): + start = start % len(bytes_in) + length = min(len(bytes_in), len(bytes_in) - start, length) + + source = f""" +@external +def foo(bytes_in: Bytes[100]) -> Bytes[{length}]: + return slice(bytes_in, {start}, {length}) + """ + contract = get_contract(source) + + vyper_ast = vy_ast.parse_to_ast(f"slice({bytes_in}, {start}, {length})") + old_node = vyper_ast.body[0].value + new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node) + + assert contract.foo(bytes_in) == new_node.value == bytes_in[start : (start + length)] + + +valid_char = [ + char for char in string.printable if char not in (string.whitespace.replace(" ", "") + '"\\') +] + + +@pytest.mark.fuzzing +@settings(max_examples=50) +@given( + string_in=st.text(alphabet=valid_char, min_size=1, max_size=100), + start=st.integers(min_value=0, max_value=99), + length=st.integers(min_value=1, max_value=100), +) +def test_slice_string(get_contract, string_in, start, length): + start = start % len(string_in) + length = min(len(string_in), len(string_in) - start, length) + + source = f""" +@external +def foo(string_in: String[100]) -> String[{length}]: + return slice(string_in, {start}, {length}) + """ + contract = get_contract(source) + + vyper_ast = vy_ast.parse_to_ast(f'slice("{string_in}", {start}, {length})') + old_node = vyper_ast.body[0].value + new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node) + + assert contract.foo(string_in) == new_node.value == string_in[start : (start + length)] diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 001939638b..3167cd816a 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -294,6 +294,48 @@ class Slice(BuiltinFunction): ] _return_type = None + def evaluate(self, node): + validate_call_args(node, 3) + bytestring, start_node, length_node = node.args + if not ( + isinstance(bytestring, (vy_ast.Bytes, vy_ast.Str, vy_ast.Hex)) + and isinstance(start_node, vy_ast.Int) + and isinstance(length_node, vy_ast.Int) + ): + raise UnfoldableNode + + (start, length) = (start_node.value, length_node.value) + + if start < 0: + raise ArgumentException("Start cannot be negative", start_node) + elif length <= 0: + raise ArgumentException("Length must be positive", length_node) + + if isinstance(bytestring, vy_ast.Hex): + bytes_value = bytes.fromhex(bytestring.value.removeprefix("0x")) + if start >= 32: + raise ArgumentException("Start cannot take that value", start_node) + if length > 32: + raise ArgumentException("Length cannot take that value", length_node) + if len(bytes_value) != 32: + raise ArgumentException("Length can only be of 32", bytestring) + elif isinstance(bytestring, vy_ast.Str): + bytes_value = bytestring.value + else: + bytes_value = bytes.fromhex(bytestring.value.hex().removeprefix("0x")) + + if start + length > len(bytes_value): + raise ArgumentException("Slice is out of bounds", start_node) + + end = start + length + res = bytes_value[start:end] + if isinstance(bytestring, vy_ast.Bytes): + return vy_ast.Bytes.from_node(node, value=res) + if isinstance(bytestring, vy_ast.Str): + return vy_ast.Str.from_node(node, value=res) + if isinstance(bytestring, vy_ast.Hex): + return vy_ast.Bytes.from_node(node, value=res) + def fetch_call_return(self, node): arg_type, _, _ = self.infer_arg_types(node)