diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py index d03fb7bb..14e696ab 100644 --- a/nobrainer/models/__init__.py +++ b/nobrainer/models/__init__.py @@ -8,7 +8,7 @@ from .progressivegan import progressivegan from .unet import unet from .unetr import unetr - +from .bayesian_meshnet import variational_meshnet def get(name): """Return callable that creates a particular `tf.keras.Model`. @@ -35,6 +35,7 @@ def get(name): "attention_unet": attention_unet, "attention_unet_with_inception": attention_unet_with_inception, "unetr": unetr, + "variational_meshnet": variational_meshnet, } try: diff --git a/nobrainer/models/tests/models_test.py b/nobrainer/models/tests/models_test.py index 3daeda40..afd58c62 100644 --- a/nobrainer/models/tests/models_test.py +++ b/nobrainer/models/tests/models_test.py @@ -19,6 +19,7 @@ from ..unetr import unetr from ..vnet import vnet from ..vox2vox import Vox_ensembler, vox_gan +from ..bayesian_meshnet import variational_meshnet def model_test(model_cls, n_classes, input_shape, kwds={}): @@ -258,3 +259,7 @@ def test_attention_unet_with_inception(): def test_unetr(): model_test(unetr, n_classes=1, input_shape=(1, 96, 96, 96, 1)) + + +def test_variational_meshnet(): + model_test(variational_meshnet, n_classes=1, input_shape=(1, 128, 128, 128, 1))