Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compile-time evaluation of slice() #3667

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
63 changes: 63 additions & 0 deletions tests/functional/builtins/codegen/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pytest
from hypothesis import given, settings

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from vyper.compiler.settings import OptimizationLevel
from vyper.exceptions import ArgumentException, TypeMismatch

Expand Down Expand Up @@ -432,3 +434,64 @@ def test_slice_bytes32_calldata_extended(get_contract, code, result):
c.bar(3, "0x0001020304050607080910111213141516171819202122232425262728293031", 5).hex()
== result
)


code_comptime = [
(
"""
@external
@view
def baz() -> Bytes[16]:
return slice(0x1234567891234567891234567891234567891234567891234567891234567891, 0, 16)
""",
"12345678912345678912345678912345",
),
(
"""
@external
@view
def baz() -> String[5]:
return slice("why hello! how are you?", 4, 5)
""",
"hello",
),
(
"""
@external
@view
def baz() -> Bytes[6]:
return slice(b'gm sir, how are you ?', 0, 6)
""",
"gm sir".encode("utf-8").hex(),
),
]


@pytest.mark.parametrize("code,result", code_comptime)
def test_comptime(get_contract, code, result):
c = get_contract(code)
ret = c.baz()
if hasattr(ret, "hex"):
assert ret.hex() == result
else:
assert ret == result


error_slice = [
"slice(0x00, 0, 1)",
"slice(b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10', 10, 1)",
"slice(b'', 0, 1)",
'slice("why hello! how are you?", 32, 1)',
'slice("why hello! how are you?", -1, 1)',
'slice("why hello! how are you?", 4, 0)',
'slice("why hello! how are you?", 0, 33)',
'slice("why hello! how are you?", 16, 10)',
]


@pytest.mark.parametrize("code", error_slice)
def test_slice_error(code):
vyper_ast = vy_ast.parse_to_ast(code)
old_node = vyper_ast.body[0].value
with pytest.raises(ArgumentException):
vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)
111 changes: 111 additions & 0 deletions tests/functional/builtins/folding/test_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import string

import pytest
from hypothesis import given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from vyper.exceptions import ArgumentException


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
bytes_in=st.binary(max_size=32),
start=st.integers(min_value=0, max_value=31),
length=st.integers(min_value=1, max_value=32),
)
def test_slice_bytes32(get_contract, bytes_in, start, length):
as_hex = "0x" + str.join("", ["00" for _ in range(32 - len(bytes_in))]) + bytes_in.hex()
length = min(32 - start, length)

source = f"""
@external
def foo(bytes_in: bytes32) -> Bytes[{length}]:
return slice(bytes_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

start *= 2
length *= 2
assert contract.foo(as_hex) == new_node.value == bytes.fromhex(as_hex[2:][start : (start + length)])


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
bytes_in=st.binary(max_size=31),
start=st.integers(min_value=0, max_value=31),
length=st.integers(min_value=1, max_value=32),
)
def test_slice_bytesnot32(bytes_in, start, length):
if not len(bytes_in):
as_hex = "0x00"
else:
as_hex = "0x" + bytes_in.hex()
length = min(32, 32 - start, length)

vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})")
old_node = vyper_ast.body[0].value
with pytest.raises(ArgumentException):
vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
bytes_in=st.binary(min_size=1, max_size=100),
start=st.integers(min_value=0, max_value=99),
length=st.integers(min_value=1, max_value=100),
)
def test_slice_dynbytes(get_contract, bytes_in, start, length):
start = start % len(bytes_in)
length = min(len(bytes_in), len(bytes_in) - start, length)

source = f"""
@external
def foo(bytes_in: Bytes[100]) -> Bytes[{length}]:
return slice(bytes_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"slice({bytes_in}, {start}, {length})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

assert contract.foo(bytes_in) == new_node.value == bytes_in[start : (start + length)]


valid_char = [
char for char in string.printable if char not in (string.whitespace.replace(" ", "") + '"\\')
]


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
string_in=st.text(alphabet=valid_char, min_size=1, max_size=100),
start=st.integers(min_value=0, max_value=99),
length=st.integers(min_value=1, max_value=100),
)
def test_slice_string(get_contract, string_in, start, length):
start = start % len(string_in)
length = min(len(string_in), len(string_in) - start, length)

source = f"""
@external
def foo(string_in: String[100]) -> String[{length}]:
return slice(string_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f'slice("{string_in}", {start}, {length})')
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

assert contract.foo(string_in) == new_node.value == string_in[start : (start + length)]
40 changes: 40 additions & 0 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,46 @@ class Slice(BuiltinFunction):
]
_return_type = None

def evaluate(self, node):
validate_call_args(node, 3)
literal_value, start, length = node.args
if (
isinstance(literal_value, (vy_ast.Bytes, vy_ast.Str, vy_ast.Hex))
and isinstance(start, vy_ast.Int)
and isinstance(length, vy_ast.Int)
):
(start_val, length_val) = (start.value, length.value)

if start_val < 0:
raise ArgumentException("Start cannot be negative", start)
elif length_val <= 0:
raise ArgumentException("Length cannot be negative", length)

if isinstance(literal_value, vy_ast.Hex):
if start_val >= 32:
raise ArgumentException("Start cannot take that value", start)
if length_val > 32:
raise ArgumentException("Length cannot take that value", length)
length = len(literal_value.value) // 2 - 1
if length != 32:
raise ArgumentException("Length can only be of 32", literal_value)
start_val *= 2
length_val *= 2

if start_val + length_val > len(literal_value.value):
raise ArgumentException("Slice is out of bounds", start)

if isinstance(literal_value, vy_ast.Bytes):
sublit = literal_value.value[start_val : (start_val + length_val)]
return vy_ast.Bytes.from_node(node, value=sublit)
elif isinstance(literal_value, vy_ast.Str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for elif because the previous branch already returned

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so put "if" everywhere right ? it was lookin safer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edited!

sublit = literal_value.value[start_val : (start_val + length_val)]
return vy_ast.Str.from_node(node, value=sublit)
elif isinstance(literal_value, vy_ast.Hex):
sublit = literal_value.value[2:][start_val : (start_val + length_val)]
return vy_ast.Bytes.from_node(node, value=f"0x{sublit}")
raise UnfoldableNode

def fetch_call_return(self, node):
arg_type, _, _ = self.infer_arg_types(node)

Expand Down