Skip to content

Commit

Permalink
added date in log tags
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoPis committed Jul 22, 2020
1 parent ffc5e4f commit 1209c3c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion example_use_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# we set i-bce-topo to indicate the use of an iterative (i is associated with iUNET) loss
# comprising balanced cross entropy (bce) and topological (topo) loss terms
manager.train(train_tfrecord=train_tfrecord, validation_tfrecord=valid_tfrecord,
loss_type='bce-topo', model_dir=model_ckpt_dir)
loss_type='i-bce-topo', model_dir=model_ckpt_dir)

########################################################################################################################
# for defining and training a SHN model we would do the following:
Expand Down
14 changes: 10 additions & 4 deletions manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import math
import cv2

import datetime
import matplotlib.pyplot as plt

from readers import Reader
Expand Down Expand Up @@ -42,6 +42,9 @@ def __init__(self,
if name is 'SHN':
assert(n_modules is not None)

dt = datetime.datetime.now()
self.date = '{}-{}-{}--{}-{}-{}'.format(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)

# to be set only if a model is trained
self.tag = None
self.model_save_dir = None
Expand Down Expand Up @@ -83,9 +86,11 @@ def _set_tag_and_create_model_dir(self, vgg_fmaps, vgg_weights, model_dir):
self.tag = self.tag + '_' + str(w)

if self.n_iterations is not None:
self.tag = self.tag + '_' + str(self.n_iterations)
self.tag = self.tag + '_modules_{}'.format(self.n_modules)
elif self.n_modules is not None:
self.tag = self.tag + '_' + str(self.n_modules)
self.tag = self.tag + '_iters_{}'.format(self.n_iterations)

self.tag = self.tag + '_' + self.date

self.model_save_dir = os.path.join(model_dir, self.name, self.tag)
if not os.path.exists(self.model_save_dir):
Expand Down Expand Up @@ -187,7 +192,8 @@ def _loss_def(self, loss_type, vgg_fmaps=None, vgg_weights=None):
use_vgg_loss=True, vgg_fmaps=vgg_fmaps,
vgg_weights=vgg_weights)

def train(self, train_tfrecord, loss_type, vgg_fmaps=None, vgg_weights=None, validation_tfrecord=None,
def train(self, train_tfrecord, loss_type,
vgg_fmaps=None, vgg_weights=None, validation_tfrecord=None,
training_steps=6000, batch_size=2,
initial_lr=10**(-4), decay_steps=2000, decay_rate=0.5, do_online_augmentation=True,
log_dir='', model_dir='models'):
Expand Down

0 comments on commit 1209c3c

Please sign in to comment.