Skip to content

Commit

Permalink
Enable deferred invocation (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomperez98 authored Jan 21, 2025
1 parent 1766ec1 commit 321dffd
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 16 deletions.
50 changes: 34 additions & 16 deletions src/resonate/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,27 +376,28 @@ def _handle_notify(self, notify: Notify) -> list[Command]:
def _handle_continue(
self, id: str, next_value: Result[Any, Exception] | None
) -> list[Command]:
loopback: list[Command]
record = self._records[id]
coro = record.get_coro()
yielded_value = coro.advance(next_value)

if isinstance(yielded_value, LFI):
return self._process_lfi(record, yielded_value)
if isinstance(yielded_value, LFC):
return self._process_lfc(record, yielded_value)
if isinstance(yielded_value, RFI):
return self._process_rfi(record, yielded_value)
if isinstance(yielded_value, RFC):
return self._process_rfc(record, yielded_value)
if isinstance(yielded_value, Promise):
return self._process_promise(record, yielded_value)
if isinstance(yielded_value, FinalValue):
return self._process_final_value(record, yielded_value.v)
if isinstance(yielded_value, DI):
# start execution from the top. Add current record to runnable
raise NotImplementedError

assert_never(yielded_value)
loopback = self._process_lfi(record, yielded_value)
elif isinstance(yielded_value, LFC):
loopback = self._process_lfc(record, yielded_value)
elif isinstance(yielded_value, RFI):
loopback = self._process_rfi(record, yielded_value)
elif isinstance(yielded_value, RFC):
loopback = self._process_rfc(record, yielded_value)
elif isinstance(yielded_value, Promise):
loopback = self._process_promise(record, yielded_value)
elif isinstance(yielded_value, FinalValue):
loopback = self._process_final_value(record, yielded_value.v)
elif isinstance(yielded_value, DI):
loopback = self._process_deferred(record, yielded_value)
else:
assert_never(yielded_value)
return loopback

def _process_invoke_msg(
self, invoke_msg: InvokeMsg, task: TaskRecord
Expand Down Expand Up @@ -791,6 +792,23 @@ def _process_final_value(
) -> list[Command]:
return [Complete(record.id, final_value)]

def _process_deferred(self, record: Record[Any], deferred: DI) -> list[Command]:
loopback = self._handle_fork_or_join(
ForkOrJoin(
deferred.id,
Handle[Any](deferred.id),
Invocation(
deferred.unit.fn,
*deferred.unit.args,
**deferred.unit.kwargs,
),
)
)
loopback.extend(
self._handle_continue(record.id, next_value=Ok(Promise[Any](deferred.id)))
)
return loopback

def _get_info_from_rfi(self, rfi: RFI) -> tuple[Data, Headers, Tags, int | None]:
data: Data
tags: Tags
Expand Down
28 changes: 28 additions & 0 deletions tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,31 @@ def foo_sleep(ctx: Context, n: int) -> Generator[Yieldable, Any, int]:
p = foo_sleep.run(f"{group}-{n}", n)
assert p.result() == n
s.stop()


@pytest.mark.skipif(
os.getenv("RESONATE_STORE_URL") is None, reason="env variable is not set"
)
def test_golden_device_deferred() -> None:
group = "test-golden-device-deferred"

def foo_golden_device_deferred(
ctx: Context, n: str
) -> Generator[Yieldable, Any, str]:
p: Promise[str] = yield ctx.deferred("bar", bar_golden_device_deferred, n)
v: str = yield p
return v

def bar_golden_device_deferred(ctx: Context, n: str) -> str: # noqa: ARG001
return n

resonate = Resonate(
store=RemoteStore(url=os.environ["RESONATE_STORE_URL"]),
task_source=Poller("http://localhost:8002", group=group),
)
resonate.register(foo_golden_device_deferred)
resonate.register(bar_golden_device_deferred)
p: Handle[str] = resonate.run(f"{group}-foo", foo_golden_device_deferred, "hi")
assert isinstance(p, Handle)
assert p.result() == "hi"
resonate.stop()

0 comments on commit 321dffd

Please sign in to comment.