diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 48729749..faf1a7b7 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -39,7 +39,7 @@ def fit( opt_args=None, loss=losses.dice, metrics=metrics.dice, - callbacks=None + callbacks=None, ): """Train a segmentation model""" # TODO: check validity of datasets @@ -84,11 +84,11 @@ def _compile(): self.model_.summary() if callbacks is not None and not isinstance(callbacks, list): - raise AttributeError('Callbacks must be either of type list or None') - + raise AttributeError("Callbacks must be either of type list or None") + if callbacks is None: callbacks = [] - + if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) self.model_.fit(