diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 791737fd..30c5be53 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -121,6 +121,9 @@ def from_tfrecords( ) block_length = len([0 for _ in first_shard]) + if not n_volumes: + n_volumes = block_length * len(files) + dataset = dataset.interleave( map_func=lambda x: tf.data.TFRecordDataset( x, compression_type=compression_type