Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[venom]: stack2mem pass implementation #4245

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
2 changes: 0 additions & 2 deletions tests/functional/codegen/features/test_clampers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from eth_utils import keccak

from tests.utils import ZERO_ADDRESS, decimal_to_int
from vyper.exceptions import StackTooDeep
from vyper.utils import int_bounds


Expand Down Expand Up @@ -502,7 +501,6 @@ def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]:


@pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)])
@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_multidimension_dynarray_clamper_passing(get_contract, value):
code = """
@external
Expand Down
2 changes: 0 additions & 2 deletions tests/functional/codegen/types/test_dynamic_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
CompilerPanic,
ImmutableViolation,
OverflowException,
StackTooDeep,
StateAccessViolation,
TypeMismatch,
)
Expand Down Expand Up @@ -737,7 +736,6 @@ def test_array_decimal_return3() -> DynArray[DynArray[decimal, 2], 2]:
]


@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_mult_list(get_contract):
code = """
nest3: DynArray[DynArray[DynArray[uint256, 2], 2], 2]
Expand Down
116 changes: 116 additions & 0 deletions tests/unit/compiler/venom/test_mem_allocator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import pytest

from vyper.venom.mem_allocator import MemoryAllocator

# Start address handed to the MemoryAllocator under test; allocations are
# expected to be reported relative to this base.
MEM_BLOCK_ADDRESS = 0x1000


@pytest.fixture
def allocator():
    """Provide a fresh allocator of size 1024 based at MEM_BLOCK_ADDRESS."""
    return MemoryAllocator(total_size=1024, start_address=MEM_BLOCK_ADDRESS)


def test_initial_state(allocator):
    """A newly created allocator reports its whole region as free."""
    free = allocator.get_free_memory()
    used = allocator.get_allocated_memory()
    assert free == 1024
    assert used == 0


def test_single_allocation(allocator):
    """The first allocation starts at the base address and is accounted for."""
    first = allocator.allocate(256)

    assert first == MEM_BLOCK_ADDRESS
    free = allocator.get_free_memory()
    used = allocator.get_allocated_memory()
    assert free == 768
    assert used == 256


def test_multiple_allocations(allocator):
    """Consecutive allocations are laid out back to back from the base address."""
    sizes = (256, 128, 64)
    addresses = [allocator.allocate(size) for size in sizes]

    expected = [MEM_BLOCK_ADDRESS, MEM_BLOCK_ADDRESS + 256, MEM_BLOCK_ADDRESS + 384]
    assert addresses == expected
    free = allocator.get_free_memory()
    used = allocator.get_allocated_memory()
    assert free == 576
    assert used == 448


def test_deallocation(allocator):
    """Freeing each live allocation succeeds and returns its size to the pool."""
    first = allocator.allocate(256)
    second = allocator.allocate(128)

    # (address to free, expected free memory, expected allocated memory)
    steps = [(first, 896, 128), (second, 1024, 0)]
    for address, expected_free, expected_used in steps:
        assert allocator.deallocate(address) is True
        assert allocator.get_free_memory() == expected_free
        assert allocator.get_allocated_memory() == expected_used


def test_allocation_after_deallocation(allocator):
    """A freed region is reused by the next allocation that fits in it."""
    released = allocator.allocate(256)
    allocator.deallocate(released)

    reused = allocator.allocate(128)

    assert reused == MEM_BLOCK_ADDRESS
    free = allocator.get_free_memory()
    used = allocator.get_allocated_memory()
    assert free == 896
    assert used == 128


def test_out_of_memory(allocator):
    """Requesting more than the remaining free space raises MemoryError."""
    # Leave only 24 units free so the next request cannot be satisfied.
    allocator.allocate(1000)

    with pytest.raises(MemoryError):
        allocator.allocate(100)


def test_invalid_deallocation(allocator):
    """Freeing an address that matches no block is reported as a failure."""
    unknown_address = 0x2000
    outcome = allocator.deallocate(unknown_address)
    assert outcome is False


