diff --git a/batchgenerators/dataloading/multi_threaded_augmenter.py b/batchgenerators/dataloading/multi_threaded_augmenter.py index 6006fbe..5461adb 100755 --- a/batchgenerators/dataloading/multi_threaded_augmenter.py +++ b/batchgenerators/dataloading/multi_threaded_augmenter.py @@ -21,6 +21,7 @@ import numpy as np import sys import logging +import signal from multiprocessing import Event from time import sleep, time from threadpoolctl import threadpool_limits @@ -32,6 +33,9 @@ def producer(queue, data_loader, transform, thread_id, seed, abort_event, wait_time: float = 0.02): + # Restore default SIGTERM handler to terminate the process + signal.signal(signal.SIGTERM, signal.SIG_DFL) + np.random.seed(seed) data_loader.set_thread_id(thread_id) item = None diff --git a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py index 530eba1..d19bedf 100755 --- a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +++ b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py @@ -23,6 +23,7 @@ from queue import Queue as thrQueue import numpy as np import logging +import signal from multiprocessing import Event from time import sleep, time @@ -37,6 +38,9 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed, abort_event: Event, wait_time: float = 0.02): + # Restore default SIGTERM handler to terminate the process + signal.signal(signal.SIGTERM, signal.SIG_DFL) + # the producer will set the abort event if something happens with threadpool_limits(1, None): np.random.seed(seed)