diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py
index 791737fd..30c5be53 100644
--- a/nobrainer/dataset.py
+++ b/nobrainer/dataset.py
@@ -121,6 +121,9 @@ def from_tfrecords(
         )
         block_length = len([0 for _ in first_shard])
 
+        if not n_volumes:
+            n_volumes = block_length * len(files)
+
         dataset = dataset.interleave(
             map_func=lambda x: tf.data.TFRecordDataset(
                 x, compression_type=compression_type