def test_fragmentation_and_merging(allocator):
    """Exercise hole creation, first-fit placement and merging of free blocks."""

    def usage():
        # (free, allocated) snapshot of the allocator
        return allocator.get_free_memory(), allocator.get_allocated_memory()

    low = allocator.allocate(256)
    middle = allocator.allocate(256)
    high = allocator.allocate(256)
    assert usage() == (256, 768)

    allocator.deallocate(low)
    assert usage() == (512, 512)

    allocator.deallocate(high)
    assert usage() == (768, 256)

    # The hole left by `low` is too small for 512, so the request must land
    # in the free space that starts where `high` used to be.
    large = allocator.allocate(512)
    assert large == MEM_BLOCK_ADDRESS + 512
    assert usage() == (256, 768)

    allocator.deallocate(middle)
    assert usage() == (512, 512)

    allocator.deallocate(large)
    assert usage() == (1024, 0)  # All blocks merged

    # Test if we can now allocate the entire memory
    everything = allocator.allocate(1024)
    assert everything == MEM_BLOCK_ADDRESS
    assert usage() == (0, 1024)


def test_exact_fit_allocation(allocator):
    """The whole region can be allocated in one request, freed, and allocated again."""
    whole = allocator.allocate(1024)
    assert whole == MEM_BLOCK_ADDRESS
    assert allocator.get_free_memory() == 0
    assert allocator.get_allocated_memory() == 1024

    allocator.deallocate(whole)

    again = allocator.allocate(1024)
    assert again == MEM_BLOCK_ADDRESS
    assert allocator.get_free_memory() == 0
    assert allocator.get_allocated_memory() == 1024
13 changes: 9 additions & 4 deletions vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,6 @@ class CodegenPanic(VyperInternalException):
"""Invalid code generated during codegen phase"""


class StackTooDeep(CodegenPanic):
"""Stack too deep""" # (should not happen)


class UnexpectedNodeType(VyperInternalException):
"""Unexpected AST node type."""

Expand All @@ -424,6 +420,15 @@ class InvalidABIType(VyperInternalException):
"""An internal routine constructed an invalid ABI type"""


class UnreachableStackException(VyperException):
    """An unreachable stack operation was encountered.

    The operand involved is kept on the exception as ``op`` so that a
    handler can inspect it.
    """

    def __init__(self, message, op):
        # Record the operand before delegating the message to the base class.
        self.op = op
        super().__init__(message)


@contextlib.contextmanager
def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None):
try:
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,7 @@ def _import_to_path(level: int, module_str: str) -> PurePath:
base_path = "../" * (level - 1)
elif level == 1:
base_path = "./"
return PurePath(f"{base_path}{module_str.replace('.','/')}/")
return PurePath(f"{base_path}{module_str.replace('.', '/')}/")


# can add more, e.g. "vyper.builtins.interfaces", etc.
Expand Down
3 changes: 3 additions & 0 deletions vyper/venom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from vyper.codegen.ir_node import IRnode
from vyper.compiler.settings import OptimizationLevel
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.basicblock import IRVariable
from vyper.venom.context import IRContext
from vyper.venom.function import IRFunction
from vyper.venom.ir_node_to_venom import ir_node_to_venom
Expand All @@ -18,6 +19,7 @@
from vyper.venom.passes.remove_unused_variables import RemoveUnusedVariablesPass
from vyper.venom.passes.sccp import SCCP
from vyper.venom.passes.simplify_cfg import SimplifyCFGPass
from vyper.venom.passes.stack2mem import Stack2Mem
from vyper.venom.passes.store_elimination import StoreElimination
from vyper.venom.venom_to_assembly import VenomCompiler

Expand Down Expand Up @@ -57,6 +59,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None:
ExtractLiteralsPass(ac, fn).run_pass()
RemoveUnusedVariablesPass(ac, fn).run_pass()
DFTPass(ac, fn).run_pass()
Stack2Mem(ac, fn).run_pass()


