Skip to content

Commit

Permalink
fix: make a task engine job stoppable
Browse files Browse the repository at this point in the history
  • Loading branch information
chisholm committed Apr 6, 2024
1 parent a886786 commit e2005bf
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/dioptra/rq/tasks/run_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
DIOPTRA_JOB_ID,
DIOPTRA_QUEUE,
)
from dioptra.rq.tasks.run_task_engine_stoppable import run_experiment_stoppable
from dioptra.sdk.utilities.paths import set_cwd
from dioptra.task_engine.task_engine import run_experiment
from dioptra.task_engine.validation import is_valid
from dioptra.worker.s3_download import s3_download

Expand Down Expand Up @@ -65,6 +65,7 @@ def run_task_engine_task(
normally used, but useful in unit tests when you need a specially
configured object with stubbed responses.
"""

rq_job = get_current_job()
rq_job_id = rq_job.get_id() if rq_job else None

Expand Down Expand Up @@ -147,11 +148,17 @@ def _run_experiment(

db_client.update_job_status(rq_job_id, "started")

run_experiment(experiment_desc, global_parameters)
was_stopped = run_experiment_stoppable(experiment_desc, global_parameters)

log.info("=== Run succeeded ===")
mlflow.end_run()
db_client.update_job_status(rq_job_id, "finished")
if was_stopped:
log.info("=== Run stopped ===")
mlflow.end_run("KILLED")
# We don't have a job status value for "stopped" or "killed"...
db_client.update_job_status(rq_job_id, "failed")
else:
log.info("=== Run succeeded ===")
mlflow.end_run()
db_client.update_job_status(rq_job_id, "finished")

except Exception:
mlflow.end_run("FAILED")
Expand Down
151 changes: 151 additions & 0 deletions src/dioptra/rq/tasks/run_task_engine_stoppable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
import multiprocessing as mp
import signal
from typing import Any, Mapping, MutableMapping

import structlog

from dioptra.task_engine.task_engine import request_stop, run_experiment

# Which endpoint to poll
_POLL_URL = "http://dioptra-deployment-restapi:5000/something"


# Web server endpoint poll interval, in seconds
_POLL_INTERVAL = 3


def _get_logger() -> Any:
"""
Get a logger for this module.
Returns:
A logger object
"""
return structlog.get_logger(__name__)


def run_experiment_stoppable(
experiment_desc: Mapping[str, Any], global_parameters: MutableMapping[str, Any]
) -> bool:
"""
Run an experiment via the task engine. This implementation runs it in a
sub-process. The parent process will poll an endpoint for a shutdown
instruction which will cause us to stop the experiment early.
Args:
experiment_desc: A declarative experiment description, as a mapping
global_parameters: Global parameters for this run, as a mapping from
parameter name to value
Returns:
True if the process was stopped prematurely; False if not
"""
child_process = mp.Process(
target=_run_experiment_child_process, args=(experiment_desc, global_parameters)
)

child_process.start()

was_stopped = _monitor_process(child_process)

child_process.close()

return was_stopped


def _run_experiment_child_process(
experiment_desc: Mapping[str, Any], global_parameters: MutableMapping[str, Any]
) -> None:
"""
Simple wrapper around run_experiment() which arranges for SIGTERM to
request graceful termination of the experiment.
Args:
experiment_desc: A declarative experiment description, as a mapping
global_parameters: Global parameters for this run, as a mapping from
parameter name to value
"""

signal.signal(signal.SIGTERM, lambda *args: request_stop())

run_experiment(experiment_desc, global_parameters)


def _monitor_process(child_process: mp.Process) -> bool:
"""
Watch the given child process while polling for a shutdown request.
If shutdown is requested, shut down the child process early.
This function blocks until the child process terminates.
Args:
child_process: The child process to watch
Returns:
True if the process was stopped prematurely; False if not
"""
log = _get_logger()
log.debug("Monitoring task engine process: %d", child_process.pid)

should_stop = False
while child_process.is_alive():
should_stop = _should_stop()

if should_stop:
log.warning("Attempting to stop pid: %d", child_process.pid)
# Send a SIGTERM to attempt a graceful shutdown
child_process.terminate()

# Wait one poll interval to see if it stops. If not, forcibly
# kill it.
child_process.join(_POLL_INTERVAL)

# Docs describe checking .exitcode, not .is_alive().
if child_process.exitcode is None:
log.warning(
"Graceful shutdown failed; killing pid: %d", child_process.pid
)
child_process.kill()
child_process.join()

else:
# Wait until next poll
child_process.join(_POLL_INTERVAL)

return should_stop


def _should_stop() -> bool:
"""
Determine whether the current experiment should be stopped.
Returns:
True if it should be stopped; False if not
"""
# resp = requests.get(_POLL_URL)
# if resp.ok:
# # Depends on what the endpoint returns
# value = cast(bool, resp.json())
#
# else:
# log.warning("Polling endpoint returned http status: %d", resp.status_code)
# value = False

value = False
return value
21 changes: 21 additions & 0 deletions src/dioptra/task_engine/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@
)
from dioptra.task_engine import util

_stop_requested = False


def request_stop() -> None:
"""
Request graceful early termination of an experiment. This will occur
in between experiment steps; it can't interrupt a running step.
"""
global _stop_requested
# I don't think we need any concurrency protection for this simple thing,
# do we?
_stop_requested = True


def _get_logger() -> logging.Logger:
"""
Expand Down Expand Up @@ -450,6 +463,10 @@ def run_experiment(
global_parameters: External parameter values to use in the
experiment, as a dict
"""
# An external entity can set this flag to request a graceful shutdown.
# See request_stop().
global _stop_requested
_stop_requested = False

log = _get_logger()

Expand All @@ -475,6 +492,10 @@ def run_experiment(
log.debug("Step order:\n %s", "\n ".join(step_order))

for step_name in step_order:
if _stop_requested:
log.warning("Experiment aborted!")
break

try:
log.info("Running step: %s", step_name)

Expand Down

0 comments on commit e2005bf

Please sign in to comment.