Commit
add tests for u-net models and resolve #291
hvgazula committed Mar 11, 2024
1 parent 9d55484 commit cff8368
Showing 5 changed files with 22 additions and 8 deletions.
File renamed without changes.
4 changes: 4 additions & 0 deletions nobrainer/models/__init__.py
@@ -1,3 +1,5 @@
+from .attention_unet import attention_unet
+from .attention_unet_with_inception import attention_unet_with_inception
 from .autoencoder import autoencoder
 from .dcgan import dcgan
 from .highresnet import highresnet
@@ -29,6 +31,8 @@ def get(name):
"progressivegan": progressivegan,
"progressiveae": progressiveae,
"dcgan": dcgan,
"attention_unet": attention_unet,
"attention_unet_with_inception": attention_unet_with_inception,
}

try:
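With the registry entries above, the new architectures become reachable through the package's name-based lookup. A minimal usage sketch, assuming `get` resolves the name against this dict and returns the model-building function (the lookup body itself is truncated in the hunk):

from nobrainer.models import get

model_fn = get("attention_unet")  # look up the constructor registered above
model = model_fn(n_classes=1, input_shape=(64, 64, 64, 1))  # shapes are illustrative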
4 changes: 2 additions & 2 deletions nobrainer/models/attention_unet.py
@@ -3,8 +3,8 @@
"""

import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow.keras import layers
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model


@@ -71,7 +71,7 @@ def attention_unet(n_classes, input_shape):
     outputs = layers.Activation(final_activation)(outputs)

     """ Model """
-    return Model(inputs=inputs, outputs=outputs, name="Attention U-Net")
+    return Model(inputs=inputs, outputs=outputs, name="Attention_U-Net")


 if __name__ == "__main__":
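The import reorder is a sort fix; the rename drops the space from the Keras model name, likely because TensorFlow name scopes reject whitespace, so a name like "Attention U-Net" can raise once it is used for scoping or saving. A quick smoke-test sketch of the constructor, with the signature taken from the hunk header and an assumed single-channel output when n_classes=1:

import numpy as np

model = attention_unet(n_classes=1, input_shape=(64, 64, 64, 1))
x = np.random.rand(1, 64, 64, 64, 1).astype("float32")
y = model.predict(x)
assert y.shape == (1, 64, 64, 64, 1)  # assumes one output channel for n_classes=1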
10 changes: 4 additions & 6 deletions nobrainer/models/attention_unet_with_inception.py
@@ -2,12 +2,10 @@
 Adapted from https://github.com/robinvvinod/unet
 """

-import tensorflow.keras.backend as K
 from tensorflow.keras import layers
+import tensorflow.keras.backend as K
 from tensorflow.keras.models import Model

-from hyperparameters import alpha
-
 K.set_image_data_format("channels_last")


@@ -44,7 +42,7 @@ def conv3d_block(
     )(input_tensor)
     if batchnorm:
         conv = layers.BatchNormalization()(conv)
-    output = layers.LeakyReLU(alpha=alpha)(conv)
+    output = layers.LeakyReLU(alpha=0.1)(conv)

     for _ in range(recurrent - 1):
         conv = layers.Conv3D(
@@ -57,7 +55,7 @@
         )(output)
         if batchnorm:
             conv = layers.BatchNormalization()(conv)
-        res = layers.LeakyReLU(alpha=alpha)(conv)
+        res = layers.LeakyReLU(alpha=0.1)(conv)
         output = layers.Add()([output, res])

     return output
@@ -141,7 +139,7 @@ def transpose_block(
         ),
         kernel_initializer="he_normal",
     )(input_tensor)
-    conv = layers.LeakyReLU(alpha=alpha)(conv)
+    conv = layers.LeakyReLU(alpha=0.1)(conv)

     act = conv3d_block(
         conv,
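All three replacements hardcode the LeakyReLU negative slope that previously came from `from hyperparameters import alpha`, an import that does not resolve inside the installed package (plausibly the bug behind #291). LeakyReLU with alpha=0.1 passes non-negative inputs through unchanged and scales negative inputs by 0.1; a small sketch of the behavior:

import tensorflow as tf

leaky = tf.keras.layers.LeakyReLU(alpha=0.1)
print(leaky(tf.constant([-2.0, 0.0, 3.0])).numpy())  # [-0.2  0.   3. ]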
12 changes: 12 additions & 0 deletions nobrainer/models/tests/models_test.py
@@ -4,6 +4,8 @@

 from nobrainer.bayesian_utils import default_mean_field_normal_fn

+from ..attention_unet import attention_unet
+from ..attention_unet_with_inception import attention_unet_with_inception
 from ..autoencoder import autoencoder
 from ..bayesian_vnet import bayesian_vnet
 from ..bayesian_vnet_semi import bayesian_vnet_semi
@@ -241,3 +243,13 @@ def test_vox2vox():
     pred_shape = (1, 2, 2, 2, 1)
     out = vox_discriminator(inputs=[y, x])
     assert out.shape == pred_shape
+
+
+def test_attention_unet():
+    model_test(attention_unet, n_classes=1, input_shape=(1, 64, 64, 64, 1))
+
+
+def test_attention_unet_with_inception():
+    model_test(
+        attention_unet_with_inception, n_classes=1, input_shape=(1, 64, 64, 64, 1)
+    )
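The `model_test` helper is defined earlier in models_test.py and is not part of this diff. Inferred from the call sites above, it presumably builds the model from the shape without the batch dimension, runs one random batch through it, and checks the output shape. A hypothetical sketch:

import numpy as np

def model_test(model_cls, n_classes, input_shape):
    # Hypothetical reconstruction: build from the per-volume shape, then
    # verify a forward pass yields one prediction channel per class.
    x = np.random.rand(*input_shape).astype(np.float32)
    model = model_cls(n_classes, input_shape=input_shape[1:])
    actual = model.predict(x)
    assert actual.shape == input_shape[:-1] + (n_classes,)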
