Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify Workflow to Allow IterableDataset Inputs #8263

Merged
merged 12 commits into from
Dec 19, 2024
Merged
20 changes: 10 additions & 10 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,24 @@ def __init__(
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
) -> None:
if iteration_update is not None:
super().__init__(iteration_update)
else:
super().__init__(self._iteration)
super().__init__(self._iteration if iteration_update is None else iteration_update)

if isinstance(data_loader, DataLoader):
sampler = data_loader.__dict__["sampler"]
sampler = getattr(data_loader, "sampler", None)

# set the epoch value for DistributedSampler objects when an epoch starts
if isinstance(sampler, DistributedSampler):

@self.on(Events.EPOCH_STARTED)
def set_sampler_epoch(engine: Engine) -> None:
sampler.set_epoch(engine.state.epoch)

# if the epoch_length isn't given, attempt to get it from the length of the data loader
if epoch_length is None:
epoch_length = len(data_loader)
else:
if epoch_length is None:
raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
try:
epoch_length = len(data_loader)
except TypeError: # raised when data_loader is given an iterable dataset which has no length
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
pass # deliberately leave epoch_length as None

# set all sharable data for the workflow based on Ignite engine.state
self.state: Any = State(
Expand All @@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None:
iteration=0,
epoch=0,
max_epochs=max_epochs,
epoch_length=epoch_length,
epoch_length=epoch_length, # None when the dataset is iterable and so has no length
output=None,
batch=None,
metrics={},
Expand Down
13 changes: 13 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import nibabel as nib
import numpy as np
import torch.nn as nn

from monai.data import DataLoader, Dataset, IterableDataset
from monai.engines import SupervisedEvaluator
from monai.transforms import Compose, LoadImaged, SimulateDelayd


Expand Down Expand Up @@ -59,6 +61,17 @@ def test_shape(self):
for d in dataloader:
self.assertTupleEqual(d["image"].shape[1:], expected_shape)

def test_supervisedevaluator(self):
"""
Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader.
"""
data = list(range(10))
dl = DataLoader(IterableDataset(data))
evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity())
evaluator.run() # fails if the epoch length or other internal setup is not done correctly

self.assertEqual(evaluator.state.iteration, len(data))


if __name__ == "__main__":
unittest.main()
Loading