From 8aa264e7744b22179d64752c4aea655be490cba9 Mon Sep 17 00:00:00 2001 From: H Gazula Date: Fri, 8 Mar 2024 07:58:45 -0500 Subject: [PATCH] resolved https://github.com/neuronets/nobrainer/issues/285 --- nobrainer/processing/segmentation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 66e1768d..48729749 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -39,6 +39,7 @@ def fit( opt_args=None, loss=losses.dice, metrics=metrics.dice, + callbacks=None ): """Train a segmentation model""" # TODO: check validity of datasets @@ -82,7 +83,12 @@ def _compile(): _compile() self.model_.summary() - callbacks = [] + if callbacks is not None and not isinstance(callbacks, list): + raise AttributeError('Callbacks must be either of type list or None') + + if callbacks is None: + callbacks = [] + if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) self.model_.fit(