def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> IRContext:
Expand Down
5 changes: 5 additions & 0 deletions vyper/venom/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from vyper.venom.basicblock import IRInstruction, IRLabel, IROperand
from vyper.venom.function import IRFunction
from vyper.venom.mem_allocator import MemoryAllocator


class IRContext:
Expand All @@ -10,13 +11,17 @@ class IRContext:
immutables_len: Optional[int]
data_segment: list[IRInstruction]
last_label: int
mem_allocator: MemoryAllocator

def __init__(self) -> None:
self.functions = {}
self.ctor_mem_size = None
self.immutables_len = None
self.data_segment = []
self.last_label = 0
self.mem_allocator = MemoryAllocator(
4096, 0x100000
) # TODO: Should get this from the original IR

def add_function(self, fn: IRFunction) -> None:
fn.ctx = self
Expand Down
58 changes: 58 additions & 0 deletions vyper/venom/mem_allocator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import List


class MemoryBlock:
    """A contiguous region tracked by the allocator; created in the free state."""

    # Extent of the region.
    size: int
    # Position of the region, as supplied by whoever creates the block.
    address: int
    # True while the region is available for allocation.
    is_free: bool

    def __init__(self, size: int, address: int):
        # Every block starts out free; the allocator flips this on allocation.
        self.is_free = True
        self.address = address
        self.size = size


class MemoryAllocator:
    """First-fit allocator over a fixed-size region of memory.

    Blocks are kept in address order in ``blocks``. Block addresses are stored
    relative to the region (the first block is at offset 0), while the public
    methods accept and return absolute addresses, i.e. ``start_address`` plus
    the block offset.
    """

    total_size: int
    start_address: int
    blocks: List[MemoryBlock]

    def __init__(self, total_size: int, start_address: int):
        self.total_size = total_size
        self.start_address = start_address
        # The region starts out as a single free block covering everything.
        self.blocks = [MemoryBlock(total_size, 0)]

    def allocate(self, size: int) -> int:
        """Reserve ``size`` units and return the absolute address of the block.

        The first free block that is large enough is used; any surplus is
        split off into a new free block placed directly after it.

        Raises:
            ValueError: if ``size`` is not positive. A zero or negative size
                would otherwise create an empty or negatively addressed block
                and corrupt the block list.
            MemoryError: if no free block is large enough.
        """
        if size <= 0:
            raise ValueError(f"Allocation size must be positive, got {size}")

        for position, block in enumerate(self.blocks):
            if not block.is_free or block.size < size:
                continue
            if block.size > size:
                remainder = MemoryBlock(block.size - size, block.address + size)
                # Safe to insert while iterating: we return right below.
                self.blocks.insert(position + 1, remainder)
                block.size = size
            block.is_free = False
            return self.start_address + block.address
        raise MemoryError("Memory allocation failed")

    def deallocate(self, address: int) -> bool:
        """Release the allocated block that starts at absolute ``address``.

        Returns True when a live allocation was released, and False when the
        address does not start a block or the block is already free (for
        example on a double free).
        """
        relative_address = address - self.start_address
        for block in self.blocks:
            if block.address != relative_address:
                continue
            if block.is_free:
                return False  # nothing allocated here
            block.is_free = True
            self._merge_adjacent_free_blocks()
            return True
        return False  # invalid address

    def _merge_adjacent_free_blocks(self) -> None:
        """Coalesce every run of neighbouring free blocks into a single block."""
        i = 0
        while i < len(self.blocks) - 1:
            current, following = self.blocks[i], self.blocks[i + 1]
            if current.is_free and following.is_free:
                current.size += following.size
                # Stay on the same index: the merged block may also be
                # adjacent to the next free block.
                self.blocks.pop(i + 1)
            else:
                i += 1

    def get_free_memory(self) -> int:
        """Return the total size of all free blocks."""
        return sum(block.size for block in self.blocks if block.is_free)

    def get_allocated_memory(self) -> int:
        """Return the total size of all allocated blocks."""
        return sum(block.size for block in self.blocks if not block.is_free)
