Skip to content

Commit

Permalink
#10: State default inference partial completion.
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesArruda committed Feb 11, 2025
1 parent 51cdbca commit 8ab6f9a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/upstage_des/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,16 @@ def __set_name__(self, owner: "Actor", name: str) -> None:
def _infer_state(self, instance: "Actor") -> tuple[Any, ...]:
"""Infer types for the state.
This should allow isinstance(value, self._infer_state(instance))
Args:
instance (Actor): The actor the state is attached to.
Returns:
le[Any,...]: The state type
tuple[Any,...]: The state type
"""
state_class = instance._state_defs[self.name]
args = get_args(state_class.__orig_class__)
args = get_args(state_class.__orig_class__) # type: ignore [attr-defined]
return args

def _set_default(self, instance: "Actor") -> None:
Expand Down
16 changes: 16 additions & 0 deletions src/upstage_des/test/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,19 @@ class Worker(UP.Actor):
assert state_name is not None
value = getattr(worker, state_name)
assert value is worker.walkie, "Wrong state retrieved"


def test_type_inference() -> None:
class A(UP.Actor):
st = UP.State[int | float]()

with UP.EnvironmentContext():
a = A(name="hi", st=1)
print(a.st)

v = a._state_defs["st"]._infer_state(a)
print(v)


if __name__ == "__main__":
test_type_inference()

0 comments on commit 8ab6f9a

Please sign in to comment.