Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add pop with index for dynamic arrays #2948

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
85 changes: 85 additions & 0 deletions tests/parser/types/test_dynamic_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,91 @@ def test_append_pop(get_contract, assert_tx_failed, code, check_result, test_dat
assert c.foo(test_data) == expected_result


@pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)])
@pytest.mark.parametrize("ix", [i for i in range(6)])
def test_pop_index_return_pass(get_contract, assert_tx_failed, test_data, ix):
code = """
@external
def foo(a: DynArray[uint256, 5], b: uint256) -> uint256:
return a.pop(ix=b)
"""
c = get_contract(code)
arr_length = len(test_data)
if ix >= arr_length:
assert_tx_failed(lambda: c.foo(test_data, ix))
else:
assert c.foo(test_data, ix) == test_data[ix]


pop_index_tests = [
(
"""
my_array: DynArray[uint256, 5]
@external
def foo(xs: DynArray[uint256, 5], i: uint256) -> DynArray[uint256, 5]:
for x in xs:
self.my_array.append(x)
for x in xs:
self.my_array.pop(ix=0)
return self.my_array
""",
lambda xs, idx: [],
),
(
"""
my_array: DynArray[uint256, 5]
@external
def foo(xs: DynArray[uint256, 5], i: uint256) -> DynArray[uint256, 5]:
for x in xs:
self.my_array.append(x)
self.my_array.pop(ix=i)
return self.my_array
""",
lambda xs, idx: None if len(xs) == 0 else xs[:idx] + xs[idx+1:],
),
# check order of evaluation.
(
"""
my_array: DynArray[uint256, 5]
@external
def foo(xs: DynArray[uint256, 5], i: uint256) -> (DynArray[uint256, 5], uint256):
for x in xs:
self.my_array.append(x)
return self.my_array, self.my_array.pop(ix=i)
""",
lambda xs, idx: None if len(xs) == 0 else [xs[:idx] + xs[idx+1:], xs[idx]],
),
# check order of evaluation.
(
"""
my_array: DynArray[uint256, 5]
@external
def foo(xs: DynArray[uint256, 5], i: uint256) -> (uint256, DynArray[uint256, 5]):
for x in xs:
self.my_array.append(x)
return self.my_array.pop(ix=i), self.my_array
""",
lambda xs, idx: None if len(xs) == 0 else [xs[idx], xs[:idx] + xs[idx+1:]],
),
]


@pytest.mark.parametrize("code,check_result", pop_index_tests)
# TODO change this to fuzz random data
@pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)])
def test_pop_index_pass(get_contract, assert_tx_failed, code, check_result, test_data):
c = get_contract(code)

arr_length = len(test_data)
for idx in range(arr_length):
expected_result = check_result(test_data, idx)
if expected_result is None:
# None is sentinel to indicate txn should revert
assert_tx_failed(lambda: c.foo(test_data, idx))
else:
assert c.foo(test_data, idx) == expected_result


