Skip to content

Commit

Permalink
fix: threaded task is_current() will return False after cancelling.
Browse files Browse the repository at this point in the history
This was the case for async tasks, not yet for threaded tasks.
  • Loading branch information
maartenbreddels committed Dec 20, 2024
1 parent a5103a7 commit f96f8b6
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion solara/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def __init__(self, function: Callable[P, R], key: str):
self.__qualname__ = function.__qualname__
self.function = function
self.lock = threading.Lock()
self._local = threading.local()

def cancel(self) -> None:
if self._cancel:
Expand Down Expand Up @@ -343,12 +344,16 @@ def cancel():
current_thread.start()

def is_current(self):
cancel_event = getattr(self._local, "cancel_event", None)
if cancel_event is not None and cancel_event.is_set():
return False
return self._current_thread == threading.current_thread()

def _run(self, _last_finished_event, previous_thread: Optional[threading.Thread], cancel_event, args, kwargs) -> None:
# use_thread has this as default, which can make code run 10x slower
intrusive_cancel = False
wait_on_previous = False
self._local.cancel_event = cancel_event

def runner():
if wait_on_previous:
Expand Down Expand Up @@ -405,7 +410,7 @@ def runner():
# this means this thread is cancelled not be request, but because
# a new thread is running, we can ignore this
finally:
if self.is_current():
if self._current_thread == threading.current_thread():
self.running_thread = None
logger.info("thread done!")
if cancel_event.is_set():
Expand Down

0 comments on commit f96f8b6

Please sign in to comment.