70 changes: 70 additions & 0 deletions vyper/venom/passes/stack2mem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from vyper.exceptions import UnreachableStackException
from vyper.venom.analysis.cfg import CFGAnalysis
from vyper.venom.analysis.dfg import DFGAnalysis
from vyper.venom.analysis.liveness import LivenessAnalysis
from vyper.venom.basicblock import IRInstruction, IRLiteral, IRVariable
from vyper.venom.mem_allocator import MemoryAllocator
from vyper.venom.passes.base_pass import IRPass
from vyper.venom.venom_to_assembly import VenomCompiler


class Stack2Mem(IRPass):
    """Spill variables to memory when assembly generation cannot reach them.

    The pass repeatedly runs ``VenomCompiler.generate_evm`` on the function's
    context. Each time it raises ``UnreachableStackException``, the operand
    carried by the exception is demoted: its value is written to a freshly
    allocated 32-unit memory slot after its definition and read back with an
    ``mload`` before every use. The loop ends when generation succeeds.
    """

    mem_allocator: MemoryAllocator

    def run_pass(self):
        fn = self.function
        self.mem_allocator = fn.ctx.mem_allocator
        self.analyses_cache.request_analysis(CFGAnalysis)
        dfg = self.analyses_cache.request_analysis(DFGAnalysis)
        self.analyses_cache.request_analysis(LivenessAnalysis)

        while True:
            compiler = VenomCompiler([fn.ctx])
            try:
                compiler.generate_evm()
            except UnreachableStackException as e:
                self._demote_variable(dfg, e.op)
                # Demotion rewrites outputs/operands and introduces new
                # variables, so the def/use information must be rebuilt before
                # the next attempt; otherwise a later demotion of one of the
                # new variables would find no producer and no uses.
                self.analyses_cache.invalidate_analysis(DFGAnalysis)
                dfg = self.analyses_cache.request_analysis(DFGAnalysis)
                self.analyses_cache.force_analysis(LivenessAnalysis)
            except Exception:
                # NOTE(review): any other failure ends the pass quietly, as
                # before; presumably the real codegen run reports it later.
                # Confirm this best-effort behaviour is intended.
                break
            else:
                break

        self.analyses_cache.invalidate_analysis(DFGAnalysis)

    def _demote_variable(self, dfg: DFGAnalysis, var: IRVariable):
        """
        Demote a stack variable to memory operations.

        Allocates a 32-unit slot, stores the value after the producing
        instruction (when the DFG knows one) and reloads it before each use.
        """
        # Snapshot the uses: the instructions are rewritten while we iterate.
        uses = list(dfg.get_uses(var))
        def_inst = dfg.get_producing_instruction(var)

        # Allocate memory for this variable
        mem_addr = self.mem_allocator.allocate(32)

        if def_inst is not None:
            self._insert_mstore_after(def_inst, mem_addr)

        for inst in uses:
            self._insert_mload_before(inst, mem_addr, var)

    def _insert_mstore_after(self, inst: IRInstruction, mem_addr: int):
        """Retarget ``inst`` to a fresh variable and ``mstore`` it right after."""
        bb = inst.parent
        idx = bb.instructions.index(inst)
        assert inst.output is not None
        new_var = self.function.get_next_variable()
        # NOTE(review): operand order [value, address] is assumed to follow
        # Venom's mstore convention; confirm.
        store_inst = IRInstruction("mstore", [new_var, IRLiteral(mem_addr)])
        bb.insert_instruction(store_inst, idx + 1)
        inst.output = new_var

    def _insert_mload_before(self, inst: IRInstruction, mem_addr: int, var: IRVariable):
        """Insert an ``mload`` of ``mem_addr`` before ``inst`` and use it instead of ``var``."""
        bb = inst.parent
        idx = bb.instructions.index(inst)
        new_var = self.function.get_next_variable()
        load_inst = IRInstruction("mload", [IRLiteral(mem_addr)])
        load_inst.output = new_var
        bb.insert_instruction(load_inst, idx)
        inst.replace_operands({var: new_var})
Loading
Loading