From 12eb8221d2174bdff3cb19a435e6175b9738f515 Mon Sep 17 00:00:00 2001 From: Jordi Aranda Date: Wed, 23 Mar 2022 16:48:42 +0100 Subject: [PATCH 1/3] enable batch fetching in advance --- petastorm/pytorch.py | 43 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/petastorm/pytorch.py b/petastorm/pytorch.py index 2550b408..8a15f826 100644 --- a/petastorm/pytorch.py +++ b/petastorm/pytorch.py @@ -16,7 +16,9 @@ import decimal # Must import pyarrow before torch. See: https://github.com/uber/petastorm/blob/master/docs/troubleshoot.rst import re +import threading import logging +from queue import Queue import numpy as np from six import PY2 from torch.utils.data.dataloader import default_collate @@ -100,11 +102,36 @@ def decimal_friendly_collate(batch): loader." +class BackgroundIterator(threading.Thread): + """Prefetch iterator results.""" + def __init__(self, iterator, prefetch=1000): + threading.Thread.__init__(self) + self.queue = Queue(prefetch) + self.iterator = iterator + self.daemon = True + self.start() + + def run(self): + for item in self.iterator: + self.queue.put(item) + self.queue.put(None) + + def __iter__(self): + return self + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + class LoaderBase(object): def __init__(self): self._in_iter = None self._error = None + self._max_prefetch = 1 def __iter__(self): if self._error is not None: @@ -118,7 +145,10 @@ def __iter__(self): self._in_iter = True try: - for batch in self._iter_impl(): + iterator = self._iter_impl() + if self._max_prefetch > 1: + iterator = BackgroundIterator(iterator, prefetch=self.max_prefetch) + for batch in iterator: yield batch except Exception as e: self._error = e @@ -264,7 +294,8 @@ class BatchedDataLoader(LoaderBase): def __init__(self, reader, batch_size=1, transform_fn=None, - shuffling_queue_capacity=0): + shuffling_queue_capacity=0, + batch_max_prefetch=None): """ Initializes a data loader object. @@ -287,6 +318,9 @@ def __init__(self, reader, batch_size=1, :param transform_fn: an optional callable to convert batches from the reader to PyTorch tensors :param shuffling_queue_capacity: Queue capacity is passed to the underlying :class:`tf.RandomShuffleQueue` instance. If set to 0, no shuffling will be done. + :param batch_max_prefetch: an optional int indicating maximum number of batches to fetch in + advance. This is specially useful when training models in order to improve model data + throughput. """ super(BatchedDataLoader, self).__init__() self.reader = reader @@ -298,6 +332,11 @@ def __init__(self, reader, batch_size=1, self.shuffling_queue_capacity = shuffling_queue_capacity self._in_iter = None + # fetch batches in advance? + if batch_max_prefetch is not None: + assert batch_max_prefetch > 0, "if set, batch_max_prefetch must be greater or equal to 1" + self._max_prefetch = batch_max_prefetch + def _iter_impl(self): """ The Data Loader iterator stops the for-loop when reader runs out of samples. From aae2993f2d9024528385bd7b717e57796b0dd26d Mon Sep 17 00:00:00 2001 From: Jordi Aranda Date: Wed, 23 Mar 2022 18:11:33 +0100 Subject: [PATCH 2/3] improve docstrings --- petastorm/pytorch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/petastorm/pytorch.py b/petastorm/pytorch.py index 8a15f826..68588a5b 100644 --- a/petastorm/pytorch.py +++ b/petastorm/pytorch.py @@ -103,7 +103,9 @@ def decimal_friendly_collate(batch): class BackgroundIterator(threading.Thread): - """Prefetch iterator results.""" + """Prefetch iterator results. A thread iterates the original iterator and + populates a queue. Iterating over this background iterator just consumes the underlying + queue until no other result is available.""" def __init__(self, iterator, prefetch=1000): threading.Thread.__init__(self) self.queue = Queue(prefetch) From 7234d81ba20ec0c170ed1dc3a55d2e42a01cb8ba Mon Sep 17 00:00:00 2001 From: Jordi Aranda Date: Mon, 28 Mar 2022 12:30:33 +0200 Subject: [PATCH 3/3] add tests --- petastorm/pytorch.py | 41 ++++++++++++---------- petastorm/tests/test_pytorch_dataloader.py | 41 +++++++++++++++++++++- 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/petastorm/pytorch.py b/petastorm/pytorch.py index 68588a5b..3fb0ce2f 100644 --- a/petastorm/pytorch.py +++ b/petastorm/pytorch.py @@ -106,17 +106,21 @@ class BackgroundIterator(threading.Thread): """Prefetch iterator results. A thread iterates the original iterator and populates a queue. Iterating over this background iterator just consumes the underlying queue until no other result is available.""" - def __init__(self, iterator, prefetch=1000): + def __init__(self, iterator, queue_size=1000): threading.Thread.__init__(self) - self.queue = Queue(prefetch) + self.name = "background_iterator" + self.queue = Queue(queue_size) self.iterator = iterator - self.daemon = True + self.stop = threading.Event() self.start() def run(self): - for item in self.iterator: - self.queue.put(item) - self.queue.put(None) + while not self.stop.isSet(): + for item in self.iterator: + self.queue.put(item) + self.queue.put(None) + self.stop.set() + return def __iter__(self): return self @@ -133,7 +137,7 @@ class LoaderBase(object): def __init__(self): self._in_iter = None self._error = None - self._max_prefetch = 1 + self._queue_size = 1 def __iter__(self): if self._error is not None: @@ -146,10 +150,10 @@ def __iter__(self): logger.warning('Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.') self._in_iter = True + iterator = self._iter_impl() try: - iterator = self._iter_impl() - if self._max_prefetch > 1: - iterator = BackgroundIterator(iterator, prefetch=self.max_prefetch) + if self._queue_size > 1: + iterator = BackgroundIterator(iterator, queue_size=self._queue_size) for batch in iterator: yield batch except Exception as e: @@ -158,6 +162,8 @@ def __iter__(self): raise finally: self._in_iter = False + if isinstance(iterator, BackgroundIterator): + iterator.stop.set() class DataLoader(LoaderBase): @@ -297,7 +303,7 @@ class BatchedDataLoader(LoaderBase): def __init__(self, reader, batch_size=1, transform_fn=None, shuffling_queue_capacity=0, - batch_max_prefetch=None): + batch_queue_size=None): """ Initializes a data loader object. @@ -320,9 +326,8 @@ def __init__(self, reader, batch_size=1, :param transform_fn: an optional callable to convert batches from the reader to PyTorch tensors :param shuffling_queue_capacity: Queue capacity is passed to the underlying :class:`tf.RandomShuffleQueue` instance. If set to 0, no shuffling will be done. - :param batch_max_prefetch: an optional int indicating maximum number of batches to fetch in - advance. This is specially useful when training models in order to improve model data - throughput. + :param batch_queue_size: an optional int indicating maximum number of batches to fetch in + parallel. This might be useful when training models in order to improve model data throughput. """ super(BatchedDataLoader, self).__init__() self.reader = reader @@ -334,10 +339,10 @@ def __init__(self, reader, batch_size=1, self.shuffling_queue_capacity = shuffling_queue_capacity self._in_iter = None - # fetch batches in advance? - if batch_max_prefetch is not None: - assert batch_max_prefetch > 0, "if set, batch_max_prefetch must be greater or equal to 1" - self._max_prefetch = batch_max_prefetch + # fetch batches in parallel? + if batch_queue_size is not None: + assert batch_queue_size > 0, "if set, batch_queue_size must be greater or equal to 1" + self._queue_size = batch_queue_size def _iter_impl(self): """ diff --git a/petastorm/tests/test_pytorch_dataloader.py b/petastorm/tests/test_pytorch_dataloader.py index 384a695e..3f27a8ce 100644 --- a/petastorm/tests/test_pytorch_dataloader.py +++ b/petastorm/tests/test_pytorch_dataloader.py @@ -1,5 +1,7 @@ from decimal import Decimal from packaging import version +import time +import threading import numpy as np import pyarrow as pa @@ -10,7 +12,8 @@ from petastorm import make_reader, TransformSpec, make_batch_reader from petastorm.pytorch import (_sanitize_pytorch_types, DataLoader, BatchedDataLoader, decimal_friendly_collate, - InMemBatchedDataLoader, _load_rows_into_mem) + InMemBatchedDataLoader, _load_rows_into_mem, + BackgroundIterator) from petastorm.tests.test_common import TestSchema BASIC_DATA_LOADERS = [DataLoader, BatchedDataLoader] @@ -331,3 +334,39 @@ def test_inmem_batched_dataloader_shuffle_per_epoch(synthetic_dataset, reader_fa with pytest.raises(StopIteration): next(it) + + +def test_background_iterator(): + # number of iterator elements + N = int(1e6) + + # wait some time for the queue to be filled + bit = BackgroundIterator(range(N), queue_size=1000) + time.sleep(1) + + assert not bit.queue.empty() + # ensure the thread exists and is populating the queue + assert "background_iterator" in [t.name for t in threading.enumerate()] + # ensure we process the same number of elements present in original iterator + n = 0 + for _ in bit: + n += 1 + assert n == N + # ensure the thread stopped when original iterator was completed + time.sleep(1) + assert "background_iterator" not in [t.name for t in threading.enumerate()] + + +@pytest.mark.parametrize('reader_factory', ALL_READER_FLAVOR_FACTORIES) +def test_batched_dataloader_background_iterator_handle_exception(synthetic_dataset, reader_factory): + with BatchedDataLoader(reader_factory(synthetic_dataset.url, schema_fields=TORCH_BATCHABLE_FIELDS, + transform_spec=TransformSpec(_sensor_name_to_int)), + batch_queue_size=100) as loader: + try: + for _ in loader: + assert "background_iterator" in [t.name for t in threading.enumerate()] + raise RuntimeError() + except RuntimeError: + # ensure we wait enough for the thread to be stopped + time.sleep(1) + assert "background_iterator" not in [t.name for t in threading.enumerate()]