append_pop_complex_tests = [
(
"""
Expand Down
55 changes: 54 additions & 1 deletion vyper/builtin_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,25 @@
IRnode,
_freshname,
add_ofst,
append_dyn_array,
bytes_data_ptr,
calculate_type_for_external_return,
check_assign,
check_external_call,
clamp,
clamp2,
clamp_basetype,
clamp_nonzero,
copy_bytes,
dummy_node_for_type,
ensure_in_memory,
eval_once_check,
eval_seq,
get_bytearray_length,
get_element_ptr,
ir_tuple_from_args,
needs_external_call_wrap,
pop_dyn_array,
promote_signed_int,
shl,
unwrap_location,
Expand All @@ -42,6 +46,7 @@
BaseType,
ByteArrayLike,
ByteArrayType,
DArrayType,
SArrayType,
StringType,
TupleType,
Expand Down Expand Up @@ -112,7 +117,7 @@
vyper_warn,
)

from .signatures import BuiltinFunction, process_inputs
from .signatures import BuiltinFunction, process_inputs, process_kwarg

SHA256_ADDRESS = 2
SHA256_BASE_GAS = 60
Expand Down Expand Up @@ -2464,6 +2469,51 @@ def build_IR(self, expr, args, kwargs, context):
)


class Append(BuiltinFunction):
_id = "append"

def build_IR(self, expr, context):
darray = Expr(expr.func.value, context).ir_node
args = [Expr(x, context).ir_node for x in expr.args]

# sanity checks
assert len(args) == 1
arg = args[0]
assert isinstance(darray.typ, DArrayType)

check_assign(dummy_node_for_type(darray.typ.subtype), dummy_node_for_type(arg.typ))

return append_dyn_array(darray, arg)


class Pop(BuiltinFunction):
_id = "pop"

def _get_kwarg_settings(self, expr):
call_type = get_exact_type_from_node(expr.func)
expected_kwargs = call_type._kwargs
return expected_kwargs

def build_IR(self, expr, context, return_popped_item):
darray = Expr(expr.func.value, context).ir_node
assert isinstance(darray.typ, DArrayType)
assert len(expr.args) == 0

kwargs = self._get_kwarg_settings(expr)

if expr.keywords:
kwarg_name = expr.keywords[0].arg
kwarg_val = expr.keywords[0].value
assert len(expr.keywords) == 1 and kwarg_name == "ix"
kwarg_settings = kwargs[kwarg_name]
expected_kwarg_type = kwarg_settings.typ
idx = process_kwarg(kwarg_val, kwarg_settings, expected_kwarg_type, context)
return pop_dyn_array(context, darray, return_popped_item=return_popped_item, pop_idx=idx)

else:
return pop_dyn_array(context, darray, return_popped_item=return_popped_item)


DISPATCH_TABLE = {
"_abi_encode": ABIEncode(),
"_abi_decode": ABIDecode(),
Expand Down Expand Up @@ -2505,6 +2555,7 @@ def build_IR(self, expr, args, kwargs, context):
"max": Max(),
"empty": Empty(),
"abs": Abs(),
"pop": Pop(),
}

STMT_DISPATCH_TABLE = {
Expand All @@ -2517,6 +2568,8 @@ def build_IR(self, expr, args, kwargs, context):
"create_forwarder_to": CreateForwarderTo(),
"create_copy_of": CreateCopyOf(),
"create_from_factory": CreateFromFactory(),
"append": Append(),
"pop": Pop(),
}

BUILTIN_FUNCTIONS = {**STMT_DISPATCH_TABLE, **DISPATCH_TABLE}.keys()
Expand Down
63 changes: 62 additions & 1 deletion vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,19 +290,80 @@ def append_dyn_array(darray_node, elem_node):
return IRnode.from_list(b1.resolve(b2.resolve(ret)))


def pop_dyn_array(darray_node, return_popped_item):
def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None):
assert isinstance(darray_node.typ, DArrayType)
assert darray_node.encoding == Encoding.VYPER
ret = ["seq"]
with darray_node.cache_when_complex("darray") as (b1, darray_node):
old_len = clamp("gt", get_dyn_array_count(darray_node), 0)
new_len = IRnode.from_list(["sub", old_len, 1], typ="uint256")

if pop_idx is not None:
# If pop from given index, assert that array length is greater than index
ret.append(clamp("gt", get_dyn_array_count(darray_node), pop_idx))

with new_len.cache_when_complex("new_len") as (b2, new_len):
ret.append(STORE(darray_node, new_len))

# Modify dynamic array
if pop_idx is not None:
body = ["seq"]

# Swap index to pop with the old last index using a temporary buffer
dst_i = get_element_ptr(darray_node, new_len, array_bounds_check=False)
buf = context.new_internal_variable(darray_node.typ.subtype)
buf = IRnode.from_list(buf, typ=darray_node.typ.subtype, location=MEMORY)
src_i = get_element_ptr(darray_node, pop_idx, array_bounds_check=False)

save_dst = make_setter(buf, dst_i)
mv_src = make_setter(dst_i, src_i)
mv_dst = make_setter(src_i, buf)

initial_swap = IRnode.from_list(["seq", save_dst, mv_src, mv_dst])
body.append(initial_swap)

# Iterate from popped index to the new last index and swap
# Set up the loop variable
loop_var = IRnode.from_list(context.fresh_varname("dynarray_pop_ix"), typ="uint256")
next_ix = IRnode.from_list(["add", loop_var, 1], typ="uint256")

# Swap value at index loop_var with index loop_var + 1
loop_save_dst = make_setter(
buf,
get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i
)
loop_mv_src = make_setter(
get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i
get_element_ptr(darray_node, next_ix, array_bounds_check=False) # src_i
)
loop_mv_dst = make_setter(
get_element_ptr(darray_node, next_ix, array_bounds_check=False), # src_i
buf,
)
loop_body = IRnode.from_list(["seq", loop_save_dst, loop_mv_src, loop_mv_dst])

# Set loop termination as new_len - 1
iter_count = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256")

# Set dynarray length as repeat bound
repeat_bound = darray_node.typ.count
loop = IRnode.from_list(["repeat", loop_var, pop_idx, iter_count, repeat_bound, loop_body])

# Enter loop only if new_len is at least 2
length_check = IRnode.from_list(["if", ["ge", new_len, 2], loop])
body.append(length_check)

# Perform the initial swap only if popped index is not the last index
swap_test = IRnode.from_list(["lt", pop_idx, new_len])
swap_check = IRnode.from_list(["if", swap_test, body])

ret.append(swap_check)

# NOTE skip array bounds check bc we already asserted len two lines up
if return_popped_item:
# Set index of popped element to last index of old array
# For pop with index, the popped element is swapped to the last index of the
# old array.
popped_item = get_element_ptr(darray_node, new_len, array_bounds_check=False)
ret.append(popped_item)
typ = popped_item.typ
Expand Down
8 changes: 1 addition & 7 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
get_element_ptr,
getpos,
make_setter,
pop_dyn_array,
unwrap_location,
)
from vyper.codegen.ir_node import IRnode
Expand All @@ -22,7 +21,6 @@
BaseType,
ByteArrayLike,
ByteArrayType,
DArrayType,
EnumType,
InterfaceType,
MappingType,
Expand Down Expand Up @@ -625,11 +623,7 @@ def parse_Call(self):
return arg_ir

elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop":
# TODO consider moving this to builtins
darray = Expr(self.expr.func.value, self.context).ir_node
assert len(self.expr.args) == 0
assert isinstance(darray.typ, DArrayType)
return pop_dyn_array(darray, return_popped_item=True)
return DISPATCH_TABLE["pop"].build_IR(self.expr, self.context, True)

elif (
# TODO use expr.func.type.is_internal once
Expand Down
19 changes: 3 additions & 16 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@
LOAD,
STORE,
IRnode,
append_dyn_array,
check_assign,
dummy_node_for_type,
get_dyn_array_count,
get_element_ptr,
getpos,
is_return_from_function,
make_byte_array_copier,
make_setter,
pop_dyn_array,
zero_pad,
)
from vyper.codegen.expr import Expr
Expand Down Expand Up @@ -138,20 +134,11 @@ def parse_Call(self):
"append",
"pop",
):
# TODO: consider moving this to builtins
darray = Expr(self.stmt.func.value, self.context).ir_node
args = [Expr(x, self.context).ir_node for x in self.stmt.args]
funcname = self.stmt.func.attr
if self.stmt.func.attr == "append":
# sanity checks
assert len(args) == 1
arg = args[0]
assert isinstance(darray.typ, DArrayType)
check_assign(dummy_node_for_type(darray.typ.subtype), dummy_node_for_type(arg.typ))

return append_dyn_array(darray, arg)
return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context)
else:
assert len(args) == 0
return pop_dyn_array(darray, return_popped_item=False)
return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context, False)

elif is_self_function:
return self_call.ir_for_self_call(self.stmt, self.context)
Expand Down
Loading