From e9e245894ca5d509fece3c2b6a4090f8644da8af Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 3 Jul 2022 16:03:55 +0800 Subject: [PATCH 01/16] feat: add kwargs to member function --- vyper/semantics/types/function.py | 16 ++++++++++++++-- vyper/semantics/types/indexable/sequence.py | 13 ++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 06f783b370..72a1888bae 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -14,6 +14,7 @@ NamespaceCollision, StateAccessViolation, StructureException, + TypeMismatch, ) from vyper.semantics.namespace import get_namespace from vyper.semantics.types.bases import BaseTypeDefinition, DataLocation, StorageSlot @@ -565,6 +566,7 @@ def __init__( arg_types: List[BaseTypeDefinition], return_type: Optional[BaseTypeDefinition], is_modifying: bool, + kwargs: Dict[str, KwargSettings] = {}, ) -> None: super().__init__(DataLocation.UNSET) self.underlying_type = underlying_type @@ -572,18 +574,28 @@ def __init__( self.arg_types = arg_types self.return_type = return_type self.is_modifying = is_modifying + self.kwargs = kwargs def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" - def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: - validate_call_args(node, len(self.arg_types)) + def _validate_arg_types(self, node: vy_ast.Call): + num_args = len(self.arg_types) + validate_call_args(node, num_args, list(self.kwargs)) assert len(node.args) == len(self.arg_types) # validate_call_args postcondition for arg, expected_type in zip(node.args, self.arg_types): # CMC 2022-04-01 this should probably be in the validation module validate_expected_type(arg, expected_type) + for kwarg in node.keywords: + kwarg_settings = self.kwargs[kwarg.arg] + if kwarg_settings.require_literal and not isinstance(kwarg.value, vy_ast.Constant): + raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) + validate_expected_type(kwarg.value, kwarg_settings.typ) + + def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: + self._validate_arg_types(node) return self.return_type diff --git a/vyper/semantics/types/indexable/sequence.py b/vyper/semantics/types/indexable/sequence.py index 9e8d855bf9..39ab464b5d 100644 --- a/vyper/semantics/types/indexable/sequence.py +++ b/vyper/semantics/types/indexable/sequence.py @@ -160,11 +160,22 @@ def __init__( # Adding members here as otherwise MemberFunctionDefinition is not yet defined # if added as _type_members from vyper.semantics.types.function import MemberFunctionDefinition + from vyper.semantics.types.utils import KwargSettings self.add_member( "append", MemberFunctionDefinition(self, "append", [self.value_type], None, True) ) - self.add_member("pop", MemberFunctionDefinition(self, "pop", [], self.value_type, True)) + self.add_member( + "pop", + MemberFunctionDefinition( + self, + "pop", + [], + self.value_type, + True, + kwargs={"ix": KwargSettings(Uint256Definition(), -1, require_literal=True)}, + ), + ) def __repr__(self): return f"DynArray[{self.value_type}, {self.length}]" From ea9ca097693cb601c01a0bf102aaf80027bae5d4 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 14:31:01 +0800 Subject: [PATCH 02/16] move pop to builtin --- vyper/builtin_functions/functions.py | 13 +++++++++++++ vyper/codegen/expr.py | 6 +----- vyper/codegen/stmt.py | 4 +++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vyper/builtin_functions/functions.py b/vyper/builtin_functions/functions.py index 7cb80f5f57..b3e4f9b837 100644 --- a/vyper/builtin_functions/functions.py +++ b/vyper/builtin_functions/functions.py @@ -31,6 +31,7 @@ get_element_ptr, ir_tuple_from_args, needs_external_call_wrap, + pop_dyn_array, promote_signed_int, shl, unwrap_location, @@ -42,6 +43,7 @@ BaseType, ByteArrayLike, ByteArrayType, + DArrayType, SArrayType, StringType, TupleType, @@ -2464,6 +2466,15 @@ def build_IR(self, expr, args, kwargs, context): ) +class Pop(BuiltinFunction): + _id = "pop" + + def build_IR(self, expr, context, darray, return_popped_item): + darray_ir = Expr(darray, context).ir_node + assert isinstance(darray_ir.typ, DArrayType) + return pop_dyn_array(darray_ir, return_popped_item=return_popped_item) + + DISPATCH_TABLE = { "_abi_encode": ABIEncode(), "_abi_decode": ABIDecode(), @@ -2505,6 +2516,7 @@ def build_IR(self, expr, args, kwargs, context): "max": Max(), "empty": Empty(), "abs": Abs(), + "pop": Pop(), } STMT_DISPATCH_TABLE = { @@ -2517,6 +2529,7 @@ def build_IR(self, expr, args, kwargs, context): "create_forwarder_to": CreateForwarderTo(), "create_copy_of": CreateCopyOf(), "create_from_factory": CreateFromFactory(), + "pop": Pop(), } BUILTIN_FUNCTIONS = {**STMT_DISPATCH_TABLE, **DISPATCH_TABLE}.keys() diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index e93dae32e1..966bb83b8f 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -625,11 +625,7 @@ def parse_Call(self): return arg_ir elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop": - # TODO consider moving this to builtins - darray = Expr(self.expr.func.value, self.context).ir_node - assert len(self.expr.args) == 0 - assert isinstance(darray.typ, DArrayType) - return pop_dyn_array(darray, return_popped_item=True) + return DISPATCH_TABLE["pop"].build_IR(self.expr, self.context, self.expr.func.value, True) elif ( # TODO use expr.func.type.is_internal once diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index aa9a2de73f..361a2516e5 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -151,7 +151,9 @@ def parse_Call(self): return append_dyn_array(darray, arg) else: assert len(args) == 0 - return pop_dyn_array(darray, return_popped_item=False) + funcname = self.stmt.func.attr + return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context, self.stmt.func.value, False) + #return pop_dyn_array(darray, return_popped_item=False) elif is_self_function: return self_call.ir_for_self_call(self.stmt, self.context) From 6b62e171dc329f0a981a9611771feb7a179c47b3 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 14:40:40 +0800 Subject: [PATCH 03/16] move append to builtin --- vyper/builtin_functions/functions.py | 29 ++++++++++++++++++++++++---- vyper/codegen/expr.py | 2 +- vyper/codegen/stmt.py | 17 +++------------- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/vyper/builtin_functions/functions.py b/vyper/builtin_functions/functions.py index b3e4f9b837..021843428c 100644 --- a/vyper/builtin_functions/functions.py +++ b/vyper/builtin_functions/functions.py @@ -16,14 +16,17 @@ IRnode, _freshname, add_ofst, + append_dyn_array, bytes_data_ptr, calculate_type_for_external_return, + check_assign, check_external_call, clamp, clamp2, clamp_basetype, clamp_nonzero, copy_bytes, + dummy_node_for_type, ensure_in_memory, eval_once_check, eval_seq, @@ -2466,13 +2469,30 @@ def build_IR(self, expr, args, kwargs, context): ) +class Append(BuiltinFunction): + _id = "append" + + def build_IR(self, expr, context): + darray = Expr(expr.func.value, context).ir_node + args = [Expr(x, context).ir_node for x in expr.args] + + # sanity checks + assert len(args) == 1 + arg = args[0] + assert isinstance(darray.typ, DArrayType) + + check_assign(dummy_node_for_type(darray.typ.subtype), dummy_node_for_type(arg.typ)) + + return append_dyn_array(darray, arg) + class Pop(BuiltinFunction): _id = "pop" - def build_IR(self, expr, context, darray, return_popped_item): - darray_ir = Expr(darray, context).ir_node - assert isinstance(darray_ir.typ, DArrayType) - return pop_dyn_array(darray_ir, return_popped_item=return_popped_item) + def build_IR(self, expr, context, return_popped_item): + darray = Expr(expr.func.value, context).ir_node + assert isinstance(darray.typ, DArrayType) + assert len(expr.args) == 0 + return pop_dyn_array(darray, return_popped_item=return_popped_item) DISPATCH_TABLE = { @@ -2529,6 +2549,7 @@ def build_IR(self, expr, context, darray, return_popped_item): "create_forwarder_to": CreateForwarderTo(), "create_copy_of": CreateCopyOf(), "create_from_factory": CreateFromFactory(), + "append": Append(), "pop": Pop(), } diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 966bb83b8f..cc37c2230f 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -625,7 +625,7 @@ def parse_Call(self): return arg_ir elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop": - return DISPATCH_TABLE["pop"].build_IR(self.expr, self.context, self.expr.func.value, True) + return DISPATCH_TABLE["pop"].build_IR(self.expr, self.context, True) elif ( # TODO use expr.func.type.is_internal once diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 361a2516e5..2926425b74 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -138,22 +138,11 @@ def parse_Call(self): "append", "pop", ): - # TODO: consider moving this to builtins - darray = Expr(self.stmt.func.value, self.context).ir_node - args = [Expr(x, self.context).ir_node for x in self.stmt.args] + funcname = self.stmt.func.attr if self.stmt.func.attr == "append": - # sanity checks - assert len(args) == 1 - arg = args[0] - assert isinstance(darray.typ, DArrayType) - check_assign(dummy_node_for_type(darray.typ.subtype), dummy_node_for_type(arg.typ)) - - return append_dyn_array(darray, arg) + return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context) else: - assert len(args) == 0 - funcname = self.stmt.func.attr - return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context, self.stmt.func.value, False) - #return pop_dyn_array(darray, return_popped_item=False) + return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context, False) elif is_self_function: return self_call.ir_for_self_call(self.stmt, self.context) From 5e3fbfece4cb8663a934bb3f0edc220cd37f7ed5 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 15:29:04 +0800 Subject: [PATCH 04/16] update codegen and semantics --- vyper/builtin_functions/functions.py | 20 ++++++++++++++++++-- vyper/codegen/core.py | 13 ++++++++++--- vyper/semantics/types/function.py | 2 ++ vyper/semantics/types/indexable/sequence.py | 2 +- vyper/semantics/validation/annotation.py | 3 +++ 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/vyper/builtin_functions/functions.py b/vyper/builtin_functions/functions.py index 021843428c..ff58bcb5ae 100644 --- a/vyper/builtin_functions/functions.py +++ b/vyper/builtin_functions/functions.py @@ -117,7 +117,7 @@ vyper_warn, ) -from .signatures import BuiltinFunction, process_inputs +from .signatures import BuiltinFunction, process_inputs, process_kwarg SHA256_ADDRESS = 2 SHA256_BASE_GAS = 60 @@ -2488,11 +2488,27 @@ def build_IR(self, expr, context): class Pop(BuiltinFunction): _id = "pop" + def _get_kwarg_settings(self, expr): + call_type = get_exact_type_from_node(expr.func) + expected_kwargs = call_type.kwargs + return expected_kwargs + def build_IR(self, expr, context, return_popped_item): darray = Expr(expr.func.value, context).ir_node assert isinstance(darray.typ, DArrayType) assert len(expr.args) == 0 - return pop_dyn_array(darray, return_popped_item=return_popped_item) + + kwargs = self._get_kwarg_settings(expr) + + if expr.keywords: + assert len(expr.keywords) == 1 and expr.keywords[0].arg == "ix" + kwarg_settings = kwargs[expr.keywords[0].arg] + expected_kwarg_type = kwarg_settings.typ + idx = process_kwarg(expr.keywords[0].value, kwarg_settings, expected_kwarg_type, context) + return pop_dyn_array(darray, return_popped_item=return_popped_item, pop_idx=idx) + + else: + return pop_dyn_array(darray, return_popped_item=return_popped_item) DISPATCH_TABLE = { diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 82882f5a9f..36601fe4f1 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -290,7 +290,7 @@ def append_dyn_array(darray_node, elem_node): return IRnode.from_list(b1.resolve(b2.resolve(ret))) -def pop_dyn_array(darray_node, return_popped_item): +def pop_dyn_array(darray_node, return_popped_item, pop_idx=None): assert isinstance(darray_node.typ, DArrayType) assert darray_node.encoding == Encoding.VYPER ret = ["seq"] @@ -298,12 +298,19 @@ def pop_dyn_array(darray_node, return_popped_item): old_len = clamp("gt", get_dyn_array_count(darray_node), 0) new_len = IRnode.from_list(["sub", old_len, 1], typ="uint256") - with new_len.cache_when_complex("new_len") as (b2, new_len): + if pop_idx is not None: + # Pop from given index + idx = clamp("gt", get_dyn_array_count(darray_node), pop_idx) + else: + # Else, pop from last index + idx = new_len + + with idx.cache_when_complex("idx") as (b2, idx): ret.append(STORE(darray_node, new_len)) # NOTE skip array bounds check bc we already asserted len two lines up if return_popped_item: - popped_item = get_element_ptr(darray_node, new_len, array_bounds_check=False) + popped_item = get_element_ptr(darray_node, idx, array_bounds_check=False) ret.append(popped_item) typ = popped_item.typ location = popped_item.location diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 72a1888bae..0c24e4e10a 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -598,6 +598,8 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: self._validate_arg_types(node) return self.return_type + def infer_kwarg_types(self, node: vy_ast.Call) -> Optional[Dict[str, BaseTypeDefinition]]: + return {i.arg: self.kwargs[i.arg].typ for i in node.keywords} def _generate_method_id(name: str, canonical_abi_types: List[str]) -> Dict[str, int]: function_sig = f"{name}({','.join(canonical_abi_types)})" diff --git a/vyper/semantics/types/indexable/sequence.py b/vyper/semantics/types/indexable/sequence.py index 39ab464b5d..d7aee33b25 100644 --- a/vyper/semantics/types/indexable/sequence.py +++ b/vyper/semantics/types/indexable/sequence.py @@ -173,7 +173,7 @@ def __init__( [], self.value_type, True, - kwargs={"ix": KwargSettings(Uint256Definition(), -1, require_literal=True)}, + kwargs={"ix": KwargSettings(Uint256Definition(), -1)}, ), ) diff --git a/vyper/semantics/validation/annotation.py b/vyper/semantics/validation/annotation.py index 011a255f61..71bc18ea9d 100644 --- a/vyper/semantics/validation/annotation.py +++ b/vyper/semantics/validation/annotation.py @@ -159,6 +159,9 @@ def visit_Call(self, node, type_): assert len(node.args) == len(call_type.arg_types) for arg, arg_type in zip(node.args, call_type.arg_types): self.visit(arg, arg_type) + kwarg_types = call_type.infer_kwarg_types(node) + for kwarg in node.keywords: + self.visit(kwarg.value, kwarg_types[kwarg.arg]) else: # builtin functions arg_types = call_type.infer_arg_types(node) From 41fb641b4dcbc683f0d18d16a251634a3b12e08f Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 15:41:52 +0800 Subject: [PATCH 05/16] fix lint --- vyper/builtin_functions/functions.py | 11 +++++++---- vyper/codegen/expr.py | 2 -- vyper/codegen/stmt.py | 4 ---- vyper/semantics/types/function.py | 13 ++++++++----- vyper/semantics/types/indexable/sequence.py | 4 ++-- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/vyper/builtin_functions/functions.py b/vyper/builtin_functions/functions.py index ff58bcb5ae..98ff64a849 100644 --- a/vyper/builtin_functions/functions.py +++ b/vyper/builtin_functions/functions.py @@ -2485,12 +2485,13 @@ def build_IR(self, expr, context): return append_dyn_array(darray, arg) + class Pop(BuiltinFunction): _id = "pop" def _get_kwarg_settings(self, expr): call_type = get_exact_type_from_node(expr.func) - expected_kwargs = call_type.kwargs + expected_kwargs = call_type._kwargs return expected_kwargs def build_IR(self, expr, context, return_popped_item): @@ -2501,10 +2502,12 @@ def build_IR(self, expr, context, return_popped_item): kwargs = self._get_kwarg_settings(expr) if expr.keywords: - assert len(expr.keywords) == 1 and expr.keywords[0].arg == "ix" - kwarg_settings = kwargs[expr.keywords[0].arg] + kwarg_name = expr.keywords[0].arg + kwarg_val = expr.keywords[0].value + assert len(expr.keywords) == 1 and kwarg_name == "ix" + kwarg_settings = kwargs[kwarg_name] expected_kwarg_type = kwarg_settings.typ - idx = process_kwarg(expr.keywords[0].value, kwarg_settings, expected_kwarg_type, context) + idx = process_kwarg(kwarg_val, kwarg_settings, expected_kwarg_type, context) return pop_dyn_array(darray, return_popped_item=return_popped_item, pop_idx=idx) else: diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index cc37c2230f..335dd653c5 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -12,7 +12,6 @@ get_element_ptr, getpos, make_setter, - pop_dyn_array, unwrap_location, ) from vyper.codegen.ir_node import IRnode @@ -22,7 +21,6 @@ BaseType, ByteArrayLike, ByteArrayType, - DArrayType, EnumType, InterfaceType, MappingType, diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 2926425b74..8a4a54fb34 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -9,16 +9,12 @@ LOAD, STORE, IRnode, - append_dyn_array, - check_assign, - dummy_node_for_type, get_dyn_array_count, get_element_ptr, getpos, is_return_from_function, make_byte_array_copier, make_setter, - pop_dyn_array, zero_pad, ) from vyper.codegen.expr import Expr diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 0c24e4e10a..b4ec34ecec 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -558,6 +558,7 @@ class MemberFunctionDefinition(BaseTypeDefinition): """ _is_callable = True + _kwargs: Dict[str, KwargSettings] = {} def __init__( self, @@ -566,7 +567,7 @@ def __init__( arg_types: List[BaseTypeDefinition], return_type: Optional[BaseTypeDefinition], is_modifying: bool, - kwargs: Dict[str, KwargSettings] = {}, + kwargs: Optional[Dict[str, KwargSettings]], ) -> None: super().__init__(DataLocation.UNSET) self.underlying_type = underlying_type @@ -574,14 +575,15 @@ def __init__( self.arg_types = arg_types self.return_type = return_type self.is_modifying = is_modifying - self.kwargs = kwargs + if kwargs is not None: + self._kwargs = kwargs def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" def _validate_arg_types(self, node: vy_ast.Call): num_args = len(self.arg_types) - validate_call_args(node, num_args, list(self.kwargs)) + validate_call_args(node, num_args, list(self._kwargs)) assert len(node.args) == len(self.arg_types) # validate_call_args postcondition for arg, expected_type in zip(node.args, self.arg_types): @@ -589,7 +591,7 @@ def _validate_arg_types(self, node: vy_ast.Call): validate_expected_type(arg, expected_type) for kwarg in node.keywords: - kwarg_settings = self.kwargs[kwarg.arg] + kwarg_settings = self._kwargs[kwarg.arg] if kwarg_settings.require_literal and not isinstance(kwarg.value, vy_ast.Constant): raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) validate_expected_type(kwarg.value, kwarg_settings.typ) @@ -599,7 +601,8 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: return self.return_type def infer_kwarg_types(self, node: vy_ast.Call) -> Optional[Dict[str, BaseTypeDefinition]]: - return {i.arg: self.kwargs[i.arg].typ for i in node.keywords} + return {i.arg: self._kwargs[i.arg].typ for i in node.keywords} + def _generate_method_id(name: str, canonical_abi_types: List[str]) -> Dict[str, int]: function_sig = f"{name}({','.join(canonical_abi_types)})" diff --git a/vyper/semantics/types/indexable/sequence.py b/vyper/semantics/types/indexable/sequence.py index d7aee33b25..8bec6af9f1 100644 --- a/vyper/semantics/types/indexable/sequence.py +++ b/vyper/semantics/types/indexable/sequence.py @@ -163,7 +163,7 @@ def __init__( from vyper.semantics.types.utils import KwargSettings self.add_member( - "append", MemberFunctionDefinition(self, "append", [self.value_type], None, True) + "append", MemberFunctionDefinition(self, "append", [self.value_type], None, True, None) ) self.add_member( "pop", @@ -173,7 +173,7 @@ def __init__( [], self.value_type, True, - kwargs={"ix": KwargSettings(Uint256Definition(), -1)}, + {"ix": KwargSettings(Uint256Definition(), -1)}, ), ) From 2224a73a9c3647edcbc4609d38935cf4ef6e4004 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 15:43:59 +0800 Subject: [PATCH 06/16] fix mypy lint --- vyper/semantics/types/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index b4ec34ecec..010aacea52 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -581,7 +581,7 @@ def __init__( def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" - def _validate_arg_types(self, node: vy_ast.Call): + def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self.arg_types) validate_call_args(node, num_args, list(self._kwargs)) From 65b371d07ce0c94248faa3ea88d911b6cdf85600 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 16:07:59 +0800 Subject: [PATCH 07/16] fix index in pop_dyn_array --- vyper/codegen/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 36601fe4f1..b1630a8b49 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -300,17 +300,17 @@ def pop_dyn_array(darray_node, return_popped_item, pop_idx=None): if pop_idx is not None: # Pop from given index - idx = clamp("gt", get_dyn_array_count(darray_node), pop_idx) + ret.append(clamp("gt", get_dyn_array_count(darray_node), pop_idx)) else: # Else, pop from last index - idx = new_len + pop_idx = new_len - with idx.cache_when_complex("idx") as (b2, idx): + with pop_idx.cache_when_complex("idx") as (b2, pop_idx): ret.append(STORE(darray_node, new_len)) # NOTE skip array bounds check bc we already asserted len two lines up if return_popped_item: - popped_item = get_element_ptr(darray_node, idx, array_bounds_check=False) + popped_item = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) ret.append(popped_item) typ = popped_item.typ location = popped_item.location From 15b188c89dcd3fd95cb1de840d65db4c0cb01ab9 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 16:08:07 +0800 Subject: [PATCH 08/16] add pop index tests --- tests/parser/types/test_dynamic_array.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py index d274c6919a..2124937b75 100644 --- a/tests/parser/types/test_dynamic_array.py +++ b/tests/parser/types/test_dynamic_array.py @@ -1150,6 +1150,21 @@ def test_append_pop(get_contract, assert_tx_failed, code, check_result, test_dat assert c.foo(test_data) == expected_result +@pytest.mark.parametrize("arr", [[1], [1, 2], [1, 2, 3, 4, 5]]) +@pytest.mark.parametrize("idx", [0, 1, 2, 3, 4, 5]) +def test_pop_index(get_contract, assert_tx_failed, arr, idx): + code = """ +@external +def foo(a: DynArray[uint256, 5], b: uint256) -> uint256: + return a.pop(ix=b) + """ + c = get_contract(code) + if idx >= len(arr): + assert_tx_failed(lambda: c.foo(arr, idx)) + else: + assert c.foo(arr, idx) == arr[idx] + + append_pop_complex_tests = [ ( """ From a84a1cb38f1744f9f5f9125048cf997bb3ea4118 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 4 Jul 2022 18:03:24 +0800 Subject: [PATCH 09/16] add tests --- tests/parser/types/test_dynamic_array.py | 72 +++++++++++++++++++++--- vyper/codegen/core.py | 3 +- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py index 2124937b75..dc25157930 100644 --- a/tests/parser/types/test_dynamic_array.py +++ b/tests/parser/types/test_dynamic_array.py @@ -1150,19 +1150,77 @@ def test_append_pop(get_contract, assert_tx_failed, code, check_result, test_dat assert c.foo(test_data) == expected_result -@pytest.mark.parametrize("arr", [[1], [1, 2], [1, 2, 3, 4, 5]]) -@pytest.mark.parametrize("idx", [0, 1, 2, 3, 4, 5]) -def test_pop_index(get_contract, assert_tx_failed, arr, idx): +@pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)]) +def test_pop_index_return(get_contract, assert_tx_failed, test_data): code = """ @external def foo(a: DynArray[uint256, 5], b: uint256) -> uint256: return a.pop(ix=b) """ c = get_contract(code) - if idx >= len(arr): - assert_tx_failed(lambda: c.foo(arr, idx)) - else: - assert c.foo(arr, idx) == arr[idx] + arr_length = len(test_data) + for idx in range(arr_length + 1): + if idx >= arr_length: + assert_tx_failed(lambda: c.foo(test_data, idx)) + else: + assert c.foo(test_data, idx) == test_data[idx] + + +pop_index_tests = [ + ( + """ +my_array: DynArray[uint256, 5] +@external +def foo(xs: DynArray[uint256, 5], i: uint256) -> DynArray[uint256, 5]: + for x in xs: + self.my_array.append(x) + for x in xs: + self.my_array.pop(ix=0) + return self.my_array + """, + lambda xs, idx: [], + ), + # check order of evaluation. + ( + """ +my_array: DynArray[uint256, 5] +@external +def foo(xs: DynArray[uint256, 5], i: uint256) -> (DynArray[uint256, 5], uint256): + for x in xs: + self.my_array.append(x) + return self.my_array, self.my_array.pop(ix=i) + """, + lambda xs, idx: None if len(xs) == 0 else [xs[:idx] + xs[idx+1:], xs[idx]], + ), + # check order of evaluation. + ( + """ +my_array: DynArray[uint256, 5] +@external +def foo(xs: DynArray[uint256, 5], i: uint256) -> (uint256, DynArray[uint256, 5]): + for x in xs: + self.my_array.append(x) + return self.my_array.pop(ix=i), self.my_array + """, + lambda xs, idx: None if len(xs) == 0 else [xs[idx], xs[:idx] + xs[idx+1:]], + ), +] + + +@pytest.mark.parametrize("code,check_result", pop_index_tests) +# TODO change this to fuzz random data +@pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)]) +def test_pop_index(get_contract, assert_tx_failed, code, check_result, test_data): + c = get_contract(code) + + arr_length = len(test_data) + for idx in range(arr_length): + expected_result = check_result(test_data, idx) + if expected_result is None: + # None is sentinel to indicate txn should revert + assert_tx_failed(lambda: c.foo(test_data, idx)) + else: + assert c.foo(test_data, idx) == expected_result append_pop_complex_tests = [ diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index b1630a8b49..c90f7e6d97 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -299,13 +299,14 @@ def pop_dyn_array(darray_node, return_popped_item, pop_idx=None): new_len = IRnode.from_list(["sub", old_len, 1], typ="uint256") if pop_idx is not None: - # Pop from given index + # If pop from given index, assert that array length is greater than index ret.append(clamp("gt", get_dyn_array_count(darray_node), pop_idx)) else: # Else, pop from last index pop_idx = new_len with pop_idx.cache_when_complex("idx") as (b2, pop_idx): + # TODO Update darray ret.append(STORE(darray_node, new_len)) # NOTE skip array bounds check bc we already asserted len two lines up From e90664618fa798dea9f586fc045721be287397a9 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 5 Jul 2022 00:28:30 +0800 Subject: [PATCH 10/16] add loop to codegen to bubble swap values --- vyper/builtin_functions/functions.py | 4 ++-- vyper/codegen/core.py | 32 ++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/vyper/builtin_functions/functions.py b/vyper/builtin_functions/functions.py index 98ff64a849..3fbc259773 100644 --- a/vyper/builtin_functions/functions.py +++ b/vyper/builtin_functions/functions.py @@ -2508,10 +2508,10 @@ def build_IR(self, expr, context, return_popped_item): kwarg_settings = kwargs[kwarg_name] expected_kwarg_type = kwarg_settings.typ idx = process_kwarg(kwarg_val, kwarg_settings, expected_kwarg_type, context) - return pop_dyn_array(darray, return_popped_item=return_popped_item, pop_idx=idx) + return pop_dyn_array(context, darray, return_popped_item=return_popped_item, pop_idx=idx) else: - return pop_dyn_array(darray, return_popped_item=return_popped_item) + return pop_dyn_array(context, darray, return_popped_item=return_popped_item) DISPATCH_TABLE = { diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c90f7e6d97..d018365a1e 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -290,7 +290,7 @@ def append_dyn_array(darray_node, elem_node): return IRnode.from_list(b1.resolve(b2.resolve(ret))) -def pop_dyn_array(darray_node, return_popped_item, pop_idx=None): +def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): assert isinstance(darray_node.typ, DArrayType) assert darray_node.encoding == Encoding.VYPER ret = ["seq"] @@ -309,9 +309,37 @@ def pop_dyn_array(darray_node, return_popped_item, pop_idx=None): # TODO Update darray ret.append(STORE(darray_node, new_len)) + if pop_idx is not None: + # Swap index to pop with the old last index + dst_i = get_element_ptr(darray_node, old_len, array_bounds_check=False) + src_i = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) + ret.append(make_setter(dst_i, src_i)) + + # Iterate from popped index to the new last index and swap + # Set up the loop variable + loop_var = IRnode.from_list(context.fresh_varname("dynarray_pop_ix"), typ="uint256") + next_ix = IRnode.from_list(["add", loop_var, 1], typ="uint256") + + # Swap value at index loop_var with index loop_var + 1 + loop_body = [ + "seq", + make_setter( + get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i + get_element_ptr(darray_node, next_ix, array_bounds_check=False) # src_i + ), + ] + + # Set loop termination as new_index - 1 + iter = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256") + + # Set dynarray length as repeat bound + repeat_bound = darray_node.typ.count + + ret.append(["repeat", loop_var, pop_idx, iter, repeat_bound, loop_body]) + # NOTE skip array bounds check bc we already asserted len two lines up if return_popped_item: - popped_item = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) + popped_item = get_element_ptr(darray_node, old_len, array_bounds_check=False) ret.append(popped_item) typ = popped_item.typ location = popped_item.location From 70b582ec5639cde7854392de26e5d337532ed57e Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 5 Jul 2022 00:43:07 +0800 Subject: [PATCH 11/16] fix index --- vyper/codegen/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index d018365a1e..a77b904461 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -301,11 +301,12 @@ def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): if pop_idx is not None: # If pop from given index, assert that array length is greater than index ret.append(clamp("gt", get_dyn_array_count(darray_node), pop_idx)) + idx = old_len else: # Else, pop from last index - pop_idx = new_len + idx = new_len - with pop_idx.cache_when_complex("idx") as (b2, pop_idx): + with pop_idx.cache_when_complex("pop_idx") as (b2, pop_idx): # TODO Update darray ret.append(STORE(darray_node, new_len)) @@ -339,7 +340,7 @@ def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): # NOTE skip array bounds check bc we already asserted len two lines up if return_popped_item: - popped_item = get_element_ptr(darray_node, old_len, array_bounds_check=False) + popped_item = get_element_ptr(darray_node, idx, array_bounds_check=False) ret.append(popped_item) typ = popped_item.typ location = popped_item.location From 67a59a8a5beebf0c7c234dbd9ca086b3d120af14 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 5 Jul 2022 10:33:58 +0800 Subject: [PATCH 12/16] wip codegen --- vyper/codegen/core.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index a77b904461..a4d13c302b 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -301,20 +301,19 @@ def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): if pop_idx is not None: # If pop from given index, assert that array length is greater than index ret.append(clamp("gt", get_dyn_array_count(darray_node), pop_idx)) - idx = old_len - else: - # Else, pop from last index - idx = new_len with pop_idx.cache_when_complex("pop_idx") as (b2, pop_idx): - # TODO Update darray ret.append(STORE(darray_node, new_len)) + # Modify dynamic array if pop_idx is not None: + body = ["seq"] + # Swap index to pop with the old last index - dst_i = get_element_ptr(darray_node, old_len, array_bounds_check=False) + dst_i = get_element_ptr(darray_node, new_len, array_bounds_check=False) src_i = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) - ret.append(make_setter(dst_i, src_i)) + swap = make_setter(dst_i, src_i) + body.append(swap) # Iterate from popped index to the new last index and swap # Set up the loop variable @@ -322,25 +321,38 @@ def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): next_ix = IRnode.from_list(["add", loop_var, 1], typ="uint256") # Swap value at index loop_var with index loop_var + 1 - loop_body = [ + loop_body = IRnode.from_list([ "seq", make_setter( get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i get_element_ptr(darray_node, next_ix, array_bounds_check=False) # src_i ), - ] + ]) - # Set loop termination as new_index - 1 - iter = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256") + # Set loop termination as new_len - 2 + iter_count = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256") # Set dynarray length as repeat bound repeat_bound = darray_node.typ.count + loop = IRnode.from_list(["repeat", loop_var, pop_idx, iter_count, repeat_bound, loop_body]) + + # Perform loop only if new_len is at least 2 + length_cmp = IRnode.from_list(["ge", new_len, 2]) + length_check = IRnode.from_list(["if", length_cmp, loop]) + body.append(length_check) + + # Perform the initial swap only if popped index is not the last index + swap_test = IRnode.from_list(["lt", pop_idx, new_len]) + swap_check = IRnode.from_list(["if", swap_test, body]) - ret.append(["repeat", loop_var, pop_idx, iter, repeat_bound, loop_body]) + ret.append(swap_check) # NOTE skip array bounds check bc we already asserted len two lines up if return_popped_item: - popped_item = get_element_ptr(darray_node, idx, array_bounds_check=False) + # Set index of popped element to last index of old array + # For pop with index, the popped element is swapped to the last index of the + # old array. + popped_item = get_element_ptr(darray_node, new_len, array_bounds_check=False) ret.append(popped_item) typ = popped_item.typ location = popped_item.location From 666b5ad8d9cf577846af71aa602f2d53345f6c24 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 5 Jul 2022 17:42:12 +0800 Subject: [PATCH 13/16] cache new darray len --- vyper/codegen/core.py | 110 +++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index a4d13c302b..99876d26ff 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -302,63 +302,63 @@ def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): # If pop from given index, assert that array length is greater than index ret.append(clamp("gt", get_dyn_array_count(darray_node), pop_idx)) - with pop_idx.cache_when_complex("pop_idx") as (b2, pop_idx): + with new_len.cache_when_complex("new_len") as (b2, new_len): ret.append(STORE(darray_node, new_len)) - # Modify dynamic array - if pop_idx is not None: - body = ["seq"] - - # Swap index to pop with the old last index - dst_i = get_element_ptr(darray_node, new_len, array_bounds_check=False) - src_i = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) - swap = make_setter(dst_i, src_i) - body.append(swap) - - # Iterate from popped index to the new last index and swap - # Set up the loop variable - loop_var = IRnode.from_list(context.fresh_varname("dynarray_pop_ix"), typ="uint256") - next_ix = IRnode.from_list(["add", loop_var, 1], typ="uint256") - - # Swap value at index loop_var with index loop_var + 1 - loop_body = IRnode.from_list([ - "seq", - make_setter( - get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i - get_element_ptr(darray_node, next_ix, array_bounds_check=False) # src_i - ), - ]) - - # Set loop termination as new_len - 2 - iter_count = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256") - - # Set dynarray length as repeat bound - repeat_bound = darray_node.typ.count - loop = IRnode.from_list(["repeat", loop_var, pop_idx, iter_count, repeat_bound, loop_body]) - - # Perform loop only if new_len is at least 2 - length_cmp = IRnode.from_list(["ge", new_len, 2]) - length_check = IRnode.from_list(["if", length_cmp, loop]) - body.append(length_check) - - # Perform the initial swap only if popped index is not the last index - swap_test = IRnode.from_list(["lt", pop_idx, new_len]) - swap_check = IRnode.from_list(["if", swap_test, body]) - - ret.append(swap_check) - - # NOTE skip array bounds check bc we already asserted len two lines up - if return_popped_item: - # Set index of popped element to last index of old array - # For pop with index, the popped element is swapped to the last index of the - # old array. - popped_item = get_element_ptr(darray_node, new_len, array_bounds_check=False) - ret.append(popped_item) - typ = popped_item.typ - location = popped_item.location - else: - typ, location = None, None - return IRnode.from_list(b1.resolve(b2.resolve(ret)), typ=typ, location=location) + with pop_idx.cache_when_complex("pop_idx") as (b3, pop_idx): + # Modify dynamic array + if pop_idx is not None: + body = ["seq"] + + # Swap index to pop with the old last index + dst_i = get_element_ptr(darray_node, new_len, array_bounds_check=False) + src_i = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) + swap = make_setter(dst_i, src_i) + body.append(swap) + + # Iterate from popped index to the new last index and swap + # Set up the loop variable + loop_var = IRnode.from_list(context.fresh_varname("dynarray_pop_ix"), typ="uint256") + next_ix = IRnode.from_list(["add", loop_var, 1], typ="uint256") + + # Swap value at index loop_var with index loop_var + 1 + loop_body = IRnode.from_list([ + "seq", + make_setter( + get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i + get_element_ptr(darray_node, next_ix, array_bounds_check=False) # src_i + ), + ]) + + # Set loop termination as new_len - 2 + iter_count = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256") + + # Set dynarray length as repeat bound + repeat_bound = darray_node.typ.count + loop = IRnode.from_list(["repeat", loop_var, pop_idx, iter_count, repeat_bound, loop_body]) + + # Enter loop only if new_len is at least 2 + length_check = IRnode.from_list(["if", ["ge", new_len, 2], loop]) + body.append(length_check) + print("body: ", body) + # Perform the initial swap only if popped index is not the last index + swap_test = IRnode.from_list(["lt", pop_idx, new_len]) + swap_check = IRnode.from_list(["if", swap_test, body]) + + ret.append(swap_check) + + # NOTE skip array bounds check bc we already asserted len two lines up + if return_popped_item: + # Set index of popped element to last index of old array + # For pop with index, the popped element is swapped to the last index of the + # old array. + popped_item = get_element_ptr(darray_node, new_len, array_bounds_check=False) + ret.append(popped_item) + typ = popped_item.typ + location = popped_item.location + else: + typ, location = None, None + return IRnode.from_list(b1.resolve(b2.resolve(b3.resolve(ret))), typ=typ, location=location) def getpos(node): From 74dcac0f5c87eb0702fff296b6c2ad155a547bee Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 5 Jul 2022 17:42:18 +0800 Subject: [PATCH 14/16] fix test --- tests/parser/types/test_dynamic_array.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py index dc25157930..8bfd1ca561 100644 --- a/tests/parser/types/test_dynamic_array.py +++ b/tests/parser/types/test_dynamic_array.py @@ -1151,7 +1151,8 @@ def test_append_pop(get_contract, assert_tx_failed, code, check_result, test_dat @pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)]) -def test_pop_index_return(get_contract, assert_tx_failed, test_data): +@pytest.mark.parametrize("ix", [i for i in range(6)]) +def test_pop_index_return(get_contract, assert_tx_failed, test_data, ix): code = """ @external def foo(a: DynArray[uint256, 5], b: uint256) -> uint256: @@ -1159,11 +1160,10 @@ def foo(a: DynArray[uint256, 5], b: uint256) -> uint256: """ c = get_contract(code) arr_length = len(test_data) - for idx in range(arr_length + 1): - if idx >= arr_length: - assert_tx_failed(lambda: c.foo(test_data, idx)) - else: - assert c.foo(test_data, idx) == test_data[idx] + if ix >= arr_length: + assert_tx_failed(lambda: c.foo(test_data, ix)) + else: + assert c.foo(test_data, ix) == test_data[ix] pop_index_tests = [ From 2f27069dbd481decade06a18977af9baeec2f381 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 5 Jul 2022 17:44:48 +0800 Subject: [PATCH 15/16] remove errant print; rename tests --- tests/parser/types/test_dynamic_array.py | 4 ++-- vyper/codegen/core.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py index 8bfd1ca561..f40abf69d7 100644 --- a/tests/parser/types/test_dynamic_array.py +++ b/tests/parser/types/test_dynamic_array.py @@ -1152,7 +1152,7 @@ def test_append_pop(get_contract, assert_tx_failed, code, check_result, test_dat @pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)]) @pytest.mark.parametrize("ix", [i for i in range(6)]) -def test_pop_index_return(get_contract, assert_tx_failed, test_data, ix): +def test_pop_index_return_pass(get_contract, assert_tx_failed, test_data, ix): code = """ @external def foo(a: DynArray[uint256, 5], b: uint256) -> uint256: @@ -1210,7 +1210,7 @@ def foo(xs: DynArray[uint256, 5], i: uint256) -> (uint256, DynArray[uint256, 5]) @pytest.mark.parametrize("code,check_result", pop_index_tests) # TODO change this to fuzz random data @pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)]) -def test_pop_index(get_contract, assert_tx_failed, code, check_result, test_data): +def test_pop_index_pass(get_contract, assert_tx_failed, code, check_result, test_data): c = get_contract(code) arr_length = len(test_data) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 99876d26ff..b489226959 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -340,7 +340,7 @@ def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): # Enter loop only if new_len is at least 2 length_check = IRnode.from_list(["if", ["ge", new_len, 2], loop]) body.append(length_check) - print("body: ", body) + # Perform the initial swap only if popped index is not the last index swap_test = IRnode.from_list(["lt", pop_idx, new_len]) swap_check = IRnode.from_list(["if", swap_test, body]) From f88d978e919c171388bfc3d47ffcda9d287783f8 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 5 Jul 2022 18:21:11 +0800 Subject: [PATCH 16/16] add temp buffer; add tests --- tests/parser/types/test_dynamic_array.py | 12 +++ vyper/codegen/core.py | 120 +++++++++++++---------- 2 files changed, 78 insertions(+), 54 deletions(-) diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py index f40abf69d7..ad0c0fa1a2 100644 --- a/tests/parser/types/test_dynamic_array.py +++ b/tests/parser/types/test_dynamic_array.py @@ -1180,6 +1180,18 @@ def foo(xs: DynArray[uint256, 5], i: uint256) -> DynArray[uint256, 5]: """, lambda xs, idx: [], ), + ( + """ +my_array: DynArray[uint256, 5] +@external +def foo(xs: DynArray[uint256, 5], i: uint256) -> DynArray[uint256, 5]: + for x in xs: + self.my_array.append(x) + self.my_array.pop(ix=i) + return self.my_array + """, + lambda xs, idx: None if len(xs) == 0 else xs[:idx] + xs[idx+1:], + ), # check order of evaluation. ( """ diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index b489226959..19c194a280 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -305,60 +305,72 @@ def pop_dyn_array(context, darray_node, return_popped_item, pop_idx=None): with new_len.cache_when_complex("new_len") as (b2, new_len): ret.append(STORE(darray_node, new_len)) - with pop_idx.cache_when_complex("pop_idx") as (b3, pop_idx): - # Modify dynamic array - if pop_idx is not None: - body = ["seq"] - - # Swap index to pop with the old last index - dst_i = get_element_ptr(darray_node, new_len, array_bounds_check=False) - src_i = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) - swap = make_setter(dst_i, src_i) - body.append(swap) - - # Iterate from popped index to the new last index and swap - # Set up the loop variable - loop_var = IRnode.from_list(context.fresh_varname("dynarray_pop_ix"), typ="uint256") - next_ix = IRnode.from_list(["add", loop_var, 1], typ="uint256") - - # Swap value at index loop_var with index loop_var + 1 - loop_body = IRnode.from_list([ - "seq", - make_setter( - get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i - get_element_ptr(darray_node, next_ix, array_bounds_check=False) # src_i - ), - ]) - - # Set loop termination as new_len - 2 - iter_count = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256") - - # Set dynarray length as repeat bound - repeat_bound = darray_node.typ.count - loop = IRnode.from_list(["repeat", loop_var, pop_idx, iter_count, repeat_bound, loop_body]) - - # Enter loop only if new_len is at least 2 - length_check = IRnode.from_list(["if", ["ge", new_len, 2], loop]) - body.append(length_check) - - # Perform the initial swap only if popped index is not the last index - swap_test = IRnode.from_list(["lt", pop_idx, new_len]) - swap_check = IRnode.from_list(["if", swap_test, body]) - - ret.append(swap_check) - - # NOTE skip array bounds check bc we already asserted len two lines up - if return_popped_item: - # Set index of popped element to last index of old array - # For pop with index, the popped element is swapped to the last index of the - # old array. - popped_item = get_element_ptr(darray_node, new_len, array_bounds_check=False) - ret.append(popped_item) - typ = popped_item.typ - location = popped_item.location - else: - typ, location = None, None - return IRnode.from_list(b1.resolve(b2.resolve(b3.resolve(ret))), typ=typ, location=location) + # Modify dynamic array + if pop_idx is not None: + body = ["seq"] + + # Swap index to pop with the old last index using a temporary buffer + dst_i = get_element_ptr(darray_node, new_len, array_bounds_check=False) + buf = context.new_internal_variable(darray_node.typ.subtype) + buf = IRnode.from_list(buf, typ=darray_node.typ.subtype, location=MEMORY) + src_i = get_element_ptr(darray_node, pop_idx, array_bounds_check=False) + + save_dst = make_setter(buf, dst_i) + mv_src = make_setter(dst_i, src_i) + mv_dst = make_setter(src_i, buf) + + initial_swap = IRnode.from_list(["seq", save_dst, mv_src, mv_dst]) + body.append(initial_swap) + + # Iterate from popped index to the new last index and swap + # Set up the loop variable + loop_var = IRnode.from_list(context.fresh_varname("dynarray_pop_ix"), typ="uint256") + next_ix = IRnode.from_list(["add", loop_var, 1], typ="uint256") + + # Swap value at index loop_var with index loop_var + 1 + loop_save_dst = make_setter( + buf, + get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i + ) + loop_mv_src = make_setter( + get_element_ptr(darray_node, loop_var, array_bounds_check=False), # dst_i + get_element_ptr(darray_node, next_ix, array_bounds_check=False) # src_i + ) + loop_mv_dst = make_setter( + get_element_ptr(darray_node, next_ix, array_bounds_check=False), # src_i + buf, + ) + loop_body = IRnode.from_list(["seq", loop_save_dst, loop_mv_src, loop_mv_dst]) + + # Set loop termination as new_len - 1 + iter_count = IRnode.from_list(["sub", IRnode.from_list(["sub", new_len, pop_idx], typ="uint256"), 1], typ="uint256") + + # Set dynarray length as repeat bound + repeat_bound = darray_node.typ.count + loop = IRnode.from_list(["repeat", loop_var, pop_idx, iter_count, repeat_bound, loop_body]) + + # Enter loop only if new_len is at least 2 + length_check = IRnode.from_list(["if", ["ge", new_len, 2], loop]) + body.append(length_check) + + # Perform the initial swap only if popped index is not the last index + swap_test = IRnode.from_list(["lt", pop_idx, new_len]) + swap_check = IRnode.from_list(["if", swap_test, body]) + + ret.append(swap_check) + + # NOTE skip array bounds check bc we already asserted len two lines up + if return_popped_item: + # Set index of popped element to last index of old array + # For pop with index, the popped element is swapped to the last index of the + # old array. + popped_item = get_element_ptr(darray_node, new_len, array_bounds_check=False) + ret.append(popped_item) + typ = popped_item.typ + location = popped_item.location + else: + typ, location = None, None + return IRnode.from_list(b1.resolve(b2.resolve(ret)), typ=typ, location=location) def getpos(node):