diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 66e1768d..48729749 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -39,6 +39,7 @@ def fit( opt_args=None, loss=losses.dice, metrics=metrics.dice, + callbacks=None ): """Train a segmentation model""" # TODO: check validity of datasets @@ -82,7 +83,12 @@ def _compile(): _compile() self.model_.summary() - callbacks = [] + if callbacks is not None and not isinstance(callbacks, list): + raise AttributeError('Callbacks must be either of type list or None') + + if callbacks is None: + callbacks = [] + if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) self.model_.fit(