diff --git a/example_use_train.py b/example_use_train.py
index d6d5c4a..d54063d 100644
--- a/example_use_train.py
+++ b/example_use_train.py
@@ -12,7 +12,7 @@
 # we set i-bce-topo to indicate the use of an iterative (i is associated with iUNET) loss
 # comprising balanced cross entropy (bce) and topological (topo) loss terms
 manager.train(train_tfrecord=train_tfrecord, validation_tfrecord=valid_tfrecord,
-              loss_type='bce-topo', model_dir=model_ckpt_dir)
+              loss_type='i-bce-topo', model_dir=model_ckpt_dir)
 
 ########################################################################################################################
 # for defining and training a SHN model we would do the following:
diff --git a/manager.py b/manager.py
index 7fc16c7..c147305 100644
--- a/manager.py
+++ b/manager.py
@@ -3,7 +3,7 @@
 import os
 import math
 import cv2
-
+import datetime
 import matplotlib.pyplot as plt
 
 from readers import Reader
@@ -42,6 +42,9 @@ def __init__(self,
         if name is 'SHN':
             assert(n_modules is not None)
 
+        dt = datetime.datetime.now()
+        self.date = '{}-{}-{}--{}-{}-{}'.format(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
+
         # to be set only if a model is trained
         self.tag = None
         self.model_save_dir = None
@@ -83,9 +86,11 @@ def _set_tag_and_create_model_dir(self, vgg_fmaps, vgg_weights, model_dir):
             self.tag = self.tag + '_' + str(w)
 
         if self.n_iterations is not None:
-            self.tag = self.tag + '_' + str(self.n_iterations)
+            self.tag = self.tag + '_iters_{}'.format(self.n_iterations)
         elif self.n_modules is not None:
-            self.tag = self.tag + '_' + str(self.n_modules)
+            self.tag = self.tag + '_modules_{}'.format(self.n_modules)
+
+        self.tag = self.tag + '_' + self.date
 
         self.model_save_dir = os.path.join(model_dir, self.name, self.tag)
         if not os.path.exists(self.model_save_dir):
@@ -187,7 +192,8 @@ def _loss_def(self, loss_type, vgg_fmaps=None, vgg_weights=None):
                          use_vgg_loss=True, vgg_fmaps=vgg_fmaps, vgg_weights=vgg_weights)
 
-    def train(self, train_tfrecord, loss_type, vgg_fmaps=None, vgg_weights=None, validation_tfrecord=None,
+    def train(self, train_tfrecord, loss_type,
+              vgg_fmaps=None, vgg_weights=None, validation_tfrecord=None,
               training_steps=6000, batch_size=2, initial_lr=10**(-4), decay_steps=2000, decay_rate=0.5,
               do_online_augmentation=True, log_dir='', model_dir='models'):
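
For reference, a minimal standalone sketch of the date-stamped tag behaviour that the manager.py change introduces. The build_tag helper and its arguments are illustrative only (they are not part of the repository); the attribute names and the iters/modules naming follow the diff, assuming n_iterations belongs to iUNET models and n_modules to SHN models.

import datetime


def build_tag(base_tag, n_iterations=None, n_modules=None):
    # illustrative helper mirroring Manager._set_tag_and_create_model_dir
    dt = datetime.datetime.now()
    date = '{}-{}-{}--{}-{}-{}'.format(dt.year, dt.month, dt.day,
                                       dt.hour, dt.minute, dt.second)

    tag = base_tag
    # iUNET models are described by a number of iterations,
    # SHN models by a number of stacked modules
    if n_iterations is not None:
        tag = tag + '_iters_{}'.format(n_iterations)
    elif n_modules is not None:
        tag = tag + '_modules_{}'.format(n_modules)

    # the date suffix keeps repeated training runs from
    # overwriting each other's checkpoint directories
    return tag + '_' + date


# e.g. 'iUNET_i-bce-topo_iters_3_2024-1-15--13-5-42'
print(build_tag('iUNET_i-bce-topo', n_iterations=3))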