Less constrained ordered task prep #1335

Merged: 3 commits, Oct 1, 2018

59 changes: 55 additions & 4 deletions tests/trinity/utils/test_ordered_task_preparation.py
@@ -20,6 +20,10 @@ async def wait(coro, timeout=DEFAULT_TIMEOUT):
return await asyncio.wait_for(coro, timeout=timeout)


class NoPrerequisites(Enum):
pass


class OnePrereq(Enum):
one = auto()

@@ -187,13 +191,60 @@ def test_reregister_duplicates():
ti.register_tasks((2, ))


-def test_empty_enum():
@pytest.mark.asyncio
async def test_no_prereq_tasks():
ti = OrderedTaskPreparation(NoPrerequisites, identity, lambda x: x - 1)
ti.set_finished_dependency(1)
ti.register_tasks((2, 3))

# with no prerequisites, tasks are *immediately* finished, as long as they are in order
finished = await wait(ti.ready_tasks())
assert finished == (2, 3)


@pytest.mark.asyncio
async def test_register_out_of_order():
ti = OrderedTaskPreparation(OnePrereq, identity, lambda x: x - 1, accept_dangling_tasks=True)
ti.set_finished_dependency(1)
ti.register_tasks((4, 5))
ti.finish_prereq(OnePrereq.one, (4, 5))

try:
finished = await wait(ti.ready_tasks())
except asyncio.TimeoutError:
pass
else:
assert False, f"No steps should be ready, but got {finished!r}"

ti.register_tasks((2, 3))
ti.finish_prereq(OnePrereq.one, (2, 3))
finished = await wait(ti.ready_tasks())
assert finished == (2, 3, 4, 5)


@pytest.mark.asyncio
async def test_no_prereq_tasks_out_of_order():
ti = OrderedTaskPreparation(
NoPrerequisites,
identity,
lambda x: x - 1,
accept_dangling_tasks=True,
)
ti.set_finished_dependency(1)
ti.register_tasks((4, 5))

-    class NoPrerequisites(Enum):
try:
finished = await wait(ti.ready_tasks())
except asyncio.TimeoutError:
pass
else:
assert False, f"No steps should be ready, but got {finished!r}"

-    with pytest.raises(ValidationError):
-        OrderedTaskPreparation(NoPrerequisites, identity, lambda x: x - 1)
ti.register_tasks((2, 3))

# with no prerequisites, tasks are *immediately* finished, as long as they are in order
finished = await wait(ti.ready_tasks())
assert finished == (2, 3, 4, 5)


@pytest.mark.asyncio
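Taken together, the new tests show the intended call pattern: an empty prerequisite enum turns OrderedTaskPreparation into a pure ordering buffer, and ``accept_dangling_tasks=True`` lets tasks be registered before their dependency and then flushed once the gap is filled. Below is a minimal standalone sketch of that pattern; it assumes trinity (and eth-utils) is importable as in this diff, and the integer tasks and names are illustrative only, not part of the change.

import asyncio
from enum import Enum

from eth_utils.toolz import identity

from trinity.utils.datastructures import OrderedTaskPreparation


class NoPrereqs(Enum):
    # empty enum: tasks have no prerequisites, only ordering via their dependency
    pass


async def demo() -> None:
    # tasks are plain ints; each task depends on the previous int
    buffer = OrderedTaskPreparation(
        NoPrereqs,
        identity,                # a task is its own id
        lambda task: task - 1,   # id of the task it depends on
        accept_dangling_tasks=True,
    )
    buffer.set_finished_dependency(0)   # bootstrap: pretend task 0 already finished

    buffer.register_tasks((3, 4))       # dangling: their dependency 2 is unknown so far
    buffer.register_tasks((1, 2))       # fills the gap back to the finished task 0

    # with no prerequisites, everything contiguous with 0 becomes ready at once
    ready = await buffer.ready_tasks()
    print(ready)                        # expected: (1, 2, 3, 4)


asyncio.run(demo())

With a non-empty enum (as in test_register_out_of_order above), the same dangling registration works, but each task must also see finish_prereq() for every enum member before it can become ready.
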
153 changes: 113 additions & 40 deletions trinity/utils/datastructures.py
@@ -7,8 +7,14 @@
)
from collections import defaultdict
from enum import Enum
-from functools import total_ordering
-from itertools import count
from functools import (
total_ordering,
)
from itertools import (
count,
repeat,
)
from operator import attrgetter
from typing import (
Any,
Callable,
@@ -27,8 +33,15 @@
to_tuple,
)
from eth_utils.toolz import (
compose,
concat,
curry,
do,
identity,
iterate,
mapcat,
nth,
pipe,
)

from trinity.utils.queues import (
@@ -291,10 +304,6 @@ class BaseTaskPrerequisites(Generic[TTask, TPrerequisite]):

@classmethod
def from_enum(cls, prereqs: Type[TPrerequisite]) -> 'Type[BaseTaskPrerequisites[Any, Any]]':

-        if len(prereqs) < 1:
-            raise ValidationError("There must be at least one prerequisite to track completions")

return type('CompletionFor' + prereqs.__name__, (cls, ), dict(_prereqs=prereqs))

def __init__(self, task: TTask) -> None:
@@ -425,7 +434,7 @@ class BlockDownloads(Enum):
- prerequisites: all these must be completed for a task to be ready
(a necessary but not sufficient condition)
- ready: a task is ready after all its prereqs are completed, and the task it depends on is
-        also ready
also ready. The initial ready task is set with :meth:`set_finished_dependency`
"""
# methods to extract the id and dependency IDs out of a task
_id_of: StaticMethod[Callable[[TTask], TTaskID]]
@@ -441,18 +450,22 @@ def __init__(
prerequisites: Type[TPrerequisite],
id_extractor: Callable[[TTask], TTaskID],
dependency_extractor: Callable[[TTask], TTaskID],
accept_dangling_tasks: bool = False,
max_depth: int = None) -> None:

self._prereq_tracker = BaseTaskPrerequisites.from_enum(prerequisites)
self._id_of = id_extractor
self._dependency_of = dependency_extractor
-        self._oldest_depth = 0
self._accept_dangling_tasks = accept_dangling_tasks

# how long to wait before pruning
if max_depth is None:
self._max_depth = self._default_max_depth
elif max_depth < 0:
raise ValidationError(f"The maximum depth must be at least 0, not {max_depth}")
else:
-            self._max_depth = min([self._default_max_depth, max_depth])
self._max_depth = max_depth

# all of the tasks that have been completed, and not pruned
self._tasks: Dict[TTaskID, BaseTaskPrerequisites[TTask, TPrerequisite]] = {}
@@ -471,14 +484,20 @@ def __init__(
# They wait in this Queue until being returned by ready_tasks().
self._ready_tasks: 'Queue[TTask]' = Queue()

-        # Track the depth from the original task at 0 to n dependencies away
-        # This is used exclusively for pruning
-        self._depths: Dict[TTaskID, int] = {}
# Declared finished with set_finished_dependency()
self._declared_finished: Set[TTaskID] = set()

def set_finished_dependency(self, finished_task: TTask) -> None:
"""
-        Mark this task as already finished. Any task being registered in
-        :meth:`register_tasks` must have dependencies that are finished.
Mark this task as already finished. This is a bootstrapping method. In general,
tasks are marked as finished by :meth:`finish_prereq`. But how do we know which task is
first, and that its dependency is complete? We call `set_finished_dependency`.

Since a task can only become ready when its dependent
task is ready, the first result from ready_tasks will be dependent on
finished_task set in this method. (More precisely, it will be dependent on *one of*
the ``finished_task`` objects set with this method, since the method may be called
multiple times)
"""
completed = self._prereq_tracker(finished_task)
completed.set_complete()
@@ -490,16 +509,16 @@ def set_finished_dependency(self, finished_task: TTask) -> None:
(finished_task, ),
)
self._tasks[task_id] = completed
-        if len(self._depths):
-            self._depths[task_id] = max(self._depths.values())
-        else:
-            self._depths[task_id] = 0
self._declared_finished.add(task_id)
# note that this task is intentionally *not* added to self._unready

def register_tasks(self, tasks: Tuple[TTask, ...]) -> None:
"""
-        Initiate a task into tracking. Each task must be registered *after* its dependency has
-        been registered.
Initiate a task into tracking. By default, each task must be registered
*after* its dependency has been registered.

If you want to be able to register non-contiguous tasks, you can
initialize this instance with: ``accept_dangling_tasks=True``.

:param tasks: the tasks to register, in iteration order
"""
Expand All @@ -520,7 +539,7 @@ def register_tasks(self, tasks: Tuple[TTask, ...]) -> None:
)

for prereq_tracker, task_id, dependency_id in task_meta_info:
-            if dependency_id not in self._tasks:
if not self._accept_dangling_tasks and dependency_id not in self._tasks:
raise MissingDependency(
f"Cannot prepare task {prereq_tracker!r} with id {task_id} and "
f"dependency {dependency_id} before preparing its dependency"
@@ -529,8 +548,10 @@ def register_tasks(self, tasks: Tuple[TTask, ...]) -> None:
self._tasks[task_id] = prereq_tracker
self._unready.add(task_id)
self._dependencies[dependency_id].add(task_id)
-            depth = self._depths[dependency_id] + 1
-            self._depths[task_id] = depth

if prereq_tracker.is_complete and self._is_ready(prereq_tracker.task):
# this is possible for tasks with 0 prerequisites (useful for pure ordering)
self._mark_complete(task_id)

def finish_prereq(self, prereq: TPrerequisite, tasks: Tuple[TTask, ...]) -> None:
"""For every task in tasks, mark the given prerequisite as completed"""
@@ -548,7 +569,7 @@ def finish_prereq(self, prereq: TPrerequisite, tasks: Tuple[TTask, ...]) -> None

task_completion = self._tasks[task_id]
task_completion.finish(prereq)
-            if task_completion.is_complete and self._dependency_of(task) not in self._unready:
if task_completion.is_complete and self._is_ready(task):
self._mark_complete(task_id)

async def ready_tasks(self) -> Tuple[TTask, ...]:
@@ -558,6 +579,17 @@ async def ready_tasks(self) -> Tuple[TTask, ...]:
"""
return await queue_get_batch(self._ready_tasks)

def _is_ready(self, task: TTask) -> bool:
dependency = self._dependency_of(task)
if dependency in self._declared_finished:
# Ready by declaration
return True
elif dependency in self._tasks and dependency not in self._unready:
# Ready by insertion and tracked completion
return True
else:
return False

def _mark_complete(self, task_id: TTaskID) -> None:
qualified_tasks = tuple([task_id])
while qualified_tasks:
@@ -582,34 +614,75 @@ def _mark_one_task_complete(self, task_id: TTaskID) -> Generator[TTaskID, None,
self._unready.remove(task_id)

# prune any completed tasks that are too old
-        self._prune(task_id)
self._prune_finished(task_id)

# resolve tasks that depend on this task
for depending_task_id in self._dependencies[task_id]:
# we already know that this task is ready, so we only need to check completion
if self._tasks[depending_task_id].is_complete:
yield depending_task_id

-    def _prune(self, task_id: TTaskID) -> None:
def _prune_finished(self, task_id: TTaskID) -> None:
"""
-        This prunes any data starting at the task completed at task_completion, and older.
This prunes any data starting more than _max_depth in history.
It is called when the task becomes ready.
"""
-        # determine how far back to prune
-        finished_depth = self._depths[task_id]
try:
oldest_id = self._find_oldest_unpruned_task_id(task_id)
except ValidationError:
# No tasks are old enough to prune, can end immediately
return

root_id, depth = self._find_root(oldest_id)
unpruned = self._prune_forward(root_id, depth)
if oldest_id not in unpruned:
raise ValidationError(
f"Expected {oldest_id} to be in {unpruned!r}, something went wrong during pruning."
)

-        prune_depth = finished_depth - self._max_depth
-        if prune_depth > self._oldest_depth:
def _validate_has_task(self, task_id: TTaskID) -> None:
if task_id not in self._tasks:
raise ValidationError(f"No task {task_id} is present")

-            for depth in range(self._oldest_depth, prune_depth):
-                prune_tasks = tuple(
-                    task_id for task_id in self._tasks.keys()
-                    if self._depths[task_id] == depth
-                )
def _find_oldest_unpruned_task_id(self, finished_task_id: TTaskID) -> TTaskID:
get_dependency_of_id = compose(
curry(do)(self._validate_has_task),
self._dependency_of,
attrgetter('task'),
self._tasks.get,
)
ancestors = iterate(get_dependency_of_id, finished_task_id)
return nth(self._max_depth, ancestors)

-                for prune_task_id in prune_tasks:
-                    del self._tasks[prune_task_id]
-                    del self._depths[prune_task_id]
-                    self._dependencies.pop(prune_task_id, None)
def _find_root(self, task_id: TTaskID) -> Tuple[TTaskID, int]:
"""
return the oldest root, and the depth to it from the seed task
"""
root_candidate = task_id
get_dependency_of_id = compose(self._dependency_of, attrgetter('task'), self._tasks.get)
# We'll use the maximum saved history (_max_depth) to cap how long the stale cache
# of history might get, when pruning. Increasing the cap should not be a problem, if needed.
for depth in range(0, self._max_depth):
dependency = get_dependency_of_id(root_candidate)
if dependency not in self._tasks:
return root_candidate, depth
else:
root_candidate = dependency
raise ValidationError(
f"Stale task history too long ({depth}) before pruning. {dependency} is still in cache."
)

-            self._oldest_depth = prune_depth
def _prune_forward(self, root_id: TTaskID, depth: int) -> Tuple[TTaskID]:
"""
Prune all forks forward from the root
"""
def prune_parent(prune_task_id: TTaskID) -> Set[TTaskID]:
children = self._dependencies.pop(prune_task_id, set())
del self._tasks[prune_task_id]
if prune_task_id in self._declared_finished:
self._declared_finished.remove(prune_task_id)
return children

prune_parent_list = compose(tuple, curry(mapcat)(prune_parent))
prune_trunk = repeat(prune_parent_list, depth)
return pipe((root_id, ), *prune_trunk)
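
The new pruning path is the least obvious part of this diff: _find_oldest_unpruned_task_id walks _max_depth ancestors back with iterate()/nth(), _find_root follows dependencies until it falls off the cached history, and _prune_forward then deletes generation by generation going forward. The toy below isolates only that last step, the mapcat/pipe walk, on hand-rolled dicts; the ids, the depth of 2, and the prune_generation name are made up for illustration and are not part of trinity.

from itertools import repeat

from eth_utils.toolz import compose, curry, mapcat, pipe

# Toy fan-out: task id -> ids of the tasks that depend on it (a fork at 2).
dependencies = {1: {2}, 2: {3, 30}, 3: {4}, 30: set(), 4: set()}
tasks = {1, 2, 3, 30, 4}


def prune_parent(task_id):
    # drop one task and hand its children to the next round of pruning
    children = dependencies.pop(task_id, set())
    tasks.discard(task_id)
    return children


# one generation of pruning: prune every id in the batch, collect all their children
prune_generation = compose(tuple, curry(mapcat)(prune_parent))

# apply it `depth` times starting from the root, like pipe((root_id, ), *prune_trunk) above
depth = 2
frontier = pipe((1, ), *repeat(prune_generation, depth))

print(frontier)       # the unpruned frontier, e.g. (3, 30) -- order follows set iteration
print(sorted(tasks))  # [3, 4, 30] -- tasks 1 and 2, older than the frontier, are gone

This mirrors what _prune_forward returns: the tuple of ids at the cut line, which _prune_finished then checks still contains the expected oldest unpruned task.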