diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py
index 5e0092b5..8dd75b44 100644
--- a/nobrainer/models/__init__.py
+++ b/nobrainer/models/__init__.py
@@ -1,3 +1,5 @@
+from pprint import pprint
+
 from .attention_unet import attention_unet
 from .attention_unet_with_inception import attention_unet_with_inception
 from .autoencoder import autoencoder
@@ -10,6 +12,22 @@
 from .unet import unet
 from .unetr import unetr
 
+__all__ = ["get", "list_available_models"]
+
+_models = {
+    "highresnet": highresnet,
+    "meshnet": meshnet,
+    "unet": unet,
+    "autoencoder": autoencoder,
+    "progressivegan": progressivegan,
+    "progressiveae": progressiveae,
+    "dcgan": dcgan,
+    "attention_unet": attention_unet,
+    "attention_unet_with_inception": attention_unet_with_inception,
+    "unetr": unetr,
+    "variational_meshnet": variational_meshnet,
+}
+
 
 def get(name):
     """Return callable that creates a particular `tf.keras.Model`.
@@ -25,24 +43,18 @@ def get(name):
     if not isinstance(name, str):
         raise ValueError("Model name must be a string.")
 
-    models = {
-        "highresnet": highresnet,
-        "meshnet": meshnet,
-        "unet": unet,
-        "autoencoder": autoencoder,
-        "progressivegan": progressivegan,
-        "progressiveae": progressiveae,
-        "dcgan": dcgan,
-        "attention_unet": attention_unet,
-        "attention_unet_with_inception": attention_unet_with_inception,
-        "unetr": unetr,
-        "variational_meshnet": variational_meshnet,
-    }
-
     try:
-        return models[name.lower()]
+        return _models[name.lower()]
     except KeyError:
-        avail = ", ".join(models.keys())
+        avail = ", ".join(_models.keys())
         raise ValueError(
             "Unknown model: '{}'. Available models are {}.".format(name, avail)
         )
+
+
+def available_models():
+    return list(_models)
+
+def list_available_models():
+    pprint(available_models())
+
diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py
index c04ee906..f5a2f951 100644
--- a/nobrainer/processing/segmentation.py
+++ b/nobrainer/processing/segmentation.py
@@ -5,6 +5,7 @@
 
 from .base import BaseEstimator
 from .. import losses, metrics
+from ..models import available_models, list_available_models
 
 logging.getLogger().setLevel(logging.INFO)
@@ -15,20 +16,38 @@ class Segmentation(BaseEstimator):
     state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]
 
     def __init__(
-        self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=True
+        self, base_model=None, model_args=None, checkpoint_filepath=None, multi_gpu=True
     ):
         super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)
 
-        if not isinstance(base_model, str):
+        if base_model is None:
+            print(
+                "No model specified. Please specify one using the 'add_model' method."
+            )
+            self.base_model = None
+        elif not isinstance(base_model, str):
             self.base_model = base_model.__name__
         else:
             self.base_model = base_model
+
+        if self.base_model and self.base_model not in available_models():
+            raise ValueError(
+                "Unknown model: '{}'. Available models are {}.".format(
+                    self.base_model, available_models()
+                )
+            )
+
         self.model_ = None
         self.model_args = model_args or {}
         self.block_shape_ = None
         self.volume_shape_ = None
         self.scalar_labels_ = None
 
+    def add_model(self, base_model, model_args=None):
+        """Add a segmentation model"""
+        self.base_model = base_model
+        self.model_args = model_args or {}
+
     def fit(
         self,
         dataset_train,
@@ -72,7 +91,7 @@ def _compile():
                 metrics=metrics,
             )
 
-        if self.model is None:
+        if self.model_ is None:
             mod = importlib.import_module("..models", "nobrainer.processing")
             base_model = getattr(mod, self.base_model)
             if batch_size % self.strategy.num_replicas_in_sync:
@@ -97,9 +116,9 @@
             epochs=epochs,
             steps_per_epoch=dataset_train.get_steps_per_epoch(),
             validation_data=dataset_validate.dataset if dataset_validate else None,
-            validation_steps=dataset_validate.get_steps_per_epoch()
-            if dataset_validate
-            else None,
+            validation_steps=(
+                dataset_validate.get_steps_per_epoch() if dataset_validate else None
+            ),
             callbacks=callbacks,
             verbose=verbose,
         )
@@ -119,3 +138,6 @@ def predict(self, x, batch_size=1, normalizer=None):
             batch_size=batch_size,
             normalizer=normalizer,
         )
+
+    def list_available_models(self):
+        list_available_models()