Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compile-time evaluation of slice() #3667

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
91 changes: 43 additions & 48 deletions tests/functional/builtins/folding/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,45 @@
@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
a=st.integers(min_value=0, max_value=2**256 - 1),
s=st.integers(min_value=0, max_value=31),
le=st.integers(min_value=1, max_value=32),
bytes_in=st.binary(max_size=32),
start=st.integers(min_value=0, max_value=31),
length=st.integers(min_value=1, max_value=32),
)
def test_slice_bytes32(get_contract, a, s, le):
a = hex(a)
while len(a) < 66:
a = f"0x0{a[2:]}"
le = min(32, 32 - s, le)
def test_slice_bytes32(get_contract, bytes_in, start, length):
as_hex = "0x" + str.join("", ["00" for _ in range(32 - len(bytes_in))]) + bytes_in.hex()
length = min(32 - start, length)

source = f"""
@external
def foo(a: bytes32) -> Bytes[{le}]:
return slice(a, {s}, {le})
def foo(bytes_in: bytes32) -> Bytes[{length}]:
return slice(bytes_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"slice({a}, {s}, {le})")
vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

s *= 2
le *= 2
assert contract.foo(a) == new_node.value == bytes.fromhex(a[2:][s : (s + le)])
start *= 2
length *= 2
assert contract.foo(as_hex) == new_node.value == bytes.fromhex(as_hex[2:][start : (start + length)])


@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
a=st.integers(min_value=0, max_value=2**256 - 1),
s=st.integers(min_value=0, max_value=31),
le=st.integers(min_value=1, max_value=32),
bytes_in=st.binary(max_size=31),
start=st.integers(min_value=0, max_value=31),
length=st.integers(min_value=1, max_value=32),
)
def test_slice_bytesnot32(a, s, le):
a = hex(a)
if len(a) == 3:
a = f"0x0{a[2:]}"
elif len(a) == 66:
a = a[:-2]
elif len(a) % 2 == 1:
a = a[:-1]
le = min(32, 32 - s, le)

vyper_ast = vy_ast.parse_to_ast(f"slice({a}, {s}, {le})")
def test_slice_bytesnot32(bytes_in, start, length):
if not len(bytes_in):
as_hex = "0x00"
else:
as_hex = "0x" + bytes_in.hex()
length = min(32, 32 - start, length)

vyper_ast = vy_ast.parse_to_ast(f"slice({as_hex}, {start}, {length})")
old_node = vyper_ast.body[0].value
with pytest.raises(ArgumentException):
vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)
Expand All @@ -64,26 +59,26 @@ def test_slice_bytesnot32(a, s, le):
@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
a=st.binary(min_size=1, max_size=100),
s=st.integers(min_value=0, max_value=99),
le=st.integers(min_value=1, max_value=100),
bytes_in=st.binary(min_size=1, max_size=100),
start=st.integers(min_value=0, max_value=99),
length=st.integers(min_value=1, max_value=100),
)
def test_slice_dynbytes(get_contract, a, s, le):
s = s % len(a)
le = min(len(a), len(a) - s, le)
def test_slice_dynbytes(get_contract, bytes_in, start, length):
start = start % len(bytes_in)
length = min(len(bytes_in), len(bytes_in) - start, length)

source = f"""
@external
def foo(a: Bytes[100]) -> Bytes[{le}]:
return slice(a, {s}, {le})
def foo(bytes_in: Bytes[100]) -> Bytes[{length}]:
return slice(bytes_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"slice({a}, {s}, {le})")
vyper_ast = vy_ast.parse_to_ast(f"slice({bytes_in}, {start}, {length})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

assert contract.foo(a) == new_node.value == a[s : (s + le)]
assert contract.foo(bytes_in) == new_node.value == bytes_in[start : (start + length)]


valid_char = [
Expand All @@ -94,23 +89,23 @@ def foo(a: Bytes[100]) -> Bytes[{le}]:
@pytest.mark.fuzzing
@settings(max_examples=50)
@given(
a=st.text(alphabet=valid_char, min_size=1, max_size=100),
s=st.integers(min_value=0, max_value=99),
le=st.integers(min_value=1, max_value=100),
string_in=st.text(alphabet=valid_char, min_size=1, max_size=100),
start=st.integers(min_value=0, max_value=99),
length=st.integers(min_value=1, max_value=100),
)
def test_slice_string(get_contract, a, s, le):
s = s % len(a)
le = min(len(a), len(a) - s, le)
def test_slice_string(get_contract, string_in, start, length):
start = start % len(string_in)
length = min(len(string_in), len(string_in) - start, length)

source = f"""
@external
def foo(a: String[100]) -> String[{le}]:
return slice(a, {s}, {le})
def foo(string_in: String[100]) -> String[{length}]:
return slice(string_in, {start}, {length})
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f'slice("{a}", {s}, {le})')
vyper_ast = vy_ast.parse_to_ast(f'slice("{string_in}", {start}, {length})')
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["slice"].evaluate(old_node)

assert contract.foo(a) == new_node.value == a[s : (s + le)]
assert contract.foo(string_in) == new_node.value == string_in[start : (start + length)]
22 changes: 11 additions & 11 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,35 +302,35 @@ def evaluate(self, node):
and isinstance(start, vy_ast.Int)
and isinstance(length, vy_ast.Int)
):
(st_val, le_val) = (start.value, length.value)
(start_val, length_val) = (start.value, length.value)

if st_val < 0:
if start_val < 0:
raise ArgumentException("Start cannot be negative", start)
elif le_val <= 0:
elif length_val <= 0:
raise ArgumentException("Length cannot be negative", length)

if isinstance(literal_value, vy_ast.Hex):
if st_val >= 32:
if start_val >= 32:
raise ArgumentException("Start cannot take that value", start)
if le_val > 32:
if length_val > 32:
raise ArgumentException("Length cannot take that value", length)
length = len(literal_value.value) // 2 - 1
if length != 32:
raise ArgumentException("Length can only be of 32", literal_value)
st_val *= 2
le_val *= 2
start_val *= 2
length_val *= 2

if st_val + le_val > len(literal_value.value):
if start_val + length_val > len(literal_value.value):
raise ArgumentException("Slice is out of bounds", start)

if isinstance(literal_value, vy_ast.Bytes):
sublit = literal_value.value[st_val : (st_val + le_val)]
sublit = literal_value.value[start_val : (start_val + length_val)]
return vy_ast.Bytes.from_node(node, value=sublit)
elif isinstance(literal_value, vy_ast.Str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for elif because the previous branch already returned

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so put "if" everywhere right ? it was lookin safer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edited!

sublit = literal_value.value[st_val : (st_val + le_val)]
sublit = literal_value.value[start_val : (start_val + length_val)]
return vy_ast.Str.from_node(node, value=sublit)
elif isinstance(literal_value, vy_ast.Hex):
sublit = literal_value.value[2:][st_val : (st_val + le_val)]
sublit = literal_value.value[2:][start_val : (start_val + length_val)]
return vy_ast.Bytes.from_node(node, value=f"0x{sublit}")
raise UnfoldableNode

Expand Down