[stdlib] Optimize compile times by preventing huge loop unrolling when filling inline arrays. #4046

Open
wants to merge 8 commits into
base: main
20 changes: 18 additions & 2 deletions mojo/stdlib/src/collections/inline_array.mojo
@@ -22,6 +22,7 @@ from collections import InlineArray
 from collections._index_normalization import normalize_index
 from sys.intrinsics import _type_is_eq

+import math
 from memory import UnsafePointer
 from memory.maybe_uninitialized import UnsafeMaybeUninitialized

@@ -131,17 +132,32 @@ struct InlineArray[

     @always_inline
     @implicit
-    fn __init__(out self, fill: Self.ElementType):
+    fn __init__[batch_size: Int = 64](out self, fill: Self.ElementType):
         """Constructs an empty array where each element is the supplied `fill`.

+        Parameters:
+            batch_size: The number of elements to unroll for filling the array.
+
         Args:
             fill: The element to fill each index.
         """
         _inline_array_construction_checks[size]()
         __mlir_op.`lit.ownership.mark_initialized`(__get_mvalue_as_litref(self))

+        alias unroll_end = math.align_down(size, batch_size)
+
+        for i in range(0, unroll_end, batch_size):
+
+            @parameter
+            for j in range(batch_size):
+                var ptr = UnsafePointer.address_of(
+                    self.unsafe_get(i + j)  # `i` already steps by `batch_size`
+                )
+                ptr.init_pointee_copy(fill)
+
+        # Fill the remainder
         @parameter
-        for i in range(size):
+        for i in range(unroll_end, size):
             var ptr = UnsafePointer.address_of(self.unsafe_get(i))
             ptr.init_pointee_copy(fill)