Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compile-time evaluation of slice() #3667

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
63 changes: 63 additions & 0 deletions tests/functional/builtins/codegen/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pytest
from hypothesis import given, settings

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from vyper.compiler.settings import OptimizationLevel
from vyper.exceptions import ArgumentException, TypeMismatch

Expand Down Expand Up @@ -432,3 +434,64 @@ def test_slice_bytes32_calldata_extended(get_contract, code, result):
c.bar(3, "0x0001020304050607080910111213141516171819202122232425262728293031", 5).hex()
== result
)


code_comptime = [
(
"""
@external
@view
def baz() -> Bytes[16]:
return slice(0x1234567891234567891234567891234567891234567891234567891234567891, 0, 16)
""",
"12345678912345678912345678912345",
),
(
"""
@external
@view
def baz() -> String[5]:
return slice("why hello! how are you?", 4, 5)
""",
"hello",
),
(
"""
@external
@view
def baz() -> Bytes[6]:
return slice(b'gm sir, how are you ?', 0, 6)
""",
"gm sir".encode("utf-8").hex(),
),
]


@pytest.mark.parametrize("code,result", code_comptime)
def test_comptime(get_contract, code, result):
c = get_contract(code)
ret = c.baz()
if isinstance(ret, bytes):
assert ret.hex() == result
else:
assert ret == result


error_slice = [
"slice(0x00, 0, 1)",
"slice(b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10', 10, 1)",
"slice(b'', 0, 1)",
'slice("why hello! how are you?", 32, 1)',
'slice("why hello! how are you?", -1, 1)',
'slice("why hello! how are you?", 4, 0)',
'slice("why hello! how are you?", 0, 33)',
'slice("why hello! how are you?", 16, 10)',
]


@pytest.mark.parametrize("code", error_slice)
def test_slice_error(code):
vyper_ast = vy_ast.parse_to_ast(code)
old_node = vyper_ast.body[0].value
with pytest.raises(ArgumentException):
vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)
115 changes: 115 additions & 0 deletions tests/functional/builtins/folding/test_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import string

import pytest
from hypothesis import given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from vyper.exceptions import ArgumentException


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
bytes_in=st.binary(max_size=32),
start=st.integers(min_value=0, max_value=31),
length=st.integers(min_value=1, max_value=32),
)
def test_slice_bytes32(get_contract, bytes_in, start, length):
as_hex = "0x" + str(bytes_in.hex()).zfill(64)
length = min(32 - start, length)

source = f"""
@external
def foo(bytes_in: bytes32) -> Bytes[{length}]:
return slice(bytes_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

start *= 2
length *= 2
assert (
contract.foo(as_hex)
== new_node.value
== bytes.fromhex(as_hex[2:][start : (start + length)])
)


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
bytes_in=st.binary(max_size=31),
start=st.integers(min_value=0, max_value=31),
length=st.integers(min_value=1, max_value=32),
)
def test_slice_bytesnot32(bytes_in, start, length):
if not len(bytes_in):
as_hex = "0x00"
else:
as_hex = "0x" + bytes_in.hex()
length = min(32, 32 - start, length)

vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})")
old_node = vyper_ast.body[0].value
with pytest.raises(ArgumentException):
vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
bytes_in=st.binary(min_size=1, max_size=100),
start=st.integers(min_value=0, max_value=99),
length=st.integers(min_value=1, max_value=100),
)
def test_slice_dynbytes(get_contract, bytes_in, start, length):
start = start % len(bytes_in)
length = min(len(bytes_in), len(bytes_in) - start, length)

source = f"""
@external
def foo(bytes_in: Bytes[100]) -> Bytes[{length}]:
return slice(bytes_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"slice({bytes_in}, {start}, {length})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

assert contract.foo(bytes_in) == new_node.value == bytes_in[start : (start + length)]


valid_char = [
char for char in string.printable if char not in (string.whitespace.replace(" ", "") + '"\\')
]


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
string_in=st.text(alphabet=valid_char, min_size=1, max_size=100),
start=st.integers(min_value=0, max_value=99),
length=st.integers(min_value=1, max_value=100),
)
def test_slice_string(get_contract, string_in, start, length):
start = start % len(string_in)
length = min(len(string_in), len(string_in) - start, length)

source = f"""
@external
def foo(string_in: String[100]) -> String[{length}]:
return slice(string_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f'slice("{string_in}", {start}, {length})')
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

assert contract.foo(string_in) == new_node.value == string_in[start : (start + length)]
42 changes: 42 additions & 0 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,48 @@ class Slice(BuiltinFunction):
]
_return_type = None

def evaluate(self, node):
validate_call_args(node, 3)
bytestring, start_node, length_node = node.args
if not (
isinstance(bytestring, (vy_ast.Bytes, vy_ast.Str, vy_ast.Hex))
and isinstance(start_node, vy_ast.Int)
and isinstance(length_node, vy_ast.Int)
):
raise UnfoldableNode

(start, length) = (start_node.value, length_node.value)

if start < 0:
raise ArgumentException("Start cannot be negative", start_node)
elif length <= 0:
raise ArgumentException("Length must be positive", length_node)

if isinstance(bytestring, vy_ast.Hex):
bytes_value = bytes.fromhex(bytestring.value.removeprefix("0x"))
if start >= 32:
raise ArgumentException("Start cannot take that value", start_node)
if length > 32:
raise ArgumentException("Length cannot take that value", length_node)
if len(bytes_value) != 32:
raise ArgumentException("Length can only be of 32", bytestring)
elif isinstance(bytestring, vy_ast.Str):
bytes_value = bytestring.value
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@charles-cooper how do you feel about this line ? Here the bytes_value type is str while otherwise it's bytes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks confusing

else:
bytes_value = bytes.fromhex(bytestring.value.hex().removeprefix("0x"))

if start + length > len(bytes_value):
raise ArgumentException("Slice is out of bounds", start_node)

end = start + length
res = bytes_value[start : end]
if isinstance(bytestring, vy_ast.Bytes):
return vy_ast.Bytes.from_node(node, value=res)
if isinstance(bytestring, vy_ast.Str):
return vy_ast.Str.from_node(node, value=res)
if isinstance(bytestring, vy_ast.Hex):
return vy_ast.Bytes.from_node(node, value=res)

def fetch_call_return(self, node):
arg_type, _, _ = self.infer_arg_types(node)

Expand Down
Loading