diff --git a/nobrainer/ext/SynthSeg/__init__.py b/nobrainer/ext/SynthSeg/__init__.py index 1d6c1d6f..9c31f73f 100644 --- a/nobrainer/ext/SynthSeg/__init__.py +++ b/nobrainer/ext/SynthSeg/__init__.py @@ -1,9 +1,11 @@ -from . import brain_generator -from . import estimate_priors -from . import evaluate -from . import labels_to_image_model -from . import metrics_model -from . import model_inputs -from . import predict -from . import training_supervised -from . import training +from . import ( + brain_generator, + estimate_priors, + evaluate, + labels_to_image_model, + metrics_model, + model_inputs, + predict, + training, + training_supervised, +) diff --git a/nobrainer/ext/SynthSeg/estimate_priors.py b/nobrainer/ext/SynthSeg/estimate_priors.py index 424c8c8e..fc8d9694 100644 --- a/nobrainer/ext/SynthSeg/estimate_priors.py +++ b/nobrainer/ext/SynthSeg/estimate_priors.py @@ -13,10 +13,11 @@ License. """ - # python imports import os + import numpy as np + try: from scipy.stats import median_absolute_deviation except ImportError: @@ -24,11 +25,12 @@ # third-party imports -from ext.lab2im import utils -from ext.lab2im import edit_volumes +from ext.lab2im import edit_volumes, utils -def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True): +def sample_intensity_stats_from_image( + image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True +): """This function takes an image and corresponding segmentation as inputs. It estimates the mean and std intensity for all specified label values. Labels can share the same statistics by being regrouped into K classes. :param image: image from which to evaluate mean intensity and std deviation. @@ -48,19 +50,27 @@ def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_ """ # reformat labels and classes - labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int')) + labels_list = np.array( + utils.reformat_to_list(labels_list, load_as_numpy=True, dtype="int") + ) if classes_list is not None: - classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int')) + classes_list = np.array( + utils.reformat_to_list(classes_list, load_as_numpy=True, dtype="int") + ) else: classes_list = np.arange(labels_list.shape[0]) - assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length' + assert len(classes_list) == len( + labels_list + ), "labels and classes lists should have the same length" # get unique classes unique_classes, unique_indices = np.unique(classes_list, return_index=True) n_classes = len(unique_classes) if not np.array_equal(unique_classes, np.arange(n_classes)): - raise ValueError('classes_list should only contain values between 0 and K-1, ' - 'where K is the total number of classes. Here K = %d' % n_classes) + raise ValueError( + "classes_list should only contain values between 0 and K-1, " + "where K is the total number of classes. Here K = %d" % n_classes + ) # compute mean/std of specified classes means = np.zeros(n_classes) @@ -80,13 +90,14 @@ def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_ # compute stats for class and put them to the location of corresponding label values if len(intensities) != 0: means[idx] = np.nanmedian(intensities) - stds[idx] = median_absolute_deviation(intensities, nan_policy='omit') + stds[idx] = median_absolute_deviation(intensities, nan_policy="omit") return np.stack([means, stds]) -def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_list, classes_list=None, max_channel=3, - rescale=True): +def sample_intensity_stats_from_single_dataset( + image_dir, labels_dir, labels_list, classes_list=None, max_channel=3, rescale=True +): """This function aims at estimating the intensity distributions of K different structure types from a set of images. The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation. Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the @@ -116,30 +127,42 @@ def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_lis # list files path_images = utils.list_images_in_folder(image_dir) path_labels = utils.list_images_in_folder(labels_dir) - assert len(path_images) == len(path_labels), 'image and labels folders do not have the same number of files' + assert len(path_images) == len( + path_labels + ), "image and labels folders do not have the same number of files" # reformat list labels and classes - labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int')) + labels_list = np.array( + utils.reformat_to_list(labels_list, load_as_numpy=True, dtype="int") + ) if classes_list is not None: - classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int')) + classes_list = np.array( + utils.reformat_to_list(classes_list, load_as_numpy=True, dtype="int") + ) else: classes_list = np.arange(labels_list.shape[0]) - assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length' + assert len(classes_list) == len( + labels_list + ), "labels and classes lists should have the same length" # get unique classes unique_classes, unique_indices = np.unique(classes_list, return_index=True) n_classes = len(unique_classes) if not np.array_equal(unique_classes, np.arange(n_classes)): - raise ValueError('classes_list should only contain values between 0 and K-1, ' - 'where K is the total number of classes. Here K = %d' % n_classes) + raise ValueError( + "classes_list should only contain values between 0 and K-1, " + "where K is the total number of classes. Here K = %d" % n_classes + ) # initialise result arrays - n_dims, n_channels = utils.get_dims(utils.load_volume(path_images[0]).shape, max_channels=max_channel) + n_dims, n_channels = utils.get_dims( + utils.load_volume(path_images[0]).shape, max_channels=max_channel + ) means = np.zeros((len(path_images), n_classes, n_channels)) stds = np.zeros((len(path_images), n_classes, n_channels)) # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'estimating', print_time=True) + loop_info = utils.LoopInfo(len(path_images), 10, "estimating", print_time=True) for idx, (path_im, path_la) in enumerate(zip(path_images, path_labels)): loop_info.update(idx) @@ -154,7 +177,9 @@ def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_lis im = image[..., channel] if rescale: im = edit_volumes.rescale_volume(im) - stats = sample_intensity_stats_from_image(im, la, labels_list, classes_list=classes_list) + stats = sample_intensity_stats_from_image( + im, la, labels_list, classes_list=classes_list + ) means[idx, :, channel] = stats[0, :] stds[idx, :, channel] = stats[1, :] @@ -176,13 +201,15 @@ def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_lis return prior_means, prior_stds -def build_intensity_stats(list_image_dir, - list_labels_dir, - result_dir, - estimation_labels, - estimation_classes=None, - max_channel=3, - rescale=True): +def build_intensity_stats( + list_image_dir, + list_labels_dir, + result_dir, + estimation_labels, + estimation_classes=None, + max_channel=3, + rescale=True, +): """This function aims at estimating the intensity distributions of K different structure types from a set of images. The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation. Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the @@ -219,22 +246,34 @@ def build_intensity_stats(list_image_dir, # reformat image/labels dir into lists list_image_dir = utils.reformat_to_list(list_image_dir) - list_labels_dir = utils.reformat_to_list(list_labels_dir, length=len(list_image_dir)) + list_labels_dir = utils.reformat_to_list( + list_labels_dir, length=len(list_image_dir) + ) # reformat list estimation labels and classes - estimation_labels = np.array(utils.reformat_to_list(estimation_labels, load_as_numpy=True, dtype='int')) + estimation_labels = np.array( + utils.reformat_to_list(estimation_labels, load_as_numpy=True, dtype="int") + ) if estimation_classes is not None: - estimation_classes = np.array(utils.reformat_to_list(estimation_classes, load_as_numpy=True, dtype='int')) + estimation_classes = np.array( + utils.reformat_to_list(estimation_classes, load_as_numpy=True, dtype="int") + ) else: estimation_classes = np.arange(estimation_labels.shape[0]) - assert len(estimation_classes) == len(estimation_labels), 'estimation labels and classes should be of same length' + assert len(estimation_classes) == len( + estimation_labels + ), "estimation labels and classes should be of same length" # get unique classes - unique_estimation_classes, unique_indices = np.unique(estimation_classes, return_index=True) + unique_estimation_classes, unique_indices = np.unique( + estimation_classes, return_index=True + ) n_classes = len(unique_estimation_classes) if not np.array_equal(unique_estimation_classes, np.arange(n_classes)): - raise ValueError('estimation_classes should only contain values between 0 and N-1, ' - 'where K is the total number of classes. Here N = %d' % n_classes) + raise ValueError( + "estimation_classes should only contain values between 0 and N-1, " + "where K is the total number of classes. Here N = %d" % n_classes + ) # loop over dataset list_datasets_prior_means = list() @@ -242,12 +281,14 @@ def build_intensity_stats(list_image_dir, for image_dir, labels_dir in zip(list_image_dir, list_labels_dir): # get prior stats for dataset - tmp_prior_means, tmp_prior_stds = sample_intensity_stats_from_single_dataset(image_dir, - labels_dir, - estimation_labels, - estimation_classes, - max_channel=max_channel, - rescale=rescale) + tmp_prior_means, tmp_prior_stds = sample_intensity_stats_from_single_dataset( + image_dir, + labels_dir, + estimation_labels, + estimation_classes, + max_channel=max_channel, + rescale=rescale, + ) # add stats arrays to list of datasets-wise statistics list_datasets_prior_means.append(tmp_prior_means) @@ -258,7 +299,7 @@ def build_intensity_stats(list_image_dir, prior_stds = np.concatenate(list_datasets_prior_stds, axis=0) # save files - np.save(os.path.join(result_dir, 'prior_means.npy'), prior_means) - np.save(os.path.join(result_dir, 'prior_stds.npy'), prior_stds) + np.save(os.path.join(result_dir, "prior_means.npy"), prior_means) + np.save(os.path.join(result_dir, "prior_stds.npy"), prior_stds) return prior_means, prior_stds diff --git a/nobrainer/ext/SynthSeg/evaluate.py b/nobrainer/ext/SynthSeg/evaluate.py index 5928f5e5..74fbc1ae 100644 --- a/nobrainer/ext/SynthSeg/evaluate.py +++ b/nobrainer/ext/SynthSeg/evaluate.py @@ -13,16 +13,14 @@ License. """ - # python imports import os -import numpy as np -from scipy.stats import wilcoxon -from scipy.ndimage.morphology import distance_transform_edt # third-party imports -from ext.lab2im import utils -from ext.lab2im import edit_volumes +from ext.lab2im import edit_volumes, utils +import numpy as np +from scipy.ndimage.morphology import distance_transform_edt +from scipy.stats import wilcoxon def fast_dice(x, y, labels): @@ -33,20 +31,30 @@ def fast_dice(x, y, labels): :return: numpy array with Dice scores in the same order as labels. """ - assert x.shape == y.shape, 'both inputs should have same size, had {} and {}'.format(x.shape, y.shape) + assert ( + x.shape == y.shape + ), "both inputs should have same size, had {} and {}".format(x.shape, y.shape) if len(labels) > 1: # sort labels labels_sorted = np.sort(labels) # build bins for histograms - label_edges = np.sort(np.concatenate([labels_sorted - 0.1, labels_sorted + 0.1])) - label_edges = np.insert(label_edges, [0, len(label_edges)], [labels_sorted[0] - 0.1, labels_sorted[-1] + 0.1]) + label_edges = np.sort( + np.concatenate([labels_sorted - 0.1, labels_sorted + 0.1]) + ) + label_edges = np.insert( + label_edges, + [0, len(label_edges)], + [labels_sorted[0] - 0.1, labels_sorted[-1] + 0.1], + ) # compute Dice and re-arrange scores in initial order hst = np.histogram2d(x.flatten(), y.flatten(), bins=label_edges)[0] idx = np.arange(start=1, stop=2 * len(labels_sorted), step=2) - dice_score = 2 * np.diag(hst)[idx] / (np.sum(hst, 0)[idx] + np.sum(hst, 1)[idx] + 1e-5) + dice_score = ( + 2 * np.diag(hst)[idx] / (np.sum(hst, 0)[idx] + np.sum(hst, 1)[idx] + 1e-5) + ) dice_score = dice_score[np.searchsorted(labels_sorted, labels)] else: @@ -60,7 +68,9 @@ def dice(x, y): return 2 * np.sum(x * y) / (np.sum(x) + np.sum(y)) -def surface_distances(x, y, hausdorff_percentile=None, return_coordinate_max_distance=False): +def surface_distances( + x, y, hausdorff_percentile=None, return_coordinate_max_distance=False +): """Computes the maximum boundary distance (Hausdorff distance), and the average boundary distance of two masks. :param x: numpy array (boolean or 0/1) :param y: numpy array (boolean or 0/1) @@ -74,7 +84,9 @@ def surface_distances(x, y, hausdorff_percentile=None, return_coordinate_max_dis mean_dist: scalar with average surface distance coordinate_max_distance: only returned return_coordinate_max_distance is True.""" - assert x.shape == y.shape, 'both inputs should have same size, had {} and {}'.format(x.shape, y.shape) + assert ( + x.shape == y.shape + ), "both inputs should have same size, had {} and {}".format(x.shape, y.shape) n_dims = len(x.shape) hausdorff_percentile = 100 if hausdorff_percentile is None else hausdorff_percentile @@ -88,7 +100,9 @@ def surface_distances(x, y, hausdorff_percentile=None, return_coordinate_max_dis if (crop_x is None) | (crop_y is None): return max(x.shape), max(x.shape) - crop = np.concatenate([np.minimum(crop_x, crop_y)[:n_dims], np.maximum(crop_x, crop_y)[n_dims:]]) + crop = np.concatenate( + [np.minimum(crop_x, crop_y)[:n_dims], np.maximum(crop_x, crop_y)[n_dims:]] + ) x = edit_volumes.crop_volume_with_idx(x, crop) y = edit_volumes.crop_volume_with_idx(y, crop) @@ -118,15 +132,23 @@ def surface_distances(x, y, hausdorff_percentile=None, return_coordinate_max_dis indices_x_surface = np.where(x_edge == 1) idx_max_distance_x = np.where(x_dists_to_y == max_dist)[0] if idx_max_distance_x.size != 0: - coordinate_max_distance = np.stack(indices_x_surface).transpose()[idx_max_distance_x] + coordinate_max_distance = np.stack(indices_x_surface).transpose()[ + idx_max_distance_x + ] else: indices_y_surface = np.where(y_edge == 1) idx_max_distance_y = np.where(y_dists_to_x == max_dist)[0] - coordinate_max_distance = np.stack(indices_y_surface).transpose()[idx_max_distance_y] + coordinate_max_distance = np.stack(indices_y_surface).transpose()[ + idx_max_distance_y + ] # find percentile of max distance else: - max_dist.append(np.percentile(np.concatenate([x_dists_to_y, y_dists_to_x]), hd_percentile)) + max_dist.append( + np.percentile( + np.concatenate([x_dists_to_y, y_dists_to_x]), hd_percentile + ) + ) # find average distance between 2 surfaces if x_dists_to_y.shape[0] > 0: @@ -150,7 +172,9 @@ def surface_distances(x, y, hausdorff_percentile=None, return_coordinate_max_dis return max_dist, mean_dist -def compute_non_parametric_paired_test(dice_ref, dice_compare, eval_indices=None, alternative='two-sided'): +def compute_non_parametric_paired_test( + dice_ref, dice_compare, eval_indices=None, alternative="two-sided" +): """Compute non-parametric paired t-tests between two sets of Dice scores. :param dice_ref: numpy array with Dice scores, rows represent structures, and columns represent subjects. Taken as reference for one-sided tests. @@ -204,28 +228,30 @@ def cohens_d(volumes_x, volumes_y): n_x = np.shape(volumes_x)[0] n_y = np.shape(volumes_y)[0] - std = np.sqrt(((n_x-1)*var_x + (n_y-1)*var_y) / (n_x + n_y - 2)) + std = np.sqrt(((n_x - 1) * var_x + (n_y - 1) * var_y) / (n_x + n_y - 2)) cohensd = (means_x - means_y) / std return cohensd -def evaluation(gt_dir, - seg_dir, - label_list, - mask_dir=None, - compute_score_whole_structure=False, - path_dice=None, - path_hausdorff=None, - path_hausdorff_99=None, - path_hausdorff_95=None, - path_mean_distance=None, - crop_margin_around_gt=10, - list_incorrect_labels=None, - list_correct_labels=None, - use_nearest_label=False, - recompute=True, - verbose=True): +def evaluation( + gt_dir, + seg_dir, + label_list, + mask_dir=None, + compute_score_whole_structure=False, + path_dice=None, + path_hausdorff=None, + path_hausdorff_99=None, + path_hausdorff_95=None, + path_mean_distance=None, + crop_margin_around_gt=10, + list_incorrect_labels=None, + list_correct_labels=None, + use_nearest_label=False, + recompute=True, + verbose=True, +): """This function computes Dice scores, as well as surface distances, between two sets of labels maps in gt_dir (ground truth) and seg_dir (typically predictions). Label maps in both folders are matched by sorting order. The resulting scores are saved at the specified locations. @@ -260,10 +286,24 @@ def evaluation(gt_dir, # check whether to recompute compute_dice = not os.path.isfile(path_dice) if (path_dice is not None) else True - compute_hausdorff = not os.path.isfile(path_hausdorff) if (path_hausdorff is not None) else False - compute_hausdorff_99 = not os.path.isfile(path_hausdorff_99) if (path_hausdorff_99 is not None) else False - compute_hausdorff_95 = not os.path.isfile(path_hausdorff_95) if (path_hausdorff_95 is not None) else False - compute_mean_dist = not os.path.isfile(path_mean_distance) if (path_mean_distance is not None) else False + compute_hausdorff = ( + not os.path.isfile(path_hausdorff) if (path_hausdorff is not None) else False + ) + compute_hausdorff_99 = ( + not os.path.isfile(path_hausdorff_99) + if (path_hausdorff_99 is not None) + else False + ) + compute_hausdorff_95 = ( + not os.path.isfile(path_hausdorff_95) + if (path_hausdorff_95 is not None) + else False + ) + compute_mean_dist = ( + not os.path.isfile(path_mean_distance) + if (path_mean_distance is not None) + else False + ) compute_hd = [compute_hausdorff, compute_hausdorff_99, compute_hausdorff_95] if compute_dice | any(compute_hd) | compute_mean_dist | recompute: @@ -273,11 +313,13 @@ def evaluation(gt_dir, path_segs = utils.list_images_in_folder(seg_dir) path_gt_labels = utils.reformat_to_list(path_gt_labels, length=len(path_segs)) if len(path_gt_labels) != len(path_segs): - print('gt and segmentation folders must have the same amount of label maps.') + print( + "gt and segmentation folders must have the same amount of label maps." + ) if mask_dir is not None: path_masks = utils.list_images_in_folder(mask_dir) if len(path_masks) != len(path_segs): - print('not the same amount of masks and segmentations.') + print("not the same amount of masks and segmentations.") else: path_masks = [None] * len(path_segs) @@ -297,26 +339,32 @@ def evaluation(gt_dir, dice_coefs = np.zeros((n_labels, len(path_segs))) # loop over segmentations - loop_info = utils.LoopInfo(len(path_segs), 10, 'evaluating', print_time=True) - for idx, (path_gt, path_seg, path_mask) in enumerate(zip(path_gt_labels, path_segs, path_masks)): + loop_info = utils.LoopInfo(len(path_segs), 10, "evaluating", print_time=True) + for idx, (path_gt, path_seg, path_mask) in enumerate( + zip(path_gt_labels, path_segs, path_masks) + ): if verbose: loop_info.update(idx) # load gt labels and segmentation - gt_labels = utils.load_volume(path_gt, dtype='int', aff_ref=np.eye(4)) - seg = utils.load_volume(path_seg, dtype='int', aff_ref=np.eye(4)) + gt_labels = utils.load_volume(path_gt, dtype="int", aff_ref=np.eye(4)) + seg = utils.load_volume(path_seg, dtype="int", aff_ref=np.eye(4)) if path_mask is not None: - mask = utils.load_volume(path_mask, dtype='bool', aff_ref=np.eye(4)) + mask = utils.load_volume(path_mask, dtype="bool", aff_ref=np.eye(4)) gt_labels[mask] = max_label seg[mask] = max_label # crop images if crop_margin_around_gt > 0: - gt_labels, cropping = edit_volumes.crop_volume_around_region(gt_labels, margin=crop_margin_around_gt) + gt_labels, cropping = edit_volumes.crop_volume_around_region( + gt_labels, margin=crop_margin_around_gt + ) seg = edit_volumes.crop_volume_with_idx(seg, cropping) if list_incorrect_labels is not None: - seg = edit_volumes.correct_label_map(seg, list_incorrect_labels, list_correct_labels, use_nearest_label) + seg = edit_volumes.correct_label_map( + seg, list_incorrect_labels, list_correct_labels, use_nearest_label + ) # compute Dice scores dice_coefs[:n_labels, idx] = fast_dice(gt_labels, seg, label_list) @@ -341,7 +389,9 @@ def evaluation(gt_dir, if (label in unique_gt_labels) & (label in unique_seg_labels): mask_gt = np.where(gt_labels == label, True, False) mask_seg = np.where(seg == label, True, False) - tmp_max_dists, mean_dists[index, idx] = surface_distances(mask_gt, mask_seg, [100, 99, 95]) + tmp_max_dists, mean_dists[index, idx] = surface_distances( + mask_gt, mask_seg, [100, 99, 95] + ) max_dists[index, idx, :] = np.array(tmp_max_dists) else: mean_dists[index, idx] = max(gt_labels.shape) @@ -349,7 +399,9 @@ def evaluation(gt_dir, # compute max/mean distances for whole structure if compute_score_whole_structure: - tmp_max_dists, mean_dists[-1, idx] = surface_distances(temp_gt, temp_seg, [100, 99, 95]) + tmp_max_dists, mean_dists[-1, idx] = surface_distances( + temp_gt, temp_seg, [100, 99, 95] + ) max_dists[-1, idx, :] = np.array(tmp_max_dists) # write results diff --git a/nobrainer/ext/SynthSeg/metrics_model.py b/nobrainer/ext/SynthSeg/metrics_model.py index 6c351341..9a7c3b63 100644 --- a/nobrainer/ext/SynthSeg/metrics_model.py +++ b/nobrainer/ext/SynthSeg/metrics_model.py @@ -13,18 +13,17 @@ License. """ +# third-party imports +from ext.lab2im import layers +import keras.layers as KL +from keras.models import Model # python imports import numpy as np import tensorflow as tf -import keras.layers as KL -from keras.models import Model -# third-party imports -from ext.lab2im import layers - -def metrics_model(input_model, label_list, metrics='dice'): +def metrics_model(input_model, label_list, metrics="dice"): # get prediction last_tensor = input_model.outputs[0] @@ -33,26 +32,32 @@ def metrics_model(input_model, label_list, metrics='dice'): # check shapes n_labels = input_shape[-1] label_list = np.unique(label_list) - assert n_labels == len(label_list), 'label_list should be as long as the posteriors channels' + assert n_labels == len( + label_list + ), "label_list should be as long as the posteriors channels" # get GT and convert it to probabilistic values - labels_gt = input_model.get_layer('labels_out').output + labels_gt = input_model.get_layer("labels_out").output labels_gt = layers.ConvertLabels(label_list)(labels_gt) - labels_gt = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels_gt) + labels_gt = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, dtype="int32"), depth=n_labels, axis=-1) + )(labels_gt) labels_gt = KL.Reshape(input_shape)(labels_gt) # make sure the tensors have the right keras shape last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list()) labels_gt._keras_shape = tuple(labels_gt.get_shape().as_list()) - if metrics == 'dice': + if metrics == "dice": last_tensor = layers.DiceLoss()([labels_gt, last_tensor]) - elif metrics == 'wl2': + elif metrics == "wl2": last_tensor = layers.WeightedL2Loss(target_value=5)([labels_gt, last_tensor]) else: - raise Exception('metrics should either be "dice or "wl2, got {}'.format(metrics)) + raise Exception( + 'metrics should either be "dice or "wl2, got {}'.format(metrics) + ) # create the model and return model = Model(inputs=input_model.inputs, outputs=last_tensor) @@ -61,13 +66,14 @@ def metrics_model(input_model, label_list, metrics='dice'): class IdentityLoss(object): """Very simple loss, as the computation of the loss as been directly implemented in the model.""" + def __init__(self, keepdims=True): self.keepdims = keepdims def loss(self, y_true, y_predicted): """Because the metrics is already calculated in the model, we simply return y_predicted. - We still need to put y_true in the inputs, as it's expected by keras.""" + We still need to put y_true in the inputs, as it's expected by keras.""" loss = y_predicted - tf.debugging.check_numerics(loss, 'Loss not finite') + tf.debugging.check_numerics(loss, "Loss not finite") return loss diff --git a/nobrainer/ext/SynthSeg/model_inputs.py b/nobrainer/ext/SynthSeg/model_inputs.py index 34f4b4ad..e6d9cda2 100644 --- a/nobrainer/ext/SynthSeg/model_inputs.py +++ b/nobrainer/ext/SynthSeg/model_inputs.py @@ -13,7 +13,6 @@ License. """ - # python imports import numpy as np import numpy.random as npr @@ -22,17 +21,19 @@ from nobrainer.ext.lab2im import utils -def build_model_inputs(path_label_maps, - n_labels, - batchsize=1, - n_channels=1, - subjects_prob=None, - generation_classes=None, - prior_distributions='uniform', - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - mix_prior_and_random=False): +def build_model_inputs( + path_label_maps, + n_labels, + batchsize=1, + n_channels=1, + subjects_prob=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, +): """ This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch). @@ -86,7 +87,9 @@ def build_model_inputs(path_label_maps, while True: # randomly pick as many images as batchsize - indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob) + indices = npr.choice( + np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob + ) # initialise input lists list_label_maps = [] @@ -96,8 +99,10 @@ def build_model_inputs(path_label_maps, for idx in indices: # load input label map - lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4)) - if (npr.uniform() > 0.7) & ('seg_cerebral' in path_label_maps[idx]): + lab = utils.load_volume( + path_label_maps[idx], dtype="int", aff_ref=np.eye(4) + ) + if (npr.uniform() > 0.7) & ("seg_cerebral" in path_label_maps[idx]): lab[lab == 24] = 0 # add label map to inputs @@ -112,42 +117,72 @@ def build_model_inputs(path_label_maps, if isinstance(prior_means, np.ndarray): if (prior_means.shape[0] > 2) & use_specific_stats_for_channel: if prior_means.shape[0] / 2 != n_channels: - raise ValueError("the number of blocks in prior_means does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_means = prior_means[2 * channel : 2 * channel + 2, :] else: tmp_prior_means = prior_means else: tmp_prior_means = prior_means - if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5): + if ( + (prior_means is not None) + & mix_prior_and_random + & (npr.uniform() > 0.5) + ): tmp_prior_means = None if isinstance(prior_stds, np.ndarray): if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel: if prior_stds.shape[0] / 2 != n_channels: - raise ValueError("the number of blocks in prior_stds does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_stds = prior_stds[2 * channel : 2 * channel + 2, :] else: tmp_prior_stds = prior_stds else: tmp_prior_stds = prior_stds - if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5): + if ( + (prior_stds is not None) + & mix_prior_and_random + & (npr.uniform() > 0.5) + ): tmp_prior_stds = None # draw means and std devs from priors - tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions, - 125., 125., positive_only=True) - tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions, - 15., 15., positive_only=True) + tmp_classes_means = utils.draw_value_from_distribution( + tmp_prior_means, + n_classes, + prior_distributions, + 125.0, + 125.0, + positive_only=True, + ) + tmp_classes_stds = utils.draw_value_from_distribution( + tmp_prior_stds, + n_classes, + prior_distributions, + 15.0, + 15.0, + positive_only=True, + ) random_coef = npr.uniform() if random_coef > 0.95: # reset the background to 0 in 5% of cases tmp_classes_means[0] = 0 tmp_classes_stds[0] = 0 - elif random_coef > 0.7: # reset the background to low Gaussian in 25% of cases + elif ( + random_coef > 0.7 + ): # reset the background to low Gaussian in 25% of cases tmp_classes_means[0] = npr.uniform(0, 15) tmp_classes_stds[0] = npr.uniform(0, 5) - tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1]) - tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1]) + tmp_means = utils.add_axis( + tmp_classes_means[generation_classes], axis=[0, -1] + ) + tmp_stds = utils.add_axis( + tmp_classes_stds[generation_classes], axis=[0, -1] + ) means = np.concatenate([means, tmp_means], axis=-1) stds = np.concatenate([stds, tmp_stds], axis=-1) list_means.append(means) diff --git a/nobrainer/ext/SynthSeg/predict.py b/nobrainer/ext/SynthSeg/predict.py index 70b2e1fc..d05796bf 100644 --- a/nobrainer/ext/SynthSeg/predict.py +++ b/nobrainer/ext/SynthSeg/predict.py @@ -13,56 +13,56 @@ License. """ +import csv # python imports import os -import csv -import numpy as np -import tensorflow as tf -import keras.layers as KL -import keras.backend as K -from keras.models import Model # project imports from SynthSeg import evaluate # third-party imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.lab2im import edit_volumes +from ext.lab2im import edit_volumes, layers, utils from ext.neuron import models as nrn_models +import keras.backend as K +import keras.layers as KL +from keras.models import Model +import numpy as np +import tensorflow as tf -def predict(path_images, - path_segmentations, - path_model, - labels_segmentation, - n_neutral_labels=None, - names_segmentation=None, - path_posteriors=None, - path_resampled=None, - path_volumes=None, - min_pad=None, - cropping=None, - target_res=1., - gradients=False, - flip=True, - topology_classes=None, - sigma_smoothing=0.5, - keep_biggest_component=True, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - gt_folder=None, - evaluation_labels=None, - list_incorrect_labels=None, - list_correct_labels=None, - compute_distances=False, - recompute=True, - verbose=True): +def predict( + path_images, + path_segmentations, + path_model, + labels_segmentation, + n_neutral_labels=None, + names_segmentation=None, + path_posteriors=None, + path_resampled=None, + path_volumes=None, + min_pad=None, + cropping=None, + target_res=1.0, + gradients=False, + flip=True, + topology_classes=None, + sigma_smoothing=0.5, + keep_biggest_component=True, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + gt_folder=None, + evaluation_labels=None, + list_incorrect_labels=None, + list_correct_labels=None, + compute_distances=False, + recompute=True, + verbose=True, +): """ This function uses trained models to segment images. It is crucial that the inputs match the architecture parameters of the trained model. @@ -130,22 +130,42 @@ def predict(path_images, """ # prepare input/output filepaths - path_images, path_segmentations, path_posteriors, path_resampled, path_volumes, compute, unique_vol_file = \ - prepare_output_files(path_images, path_segmentations, path_posteriors, path_resampled, path_volumes, recompute) + ( + path_images, + path_segmentations, + path_posteriors, + path_resampled, + path_volumes, + compute, + unique_vol_file, + ) = prepare_output_files( + path_images, + path_segmentations, + path_posteriors, + path_resampled, + path_volumes, + recompute, + ) # get label list labels_segmentation, _ = utils.get_list_labels(label_list=labels_segmentation) if (n_neutral_labels is not None) & flip: - labels_segmentation, flip_indices, unique_idx = get_flip_indices(labels_segmentation, n_neutral_labels) + labels_segmentation, flip_indices, unique_idx = get_flip_indices( + labels_segmentation, n_neutral_labels + ) else: - labels_segmentation, unique_idx = np.unique(labels_segmentation, return_index=True) + labels_segmentation, unique_idx = np.unique( + labels_segmentation, return_index=True + ) flip_indices = None # prepare other labels list if names_segmentation is not None: names_segmentation = utils.load_array_if_path(names_segmentation)[unique_idx] if topology_classes is not None: - topology_classes = utils.load_array_if_path(topology_classes, load_as_numpy=True)[unique_idx] + topology_classes = utils.load_array_if_path( + topology_classes, load_as_numpy=True + )[unique_idx] # prepare volumes if necessary if unique_vol_file & (path_volumes[0] is not None): @@ -154,30 +174,32 @@ def predict(path_images, # build network _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_images[0]) model_input_shape = [None] * n_dims + [n_channels] - net = build_model(path_model=path_model, - input_shape=model_input_shape, - labels_segmentation=labels_segmentation, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation, - sigma_smoothing=sigma_smoothing, - flip_indices=flip_indices, - gradients=gradients) + net = build_model( + path_model=path_model, + input_shape=model_input_shape, + labels_segmentation=labels_segmentation, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + sigma_smoothing=sigma_smoothing, + flip_indices=flip_indices, + gradients=gradients, + ) # set cropping/padding if (cropping is not None) & (min_pad is not None): - cropping = utils.reformat_to_list(cropping, length=n_dims, dtype='int') - min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype='int') + cropping = utils.reformat_to_list(cropping, length=n_dims, dtype="int") + min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype="int") min_pad = np.minimum(cropping, min_pad) # perform segmentation if len(path_images) <= 10: - loop_info = utils.LoopInfo(len(path_images), 1, 'predicting', True) + loop_info = utils.LoopInfo(len(path_images), 1, "predicting", True) else: - loop_info = utils.LoopInfo(len(path_images), 10, 'predicting', True) + loop_info = utils.LoopInfo(len(path_images), 10, "predicting", True) for i in range(len(path_images)): if verbose: loop_info.update(i) @@ -186,39 +208,53 @@ def predict(path_images, if compute[i]: # preprocessing - image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_image=path_images[i], - n_levels=n_levels, - target_res=target_res, - crop=cropping, - min_pad=min_pad, - path_resample=path_resampled[i]) + image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess( + path_image=path_images[i], + n_levels=n_levels, + target_res=target_res, + crop=cropping, + min_pad=min_pad, + path_resample=path_resampled[i], + ) # prediction post_patch = net.predict(image) # postprocessing - seg, posteriors, volumes = postprocess(post_patch=post_patch, - shape=shape, - pad_idx=pad_idx, - crop_idx=crop_idx, - n_dims=n_dims, - labels_segmentation=labels_segmentation, - keep_biggest_component=keep_biggest_component, - aff=aff, - im_res=im_res, - topology_classes=topology_classes) + seg, posteriors, volumes = postprocess( + post_patch=post_patch, + shape=shape, + pad_idx=pad_idx, + crop_idx=crop_idx, + n_dims=n_dims, + labels_segmentation=labels_segmentation, + keep_biggest_component=keep_biggest_component, + aff=aff, + im_res=im_res, + topology_classes=topology_classes, + ) # write results to disk - utils.save_volume(seg, aff, h, path_segmentations[i], dtype='int32') + utils.save_volume(seg, aff, h, path_segmentations[i], dtype="int32") if path_posteriors[i] is not None: if n_channels > 1: posteriors = utils.add_axis(posteriors, axis=[0, -1]) - utils.save_volume(posteriors, aff, h, path_posteriors[i], dtype='float32') + utils.save_volume( + posteriors, aff, h, path_posteriors[i], dtype="float32" + ) # compute volumes if path_volumes[i] is not None: - row = [os.path.basename(path_images[i]).replace('.nii.gz', '')] + [str(vol) for vol in volumes] - write_csv(path_volumes[i], row, unique_vol_file, labels_segmentation, names_segmentation) + row = [os.path.basename(path_images[i]).replace(".nii.gz", "")] + [ + str(vol) for vol in volumes + ] + write_csv( + path_volumes[i], + row, + unique_vol_file, + labels_segmentation, + names_segmentation, + ) # evaluate if gt_folder is not None: @@ -230,57 +266,82 @@ def predict(path_images, # set path of result arrays for surface distance if necessary if compute_distances: - path_hausdorff = os.path.join(eval_folder, 'hausdorff.npy') - path_hausdorff_99 = os.path.join(eval_folder, 'hausdorff_99.npy') - path_hausdorff_95 = os.path.join(eval_folder, 'hausdorff_95.npy') - path_mean_distance = os.path.join(eval_folder, 'mean_distance.npy') + path_hausdorff = os.path.join(eval_folder, "hausdorff.npy") + path_hausdorff_99 = os.path.join(eval_folder, "hausdorff_99.npy") + path_hausdorff_95 = os.path.join(eval_folder, "hausdorff_95.npy") + path_mean_distance = os.path.join(eval_folder, "mean_distance.npy") else: - path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = path_mean_distance = None + path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = ( + path_mean_distance + ) = None # compute evaluation metrics - evaluate.evaluation(gt_folder, - eval_folder, - evaluation_labels, - path_dice=os.path.join(eval_folder, 'dice.npy'), - path_hausdorff=path_hausdorff, - path_hausdorff_99=path_hausdorff_99, - path_hausdorff_95=path_hausdorff_95, - path_mean_distance=path_mean_distance, - list_incorrect_labels=list_incorrect_labels, - list_correct_labels=list_correct_labels, - recompute=recompute, - verbose=verbose) - - -def prepare_output_files(path_images, out_seg, out_posteriors, out_resampled, out_volumes, recompute): + evaluate.evaluation( + gt_folder, + eval_folder, + evaluation_labels, + path_dice=os.path.join(eval_folder, "dice.npy"), + path_hausdorff=path_hausdorff, + path_hausdorff_99=path_hausdorff_99, + path_hausdorff_95=path_hausdorff_95, + path_mean_distance=path_mean_distance, + list_incorrect_labels=list_incorrect_labels, + list_correct_labels=list_correct_labels, + recompute=recompute, + verbose=verbose, + ) + + +def prepare_output_files( + path_images, out_seg, out_posteriors, out_resampled, out_volumes, recompute +): # check inputs - assert path_images is not None, 'please specify an input file/folder (--i)' - assert out_seg is not None, 'please specify an output file/folder (--o)' + assert path_images is not None, "please specify an input file/folder (--i)" + assert out_seg is not None, "please specify an output file/folder (--o)" # convert path to absolute paths path_images = os.path.abspath(path_images) basename = os.path.basename(path_images) out_seg = os.path.abspath(out_seg) - out_posteriors = os.path.abspath(out_posteriors) if (out_posteriors is not None) else out_posteriors - out_resampled = os.path.abspath(out_resampled) if (out_resampled is not None) else out_resampled - out_volumes = os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + out_posteriors = ( + os.path.abspath(out_posteriors) + if (out_posteriors is not None) + else out_posteriors + ) + out_resampled = ( + os.path.abspath(out_resampled) if (out_resampled is not None) else out_resampled + ) + out_volumes = ( + os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + ) # path_images is a text file - if basename[-4:] == '.txt': + if basename[-4:] == ".txt": # input images if not os.path.isfile(path_images): - raise Exception('provided text file containing paths of input images does not exist' % path_images) - with open(path_images, 'r') as f: - path_images = [line.replace('\n', '') for line in f.readlines() if line != '\n'] + raise Exception( + "provided text file containing paths of input images does not exist" + % path_images + ) + with open(path_images, "r") as f: + path_images = [ + line.replace("\n", "") for line in f.readlines() if line != "\n" + ] # define helper to deal with outputs def text_helper(path, name): if path is not None: - assert path[-4:] == '.txt', 'if path_images given as text file, so must be %s' % name - with open(path, 'r') as ff: - path = [line.replace('\n', '') for line in ff.readlines() if line != '\n'] + assert path[-4:] == ".txt", ( + "if path_images given as text file, so must be %s" % name + ) + with open(path, "r") as ff: + path = [ + line.replace("\n", "") + for line in ff.readlines() + if line != "\n" + ] recompute_files = [not os.path.isfile(p) for p in path] else: path = [None] * len(path_images) @@ -289,38 +350,64 @@ def text_helper(path, name): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = text_helper(out_seg, 'path_segmentations') - out_posteriors, recompute_post, _ = text_helper(out_posteriors, 'path_posteriors') - out_resampled, recompute_resampled, _ = text_helper(out_resampled, 'path_resampled') - out_volumes, recompute_volume, unique_volume_file = text_helper(out_volumes, 'path_volume') + out_seg, recompute_seg, _ = text_helper(out_seg, "path_segmentations") + out_posteriors, recompute_post, _ = text_helper( + out_posteriors, "path_posteriors" + ) + out_resampled, recompute_resampled, _ = text_helper( + out_resampled, "path_resampled" + ) + out_volumes, recompute_volume, unique_volume_file = text_helper( + out_volumes, "path_volume" + ) # path_images is a folder - elif ('.nii.gz' not in basename) & ('.nii' not in basename) & ('.mgz' not in basename) & ('.npz' not in basename): + elif ( + (".nii.gz" not in basename) + & (".nii" not in basename) + & (".mgz" not in basename) + & (".npz" not in basename) + ): # input images if os.path.isfile(path_images): - raise Exception('Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz' % path_images) + raise Exception( + "Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz" + % path_images + ) path_images = utils.list_images_in_folder(path_images) # define helper to deal with outputs def helper_dir(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_images is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_images is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" path = [path] * len(path_images) recompute_files = [True] * len(path_images) unique_file = True else: - if (path[-7:] == '.nii.gz') | (path[-4:] == '.nii') | (path[-4:] == '.mgz') | (path[-4:] == '.npz'): - raise Exception('Output FOLDER had a FILE extension' % path) - path = [os.path.join(path, os.path.basename(p)) for p in path_images] - path = [p.replace('.nii', '_%s.nii' % suffix) for p in path] - path = [p.replace('.mgz', '_%s.mgz' % suffix) for p in path] - path = [p.replace('.npz', '_%s.npz' % suffix) for p in path] + if ( + (path[-7:] == ".nii.gz") + | (path[-4:] == ".nii") + | (path[-4:] == ".mgz") + | (path[-4:] == ".npz") + ): + raise Exception("Output FOLDER had a FILE extension" % path) + path = [ + os.path.join(path, os.path.basename(p)) for p in path_images + ] + path = [p.replace(".nii", "_%s.nii" % suffix) for p in path] + path = [p.replace(".mgz", "_%s.mgz" % suffix) for p in path] + path = [p.replace(".npz", "_%s.npz" % suffix) for p in path] recompute_files = [not os.path.isfile(p) for p in path] utils.mkdir(os.path.dirname(path[0])) else: @@ -329,35 +416,57 @@ def helper_dir(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = helper_dir(out_seg, 'path_segmentations', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_dir(out_posteriors, 'path_posteriors', '', 'posteriors') - out_resampled, recompute_resampled, _ = helper_dir(out_resampled, 'path_resampled', '', 'resampled') - out_volumes, recompute_volume, unique_volume_file = helper_dir(out_volumes, 'path_volumes', 'csv', '') + out_seg, recompute_seg, _ = helper_dir( + out_seg, "path_segmentations", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_dir( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_resampled, recompute_resampled, _ = helper_dir( + out_resampled, "path_resampled", "", "resampled" + ) + out_volumes, recompute_volume, unique_volume_file = helper_dir( + out_volumes, "path_volumes", "csv", "" + ) # path_images is an image else: # input image - assert os.path.isfile(path_images), 'file does not exist: %s \n' \ - 'please make sure the path and the extension are correct' % path_images + assert os.path.isfile(path_images), ( + "file does not exist: %s \n" + "please make sure the path and the extension are correct" % path_images + ) path_images = [path_images] # define helper to deal with outputs def helper_im(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_images is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_images is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" recompute_files = [True] unique_file = True else: - if ('.nii.gz' not in path) & ('.nii' not in path) & ('.mgz' not in path) & ('.npz' not in path): - file_name = os.path.basename(path_images[0]).replace('.nii', '_%s.nii' % suffix) - file_name = file_name.replace('.mgz', '_%s.mgz' % suffix) - file_name = file_name.replace('.npz', '_%s.npz' % suffix) + if ( + (".nii.gz" not in path) + & (".nii" not in path) + & (".mgz" not in path) + & (".npz" not in path) + ): + file_name = os.path.basename(path_images[0]).replace( + ".nii", "_%s.nii" % suffix + ) + file_name = file_name.replace(".mgz", "_%s.mgz" % suffix) + file_name = file_name.replace(".npz", "_%s.npz" % suffix) path = os.path.join(path, file_name) recompute_files = [not os.path.isfile(path)] utils.mkdir(os.path.dirname(path)) @@ -367,18 +476,40 @@ def helper_im(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = helper_im(out_seg, 'path_segmentations', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_im(out_posteriors, 'path_posteriors', '', 'posteriors') - out_resampled, recompute_resampled, _ = helper_im(out_resampled, 'path_resampled', '', 'resampled') - out_volumes, recompute_volume, unique_volume_file = helper_im(out_volumes, 'path_volumes', 'csv', '') - - recompute_list = [recompute | re_seg | re_post | re_res | re_vol for (re_seg, re_post, re_res, re_vol) - in zip(recompute_seg, recompute_post, recompute_resampled, recompute_volume)] - - return path_images, out_seg, out_posteriors, out_resampled, out_volumes, recompute_list, unique_volume_file - - -def preprocess(path_image, n_levels, target_res, crop=None, min_pad=None, path_resample=None): + out_seg, recompute_seg, _ = helper_im( + out_seg, "path_segmentations", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_im( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_resampled, recompute_resampled, _ = helper_im( + out_resampled, "path_resampled", "", "resampled" + ) + out_volumes, recompute_volume, unique_volume_file = helper_im( + out_volumes, "path_volumes", "csv", "" + ) + + recompute_list = [ + recompute | re_seg | re_post | re_res | re_vol + for (re_seg, re_post, re_res, re_vol) in zip( + recompute_seg, recompute_post, recompute_resampled, recompute_volume + ) + ] + + return ( + path_images, + out_seg, + out_posteriors, + out_resampled, + out_volumes, + recompute_list, + unique_volume_file, + ) + + +def preprocess( + path_image, n_levels, target_res, crop=None, min_pad=None, path_resample=None +): # read image and corresponding info im, _, aff, n_dims, n_channels, h, im_res = utils.get_volume_info(path_image, True) @@ -393,33 +524,57 @@ def preprocess(path_image, n_levels, target_res, crop=None, min_pad=None, path_r utils.save_volume(im, aff, h, path_resample) # align image - im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False) + im = edit_volumes.align_volume_to_ref( + im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False + ) shape = list(im.shape[:n_dims]) # crop image if necessary if crop is not None: - crop = utils.reformat_to_list(crop, length=n_dims, dtype='int') - crop_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop] - im, crop_idx = edit_volumes.crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True) + crop = utils.reformat_to_list(crop, length=n_dims, dtype="int") + crop_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in crop + ] + im, crop_idx = edit_volumes.crop_volume( + im, cropping_shape=crop_shape, return_crop_idx=True + ) else: crop_idx = None # normalise image if n_channels == 1: - im = edit_volumes.rescale_volume(im, new_min=0., new_max=1., min_percentile=0.5, max_percentile=99.5) + im = edit_volumes.rescale_volume( + im, new_min=0.0, new_max=1.0, min_percentile=0.5, max_percentile=99.5 + ) else: for i in range(im.shape[-1]): - im[..., i] = edit_volumes.rescale_volume(im[..., i], new_min=0., new_max=1., - min_percentile=0.5, max_percentile=99.5) + im[..., i] = edit_volumes.rescale_volume( + im[..., i], + new_min=0.0, + new_max=1.0, + min_percentile=0.5, + max_percentile=99.5, + ) # pad image input_shape = im.shape[:n_dims] - pad_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in input_shape] - if min_pad is not None: # in SynthSeg predict use crop flag and then if used do min_pad=crop else min_pad = 192 - min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype='int') - min_pad = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in min_pad] + pad_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in input_shape + ] + if ( + min_pad is not None + ): # in SynthSeg predict use crop flag and then if used do min_pad=crop else min_pad = 192 + min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype="int") + min_pad = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in min_pad + ] pad_shape = np.maximum(pad_shape, min_pad) - im, pad_idx = edit_volumes.pad_volume(im, padding_shape=pad_shape, return_pad_idx=True) + im, pad_idx = edit_volumes.pad_volume( + im, padding_shape=pad_shape, return_pad_idx=True + ) # add batch and channel axes im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im, axis=[0, -1]) @@ -427,18 +582,20 @@ def preprocess(path_image, n_levels, target_res, crop=None, min_pad=None, path_r return im, aff, h, im_res, shape, pad_idx, crop_idx -def build_model(path_model, - input_shape, - labels_segmentation, - n_levels, - nb_conv_per_level, - conv_size, - unet_feat_count, - feat_multiplier, - activation, - sigma_smoothing, - flip_indices, - gradients): +def build_model( + path_model, + input_shape, + labels_segmentation, + n_levels, + nb_conv_per_level, + conv_size, + unet_feat_count, + feat_multiplier, + activation, + sigma_smoothing, + flip_indices, + gradients, +): assert os.path.isfile(path_model), "The provided model path does not exist." @@ -447,23 +604,27 @@ def build_model(path_model, if gradients: input_image = KL.Input(input_shape) - last_tensor = layers.ImageGradients('sobel', True)(input_image) - last_tensor = KL.Lambda(lambda x: (x - K.min(x)) / (K.max(x) - K.min(x) + K.epsilon()))(last_tensor) + last_tensor = layers.ImageGradients("sobel", True)(input_image) + last_tensor = KL.Lambda( + lambda x: (x - K.min(x)) / (K.max(x) - K.min(x) + K.epsilon()) + )(last_tensor) net = Model(inputs=input_image, outputs=last_tensor) else: net = None # build UNet - net = nrn_models.unet(input_model=net, - input_shape=input_shape, - nb_labels=n_labels_seg, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1) + net = nrn_models.unet( + input_model=net, + input_shape=input_shape, + nb_labels=n_labels_seg, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + ) net.load_weights(path_model, by_name=True) # smooth posteriors if specified @@ -483,33 +644,55 @@ def build_model(path_model, # flip back and re-order channels last_tensor = layers.RandomFlip(axis=0, prob=1)(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.split(x, [1] * n_labels_seg, axis=-1), name='split')(last_tensor) + last_tensor = KL.Lambda( + lambda x: tf.split(x, [1] * n_labels_seg, axis=-1), name="split" + )(last_tensor) reordered_channels = [last_tensor[flip_indices[i]] for i in range(n_labels_seg)] - last_tensor = KL.Lambda(lambda x: tf.concat(x, -1), name='concat')(reordered_channels) + last_tensor = KL.Lambda(lambda x: tf.concat(x, -1), name="concat")( + reordered_channels + ) # average two segmentations and build model - name_segm_prediction_layer = 'average_lr' - last_tensor = KL.Lambda(lambda x: 0.5 * (x[0] + x[1]), name=name_segm_prediction_layer)([seg, last_tensor]) + name_segm_prediction_layer = "average_lr" + last_tensor = KL.Lambda( + lambda x: 0.5 * (x[0] + x[1]), name=name_segm_prediction_layer + )([seg, last_tensor]) net = Model(inputs=net.inputs, outputs=last_tensor) return net -def postprocess(post_patch, shape, pad_idx, crop_idx, n_dims, - labels_segmentation, keep_biggest_component, aff, im_res, topology_classes=None): +def postprocess( + post_patch, + shape, + pad_idx, + crop_idx, + n_dims, + labels_segmentation, + keep_biggest_component, + aff, + im_res, + topology_classes=None, +): # get posteriors post_patch = np.squeeze(post_patch) if topology_classes is None: - post_patch = edit_volumes.crop_volume_with_idx(post_patch, pad_idx, n_dims=3, return_copy=False) + post_patch = edit_volumes.crop_volume_with_idx( + post_patch, pad_idx, n_dims=3, return_copy=False + ) # keep biggest connected component if keep_biggest_component: tmp_post_patch = post_patch[..., 1:] post_patch_mask = np.sum(tmp_post_patch, axis=-1) > 0.25 post_patch_mask = edit_volumes.get_largest_connected_component(post_patch_mask) - post_patch_mask = np.stack([post_patch_mask]*tmp_post_patch.shape[-1], axis=-1) - tmp_post_patch = edit_volumes.mask_volume(tmp_post_patch, mask=post_patch_mask, return_copy=False) + post_patch_mask = np.stack( + [post_patch_mask] * tmp_post_patch.shape[-1], axis=-1 + ) + tmp_post_patch = edit_volumes.mask_volume( + tmp_post_patch, mask=post_patch_mask, return_copy=False + ) post_patch[..., 1:] = tmp_post_patch # reset posteriors to zero outside the largest connected component of each topological class @@ -521,35 +704,56 @@ def postprocess(post_patch, shape, pad_idx, crop_idx, n_dims, tmp_mask = edit_volumes.get_largest_connected_component(tmp_mask) for idx in tmp_topology_indices: post_patch[..., idx] *= tmp_mask - post_patch = edit_volumes.crop_volume_with_idx(post_patch, pad_idx, n_dims=3, return_copy=False) + post_patch = edit_volumes.crop_volume_with_idx( + post_patch, pad_idx, n_dims=3, return_copy=False + ) # normalise posteriors and get hard segmentation if keep_biggest_component | (topology_classes is not None): post_patch /= np.sum(post_patch, axis=-1)[..., np.newaxis] - seg_patch = labels_segmentation[post_patch.argmax(-1).astype('int32')].astype('int32') + seg_patch = labels_segmentation[post_patch.argmax(-1).astype("int32")].astype( + "int32" + ) # paste patches back to matrix of original image size if crop_idx is not None: # we need to go through this because of the posteriors of the background, otherwise pad_volume would work - seg = np.zeros(shape=shape, dtype='int32') + seg = np.zeros(shape=shape, dtype="int32") posteriors = np.zeros(shape=[*shape, labels_segmentation.shape[0]]) posteriors[..., 0] = np.ones(shape) # place background around patch if n_dims == 2: - seg[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3]] = seg_patch - posteriors[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], :] = post_patch + seg[crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3]] = seg_patch + posteriors[crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], :] = ( + post_patch + ) elif n_dims == 3: - seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch - posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch + seg[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ] = seg_patch + posteriors[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + :, + ] = post_patch else: seg = seg_patch posteriors = post_patch # align prediction back to first orientation - seg = edit_volumes.align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False) - posteriors = edit_volumes.align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False) + seg = edit_volumes.align_volume_to_ref( + seg, aff=np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False + ) + posteriors = edit_volumes.align_volume_to_ref( + posteriors, np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False + ) # compute volumes - volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1))) + volumes = np.sum( + posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)) + ) volumes = np.around(volumes * np.prod(im_res), 3) return seg, posteriors, volumes @@ -560,14 +764,24 @@ def get_flip_indices(labels_segmentation, n_neutral_labels): # get position labels n_sided_labels = int((len(labels_segmentation) - n_neutral_labels) / 2) neutral_labels = labels_segmentation[:n_neutral_labels] - left = labels_segmentation[n_neutral_labels:n_neutral_labels + n_sided_labels] + left = labels_segmentation[n_neutral_labels : n_neutral_labels + n_sided_labels] # get correspondence between labels - lr_corresp = np.stack([labels_segmentation[n_neutral_labels:n_neutral_labels + n_sided_labels], - labels_segmentation[n_neutral_labels + n_sided_labels:]]) - lr_corresp_unique, lr_corresp_indices = np.unique(lr_corresp[0, :], return_index=True) + lr_corresp = np.stack( + [ + labels_segmentation[n_neutral_labels : n_neutral_labels + n_sided_labels], + labels_segmentation[n_neutral_labels + n_sided_labels :], + ] + ) + lr_corresp_unique, lr_corresp_indices = np.unique( + lr_corresp[0, :], return_index=True + ) lr_corresp_unique = np.stack([lr_corresp_unique, lr_corresp[1, lr_corresp_indices]]) - lr_corresp_unique = lr_corresp_unique[:, 1:] if not np.all(lr_corresp_unique[:, 0]) else lr_corresp_unique + lr_corresp_unique = ( + lr_corresp_unique[:, 1:] + if not np.all(lr_corresp_unique[:, 0]) + else lr_corresp_unique + ) # get unique labels labels_segmentation, unique_idx = np.unique(labels_segmentation, return_index=True) @@ -584,14 +798,20 @@ def get_flip_indices(labels_segmentation, n_neutral_labels): if labels_segmentation[i] in neutral_labels: flip_indices[i] = i elif labels_segmentation[i] in left: - flip_indices[i] = lr_indices[1, np.where(lr_corresp_unique[0, :] == labels_segmentation[i])] + flip_indices[i] = lr_indices[ + 1, np.where(lr_corresp_unique[0, :] == labels_segmentation[i]) + ] else: - flip_indices[i] = lr_indices[0, np.where(lr_corresp_unique[1, :] == labels_segmentation[i])] + flip_indices[i] = lr_indices[ + 0, np.where(lr_corresp_unique[1, :] == labels_segmentation[i]) + ] return labels_segmentation, flip_indices, unique_idx -def write_csv(path_csv, data, unique_file, labels, names, skip_first=True, last_first=False): +def write_csv( + path_csv, data, unique_file, labels, names, skip_first=True, last_first=False +): # initialisation utils.mkdir(os.path.dirname(path_csv)) @@ -608,19 +828,19 @@ def write_csv(path_csv, data, unique_file, labels, names, skip_first=True, last_ if last_first: header = [header[-1]] + header[:-1] if (not unique_file) & (data is None): - raise ValueError('data can only be None when initialising a unique volume file') + raise ValueError("data can only be None when initialising a unique volume file") # modify data if unique_file: if data is None: - type_open = 'w' - data = ['subject'] + header + type_open = "w" + data = ["subject"] + header else: - type_open = 'a' + type_open = "a" data = [data] else: - type_open = 'w' - header = [''] + header + type_open = "w" + header = [""] + header data = [header, data] # write csv diff --git a/nobrainer/ext/SynthSeg/predict_denoiser.py b/nobrainer/ext/SynthSeg/predict_denoiser.py index 71b38388..69897d4a 100644 --- a/nobrainer/ext/SynthSeg/predict_denoiser.py +++ b/nobrainer/ext/SynthSeg/predict_denoiser.py @@ -13,62 +13,78 @@ License. """ - # python imports import os -import numpy as np -import tensorflow as tf -import keras.layers as KL -from keras.models import Model # project imports from SynthSeg import evaluate -from SynthSeg.predict import write_csv, postprocess +from SynthSeg.predict import postprocess, write_csv # third-party imports -from ext.lab2im import edit_volumes -from ext.lab2im import utils, layers +from ext.lab2im import edit_volumes, layers, utils from ext.neuron import models as nrn_models +import keras.layers as KL +from keras.models import Model +import numpy as np +import tensorflow as tf -def predict(path_predictions, - path_corrections, - path_model, - input_segmentation_labels, - target_segmentation_labels=None, - names_segmentation=None, - path_posteriors=None, - path_volumes=None, - min_pad=None, - cropping=None, - topology_classes=None, - sigma_smoothing=0.5, - keep_biggest_component=True, - n_levels=5, - nb_conv_per_level=2, - conv_size=5, - unet_feat_count=16, - feat_multiplier=2, - activation='elu', - skip_n_concatenations=2, - gt_folder=None, - evaluation_labels=None, - list_incorrect_labels=None, - list_correct_labels=None, - compute_distances=False, - recompute=True, - verbose=True): +def predict( + path_predictions, + path_corrections, + path_model, + input_segmentation_labels, + target_segmentation_labels=None, + names_segmentation=None, + path_posteriors=None, + path_volumes=None, + min_pad=None, + cropping=None, + topology_classes=None, + sigma_smoothing=0.5, + keep_biggest_component=True, + n_levels=5, + nb_conv_per_level=2, + conv_size=5, + unet_feat_count=16, + feat_multiplier=2, + activation="elu", + skip_n_concatenations=2, + gt_folder=None, + evaluation_labels=None, + list_incorrect_labels=None, + list_correct_labels=None, + compute_distances=False, + recompute=True, + verbose=True, +): # prepare input/output filepaths - path_predictions, path_corrections, path_posteriors, path_volumes, compute, unique_vol_file = \ - prepare_output_files(path_predictions, path_corrections, path_posteriors, path_volumes, recompute) + ( + path_predictions, + path_corrections, + path_posteriors, + path_volumes, + compute, + unique_vol_file, + ) = prepare_output_files( + path_predictions, path_corrections, path_posteriors, path_volumes, recompute + ) # get label list - input_segmentation_labels = utils.get_list_labels(label_list=input_segmentation_labels)[0] - input_segmentation_labels, unique_idx = np.unique(input_segmentation_labels, return_index=True) + input_segmentation_labels = utils.get_list_labels( + label_list=input_segmentation_labels + )[0] + input_segmentation_labels, unique_idx = np.unique( + input_segmentation_labels, return_index=True + ) if target_segmentation_labels is not None: - target_segmentation_labels = utils.get_list_labels(label_list=target_segmentation_labels)[0] - target_segmentation_labels, unique_idx = np.unique(target_segmentation_labels, return_index=True) + target_segmentation_labels = utils.get_list_labels( + label_list=target_segmentation_labels + )[0] + target_segmentation_labels, unique_idx = np.unique( + target_segmentation_labels, return_index=True + ) else: target_segmentation_labels = input_segmentation_labels @@ -76,39 +92,45 @@ def predict(path_predictions, if names_segmentation is not None: names_segmentation = utils.load_array_if_path(names_segmentation)[unique_idx] if topology_classes is not None: - topology_classes = utils.load_array_if_path(topology_classes, load_as_numpy=True)[unique_idx] + topology_classes = utils.load_array_if_path( + topology_classes, load_as_numpy=True + )[unique_idx] # prepare volumes if necessary if unique_vol_file & (path_volumes[0] is not None): - write_csv(path_volumes[0], None, True, target_segmentation_labels, names_segmentation) + write_csv( + path_volumes[0], None, True, target_segmentation_labels, names_segmentation + ) # build network _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_predictions[0]) model_input_shape = [None] * n_dims + [n_channels] - net = build_model(path_model=path_model, - input_shape=model_input_shape, - input_label_list=input_segmentation_labels, - target_label_list=target_segmentation_labels, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation, - skip_n_concatenations=skip_n_concatenations, - sigma_smoothing=sigma_smoothing) + net = build_model( + path_model=path_model, + input_shape=model_input_shape, + input_label_list=input_segmentation_labels, + target_label_list=target_segmentation_labels, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + skip_n_concatenations=skip_n_concatenations, + sigma_smoothing=sigma_smoothing, + ) # set cropping/padding if (cropping is not None) & (min_pad is not None): - cropping = utils.reformat_to_list(cropping, length=n_dims, dtype='int') - min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype='int') + cropping = utils.reformat_to_list(cropping, length=n_dims, dtype="int") + min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype="int") min_pad = np.minimum(cropping, min_pad) # perform segmentation if len(path_predictions) <= 10: - loop_info = utils.LoopInfo(len(path_predictions), 1, 'predicting', True) + loop_info = utils.LoopInfo(len(path_predictions), 1, "predicting", True) else: - loop_info = utils.LoopInfo(len(path_predictions), 10, 'predicting', True) + loop_info = utils.LoopInfo(len(path_predictions), 10, "predicting", True) for i in range(len(path_predictions)): if verbose: loop_info.update(i) @@ -117,37 +139,51 @@ def predict(path_predictions, if compute[i]: # preprocessing - prediction, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_prediction=path_predictions[i], - n_levels=n_levels, - crop=cropping, - min_pad=min_pad) + prediction, aff, h, im_res, shape, pad_idx, crop_idx = preprocess( + path_prediction=path_predictions[i], + n_levels=n_levels, + crop=cropping, + min_pad=min_pad, + ) # prediction post_patch = net.predict(prediction) # postprocessing - seg, posteriors, volumes = postprocess(post_patch=post_patch, - shape=shape, - pad_idx=pad_idx, - crop_idx=crop_idx, - n_dims=n_dims, - labels_segmentation=target_segmentation_labels, - keep_biggest_component=keep_biggest_component, - aff=aff, - im_res=im_res, - topology_classes=topology_classes) + seg, posteriors, volumes = postprocess( + post_patch=post_patch, + shape=shape, + pad_idx=pad_idx, + crop_idx=crop_idx, + n_dims=n_dims, + labels_segmentation=target_segmentation_labels, + keep_biggest_component=keep_biggest_component, + aff=aff, + im_res=im_res, + topology_classes=topology_classes, + ) # write results to disk - utils.save_volume(seg, aff, h, path_corrections[i], dtype='int32') + utils.save_volume(seg, aff, h, path_corrections[i], dtype="int32") if path_posteriors[i] is not None: if n_channels > 1: posteriors = utils.add_axis(posteriors, axis=[0, -1]) - utils.save_volume(posteriors, aff, h, path_posteriors[i], dtype='float32') + utils.save_volume( + posteriors, aff, h, path_posteriors[i], dtype="float32" + ) # compute volumes if path_volumes[i] is not None: - row = [os.path.basename(path_predictions[i]).replace('.nii.gz', '')] + [str(vol) for vol in volumes] - write_csv(path_volumes[i], row, unique_vol_file, target_segmentation_labels, names_segmentation) + row = [os.path.basename(path_predictions[i]).replace(".nii.gz", "")] + [ + str(vol) for vol in volumes + ] + write_csv( + path_volumes[i], + row, + unique_vol_file, + target_segmentation_labels, + names_segmentation, + ) # evaluate if gt_folder is not None: @@ -159,56 +195,79 @@ def predict(path_predictions, # set path of result arrays for surface distance if necessary if compute_distances: - path_hausdorff = os.path.join(eval_folder, 'hausdorff.npy') - path_hausdorff_99 = os.path.join(eval_folder, 'hausdorff_99.npy') - path_hausdorff_95 = os.path.join(eval_folder, 'hausdorff_95.npy') - path_mean_distance = os.path.join(eval_folder, 'mean_distance.npy') + path_hausdorff = os.path.join(eval_folder, "hausdorff.npy") + path_hausdorff_99 = os.path.join(eval_folder, "hausdorff_99.npy") + path_hausdorff_95 = os.path.join(eval_folder, "hausdorff_95.npy") + path_mean_distance = os.path.join(eval_folder, "mean_distance.npy") else: - path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = path_mean_distance = None + path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = ( + path_mean_distance + ) = None # compute evaluation metrics - evaluate.evaluation(gt_folder, - eval_folder, - evaluation_labels, - path_dice=os.path.join(eval_folder, 'dice.npy'), - path_hausdorff=path_hausdorff, - path_hausdorff_99=path_hausdorff_99, - path_hausdorff_95=path_hausdorff_95, - path_mean_distance=path_mean_distance, - list_incorrect_labels=list_incorrect_labels, - list_correct_labels=list_correct_labels, - recompute=recompute, - verbose=verbose) - - -def prepare_output_files(path_predictions, out_corrections, out_posteriors, out_volumes, recompute): + evaluate.evaluation( + gt_folder, + eval_folder, + evaluation_labels, + path_dice=os.path.join(eval_folder, "dice.npy"), + path_hausdorff=path_hausdorff, + path_hausdorff_99=path_hausdorff_99, + path_hausdorff_95=path_hausdorff_95, + path_mean_distance=path_mean_distance, + list_incorrect_labels=list_incorrect_labels, + list_correct_labels=list_correct_labels, + recompute=recompute, + verbose=verbose, + ) + + +def prepare_output_files( + path_predictions, out_corrections, out_posteriors, out_volumes, recompute +): # check inputs - assert path_predictions is not None, 'please specify an input file/folder (--i)' - assert out_corrections is not None, 'please specify an output file/folder (--o)' + assert path_predictions is not None, "please specify an input file/folder (--i)" + assert out_corrections is not None, "please specify an output file/folder (--o)" # convert path to absolute paths path_predictions = os.path.abspath(path_predictions) basename = os.path.basename(path_predictions) out_corrections = os.path.abspath(out_corrections) - out_posteriors = os.path.abspath(out_posteriors) if (out_posteriors is not None) else out_posteriors - out_volumes = os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + out_posteriors = ( + os.path.abspath(out_posteriors) + if (out_posteriors is not None) + else out_posteriors + ) + out_volumes = ( + os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + ) # path_images is a text file - if basename[-4:] == '.txt': + if basename[-4:] == ".txt": # input predictions if not os.path.isfile(path_predictions): - raise Exception('provided text file containing paths of input prediction does not exist' % path_predictions) - with open(path_predictions, 'r') as f: - path_predictions = [line.replace('\n', '') for line in f.readlines() if line != '\n'] + raise Exception( + "provided text file containing paths of input prediction does not exist" + % path_predictions + ) + with open(path_predictions, "r") as f: + path_predictions = [ + line.replace("\n", "") for line in f.readlines() if line != "\n" + ] # define helper to deal with outputs def text_helper(path, name): if path is not None: - assert path[-4:] == '.txt', 'if path_predictions given as text file, so must be %s' % name - with open(path, 'r') as ff: - path = [line.replace('\n', '') for line in ff.readlines() if line != '\n'] + assert path[-4:] == ".txt", ( + "if path_predictions given as text file, so must be %s" % name + ) + with open(path, "r") as ff: + path = [ + line.replace("\n", "") + for line in ff.readlines() + if line != "\n" + ] recompute_files = [not os.path.isfile(p) for p in path] else: path = [None] * len(path_predictions) @@ -217,37 +276,64 @@ def text_helper(path, name): return path, recompute_files, unique_file # use helper on all outputs - out_corrections, recompute_cor, _ = text_helper(out_corrections, 'path_corrections') - out_posteriors, recompute_post, _ = text_helper(out_posteriors, 'path_posteriors') - out_volumes, recompute_volume, unique_volume_file = text_helper(out_volumes, 'path_volume') + out_corrections, recompute_cor, _ = text_helper( + out_corrections, "path_corrections" + ) + out_posteriors, recompute_post, _ = text_helper( + out_posteriors, "path_posteriors" + ) + out_volumes, recompute_volume, unique_volume_file = text_helper( + out_volumes, "path_volume" + ) # path_predictions is a folder - elif ('.nii.gz' not in basename) & ('.nii' not in basename) & ('.mgz' not in basename) & ('.npz' not in basename): + elif ( + (".nii.gz" not in basename) + & (".nii" not in basename) + & (".mgz" not in basename) + & (".npz" not in basename) + ): # input predictions if os.path.isfile(path_predictions): - raise Exception('Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz' % path_predictions) + raise Exception( + "Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz" + % path_predictions + ) path_predictions = utils.list_images_in_folder(path_predictions) # define helper to deal with outputs def helper_dir(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_predictions is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_predictions is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" path = [path] * len(path_predictions) recompute_files = [True] * len(path_predictions) unique_file = True else: - if (path[-7:] == '.nii.gz') | (path[-4:] == '.nii') | (path[-4:] == '.mgz') | (path[-4:] == '.npz'): - raise Exception('Output FOLDER had a FILE extension' % path) - path = [os.path.join(path, os.path.basename(p)) for p in path_predictions] - path = [p.replace('.nii', '_%s.nii' % suffix) for p in path] - path = [p.replace('.mgz', '_%s.mgz' % suffix) for p in path] - path = [p.replace('.npz', '_%s.npz' % suffix) for p in path] + if ( + (path[-7:] == ".nii.gz") + | (path[-4:] == ".nii") + | (path[-4:] == ".mgz") + | (path[-4:] == ".npz") + ): + raise Exception("Output FOLDER had a FILE extension" % path) + path = [ + os.path.join(path, os.path.basename(p)) + for p in path_predictions + ] + path = [p.replace(".nii", "_%s.nii" % suffix) for p in path] + path = [p.replace(".mgz", "_%s.mgz" % suffix) for p in path] + path = [p.replace(".npz", "_%s.npz" % suffix) for p in path] recompute_files = [not os.path.isfile(p) for p in path] utils.mkdir(os.path.dirname(path[0])) else: @@ -256,34 +342,54 @@ def helper_dir(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_corrections, recompute_cor, _ = helper_dir(out_corrections, 'path_corrections', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_dir(out_posteriors, 'path_posteriors', '', 'posteriors') - out_volumes, recompute_volume, unique_volume_file = helper_dir(out_volumes, 'path_volumes', 'csv', '') + out_corrections, recompute_cor, _ = helper_dir( + out_corrections, "path_corrections", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_dir( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_volumes, recompute_volume, unique_volume_file = helper_dir( + out_volumes, "path_volumes", "csv", "" + ) # path_predictions is an image else: # input prediction - assert os.path.isfile(path_predictions), 'file does not exist: %s \nplease make sure the path and ' \ - 'the extension are correct' % path_predictions + assert os.path.isfile(path_predictions), ( + "file does not exist: %s \nplease make sure the path and " + "the extension are correct" % path_predictions + ) path_predictions = [path_predictions] # define helper to deal with outputs def helper_im(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_predictions is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_predictions is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" recompute_files = [True] unique_file = True else: - if ('.nii.gz' not in path) & ('.nii' not in path) & ('.mgz' not in path) & ('.npz' not in path): - file_name = os.path.basename(path_predictions[0]).replace('.nii', '_%s.nii' % suffix) - file_name = file_name.replace('.mgz', '_%s.mgz' % suffix) - file_name = file_name.replace('.npz', '_%s.npz' % suffix) + if ( + (".nii.gz" not in path) + & (".nii" not in path) + & (".mgz" not in path) + & (".npz" not in path) + ): + file_name = os.path.basename(path_predictions[0]).replace( + ".nii", "_%s.nii" % suffix + ) + file_name = file_name.replace(".mgz", "_%s.mgz" % suffix) + file_name = file_name.replace(".npz", "_%s.npz" % suffix) path = os.path.join(path, file_name) recompute_files = [not os.path.isfile(path)] utils.mkdir(os.path.dirname(path)) @@ -293,58 +399,95 @@ def helper_im(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_corrections, recompute_cor, _ = helper_im(out_corrections, 'path_corrections', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_im(out_posteriors, 'path_posteriors', '', 'posteriors') - out_volumes, recompute_volume, unique_volume_file = helper_im(out_volumes, 'path_volumes', 'csv', '') - - recompute_list = [recompute | re_cor | re_post | re_vol for (re_cor, re_post, re_vol) - in zip(recompute_cor, recompute_post, recompute_volume)] - - return path_predictions, out_corrections, out_posteriors, out_volumes, recompute_list, unique_volume_file + out_corrections, recompute_cor, _ = helper_im( + out_corrections, "path_corrections", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_im( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_volumes, recompute_volume, unique_volume_file = helper_im( + out_volumes, "path_volumes", "csv", "" + ) + + recompute_list = [ + recompute | re_cor | re_post | re_vol + for (re_cor, re_post, re_vol) in zip( + recompute_cor, recompute_post, recompute_volume + ) + ] + + return ( + path_predictions, + out_corrections, + out_posteriors, + out_volumes, + recompute_list, + unique_volume_file, + ) def preprocess(path_prediction, n_levels, crop=None, min_pad=None): # read image and corresponding info - pred, _, aff_pred, n_dims, _, h_pred, res_pred = utils.get_volume_info(path_prediction, True) + pred, _, aff_pred, n_dims, _, h_pred, res_pred = utils.get_volume_info( + path_prediction, True + ) # align image - pred = edit_volumes.align_volume_to_ref(pred, aff_pred, aff_ref=np.eye(4), n_dims=n_dims) + pred = edit_volumes.align_volume_to_ref( + pred, aff_pred, aff_ref=np.eye(4), n_dims=n_dims + ) shape = list(pred.shape[:n_dims]) # crop image if necessary if crop is not None: - crop = utils.reformat_to_list(crop, length=n_dims, dtype='int') - input_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop] - pred, crop_idx = edit_volumes.crop_volume(pred, cropping_shape=input_shape, return_crop_idx=True) + crop = utils.reformat_to_list(crop, length=n_dims, dtype="int") + input_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in crop + ] + pred, crop_idx = edit_volumes.crop_volume( + pred, cropping_shape=input_shape, return_crop_idx=True + ) else: crop_idx = None - input_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in shape] + input_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in shape + ] # pad image - if min_pad is not None: # in SynthSeg predict use crop flag and then if used do min_pad=crop else min_pad = 192 - min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype='int') + if ( + min_pad is not None + ): # in SynthSeg predict use crop flag and then if used do min_pad=crop else min_pad = 192 + min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype="int") input_shape = np.maximum(input_shape, min_pad) - pred, pad_idx = edit_volumes.pad_volume(pred, padding_shape=input_shape, return_pad_idx=True) + pred, pad_idx = edit_volumes.pad_volume( + pred, padding_shape=input_shape, return_pad_idx=True + ) # add batch and channel axes - pred = utils.add_axis(pred, axis=0) # channel axis will be added later when computing one-hot + pred = utils.add_axis( + pred, axis=0 + ) # channel axis will be added later when computing one-hot return pred, aff_pred, h_pred, res_pred, shape, pad_idx, crop_idx -def build_model(path_model, - input_shape, - input_label_list, - target_label_list, - n_levels, - nb_conv_per_level, - conv_size, - unet_feat_count, - feat_multiplier, - activation, - skip_n_concatenations, - sigma_smoothing): +def build_model( + path_model, + input_shape, + input_label_list, + target_label_list, + n_levels, + nb_conv_per_level, + conv_size, + unet_feat_count, + feat_multiplier, + activation, + skip_n_concatenations, + sigma_smoothing, +): assert os.path.isfile(path_model), "The provided model path does not exist." @@ -354,22 +497,28 @@ def build_model(path_model, # one-hot encoding of the input prediction as the network expects soft probabilities input_labels = KL.Input(input_shape[:-1]) labels = layers.ConvertLabels(input_label_list)(input_labels) - labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=len(input_label_list), axis=-1))(labels) + labels = KL.Lambda( + lambda x: tf.one_hot( + tf.cast(x, dtype="int32"), depth=len(input_label_list), axis=-1 + ) + )(labels) net = Model(inputs=input_labels, outputs=labels) # build UNet - net = nrn_models.unet(input_model=net, - input_shape=input_shape, - nb_labels=n_labels_seg, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1, - skip_n_concatenations=skip_n_concatenations, - name='l2l') + net = nrn_models.unet( + input_model=net, + input_shape=input_shape, + nb_labels=n_labels_seg, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + skip_n_concatenations=skip_n_concatenations, + name="l2l", + ) net.load_weights(path_model, by_name=True) # smooth posteriors if specified diff --git a/nobrainer/ext/SynthSeg/predict_group.py b/nobrainer/ext/SynthSeg/predict_group.py index 22d6126a..baae2438 100644 --- a/nobrainer/ext/SynthSeg/predict_group.py +++ b/nobrainer/ext/SynthSeg/predict_group.py @@ -13,56 +13,69 @@ License. """ - # python imports import os -import numpy as np -import tensorflow as tf -import keras.layers as KL -from keras.models import Model # project imports from SynthSeg import evaluate from SynthSeg.predict import write_csv # third-party imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.lab2im import edit_volumes +from ext.lab2im import edit_volumes, layers, utils from ext.neuron import models as nrn_models +import keras.layers as KL +from keras.models import Model +import numpy as np +import tensorflow as tf -def predict(path_images, - path_masks, - path_segmentations, - path_model, - labels_segmentation, - labels_mask, - path_posteriors=None, - path_volumes=None, - names_segmentation=None, - min_pad=None, - cropping=None, - sigma_smoothing=0.5, - strict_masking=False, - keep_biggest_component=True, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - gt_folder=None, - evaluation_labels=None, - list_incorrect_labels=None, - list_correct_labels=None, - compute_distances=False, - recompute=True, - verbose=True): +def predict( + path_images, + path_masks, + path_segmentations, + path_model, + labels_segmentation, + labels_mask, + path_posteriors=None, + path_volumes=None, + names_segmentation=None, + min_pad=None, + cropping=None, + sigma_smoothing=0.5, + strict_masking=False, + keep_biggest_component=True, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + gt_folder=None, + evaluation_labels=None, + list_incorrect_labels=None, + list_correct_labels=None, + compute_distances=False, + recompute=True, + verbose=True, +): # prepare input/output filepaths - path_images, path_masks, path_segmentations, path_posteriors, path_volumes, compute, unique_vol_file = \ - prepare_output_files(path_images, path_masks, path_segmentations, path_posteriors, path_volumes, recompute) + ( + path_images, + path_masks, + path_segmentations, + path_posteriors, + path_volumes, + compute, + unique_vol_file, + ) = prepare_output_files( + path_images, + path_masks, + path_segmentations, + path_posteriors, + path_volumes, + recompute, + ) # get label list labels_mask, _ = utils.get_list_labels(label_list=labels_mask) @@ -77,29 +90,31 @@ def predict(path_images, # build network _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_images[0]) model_input_shape = [None] * n_dims + [n_channels] - net = build_model(path_model=path_model, - input_shape=model_input_shape, - labels_segmentation=labels_segmentation, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation, - sigma_smoothing=sigma_smoothing, - mask_labels=mask_labels_unique) + net = build_model( + path_model=path_model, + input_shape=model_input_shape, + labels_segmentation=labels_segmentation, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + sigma_smoothing=sigma_smoothing, + mask_labels=mask_labels_unique, + ) # set cropping/padding if (cropping is not None) & (min_pad is not None): - cropping = utils.reformat_to_list(cropping, length=n_dims, dtype='int') - min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype='int') + cropping = utils.reformat_to_list(cropping, length=n_dims, dtype="int") + min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype="int") min_pad = np.minimum(cropping, min_pad) # perform segmentation if len(path_images) <= 10: - loop_info = utils.LoopInfo(len(path_images), 1, 'predicting', True) + loop_info = utils.LoopInfo(len(path_images), 1, "predicting", True) else: - loop_info = utils.LoopInfo(len(path_images), 10, 'predicting', True) + loop_info = utils.LoopInfo(len(path_images), 10, "predicting", True) for i in range(len(path_images)): if verbose: loop_info.update(i) @@ -108,39 +123,53 @@ def predict(path_images, if compute[i]: # preprocessing - image, mask, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_image=path_images[i], - path_mask=path_masks[i], - n_levels=n_levels, - crop=cropping, - min_pad=min_pad) + image, mask, aff, h, im_res, shape, pad_idx, crop_idx = preprocess( + path_image=path_images[i], + path_mask=path_masks[i], + n_levels=n_levels, + crop=cropping, + min_pad=min_pad, + ) # prediction post_patch = net.predict([image, mask]) # postprocessing - seg, posteriors, volumes = postprocess(post_patch=post_patch, - mask=mask, - shape=shape, - pad_idx=pad_idx, - crop_idx=crop_idx, - n_dims=n_dims, - labels_segmentation=labels_segmentation, - strict_masking=strict_masking, - keep_biggest_component=keep_biggest_component, - aff=aff, - im_res=im_res) + seg, posteriors, volumes = postprocess( + post_patch=post_patch, + mask=mask, + shape=shape, + pad_idx=pad_idx, + crop_idx=crop_idx, + n_dims=n_dims, + labels_segmentation=labels_segmentation, + strict_masking=strict_masking, + keep_biggest_component=keep_biggest_component, + aff=aff, + im_res=im_res, + ) # write results to disk - utils.save_volume(seg, aff, h, path_segmentations[i], dtype='int32') + utils.save_volume(seg, aff, h, path_segmentations[i], dtype="int32") if path_posteriors[i] is not None: if n_channels > 1: posteriors = utils.add_axis(posteriors, axis=[0, -1]) - utils.save_volume(posteriors, aff, h, path_posteriors[i], dtype='float32') + utils.save_volume( + posteriors, aff, h, path_posteriors[i], dtype="float32" + ) # compute volumes if path_volumes[i] is not None: - row = [os.path.basename(path_images[i]).replace('.nii.gz', '')] + [str(vol) for vol in volumes] - write_csv(path_volumes[i], row, unique_vol_file, labels_segmentation, names_segmentation) + row = [os.path.basename(path_images[i]).replace(".nii.gz", "")] + [ + str(vol) for vol in volumes + ] + write_csv( + path_volumes[i], + row, + unique_vol_file, + labels_segmentation, + names_segmentation, + ) # evaluate if gt_folder is not None: @@ -152,64 +181,92 @@ def predict(path_images, # set path of result arrays for surface distance if necessary if compute_distances: - path_hausdorff = os.path.join(eval_folder, 'hausdorff.npy') - path_hausdorff_99 = os.path.join(eval_folder, 'hausdorff_99.npy') - path_hausdorff_95 = os.path.join(eval_folder, 'hausdorff_95.npy') - path_mean_distance = os.path.join(eval_folder, 'mean_distance.npy') + path_hausdorff = os.path.join(eval_folder, "hausdorff.npy") + path_hausdorff_99 = os.path.join(eval_folder, "hausdorff_99.npy") + path_hausdorff_95 = os.path.join(eval_folder, "hausdorff_95.npy") + path_mean_distance = os.path.join(eval_folder, "mean_distance.npy") else: - path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = path_mean_distance = None + path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = ( + path_mean_distance + ) = None # compute evaluation metrics - evaluate.evaluation(gt_folder, - eval_folder, - evaluation_labels, - path_dice=os.path.join(eval_folder, 'dice.npy'), - path_hausdorff=path_hausdorff, - path_hausdorff_99=path_hausdorff_99, - path_hausdorff_95=path_hausdorff_95, - path_mean_distance=path_mean_distance, - list_incorrect_labels=list_incorrect_labels, - list_correct_labels=list_correct_labels, - recompute=recompute, - verbose=verbose) - - -def prepare_output_files(path_images, path_masks, out_seg, out_posteriors, out_volumes, recompute): + evaluate.evaluation( + gt_folder, + eval_folder, + evaluation_labels, + path_dice=os.path.join(eval_folder, "dice.npy"), + path_hausdorff=path_hausdorff, + path_hausdorff_99=path_hausdorff_99, + path_hausdorff_95=path_hausdorff_95, + path_mean_distance=path_mean_distance, + list_incorrect_labels=list_incorrect_labels, + list_correct_labels=list_correct_labels, + recompute=recompute, + verbose=verbose, + ) + + +def prepare_output_files( + path_images, path_masks, out_seg, out_posteriors, out_volumes, recompute +): # check inputs - assert path_images is not None, 'please specify an input file/folder (--i)' - assert path_masks is not None, 'please specify an input file/folder (--i)' - assert out_seg is not None, 'please specify an output file/folder (--o)' + assert path_images is not None, "please specify an input file/folder (--i)" + assert path_masks is not None, "please specify an input file/folder (--i)" + assert out_seg is not None, "please specify an output file/folder (--o)" # convert path to absolute paths path_images = os.path.abspath(path_images) path_masks = os.path.abspath(path_masks) basename = os.path.basename(path_images) out_seg = os.path.abspath(out_seg) - out_posteriors = os.path.abspath(out_posteriors) if (out_posteriors is not None) else out_posteriors - out_volumes = os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + out_posteriors = ( + os.path.abspath(out_posteriors) + if (out_posteriors is not None) + else out_posteriors + ) + out_volumes = ( + os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + ) # path_images is a text file - if basename[-4:] == '.txt': + if basename[-4:] == ".txt": # input images if not os.path.isfile(path_images): - raise Exception('provided text file containing paths of input images does not exist' % path_images) - with open(path_images, 'r') as f: - path_images = [line.replace('\n', '') for line in f.readlines() if line != '\n'] + raise Exception( + "provided text file containing paths of input images does not exist" + % path_images + ) + with open(path_images, "r") as f: + path_images = [ + line.replace("\n", "") for line in f.readlines() if line != "\n" + ] # masks if not os.path.isfile(path_masks): - raise Exception('provided text file containing paths of input images does not exist' % path_masks) - with open(path_masks, 'r') as f: - path_masks = [line.replace('\n', '') for line in f.readlines() if line != '\n'] + raise Exception( + "provided text file containing paths of input images does not exist" + % path_masks + ) + with open(path_masks, "r") as f: + path_masks = [ + line.replace("\n", "") for line in f.readlines() if line != "\n" + ] # define helper to deal with outputs def text_helper(path, name): if path is not None: - assert path[-4:] == '.txt', 'if path_images given as text file, so must be %s' % name - with open(path, 'r') as ff: - path = [line.replace('\n', '') for line in ff.readlines() if line != '\n'] + assert path[-4:] == ".txt", ( + "if path_images given as text file, so must be %s" % name + ) + with open(path, "r") as ff: + path = [ + line.replace("\n", "") + for line in ff.readlines() + if line != "\n" + ] recompute_files = [not os.path.isfile(p) for p in path] else: path = [None] * len(path_images) @@ -218,42 +275,69 @@ def text_helper(path, name): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = text_helper(out_seg, 'path_segmentations') - out_posteriors, recompute_post, _ = text_helper(out_posteriors, 'path_posteriors') - out_volumes, recompute_volume, unique_volume_file = text_helper(out_volumes, 'path_volume') + out_seg, recompute_seg, _ = text_helper(out_seg, "path_segmentations") + out_posteriors, recompute_post, _ = text_helper( + out_posteriors, "path_posteriors" + ) + out_volumes, recompute_volume, unique_volume_file = text_helper( + out_volumes, "path_volume" + ) # path_images is a folder - elif ('.nii.gz' not in basename) & ('.nii' not in basename) & ('.mgz' not in basename) & ('.npz' not in basename): + elif ( + (".nii.gz" not in basename) + & (".nii" not in basename) + & (".mgz" not in basename) + & (".npz" not in basename) + ): # input images if os.path.isfile(path_images): - raise Exception('Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz' % path_images) + raise Exception( + "Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz" + % path_images + ) path_images = utils.list_images_in_folder(path_images) # masks if os.path.isfile(path_masks): - raise Exception('Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz' % path_masks) + raise Exception( + "Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz" + % path_masks + ) path_masks = utils.list_images_in_folder(path_masks) # define helper to deal with outputs def helper_dir(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_images is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_images is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" path = [path] * len(path_images) recompute_files = [True] * len(path_images) unique_file = True else: - if (path[-7:] == '.nii.gz') | (path[-4:] == '.nii') | (path[-4:] == '.mgz') | (path[-4:] == '.npz'): - raise Exception('Output FOLDER had a FILE extension' % path) - path = [os.path.join(path, os.path.basename(p)) for p in path_images] - path = [p.replace('.nii', '_%s.nii' % suffix) for p in path] - path = [p.replace('.mgz', '_%s.mgz' % suffix) for p in path] - path = [p.replace('.npz', '_%s.npz' % suffix) for p in path] + if ( + (path[-7:] == ".nii.gz") + | (path[-4:] == ".nii") + | (path[-4:] == ".mgz") + | (path[-4:] == ".npz") + ): + raise Exception("Output FOLDER had a FILE extension" % path) + path = [ + os.path.join(path, os.path.basename(p)) for p in path_images + ] + path = [p.replace(".nii", "_%s.nii" % suffix) for p in path] + path = [p.replace(".mgz", "_%s.mgz" % suffix) for p in path] + path = [p.replace(".npz", "_%s.npz" % suffix) for p in path] recompute_files = [not os.path.isfile(p) for p in path] utils.mkdir(os.path.dirname(path[0])) else: @@ -262,39 +346,61 @@ def helper_dir(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = helper_dir(out_seg, 'path_segmentations', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_dir(out_posteriors, 'path_posteriors', '', 'posteriors') - out_volumes, recompute_volume, unique_volume_file = helper_dir(out_volumes, 'path_volumes', 'csv', '') + out_seg, recompute_seg, _ = helper_dir( + out_seg, "path_segmentations", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_dir( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_volumes, recompute_volume, unique_volume_file = helper_dir( + out_volumes, "path_volumes", "csv", "" + ) # path_images is an image else: # input images - assert os.path.isfile(path_images), 'file does not exist: %s \n' \ - 'please make sure the path and the extension are correct' % path_images + assert os.path.isfile(path_images), ( + "file does not exist: %s \n" + "please make sure the path and the extension are correct" % path_images + ) path_images = [path_images] # masks - assert os.path.isfile(path_masks), 'file does not exist: %s \n' \ - 'please make sure the path and the extension are correct' % path_masks + assert os.path.isfile(path_masks), ( + "file does not exist: %s \n" + "please make sure the path and the extension are correct" % path_masks + ) path_masks = [path_masks] # define helper to deal with outputs def helper_im(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_images is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_images is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" recompute_files = [True] unique_file = True else: - if ('.nii.gz' not in path) & ('.nii' not in path) & ('.mgz' not in path) & ('.npz' not in path): - file_name = os.path.basename(path_images[0]).replace('.nii', '_%s.nii' % suffix) - file_name = file_name.replace('.mgz', '_%s.mgz' % suffix) - file_name = file_name.replace('.npz', '_%s.npz' % suffix) + if ( + (".nii.gz" not in path) + & (".nii" not in path) + & (".mgz" not in path) + & (".npz" not in path) + ): + file_name = os.path.basename(path_images[0]).replace( + ".nii", "_%s.nii" % suffix + ) + file_name = file_name.replace(".mgz", "_%s.mgz" % suffix) + file_name = file_name.replace(".npz", "_%s.npz" % suffix) path = os.path.join(path, file_name) recompute_files = [not os.path.isfile(path)] utils.mkdir(os.path.dirname(path)) @@ -304,14 +410,32 @@ def helper_im(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = helper_im(out_seg, 'path_segmentations', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_im(out_posteriors, 'path_posteriors', '', 'posteriors') - out_volumes, recompute_volume, unique_volume_file = helper_im(out_volumes, 'path_volumes', 'csv', '') - - recompute_list = [recompute | re_seg | re_post | re_vol for (re_seg, re_post, re_vol) - in zip(recompute_seg, recompute_post, recompute_volume)] - - return path_images, path_masks, out_seg, out_posteriors, out_volumes, recompute_list, unique_volume_file + out_seg, recompute_seg, _ = helper_im( + out_seg, "path_segmentations", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_im( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_volumes, recompute_volume, unique_volume_file = helper_im( + out_volumes, "path_volumes", "csv", "" + ) + + recompute_list = [ + recompute | re_seg | re_post | re_vol + for (re_seg, re_post, re_vol) in zip( + recompute_seg, recompute_post, recompute_volume + ) + ] + + return ( + path_images, + path_masks, + out_seg, + out_posteriors, + out_volumes, + recompute_list, + unique_volume_file, + ) def preprocess(path_image, path_mask, n_levels, crop=None, min_pad=None): @@ -321,55 +445,85 @@ def preprocess(path_image, path_mask, n_levels, crop=None, min_pad=None): mask = utils.load_volume(path_mask, True) # align image - im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False) - mask = edit_volumes.align_volume_to_ref(mask, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False) + im = edit_volumes.align_volume_to_ref( + im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False + ) + mask = edit_volumes.align_volume_to_ref( + mask, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False + ) shape = list(im.shape[:n_dims]) # crop image if necessary if crop is not None: - crop = utils.reformat_to_list(crop, length=n_dims, dtype='int') - crop_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop] - im, crop_idx = edit_volumes.crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True) + crop = utils.reformat_to_list(crop, length=n_dims, dtype="int") + crop_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in crop + ] + im, crop_idx = edit_volumes.crop_volume( + im, cropping_shape=crop_shape, return_crop_idx=True + ) mask = edit_volumes.crop_volume_with_idx(mask, crop_idx, n_dims=n_dims) else: crop_idx = None # normalise image if n_channels == 1: - im = edit_volumes.rescale_volume(im, new_min=0., new_max=1., min_percentile=0.5, max_percentile=99.5) + im = edit_volumes.rescale_volume( + im, new_min=0.0, new_max=1.0, min_percentile=0.5, max_percentile=99.5 + ) else: for i in range(im.shape[-1]): - im[..., i] = edit_volumes.rescale_volume(im[..., i], new_min=0., new_max=1., - min_percentile=0.5, max_percentile=99.5) + im[..., i] = edit_volumes.rescale_volume( + im[..., i], + new_min=0.0, + new_max=1.0, + min_percentile=0.5, + max_percentile=99.5, + ) # pad image input_shape = im.shape[:n_dims] - pad_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in input_shape] - if min_pad is not None: # in SynthSeg predict use crop flag and then if used do min_pad=crop else min_pad = 192 - min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype='int') - min_pad = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in min_pad] + pad_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in input_shape + ] + if ( + min_pad is not None + ): # in SynthSeg predict use crop flag and then if used do min_pad=crop else min_pad = 192 + min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype="int") + min_pad = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in min_pad + ] pad_shape = np.maximum(pad_shape, min_pad) - im, pad_idx = edit_volumes.pad_volume(im, padding_shape=pad_shape, return_pad_idx=True) + im, pad_idx = edit_volumes.pad_volume( + im, padding_shape=pad_shape, return_pad_idx=True + ) mask = edit_volumes.pad_volume(mask, padding_shape=pad_shape) # add batch and channel axes im = utils.add_axis(im) if n_channels > 1 else utils.add_axis(im, axis=[0, -1]) - mask = utils.add_axis(mask, axis=0) # channel axis will be added later when computing one-hot + mask = utils.add_axis( + mask, axis=0 + ) # channel axis will be added later when computing one-hot return im, mask, aff, h, im_res, shape, pad_idx, crop_idx -def build_model(path_model, - input_shape, - labels_segmentation, - n_levels, - nb_conv_per_level, - conv_size, - unet_feat_count, - feat_multiplier, - activation, - sigma_smoothing, - mask_labels): +def build_model( + path_model, + input_shape, + labels_segmentation, + n_levels, + nb_conv_per_level, + conv_size, + unet_feat_count, + feat_multiplier, + activation, + sigma_smoothing, + mask_labels, +): assert os.path.isfile(path_model), "The provided model path does not exist." @@ -378,22 +532,28 @@ def build_model(path_model, # one-hot encoding of the input prediction as the network expects soft probabilities input_image = KL.Input(input_shape) - input_labels = KL.Input(input_shape[:-1], dtype='int32') - labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=len(mask_labels), axis=-1))(input_labels) - image = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), 'float32'))([input_image, labels]) + input_labels = KL.Input(input_shape[:-1], dtype="int32") + labels = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, "int32"), depth=len(mask_labels), axis=-1) + )(input_labels) + image = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), "float32"))( + [input_image, labels] + ) net = Model(inputs=[input_image, input_labels], outputs=image) # build UNet - net = nrn_models.unet(input_model=net, - input_shape=input_shape, - nb_labels=n_labels_seg, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1) + net = nrn_models.unet( + input_model=net, + input_shape=input_shape, + nb_labels=n_labels_seg, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + ) net.load_weights(path_model, by_name=True) # smooth posteriors if specified @@ -406,8 +566,19 @@ def build_model(path_model, return net -def postprocess(post_patch, mask, shape, pad_idx, crop_idx, n_dims, - labels_segmentation, strict_masking, keep_biggest_component, aff, im_res): +def postprocess( + post_patch, + mask, + shape, + pad_idx, + crop_idx, + n_dims, + labels_segmentation, + strict_masking, + keep_biggest_component, + aff, + im_res, +): # get posteriors post_patch = np.squeeze(post_patch) @@ -416,45 +587,74 @@ def postprocess(post_patch, mask, shape, pad_idx, crop_idx, n_dims, # reset posteriors of background to 1 outside mask and to 0 inside if strict_masking: post_patch[..., 0] = np.ones_like(post_patch[..., 0]) - post_patch[..., 0] = edit_volumes.mask_volume(post_patch[..., 0], mask=mask < 0.1, return_copy=False) + post_patch[..., 0] = edit_volumes.mask_volume( + post_patch[..., 0], mask=mask < 0.1, return_copy=False + ) # keep biggest connected component (use it with smoothing!) elif keep_biggest_component: tmp_post_patch = post_patch[..., 1:] post_patch_mask = np.sum(tmp_post_patch, axis=-1) > 0.25 post_patch_mask = edit_volumes.get_largest_connected_component(post_patch_mask) - post_patch_mask = np.stack([post_patch_mask]*tmp_post_patch.shape[-1], axis=-1) - tmp_post_patch = edit_volumes.mask_volume(tmp_post_patch, mask=post_patch_mask, return_copy=False) + post_patch_mask = np.stack( + [post_patch_mask] * tmp_post_patch.shape[-1], axis=-1 + ) + tmp_post_patch = edit_volumes.mask_volume( + tmp_post_patch, mask=post_patch_mask, return_copy=False + ) post_patch[..., 1:] = tmp_post_patch # normalise posteriors and get hard segmentation if strict_masking | keep_biggest_component: post_patch /= np.sum(post_patch, axis=-1)[..., np.newaxis] - seg_patch = labels_segmentation[post_patch.argmax(-1).astype('int32')].astype('int32') + seg_patch = labels_segmentation[post_patch.argmax(-1).astype("int32")].astype( + "int32" + ) # paste patches back to matrix of original image size - seg_patch = edit_volumes.crop_volume_with_idx(seg_patch, pad_idx, n_dims=n_dims, return_copy=False) - post_patch = edit_volumes.crop_volume_with_idx(post_patch, pad_idx, n_dims=n_dims, return_copy=False) + seg_patch = edit_volumes.crop_volume_with_idx( + seg_patch, pad_idx, n_dims=n_dims, return_copy=False + ) + post_patch = edit_volumes.crop_volume_with_idx( + post_patch, pad_idx, n_dims=n_dims, return_copy=False + ) if crop_idx is not None: # we need to go through this because of the posteriors of the background, otherwise pad_volume would work - seg = np.zeros(shape=shape, dtype='int32') + seg = np.zeros(shape=shape, dtype="int32") posteriors = np.zeros(shape=[*shape, labels_segmentation.shape[0]]) posteriors[..., 0] = np.ones(shape) # place background around patch if n_dims == 2: - seg[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3]] = seg_patch - posteriors[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], :] = post_patch + seg[crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3]] = seg_patch + posteriors[crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], :] = ( + post_patch + ) elif n_dims == 3: - seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch - posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch + seg[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ] = seg_patch + posteriors[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + :, + ] = post_patch else: seg = seg_patch posteriors = post_patch # align prediction back to first orientation - seg = edit_volumes.align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False) - posteriors = edit_volumes.align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False) + seg = edit_volumes.align_volume_to_ref( + seg, aff=np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False + ) + posteriors = edit_volumes.align_volume_to_ref( + posteriors, np.eye(4), aff_ref=aff, n_dims=n_dims, return_copy=False + ) # compute volumes - volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1))) + volumes = np.sum( + posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)) + ) volumes = np.around(volumes * np.prod(im_res), 3) return seg, posteriors, volumes diff --git a/nobrainer/ext/SynthSeg/predict_qc.py b/nobrainer/ext/SynthSeg/predict_qc.py index d81efea8..4897a8d6 100644 --- a/nobrainer/ext/SynthSeg/predict_qc.py +++ b/nobrainer/ext/SynthSeg/predict_qc.py @@ -13,42 +13,43 @@ License. """ - # python imports import os -import numpy as np -import tensorflow as tf -import keras.layers as KL -from keras.models import Model # project imports from SynthSeg import evaluate # third-party imports -from ext.lab2im import utils -from ext.lab2im import edit_volumes +from ext.lab2im import edit_volumes, utils from ext.neuron import models as nrn_models +import keras.layers as KL +from keras.models import Model +import numpy as np +import tensorflow as tf -def predict(path_predictions, - path_qc_results, - path_model, - labels_list, - labels_to_convert=None, - convert_gt=False, - shape=224, - n_levels=5, - nb_conv_per_level=3, - conv_size=5, - unet_feat_count=24, - feat_multiplier=2, - activation='relu', - path_gts=None, - verbose=True): +def predict( + path_predictions, + path_qc_results, + path_model, + labels_list, + labels_to_convert=None, + convert_gt=False, + shape=224, + n_levels=5, + nb_conv_per_level=3, + conv_size=5, + unet_feat_count=24, + feat_multiplier=2, + activation="relu", + path_gts=None, + verbose=True, +): # prepare input/output filepaths - path_predictions, path_gts, path_qc_results, path_gt_results, path_diff = \ + path_predictions, path_gts, path_qc_results, path_gt_results, path_diff = ( prepare_output_files(path_predictions, path_gts, path_qc_results) + ) # get label list labels_list, _ = utils.get_list_labels(label_list=labels_list) @@ -58,22 +59,28 @@ def predict(path_predictions, # prepare qc results pred_qc_results = np.zeros((len(labels_list_unique) + 1, len(path_predictions))) - gt_qc_results = np.zeros((len(labels_list_unique), len(path_predictions))) if path_gt_results is not None else None + gt_qc_results = ( + np.zeros((len(labels_list_unique), len(path_predictions))) + if path_gt_results is not None + else None + ) # build network model_input_shape = [None, None, None, 1] - net = build_qc_model(path_model=path_model, - input_shape=model_input_shape, - label_list=labels_list_unique, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation) + net = build_qc_model( + path_model=path_model, + input_shape=model_input_shape, + label_list=labels_list_unique, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + ) # perform segmentation - loop_info = utils.LoopInfo(len(path_predictions), 10, 'predicting', True) + loop_info = utils.LoopInfo(len(path_predictions), 10, "predicting", True) for idx, (path_prediction, path_gt) in enumerate(zip(path_predictions, path_gts)): # compute segmentation only if needed @@ -81,7 +88,9 @@ def predict(path_predictions, loop_info.update(idx) # preprocessing - prediction, gt_scores = preprocess(path_prediction, path_gt, shape, labels_list, labels_to_convert, convert_gt) + prediction, gt_scores = preprocess( + path_prediction, path_gt, shape, labels_list, labels_to_convert, convert_gt + ) # get predicted scores pred_qc_results[-1, idx] = np.sum(prediction > 0) @@ -101,8 +110,8 @@ def predict(path_predictions, def prepare_output_files(path_predictions, path_gts, path_qc_results): # check inputs - assert path_predictions is not None, 'please specify an input file/folder (--i)' - assert path_qc_results is not None, 'please specify an output file/folder (--o)' + assert path_predictions is not None, "please specify an input file/folder (--i)" + assert path_qc_results is not None, "please specify an output file/folder (--o)" # convert path to absolute paths path_predictions = os.path.abspath(path_predictions) @@ -112,16 +121,20 @@ def prepare_output_files(path_predictions, path_gts, path_qc_results): path_predictions = utils.list_images_in_folder(path_predictions) # build path output with qc results - if path_qc_results[-4:] != '.npy': - print('Path for QC outputs provided without npy extension. Adding npy extension.') - path_qc_results += '.npy' + if path_qc_results[-4:] != ".npy": + print( + "Path for QC outputs provided without npy extension. Adding npy extension." + ) + path_qc_results += ".npy" utils.mkdir(os.path.dirname(path_qc_results)) if path_gts is not None: path_gts = utils.list_images_in_folder(path_gts) - assert len(path_gts) == len(path_predictions), 'not the same number of predictions and GTs' - path_gt_results = path_qc_results.replace('.npy', '_gt.npy') - path_diff = path_qc_results.replace('.npy', '_diff.npy') + assert len(path_gts) == len( + path_predictions + ), "not the same number of predictions and GTs" + path_gt_results = path_qc_results.replace(".npy", "_gt.npy") + path_diff = path_qc_results.replace(".npy", "_diff.npy") else: path_gts = [None] * len(path_predictions) path_gt_results = path_diff = None @@ -129,14 +142,23 @@ def prepare_output_files(path_predictions, path_gts, path_qc_results): return path_predictions, path_gts, path_qc_results, path_gt_results, path_diff -def preprocess(path_prediction, path_gt=None, shape=224, labels_list=None, labels_to_convert=None, convert_gt=False): +def preprocess( + path_prediction, + path_gt=None, + shape=224, + labels_list=None, + labels_to_convert=None, + convert_gt=False, +): # read image and corresponding info pred, _, aff_pred, n_dims, _, _, _ = utils.get_volume_info(path_prediction, True) gt = utils.load_volume(path_gt, aff_ref=np.eye(4)) if path_gt is not None else None # align - pred = edit_volumes.align_volume_to_ref(pred, aff_pred, aff_ref=np.eye(4), n_dims=n_dims) + pred = edit_volumes.align_volume_to_ref( + pred, aff_pred, aff_ref=np.eye(4), n_dims=n_dims + ) # pad/crop to 224, such that segmentations are in the middle of the patch if gt is not None: @@ -147,15 +169,19 @@ def preprocess(path_prediction, path_gt=None, shape=224, labels_list=None, label # convert labels if necessary if labels_to_convert is not None: lut = utils.get_mapping_lut(labels_to_convert, labels_list) - pred = lut[pred.astype('int32')] + pred = lut[pred.astype("int32")] if convert_gt & (gt is not None): - gt = lut[gt.astype('int32')] + gt = lut[gt.astype("int32")] # compute GT dice scores - gt_scores = evaluate.fast_dice(pred, gt, np.unique(labels_list)) if gt is not None else None + gt_scores = ( + evaluate.fast_dice(pred, gt, np.unique(labels_list)) if gt is not None else None + ) # add batch and channel axes - pred = utils.add_axis(pred, axis=0) # channel axis will be added later when computing one-hot + pred = utils.add_axis( + pred, axis=0 + ) # channel axis will be added later when computing one-hot return pred, gt_scores @@ -175,11 +201,17 @@ def make_shape(pred, gt, shape, n_dims): # expand/retract (depending on the desired shape) the cropping region around the centre intermediate_vol_shape = max_idx - min_idx cropping_shape = np.array(utils.reformat_to_list(shape, length=n_dims)) - min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) - max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) + min_idx = min_idx - np.int32( + np.ceil((cropping_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((cropping_shape - intermediate_vol_shape) / 2) + ) # crop volume - cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + cropping = np.concatenate( + [np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)] + ) pred = edit_volumes.crop_volume_with_idx(pred, cropping, n_dims=n_dims) gt = edit_volumes.crop_volume_with_idx(gt, cropping, n_dims=n_dims) @@ -187,22 +219,26 @@ def make_shape(pred, gt, shape, n_dims): min_padding = np.abs(np.minimum(min_idx, 0)) max_padding = np.maximum(max_idx - vol_shape, 0) if np.any(min_padding > 0) | np.any(max_padding > 0): - pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) - pred = np.pad(pred, pad_margins, mode='constant', constant_values=0) - gt = np.pad(gt, pad_margins, mode='constant', constant_values=0) + pad_margins = tuple( + [(min_padding[i], max_padding[i]) for i in range(n_dims)] + ) + pred = np.pad(pred, pad_margins, mode="constant", constant_values=0) + gt = np.pad(gt, pad_margins, mode="constant", constant_values=0) return pred, gt -def build_qc_model(path_model, - input_shape, - label_list, - n_levels, - nb_conv_per_level, - conv_size, - unet_feat_count, - feat_multiplier, - activation): +def build_qc_model( + path_model, + input_shape, + label_list, + n_levels, + nb_conv_per_level, + conv_size, + unet_feat_count, + feat_multiplier, + activation, +): assert os.path.isfile(path_model), "The provided model path does not exist." label_list_unique = np.unique(label_list) @@ -210,28 +246,44 @@ def build_qc_model(path_model, # one-hot encoding of the input prediction as the network expects soft probabilities input_labels = KL.Input(input_shape[:-1]) - labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(input_labels) + labels = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, dtype="int32"), depth=n_labels, axis=-1) + )(input_labels) net = Model(inputs=input_labels, outputs=labels) # build model - model = nrn_models.conv_enc(input_model=net, - input_shape=input_shape, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1, - use_residuals=True, - name='qc') + model = nrn_models.conv_enc( + input_model=net, + input_shape=input_shape, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + use_residuals=True, + name="qc", + ) last = model.outputs[0] - conv_kwargs = {'padding': 'same', 'activation': 'relu', 'data_format': 'channels_last'} - last = KL.MaxPool3D(pool_size=(2, 2, 2), name='qc_maxpool_%s' % (n_levels - 1), padding='same')(last) - last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_0')(last) - last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_1')(last) - last = KL.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name='qc_final_pred')(last) + conv_kwargs = { + "padding": "same", + "activation": "relu", + "data_format": "channels_last", + } + last = KL.MaxPool3D( + pool_size=(2, 2, 2), name="qc_maxpool_%s" % (n_levels - 1), padding="same" + )(last) + last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name="qc_final_conv_0")( + last + ) + last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name="qc_final_conv_1")( + last + ) + last = KL.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name="qc_final_pred")( + last + ) net = Model(inputs=net.inputs, outputs=last) net.load_weights(path_model, by_name=True) diff --git a/nobrainer/ext/SynthSeg/predict_synthseg.py b/nobrainer/ext/SynthSeg/predict_synthseg.py index 39c060cf..d76d97d9 100644 --- a/nobrainer/ext/SynthSeg/predict_synthseg.py +++ b/nobrainer/ext/SynthSeg/predict_synthseg.py @@ -13,66 +13,72 @@ License. """ - # python imports import os import sys import traceback -import numpy as np -import tensorflow as tf -import keras.layers as KL -import keras.backend as K -from keras.models import Model # project imports from SynthSeg import evaluate -from SynthSeg.predict import write_csv, get_flip_indices +from SynthSeg.predict import get_flip_indices, write_csv # third-party imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.lab2im import edit_volumes +from ext.lab2im import edit_volumes, layers, utils from ext.neuron import models as nrn_models +import keras.backend as K +import keras.layers as KL +from keras.models import Model +import numpy as np +import tensorflow as tf -def predict(path_images, - path_segmentations, - path_model_segmentation, - labels_segmentation, - robust, - fast, - v1, - n_neutral_labels, - labels_denoiser, - path_posteriors, - path_resampled, - path_volumes, - do_parcellation, - path_model_parcellation, - labels_parcellation, - path_qc_scores, - path_model_qc, - labels_qc, - cropping, - ct=False, - names_segmentation=None, - names_parcellation=None, - names_qc=None, - topology_classes=None, - sigma_smoothing=0.5, - input_shape_qc=224, - gt_folder=None, - evaluation_labels=None, - mask_folder=None, - list_incorrect_labels=None, - list_correct_labels=None, - compute_distances=False, - recompute=True, - verbose=True): +def predict( + path_images, + path_segmentations, + path_model_segmentation, + labels_segmentation, + robust, + fast, + v1, + n_neutral_labels, + labels_denoiser, + path_posteriors, + path_resampled, + path_volumes, + do_parcellation, + path_model_parcellation, + labels_parcellation, + path_qc_scores, + path_model_qc, + labels_qc, + cropping, + ct=False, + names_segmentation=None, + names_parcellation=None, + names_qc=None, + topology_classes=None, + sigma_smoothing=0.5, + input_shape_qc=224, + gt_folder=None, + evaluation_labels=None, + mask_folder=None, + list_incorrect_labels=None, + list_correct_labels=None, + compute_distances=False, + recompute=True, + verbose=True, +): # prepare input/output filepaths - outputs = prepare_output_files(path_images, path_segmentations, path_posteriors, path_resampled, - path_volumes, path_qc_scores, recompute) + outputs = prepare_output_files( + path_images, + path_segmentations, + path_posteriors, + path_resampled, + path_volumes, + path_qc_scores, + recompute, + ) path_images = outputs[0] path_segmentations = outputs[1] path_posteriors = outputs[2] @@ -86,22 +92,32 @@ def predict(path_images, # get label lists labels_segmentation, _ = utils.get_list_labels(label_list=labels_segmentation) if (n_neutral_labels is not None) & (not fast) & (not robust): - labels_segmentation, flip_indices, unique_idx = get_flip_indices(labels_segmentation, n_neutral_labels) + labels_segmentation, flip_indices, unique_idx = get_flip_indices( + labels_segmentation, n_neutral_labels + ) else: - labels_segmentation, unique_idx = np.unique(labels_segmentation, return_index=True) + labels_segmentation, unique_idx = np.unique( + labels_segmentation, return_index=True + ) flip_indices = None # prepare other labels list if names_segmentation is not None: names_segmentation = utils.load_array_if_path(names_segmentation)[unique_idx] if topology_classes is not None: - topology_classes = utils.load_array_if_path(topology_classes, load_as_numpy=True)[unique_idx] + topology_classes = utils.load_array_if_path( + topology_classes, load_as_numpy=True + )[unique_idx] labels_denoiser = np.unique(utils.get_list_labels(labels_denoiser)[0]) if do_parcellation: - labels_parcellation, unique_i_parc = np.unique(utils.get_list_labels(labels_parcellation)[0], return_index=True) + labels_parcellation, unique_i_parc = np.unique( + utils.get_list_labels(labels_parcellation)[0], return_index=True + ) labels_volumes = np.concatenate([labels_segmentation, labels_parcellation[1:]]) if (names_parcellation is not None) & (names_segmentation is not None): - names_parcellation = utils.load_array_if_path(names_parcellation)[unique_i_parc][1:] + names_parcellation = utils.load_array_if_path(names_parcellation)[ + unique_i_parc + ][1:] names_volumes = np.concatenate([names_segmentation, names_parcellation]) else: names_volumes = names_segmentation @@ -109,9 +125,13 @@ def predict(path_images, labels_volumes = labels_segmentation names_volumes = names_segmentation if not v1: - labels_volumes = np.concatenate([labels_volumes, np.array([np.max(labels_volumes + 1)])]) + labels_volumes = np.concatenate( + [labels_volumes, np.array([np.max(labels_volumes + 1)])] + ) if names_segmentation is not None: - names_volumes = np.concatenate([names_volumes, np.array(['total intracranial'])]) + names_volumes = np.concatenate( + [names_volumes, np.array(["total intracranial"])] + ) do_qc = True if path_qc_scores[0] is not None else False if do_qc: labels_qc = utils.get_list_labels(labels_qc)[0][unique_idx] @@ -120,37 +140,46 @@ def predict(path_images, # prepare volume/QC files if necessary if unique_vol_file & (path_volumes[0] is not None): - write_csv(path_volumes[0], None, True, labels_volumes, names_volumes, last_first=(not v1)) + write_csv( + path_volumes[0], + None, + True, + labels_volumes, + names_volumes, + last_first=(not v1), + ) if unique_qc_file & do_qc: write_csv(path_qc_scores[0], None, True, labels_qc, names_qc) # build network - net = build_model(path_model_segmentation=path_model_segmentation, - path_model_parcellation=path_model_parcellation, - path_model_qc=path_model_qc, - input_shape_qc=input_shape_qc, - labels_segmentation=labels_segmentation, - labels_denoiser=labels_denoiser, - labels_parcellation=labels_parcellation, - labels_qc=labels_qc, - sigma_smoothing=sigma_smoothing, - flip_indices=flip_indices, - robust=robust, - do_parcellation=do_parcellation, - do_qc=do_qc) + net = build_model( + path_model_segmentation=path_model_segmentation, + path_model_parcellation=path_model_parcellation, + path_model_qc=path_model_qc, + input_shape_qc=input_shape_qc, + labels_segmentation=labels_segmentation, + labels_denoiser=labels_denoiser, + labels_parcellation=labels_parcellation, + labels_qc=labels_qc, + sigma_smoothing=sigma_smoothing, + flip_indices=flip_indices, + robust=robust, + do_parcellation=do_parcellation, + do_qc=do_qc, + ) # set cropping/padding if cropping is not None: - cropping = utils.reformat_to_list(cropping, length=3, dtype='int') + cropping = utils.reformat_to_list(cropping, length=3, dtype="int") min_pad = cropping else: min_pad = 128 # perform segmentation if len(path_images) <= 10: - loop_info = utils.LoopInfo(len(path_images), 1, 'predicting', True) + loop_info = utils.LoopInfo(len(path_images), 1, "predicting", True) else: - loop_info = utils.LoopInfo(len(path_images), 10, 'predicting', True) + loop_info = utils.LoopInfo(len(path_images), 10, "predicting", True) list_errors = list() for i in range(len(path_images)): if verbose: @@ -162,98 +191,151 @@ def predict(path_images, try: # preprocessing - image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_image=path_images[i], - ct=ct, - crop=cropping, - min_pad=min_pad, - path_resample=path_resampled[i]) + image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess( + path_image=path_images[i], + ct=ct, + crop=cropping, + min_pad=min_pad, + path_resample=path_resampled[i], + ) # prediction shape_input = utils.add_axis(np.array(image.shape[1:-1])) if do_parcellation & do_qc: - post_patch_segmentation, post_patch_parcellation, qc_score = net.predict([image, shape_input]) + post_patch_segmentation, post_patch_parcellation, qc_score = ( + net.predict([image, shape_input]) + ) elif do_parcellation & (not do_qc): - post_patch_segmentation, post_patch_parcellation = net.predict(image) + post_patch_segmentation, post_patch_parcellation = net.predict( + image + ) qc_score = None elif (not do_parcellation) & do_qc: - post_patch_segmentation, qc_score = net.predict([image, shape_input]) + post_patch_segmentation, qc_score = net.predict( + [image, shape_input] + ) post_patch_parcellation = None else: post_patch_segmentation = net.predict(image) post_patch_parcellation = qc_score = None # postprocessing - seg, posteriors, volumes = postprocess(post_patch_seg=post_patch_segmentation, - post_patch_parc=post_patch_parcellation, - shape=shape, - pad_idx=pad_idx, - crop_idx=crop_idx, - labels_segmentation=labels_segmentation, - labels_parcellation=labels_parcellation, - aff=aff, - im_res=im_res, - fast=fast, - topology_classes=topology_classes, - v1=v1) + seg, posteriors, volumes = postprocess( + post_patch_seg=post_patch_segmentation, + post_patch_parc=post_patch_parcellation, + shape=shape, + pad_idx=pad_idx, + crop_idx=crop_idx, + labels_segmentation=labels_segmentation, + labels_parcellation=labels_parcellation, + aff=aff, + im_res=im_res, + fast=fast, + topology_classes=topology_classes, + v1=v1, + ) # write predictions to disc - utils.save_volume(seg, aff, h, path_segmentations[i], dtype='int32') + utils.save_volume(seg, aff, h, path_segmentations[i], dtype="int32") if path_posteriors[i] is not None: - utils.save_volume(posteriors, aff, h, path_posteriors[i], dtype='float32') + utils.save_volume( + posteriors, aff, h, path_posteriors[i], dtype="float32" + ) # write volumes to disc if necessary if path_volumes[i] is not None: - row = [os.path.basename(path_images[i]).replace('.nii.gz', '')] + [str(vol) for vol in volumes] - write_csv(path_volumes[i], row, unique_vol_file, labels_volumes, names_volumes, last_first=(not v1)) + row = [os.path.basename(path_images[i]).replace(".nii.gz", "")] + [ + str(vol) for vol in volumes + ] + write_csv( + path_volumes[i], + row, + unique_vol_file, + labels_volumes, + names_volumes, + last_first=(not v1), + ) # write QC scores to disc if necessary if path_qc_scores[i] is not None: qc_score = np.around(np.clip(np.squeeze(qc_score)[1:], 0, 1), 4) - row = [os.path.basename(path_images[i]).replace('.nii.gz', '')] + ['%.4f' % q for q in qc_score] - write_csv(path_qc_scores[i], row, unique_qc_file, labels_qc, names_qc) + row = [os.path.basename(path_images[i]).replace(".nii.gz", "")] + [ + "%.4f" % q for q in qc_score + ] + write_csv( + path_qc_scores[i], row, unique_qc_file, labels_qc, names_qc + ) except Exception as e: list_errors.append(path_images[i]) - print('\nthe following problem occurred with image %s :' % path_images[i]) + print( + "\nthe following problem occurred with image %s :" % path_images[i] + ) print(traceback.format_exc()) - print('resuming program execution\n') + print("resuming program execution\n") continue # print output info if len(path_segmentations) == 1: # only one image is processed - print('\nsegmentation saved in: ' + path_segmentations[0]) + print("\nsegmentation saved in: " + path_segmentations[0]) if path_posteriors[0] is not None: - print('posteriors saved in: ' + path_posteriors[0]) + print("posteriors saved in: " + path_posteriors[0]) if path_resampled[0] is not None: - print('resampled image saved in: ' + path_resampled[0]) + print("resampled image saved in: " + path_resampled[0]) if path_volumes[0] is not None: - print('volumes saved in: ' + path_volumes[0]) + print("volumes saved in: " + path_volumes[0]) if path_qc_scores[0] is not None: - print('QC scores saved in: ' + path_qc_scores[0]) + print("QC scores saved in: " + path_qc_scores[0]) else: # all segmentations are in the same folder, and we have unique vol/QC files - if len(set([os.path.dirname(path_segmentations[i]) for i in range(len(path_segmentations))])) <= 1: - print('\nsegmentations saved in: ' + os.path.dirname(path_segmentations[0])) + if ( + len( + set( + [ + os.path.dirname(path_segmentations[i]) + for i in range(len(path_segmentations)) + ] + ) + ) + <= 1 + ): + print( + "\nsegmentations saved in: " + os.path.dirname(path_segmentations[0]) + ) if path_posteriors[0] is not None: - print('posteriors saved in: ' + os.path.dirname(path_posteriors[0])) + print( + "posteriors saved in: " + os.path.dirname(path_posteriors[0]) + ) if path_resampled[0] is not None: - print('resampled images saved in: ' + os.path.dirname(path_resampled[0])) + print( + "resampled images saved in: " + os.path.dirname(path_resampled[0]) + ) if path_volumes[0] is not None: - print('volumes saved in: ' + path_volumes[0]) + print("volumes saved in: " + path_volumes[0]) if path_qc_scores[0] is not None: - print('QC scores saved in: ' + path_qc_scores[0]) + print("QC scores saved in: " + path_qc_scores[0]) if robust: - print('\nIf you use the new robust version of SynthSeg in a publication, please cite:') - print('Robust machine learning segmentation for large-scale analysis of heterogeneous clinical brain MRI ' - 'datasets') - print('B. Billot, M. Collin, S.E. Arnold, S. Das, J.E. Iglesias') + print( + "\nIf you use the new robust version of SynthSeg in a publication, please cite:" + ) + print( + "Robust machine learning segmentation for large-scale analysis of heterogeneous clinical brain MRI " + "datasets" + ) + print("B. Billot, M. Collin, S.E. Arnold, S. Das, J.E. Iglesias") else: - print('\nIf you use this tool in a publication, please cite:') - print('SynthSeg: domain randomisation for segmentation of brain MRI scans of any contrast and resolution') - print('B. Billot, D.N. Greve, O. Puonti, A. Thielscher, K. Van Leemput, B. Fischl, A.V. Dalca, J.E. Iglesias') + print("\nIf you use this tool in a publication, please cite:") + print( + "SynthSeg: domain randomisation for segmentation of brain MRI scans of any contrast and resolution" + ) + print( + "B. Billot, D.N. Greve, O. Puonti, A. Thielscher, K. Van Leemput, B. Fischl, A.V. Dalca, J.E. Iglesias" + ) if len(list_errors) > 0: - print('\nERROR: some problems occurred for the following inputs (see corresponding errors above):') + print( + "\nERROR: some problems occurred for the following inputs (see corresponding errors above):" + ) for path_error_image in list_errors: print(path_error_image) sys.exit(1) @@ -268,59 +350,84 @@ def predict(path_images, # set path of result arrays for surface distance if necessary if compute_distances: - path_hausdorff = os.path.join(eval_folder, 'hausdorff.npy') - path_hausdorff_99 = os.path.join(eval_folder, 'hausdorff_99.npy') - path_hausdorff_95 = os.path.join(eval_folder, 'hausdorff_95.npy') - path_mean_distance = os.path.join(eval_folder, 'mean_distance.npy') + path_hausdorff = os.path.join(eval_folder, "hausdorff.npy") + path_hausdorff_99 = os.path.join(eval_folder, "hausdorff_99.npy") + path_hausdorff_95 = os.path.join(eval_folder, "hausdorff_95.npy") + path_mean_distance = os.path.join(eval_folder, "mean_distance.npy") else: - path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = path_mean_distance = None + path_hausdorff = path_hausdorff_99 = path_hausdorff_95 = ( + path_mean_distance + ) = None # compute evaluation metrics - evaluate.evaluation(gt_folder, - eval_folder, - evaluation_labels, - mask_dir=mask_folder, - path_dice=os.path.join(eval_folder, 'dice.npy'), - path_hausdorff=path_hausdorff, - path_hausdorff_99=path_hausdorff_99, - path_hausdorff_95=path_hausdorff_95, - path_mean_distance=path_mean_distance, - list_incorrect_labels=list_incorrect_labels, - list_correct_labels=list_correct_labels, - recompute=recompute, - verbose=verbose) - - -def prepare_output_files(path_images, out_seg, out_posteriors, out_resampled, out_volumes, out_qc, recompute): + evaluate.evaluation( + gt_folder, + eval_folder, + evaluation_labels, + mask_dir=mask_folder, + path_dice=os.path.join(eval_folder, "dice.npy"), + path_hausdorff=path_hausdorff, + path_hausdorff_99=path_hausdorff_99, + path_hausdorff_95=path_hausdorff_95, + path_mean_distance=path_mean_distance, + list_incorrect_labels=list_incorrect_labels, + list_correct_labels=list_correct_labels, + recompute=recompute, + verbose=verbose, + ) + + +def prepare_output_files( + path_images, out_seg, out_posteriors, out_resampled, out_volumes, out_qc, recompute +): # check inputs - assert path_images is not None, 'please specify an input file/folder (--i)' - assert out_seg is not None, 'please specify an output file/folder (--o)' + assert path_images is not None, "please specify an input file/folder (--i)" + assert out_seg is not None, "please specify an output file/folder (--o)" # convert path to absolute paths path_images = os.path.abspath(path_images) basename = os.path.basename(path_images) out_seg = os.path.abspath(out_seg) - out_posteriors = os.path.abspath(out_posteriors) if (out_posteriors is not None) else out_posteriors - out_resampled = os.path.abspath(out_resampled) if (out_resampled is not None) else out_resampled - out_volumes = os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + out_posteriors = ( + os.path.abspath(out_posteriors) + if (out_posteriors is not None) + else out_posteriors + ) + out_resampled = ( + os.path.abspath(out_resampled) if (out_resampled is not None) else out_resampled + ) + out_volumes = ( + os.path.abspath(out_volumes) if (out_volumes is not None) else out_volumes + ) out_qc = os.path.abspath(out_qc) if (out_qc is not None) else out_qc # path_images is a text file - if basename[-4:] == '.txt': + if basename[-4:] == ".txt": # input images if not os.path.isfile(path_images): - raise Exception('provided text file containing paths of input images does not exist' % path_images) - with open(path_images, 'r') as f: - path_images = [line.replace('\n', '') for line in f.readlines() if line != '\n'] + raise Exception( + "provided text file containing paths of input images does not exist" + % path_images + ) + with open(path_images, "r") as f: + path_images = [ + line.replace("\n", "") for line in f.readlines() if line != "\n" + ] # define helper to deal with outputs def text_helper(path, name): if path is not None: - assert path[-4:] == '.txt', 'if path_images given as text file, so must be %s' % name - with open(path, 'r') as ff: - path = [line.replace('\n', '') for line in ff.readlines() if line != '\n'] + assert path[-4:] == ".txt", ( + "if path_images given as text file, so must be %s" % name + ) + with open(path, "r") as ff: + path = [ + line.replace("\n", "") + for line in ff.readlines() + if line != "\n" + ] recompute_files = [not os.path.isfile(p) for p in path] else: path = [None] * len(path_images) @@ -329,39 +436,65 @@ def text_helper(path, name): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = text_helper(out_seg, 'path_segmentations') - out_posteriors, recompute_post, _ = text_helper(out_posteriors, 'path_posteriors') - out_resampled, recompute_resampled, _ = text_helper(out_resampled, 'path_resampled') - out_volumes, recompute_volume, unique_volume_file = text_helper(out_volumes, 'path_volume') - out_qc, recompute_qc, unique_qc_file = text_helper(out_qc, 'path_qc_scores') + out_seg, recompute_seg, _ = text_helper(out_seg, "path_segmentations") + out_posteriors, recompute_post, _ = text_helper( + out_posteriors, "path_posteriors" + ) + out_resampled, recompute_resampled, _ = text_helper( + out_resampled, "path_resampled" + ) + out_volumes, recompute_volume, unique_volume_file = text_helper( + out_volumes, "path_volume" + ) + out_qc, recompute_qc, unique_qc_file = text_helper(out_qc, "path_qc_scores") # path_images is a folder - elif ('.nii.gz' not in basename) & ('.nii' not in basename) & ('.mgz' not in basename) & ('.npz' not in basename): + elif ( + (".nii.gz" not in basename) + & (".nii" not in basename) + & (".mgz" not in basename) + & (".npz" not in basename) + ): # input images if os.path.isfile(path_images): - raise Exception('Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz' % path_images) + raise Exception( + "Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz" + % path_images + ) path_images = utils.list_images_in_folder(path_images) # define helper to deal with outputs def helper_dir(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_images is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_images is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" path = [path] * len(path_images) recompute_files = [True] * len(path_images) unique_file = True else: - if (path[-7:] == '.nii.gz') | (path[-4:] == '.nii') | (path[-4:] == '.mgz') | (path[-4:] == '.npz'): - raise Exception('Output FOLDER had a FILE extension' % path) - path = [os.path.join(path, os.path.basename(p)) for p in path_images] - path = [p.replace('.nii', '_%s.nii' % suffix) for p in path] - path = [p.replace('.mgz', '_%s.mgz' % suffix) for p in path] - path = [p.replace('.npz', '_%s.npz' % suffix) for p in path] + if ( + (path[-7:] == ".nii.gz") + | (path[-4:] == ".nii") + | (path[-4:] == ".mgz") + | (path[-4:] == ".npz") + ): + raise Exception("Output FOLDER had a FILE extension" % path) + path = [ + os.path.join(path, os.path.basename(p)) for p in path_images + ] + path = [p.replace(".nii", "_%s.nii" % suffix) for p in path] + path = [p.replace(".mgz", "_%s.mgz" % suffix) for p in path] + path = [p.replace(".npz", "_%s.npz" % suffix) for p in path] recompute_files = [not os.path.isfile(p) for p in path] utils.mkdir(os.path.dirname(path[0])) else: @@ -370,36 +503,60 @@ def helper_dir(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = helper_dir(out_seg, 'path_segmentations', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_dir(out_posteriors, 'path_posteriors', '', 'posteriors') - out_resampled, recompute_resampled, _ = helper_dir(out_resampled, 'path_resampled', '', 'resampled') - out_volumes, recompute_volume, unique_volume_file = helper_dir(out_volumes, 'path_volumes', 'csv', '') - out_qc, recompute_qc, unique_qc_file = helper_dir(out_qc, 'path_qc_scores', 'csv', '') + out_seg, recompute_seg, _ = helper_dir( + out_seg, "path_segmentations", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_dir( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_resampled, recompute_resampled, _ = helper_dir( + out_resampled, "path_resampled", "", "resampled" + ) + out_volumes, recompute_volume, unique_volume_file = helper_dir( + out_volumes, "path_volumes", "csv", "" + ) + out_qc, recompute_qc, unique_qc_file = helper_dir( + out_qc, "path_qc_scores", "csv", "" + ) # path_images is an image else: # input images - assert os.path.isfile(path_images), 'file does not exist: %s \n' \ - 'please make sure the path and the extension are correct' % path_images + assert os.path.isfile(path_images), ( + "file does not exist: %s \n" + "please make sure the path and the extension are correct" % path_images + ) path_images = [path_images] # define helper to deal with outputs def helper_im(path, name, file_type, suffix): unique_file = False if path is not None: - assert path[-4:] != '.txt', '%s can only be given as text file when path_images is.' % name - if file_type == 'csv': - if path[-4:] != '.csv': - print('%s provided without csv extension. Adding csv extension.' % name) - path += '.csv' + assert path[-4:] != ".txt", ( + "%s can only be given as text file when path_images is." % name + ) + if file_type == "csv": + if path[-4:] != ".csv": + print( + "%s provided without csv extension. Adding csv extension." + % name + ) + path += ".csv" recompute_files = [True] unique_file = True else: - if ('.nii.gz' not in path) & ('.nii' not in path) & ('.mgz' not in path) & ('.npz' not in path): - file_name = os.path.basename(path_images[0]).replace('.nii', '_%s.nii' % suffix) - file_name = file_name.replace('.mgz', '_%s.mgz' % suffix) - file_name = file_name.replace('.npz', '_%s.npz' % suffix) + if ( + (".nii.gz" not in path) + & (".nii" not in path) + & (".mgz" not in path) + & (".npz" not in path) + ): + file_name = os.path.basename(path_images[0]).replace( + ".nii", "_%s.nii" % suffix + ) + file_name = file_name.replace(".mgz", "_%s.mgz" % suffix) + file_name = file_name.replace(".npz", "_%s.npz" % suffix) path = os.path.join(path, file_name) recompute_files = [not os.path.isfile(path)] utils.mkdir(os.path.dirname(path)) @@ -409,39 +566,75 @@ def helper_im(path, name, file_type, suffix): return path, recompute_files, unique_file # use helper on all outputs - out_seg, recompute_seg, _ = helper_im(out_seg, 'path_segmentations', '', 'synthseg') - out_posteriors, recompute_post, _ = helper_im(out_posteriors, 'path_posteriors', '', 'posteriors') - out_resampled, recompute_resampled, _ = helper_im(out_resampled, 'path_resampled', '', 'resampled') - out_volumes, recompute_volume, unique_volume_file = helper_im(out_volumes, 'path_volumes', 'csv', '') - out_qc, recompute_qc, unique_qc_file = helper_im(out_qc, 'path_qc_scores', 'csv', '') - - recompute_list = [recompute | re_seg | re_post | re_res | re_vol | re_qc - for (re_seg, re_post, re_res, re_vol, re_qc) - in zip(recompute_seg, recompute_post, recompute_resampled, recompute_volume, recompute_qc)] - - return path_images, out_seg, out_posteriors, out_resampled, out_volumes, unique_volume_file, \ - out_qc, unique_qc_file, recompute_list - - -def preprocess(path_image, ct, target_res=1., n_levels=5, crop=None, min_pad=None, path_resample=None): + out_seg, recompute_seg, _ = helper_im( + out_seg, "path_segmentations", "", "synthseg" + ) + out_posteriors, recompute_post, _ = helper_im( + out_posteriors, "path_posteriors", "", "posteriors" + ) + out_resampled, recompute_resampled, _ = helper_im( + out_resampled, "path_resampled", "", "resampled" + ) + out_volumes, recompute_volume, unique_volume_file = helper_im( + out_volumes, "path_volumes", "csv", "" + ) + out_qc, recompute_qc, unique_qc_file = helper_im( + out_qc, "path_qc_scores", "csv", "" + ) + + recompute_list = [ + recompute | re_seg | re_post | re_res | re_vol | re_qc + for (re_seg, re_post, re_res, re_vol, re_qc) in zip( + recompute_seg, + recompute_post, + recompute_resampled, + recompute_volume, + recompute_qc, + ) + ] + + return ( + path_images, + out_seg, + out_posteriors, + out_resampled, + out_volumes, + unique_volume_file, + out_qc, + unique_qc_file, + recompute_list, + ) + + +def preprocess( + path_image, + ct, + target_res=1.0, + n_levels=5, + crop=None, + min_pad=None, + path_resample=None, +): # read image and corresponding info im, _, aff, n_dims, n_channels, h, im_res = utils.get_volume_info(path_image, True) if n_dims == 2 and 1 < n_channels < 4: - raise Exception('either the input is 2D with several channels, or is 3D with at most 3 slices. ' - 'Either way, results are going to be poor...') + raise Exception( + "either the input is 2D with several channels, or is 3D with at most 3 slices. " + "Either way, results are going to be poor..." + ) elif n_dims == 2 and 3 < n_channels < 11: - print('warning: input with very few slices') + print("warning: input with very few slices") n_dims = 3 elif n_dims < 3: - raise Exception('input should have 3 dimensions, had %s' % n_dims) + raise Exception("input should have 3 dimensions, had %s" % n_dims) elif n_dims == 4 and n_channels == 1: n_dims = 3 im = im[..., 0] elif n_dims > 3: - raise Exception('input should have 3 dimensions, had %s' % n_dims) + raise Exception("input should have 3 dimensions, had %s" % n_dims) elif n_channels > 1: - print('WARNING: detected more than 1 channel, only keeping the first channel.') + print("WARNING: detected more than 1 channel, only keeping the first channel.") im = im[..., 0] # resample image if necessary @@ -453,29 +646,46 @@ def preprocess(path_image, ct, target_res=1., n_levels=5, crop=None, min_pad=Non utils.save_volume(im, aff, h, path_resample) # align image - im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False) + im = edit_volumes.align_volume_to_ref( + im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False + ) shape = list(im.shape[:n_dims]) # crop image if necessary if crop is not None: - crop = utils.reformat_to_list(crop, length=n_dims, dtype='int') - crop_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop] - im, crop_idx = edit_volumes.crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True) + crop = utils.reformat_to_list(crop, length=n_dims, dtype="int") + crop_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in crop + ] + im, crop_idx = edit_volumes.crop_volume( + im, cropping_shape=crop_shape, return_crop_idx=True + ) else: crop_idx = None # normalise image if ct: im = np.clip(im, 0, 80) - im = edit_volumes.rescale_volume(im, new_min=0., new_max=1., min_percentile=0.5, max_percentile=99.5) + im = edit_volumes.rescale_volume( + im, new_min=0.0, new_max=1.0, min_percentile=0.5, max_percentile=99.5 + ) # pad image input_shape = im.shape[:n_dims] - pad_shape = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in input_shape] - min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype='int') - min_pad = [utils.find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in min_pad] + pad_shape = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in input_shape + ] + min_pad = utils.reformat_to_list(min_pad, length=n_dims, dtype="int") + min_pad = [ + utils.find_closest_number_divisible_by_m(s, 2**n_levels, "higher") + for s in min_pad + ] pad_shape = np.maximum(pad_shape, min_pad) - im, pad_idx = edit_volumes.pad_volume(im, padding_shape=pad_shape, return_pad_idx=True) + im, pad_idx = edit_volumes.pad_volume( + im, padding_shape=pad_shape, return_pad_idx=True + ) # add batch and channel axes im = utils.add_axis(im, axis=[0, -1]) @@ -483,21 +693,25 @@ def preprocess(path_image, ct, target_res=1., n_levels=5, crop=None, min_pad=Non return im, aff, h, im_res, shape, pad_idx, crop_idx -def build_model(path_model_segmentation, - path_model_parcellation, - path_model_qc, - input_shape_qc, - labels_segmentation, - labels_denoiser, - labels_parcellation, - labels_qc, - sigma_smoothing, - flip_indices, - robust, - do_parcellation, - do_qc): - - assert os.path.isfile(path_model_segmentation), "The provided model path does not exist." +def build_model( + path_model_segmentation, + path_model_parcellation, + path_model_qc, + input_shape_qc, + labels_segmentation, + labels_denoiser, + labels_parcellation, + labels_qc, + sigma_smoothing, + flip_indices, + robust, + do_parcellation, + do_qc, +): + + assert os.path.isfile( + path_model_segmentation + ), "The provided model path does not exist." # get labels n_labels_seg = len(labels_segmentation) @@ -506,78 +720,92 @@ def build_model(path_model_segmentation, n_groups = len(labels_denoiser) # build first UNet - net = nrn_models.unet(input_shape=[None, None, None, 1], - nb_labels=n_groups, - nb_levels=5, - nb_conv_per_level=2, - conv_size=3, - nb_features=24, - feat_mult=2, - activation='elu', - batch_norm=-1, - name='unet') + net = nrn_models.unet( + input_shape=[None, None, None, 1], + nb_labels=n_groups, + nb_levels=5, + nb_conv_per_level=2, + conv_size=3, + nb_features=24, + feat_mult=2, + activation="elu", + batch_norm=-1, + name="unet", + ) # transition between the two networks: one_hot -> argmax -> one_hot (it simulates how the network was trained) last_tensor = net.output last_tensor = KL.Lambda(lambda x: tf.argmax(x, axis=-1))(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=n_groups, axis=-1))(last_tensor) + last_tensor = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, "int32"), depth=n_groups, axis=-1) + )(last_tensor) net = Model(inputs=net.inputs, outputs=last_tensor) # build denoiser - net = nrn_models.unet(input_model=net, - input_shape=[None, None, None, 1], - nb_labels=n_groups, - nb_levels=5, - nb_conv_per_level=2, - conv_size=5, - nb_features=16, - feat_mult=2, - activation='elu', - batch_norm=-1, - skip_n_concatenations=2, - name='l2l') + net = nrn_models.unet( + input_model=net, + input_shape=[None, None, None, 1], + nb_labels=n_groups, + nb_levels=5, + nb_conv_per_level=2, + conv_size=5, + nb_features=16, + feat_mult=2, + activation="elu", + batch_norm=-1, + skip_n_concatenations=2, + name="l2l", + ) # transition between the two networks: one_hot -> argmax -> one_hot, and concatenate input image and labels input_image = net.inputs[0] last_tensor = net.output last_tensor = KL.Lambda(lambda x: tf.argmax(x, axis=-1))(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=n_groups, axis=-1))(last_tensor) + last_tensor = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, "int32"), depth=n_groups, axis=-1) + )(last_tensor) if n_groups <= 2: last_tensor = KL.Lambda(lambda x: x[..., 1:])(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), 'float32'))([input_image, last_tensor]) + last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), "float32"))( + [input_image, last_tensor] + ) net = Model(inputs=net.inputs, outputs=last_tensor) # build 2nd network - net = nrn_models.unet(input_model=net, - input_shape=[None, None, None, 2], - nb_labels=n_labels_seg, - nb_levels=5, - nb_conv_per_level=2, - conv_size=3, - nb_features=24, - feat_mult=2, - activation='elu', - batch_norm=-1, - name='unet2') + net = nrn_models.unet( + input_model=net, + input_shape=[None, None, None, 2], + nb_labels=n_labels_seg, + nb_levels=5, + nb_conv_per_level=2, + conv_size=3, + nb_features=24, + feat_mult=2, + activation="elu", + batch_norm=-1, + name="unet2", + ) net.load_weights(path_model_segmentation, by_name=True) - name_segm_prediction_layer = 'unet2_prediction' + name_segm_prediction_layer = "unet2_prediction" else: # build UNet - net = nrn_models.unet(input_shape=[None, None, None, 1], - nb_labels=n_labels_seg, - nb_levels=5, - nb_conv_per_level=2, - conv_size=3, - nb_features=24, - feat_mult=2, - activation='elu', - batch_norm=-1, - name='unet') + net = nrn_models.unet( + input_shape=[None, None, None, 1], + nb_labels=n_labels_seg, + nb_levels=5, + nb_conv_per_level=2, + conv_size=3, + nb_features=24, + feat_mult=2, + activation="elu", + batch_norm=-1, + name="unet", + ) net.load_weights(path_model_segmentation, by_name=True) input_image = net.inputs[0] - name_segm_prediction_layer = 'unet_prediction' + name_segm_prediction_layer = "unet_prediction" # smooth posteriors if specified if sigma_smoothing > 0: @@ -595,13 +823,21 @@ def build_model(path_model_segmentation, # flip back and re-order channels last_tensor = layers.RandomFlip(axis=0, prob=1)(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.split(x, [1] * n_labels_seg, axis=-1), name='split')(last_tensor) - reordered_channels = [last_tensor[flip_indices[i]] for i in range(n_labels_seg)] - last_tensor = KL.Lambda(lambda x: tf.concat(x, -1), name='concat')(reordered_channels) + last_tensor = KL.Lambda( + lambda x: tf.split(x, [1] * n_labels_seg, axis=-1), name="split" + )(last_tensor) + reordered_channels = [ + last_tensor[flip_indices[i]] for i in range(n_labels_seg) + ] + last_tensor = KL.Lambda(lambda x: tf.concat(x, -1), name="concat")( + reordered_channels + ) # average two segmentations and build model - name_segm_prediction_layer = 'average_lr' - last_tensor = KL.Lambda(lambda x: 0.5 * (x[0] + x[1]), name=name_segm_prediction_layer)([seg, last_tensor]) + name_segm_prediction_layer = "average_lr" + last_tensor = KL.Lambda( + lambda x: 0.5 * (x[0] + x[1]), name=name_segm_prediction_layer + )([seg, last_tensor]) net = Model(inputs=net.inputs, outputs=last_tensor) # add aparc segmenter if needed @@ -610,74 +846,115 @@ def build_model(path_model_segmentation, # build input for S3: only takes one map for cortical segmentation (no image), 1 = cortex, 0 = other last_tensor = net.output - last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), 'int32'))(last_tensor) - last_tensor = layers.ConvertLabels(np.arange(n_labels_seg), labels_segmentation)(last_tensor) - parcellation_masking_values = np.array([1 if ((ll == 3) | (ll == 42)) else 0 for ll in labels_segmentation]) - last_tensor = layers.ConvertLabels(labels_segmentation, parcellation_masking_values)(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=2, axis=-1))(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), 'float32'))([input_image, last_tensor]) + last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), "int32"))( + last_tensor + ) + last_tensor = layers.ConvertLabels( + np.arange(n_labels_seg), labels_segmentation + )(last_tensor) + parcellation_masking_values = np.array( + [1 if ((ll == 3) | (ll == 42)) else 0 for ll in labels_segmentation] + ) + last_tensor = layers.ConvertLabels( + labels_segmentation, parcellation_masking_values + )(last_tensor) + last_tensor = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, "int32"), depth=2, axis=-1) + )(last_tensor) + last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), "float32"))( + [input_image, last_tensor] + ) net = Model(inputs=net.inputs, outputs=last_tensor) # build UNet - net = nrn_models.unet(input_model=net, - input_shape=[None, None, None, 3], - nb_labels=n_labels_parcellation, - nb_levels=5, - nb_conv_per_level=2, - conv_size=3, - nb_features=24, - feat_mult=2, - activation='elu', - batch_norm=-1, - name='unet_parc') + net = nrn_models.unet( + input_model=net, + input_shape=[None, None, None, 3], + nb_labels=n_labels_parcellation, + nb_levels=5, + nb_conv_per_level=2, + conv_size=3, + nb_features=24, + feat_mult=2, + activation="elu", + batch_norm=-1, + name="unet_parc", + ) net.load_weights(path_model_parcellation, by_name=True) # smooth predictions last_tensor = net.output last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list()) last_tensor = layers.GaussianBlur(sigma=0.5)(last_tensor) - net = Model(inputs=net.inputs, outputs=[net.get_layer(name_segm_prediction_layer).output, last_tensor]) + net = Model( + inputs=net.inputs, + outputs=[net.get_layer(name_segm_prediction_layer).output, last_tensor], + ) # add CNN regressor for automated QC if needed if do_qc: n_labels_qc = len(np.unique(labels_qc)) # transition between the two networks: one_hot -> argmax -> qc_labels -> one_hot - shape_prediction = KL.Input([3], dtype='int32') + shape_prediction = KL.Input([3], dtype="int32") if do_parcellation: - last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x[0], axis=-1), 'int32'))(net.outputs) + last_tensor = KL.Lambda( + lambda x: tf.cast(tf.argmax(x[0], axis=-1), "int32") + )(net.outputs) else: - last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), 'int32'))(net.output) + last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), "int32"))( + net.output + ) last_tensor = MakeShape(input_shape_qc)([last_tensor, shape_prediction]) - last_tensor = layers.ConvertLabels(np.arange(n_labels_seg), labels_segmentation)(last_tensor) + last_tensor = layers.ConvertLabels( + np.arange(n_labels_seg), labels_segmentation + )(last_tensor) last_tensor = layers.ConvertLabels(labels_segmentation, labels_qc)(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=n_labels_qc, axis=-1))(last_tensor) + last_tensor = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, "int32"), depth=n_labels_qc, axis=-1) + )(last_tensor) net = Model(inputs=[*net.inputs, shape_prediction], outputs=last_tensor) # build QC regressor network - net = nrn_models.conv_enc(input_model=net, - input_shape=[None, None, None, 1], - nb_levels=4, - nb_conv_per_level=2, - conv_size=5, - nb_features=24, - feat_mult=2, - activation='relu', - batch_norm=-1, - use_residuals=True, - name='qc') + net = nrn_models.conv_enc( + input_model=net, + input_shape=[None, None, None, 1], + nb_levels=4, + nb_conv_per_level=2, + conv_size=5, + nb_features=24, + feat_mult=2, + activation="relu", + batch_norm=-1, + use_residuals=True, + name="qc", + ) last_tensor = net.outputs[0] - conv_kwargs = {'padding': 'same', 'activation': 'relu', 'data_format': 'channels_last'} - last_tensor = KL.MaxPool3D(pool_size=(2, 2, 2), name='qc_maxpool_3', padding='same')(last_tensor) - last_tensor = KL.Conv3D(n_labels_qc, kernel_size=5, **conv_kwargs, name='qc_final_conv_0')(last_tensor) - last_tensor = KL.Conv3D(n_labels_qc, kernel_size=5, **conv_kwargs, name='qc_final_conv_1')(last_tensor) - last_tensor = KL.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name='qc_final_pred')(last_tensor) + conv_kwargs = { + "padding": "same", + "activation": "relu", + "data_format": "channels_last", + } + last_tensor = KL.MaxPool3D( + pool_size=(2, 2, 2), name="qc_maxpool_3", padding="same" + )(last_tensor) + last_tensor = KL.Conv3D( + n_labels_qc, kernel_size=5, **conv_kwargs, name="qc_final_conv_0" + )(last_tensor) + last_tensor = KL.Conv3D( + n_labels_qc, kernel_size=5, **conv_kwargs, name="qc_final_conv_1" + )(last_tensor) + last_tensor = KL.Lambda( + lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name="qc_final_pred" + )(last_tensor) # build model if do_parcellation: - outputs = [net.get_layer(name_segm_prediction_layer).output, - net.get_layer('unet_parc_prediction').output, - last_tensor] + outputs = [ + net.get_layer(name_segm_prediction_layer).output, + net.get_layer("unet_parc_prediction").output, + last_tensor, + ] else: outputs = [net.get_layer(name_segm_prediction_layer).output, last_tensor] net = Model(inputs=net.inputs, outputs=outputs) @@ -686,20 +963,40 @@ def build_model(path_model_segmentation, return net -def postprocess(post_patch_seg, post_patch_parc, shape, pad_idx, crop_idx, - labels_segmentation, labels_parcellation, aff, im_res, fast, topology_classes, v1): +def postprocess( + post_patch_seg, + post_patch_parc, + shape, + pad_idx, + crop_idx, + labels_segmentation, + labels_parcellation, + aff, + im_res, + fast, + topology_classes, + v1, +): # get posteriors post_patch_seg = np.squeeze(post_patch_seg) if fast | (topology_classes is None): - post_patch_seg = edit_volumes.crop_volume_with_idx(post_patch_seg, pad_idx, n_dims=3, return_copy=False) + post_patch_seg = edit_volumes.crop_volume_with_idx( + post_patch_seg, pad_idx, n_dims=3, return_copy=False + ) # keep biggest connected component tmp_post_patch_seg = post_patch_seg[..., 1:] post_patch_seg_mask = np.sum(tmp_post_patch_seg, axis=-1) > 0.25 - post_patch_seg_mask = edit_volumes.get_largest_connected_component(post_patch_seg_mask) - post_patch_seg_mask = np.stack([post_patch_seg_mask]*tmp_post_patch_seg.shape[-1], axis=-1) - tmp_post_patch_seg = edit_volumes.mask_volume(tmp_post_patch_seg, mask=post_patch_seg_mask, return_copy=False) + post_patch_seg_mask = edit_volumes.get_largest_connected_component( + post_patch_seg_mask + ) + post_patch_seg_mask = np.stack( + [post_patch_seg_mask] * tmp_post_patch_seg.shape[-1], axis=-1 + ) + tmp_post_patch_seg = edit_volumes.mask_volume( + tmp_post_patch_seg, mask=post_patch_seg_mask, return_copy=False + ) post_patch_seg[..., 1:] = tmp_post_patch_seg # reset posteriors to zero outside the largest connected component of each topological class @@ -711,49 +1008,81 @@ def postprocess(post_patch_seg, post_patch_parc, shape, pad_idx, crop_idx, tmp_mask = edit_volumes.get_largest_connected_component(tmp_mask) for idx in tmp_topology_indices: post_patch_seg[..., idx] *= tmp_mask - post_patch_seg = edit_volumes.crop_volume_with_idx(post_patch_seg, pad_idx, n_dims=3, return_copy=False) + post_patch_seg = edit_volumes.crop_volume_with_idx( + post_patch_seg, pad_idx, n_dims=3, return_copy=False + ) else: post_patch_seg_mask = post_patch_seg > 0.2 post_patch_seg[..., 1:] *= post_patch_seg_mask[..., 1:] # get hard segmentation post_patch_seg /= np.sum(post_patch_seg, axis=-1)[..., np.newaxis] - seg_patch = labels_segmentation[post_patch_seg.argmax(-1).astype('int32')].astype('int32') + seg_patch = labels_segmentation[post_patch_seg.argmax(-1).astype("int32")].astype( + "int32" + ) # postprocess parcellation if post_patch_parc is not None: post_patch_parc = np.squeeze(post_patch_parc) - post_patch_parc = edit_volumes.crop_volume_with_idx(post_patch_parc, pad_idx, n_dims=3, return_copy=False) + post_patch_parc = edit_volumes.crop_volume_with_idx( + post_patch_parc, pad_idx, n_dims=3, return_copy=False + ) mask = (seg_patch == 3) | (seg_patch == 42) post_patch_parc[..., 0] = np.ones_like(post_patch_parc[..., 0]) - post_patch_parc[..., 0] = edit_volumes.mask_volume(post_patch_parc[..., 0], mask=mask < 0.1, return_copy=False) + post_patch_parc[..., 0] = edit_volumes.mask_volume( + post_patch_parc[..., 0], mask=mask < 0.1, return_copy=False + ) post_patch_parc /= np.sum(post_patch_parc, axis=-1)[..., np.newaxis] - parc_patch = labels_parcellation[post_patch_parc.argmax(-1).astype('int32')].astype('int32') + parc_patch = labels_parcellation[ + post_patch_parc.argmax(-1).astype("int32") + ].astype("int32") seg_patch[mask] = parc_patch[mask] # paste patches back to matrix of original image size if crop_idx is not None: # we need to go through this because of the posteriors of the background, otherwise pad_volume would work - seg = np.zeros(shape=shape, dtype='int32') + seg = np.zeros(shape=shape, dtype="int32") posteriors = np.zeros(shape=[*shape, labels_segmentation.shape[0]]) posteriors[..., 0] = np.ones(shape) # place background around patch - seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch - posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch_seg + seg[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ] = seg_patch + posteriors[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + :, + ] = post_patch_seg else: seg = seg_patch posteriors = post_patch_seg # align prediction back to first orientation - seg = edit_volumes.align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=3, return_copy=False) - posteriors = edit_volumes.align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=3, return_copy=False) + seg = edit_volumes.align_volume_to_ref( + seg, aff=np.eye(4), aff_ref=aff, n_dims=3, return_copy=False + ) + posteriors = edit_volumes.align_volume_to_ref( + posteriors, np.eye(4), aff_ref=aff, n_dims=3, return_copy=False + ) # compute volumes - volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1))) + volumes = np.sum( + posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)) + ) if not v1: volumes = np.concatenate([np.array([np.sum(volumes)]), volumes]) if post_patch_parc is not None: - volumes_parc = np.sum(post_patch_parc[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1))) - total_volume_cortex = np.sum(volumes[np.where((labels_segmentation == 3) | (labels_segmentation == 42))[0] - 1]) + volumes_parc = np.sum( + post_patch_parc[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)) + ) + total_volume_cortex = np.sum( + volumes[ + np.where((labels_segmentation == 3) | (labels_segmentation == 42))[0] + - 1 + ] + ) volumes_parc = volumes_parc / np.sum(volumes_parc) * total_volume_cortex volumes = np.concatenate([volumes, volumes_parc]) volumes = np.around(volumes * np.prod(im_res), 3) @@ -777,7 +1106,9 @@ def get_config(self): def build(self, input_shape): self.n_dims = input_shape[1][1] - self.cropping_shape = np.array(utils.reformat_to_list(self.target_shape, length=self.n_dims)) + self.cropping_shape = np.array( + utils.reformat_to_list(self.target_shape, length=self.n_dims) + ) self.built = True super(MakeShape, self).build(input_shape) @@ -791,28 +1122,42 @@ def _single_process(self, inputs): # find cropping indices mask = tf.logical_and(tf.not_equal(x, 0), tf.not_equal(x, 24)) - indices = tf.cast(tf.where(mask), 'int32') - - min_idx = K.switch(tf.equal(tf.shape(indices)[0], 0), - tf.zeros(self.n_dims, dtype='int32'), - tf.maximum(tf.reduce_min(indices, axis=0), 0)) - max_idx = K.switch(tf.equal(tf.shape(indices)[0], 0), - tf.minimum(shape, self.cropping_shape), - tf.minimum(tf.reduce_max(indices, axis=0) + 1, shape)) + indices = tf.cast(tf.where(mask), "int32") + + min_idx = K.switch( + tf.equal(tf.shape(indices)[0], 0), + tf.zeros(self.n_dims, dtype="int32"), + tf.maximum(tf.reduce_min(indices, axis=0), 0), + ) + max_idx = K.switch( + tf.equal(tf.shape(indices)[0], 0), + tf.minimum(shape, self.cropping_shape), + tf.minimum(tf.reduce_max(indices, axis=0) + 1, shape), + ) # expand/retract (depending on the desired shape) the cropping region around the centre intermediate_vol_shape = max_idx - min_idx - min_idx = min_idx - tf.cast(tf.math.ceil((self.cropping_shape - intermediate_vol_shape) / 2), 'int32') - max_idx = max_idx + tf.cast(tf.math.floor((self.cropping_shape - intermediate_vol_shape) / 2), 'int32') + min_idx = min_idx - tf.cast( + tf.math.ceil((self.cropping_shape - intermediate_vol_shape) / 2), "int32" + ) + max_idx = max_idx + tf.cast( + tf.math.floor((self.cropping_shape - intermediate_vol_shape) / 2), "int32" + ) tmp_min_idx = tf.maximum(min_idx, 0) tmp_max_idx = tf.minimum(max_idx, shape) - x = tf.slice(x, begin=tmp_min_idx, size=tf.minimum(tmp_max_idx - tmp_min_idx, shape)) + x = tf.slice( + x, begin=tmp_min_idx, size=tf.minimum(tmp_max_idx - tmp_min_idx, shape) + ) # pad if necessary min_padding = tf.abs(tf.minimum(min_idx, 0)) max_padding = tf.maximum(max_idx - shape, 0) - x = K.switch(tf.reduce_any(tf.logical_or(tf.greater(min_padding, 0), tf.greater(max_padding, 0))), - tf.pad(x, tf.stack([min_padding, max_padding], axis=1)), - x) + x = K.switch( + tf.reduce_any( + tf.logical_or(tf.greater(min_padding, 0), tf.greater(max_padding, 0)) + ), + tf.pad(x, tf.stack([min_padding, max_padding], axis=1)), + x, + ) return x diff --git a/nobrainer/ext/SynthSeg/sample_segmentation_pairs_d.py b/nobrainer/ext/SynthSeg/sample_segmentation_pairs_d.py index 557fea4f..6f27fdc9 100644 --- a/nobrainer/ext/SynthSeg/sample_segmentation_pairs_d.py +++ b/nobrainer/ext/SynthSeg/sample_segmentation_pairs_d.py @@ -13,46 +13,49 @@ License. """ +import copy + # python imports import os -import copy -import numpy as np -import tensorflow as tf -from keras import models -import keras.layers as KL # third-party imports +from ext.lab2im import edit_tensors as l2i_et +from ext.lab2im import edit_volumes from ext.lab2im import layers as layers -from ext.lab2im import utils, edit_volumes +from ext.lab2im import utils from ext.neuron import models as nrn_models -from ext.lab2im import edit_tensors as l2i_et - +from keras import models +import keras.layers as KL +import numpy as np +import tensorflow as tf -def sample_segmentation_pairs(image_dir, - labels_dir, - results_dir, - n_examples, - path_model, - segmentation_labels, - n_neutral_labels=None, - batchsize=1, - flipping=True, - scaling_bounds=.15, - rotation_bounds=15, - shearing_bounds=.012, - translation_bounds=False, - nonlin_std=3., - nonlin_scale=.04, - min_res=1., - max_res_iso=4., - max_res_aniso=8., - noise_std_lr=3., - blur_range=1.03, - bias_field_std=.5, - bias_scale=.025, - noise_std=10, - gamma_std=.5): +def sample_segmentation_pairs( + image_dir, + labels_dir, + results_dir, + n_examples, + path_model, + segmentation_labels, + n_neutral_labels=None, + batchsize=1, + flipping=True, + scaling_bounds=0.15, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_scale=0.04, + min_res=1.0, + max_res_iso=4.0, + max_res_aniso=8.0, + noise_std_lr=3.0, + blur_range=1.03, + bias_field_std=0.5, + bias_scale=0.025, + noise_std=10, + gamma_std=0.5, +): """ This function enables us to obtain segmentations from a segmentation network along with the corresponding ground truths. The segmentations are obtained by taking real images and aggressively augmenting them with spatial and @@ -135,64 +138,76 @@ def sample_segmentation_pairs(image_dir, # prepare data files path_images = utils.list_images_in_folder(image_dir) path_labels = utils.list_images_in_folder(labels_dir) - assert len(path_images) == len(path_labels), "There should be as many images as label maps." + assert len(path_images) == len( + path_labels + ), "There should be as many images as label maps." # prepare results subfolders - gt_result_dir = os.path.join(results_dir, 'labels_gt') - pred_result_dir = os.path.join(results_dir, 'labels_seg') + gt_result_dir = os.path.join(results_dir, "labels_gt") + pred_result_dir = os.path.join(results_dir, "labels_seg") utils.mkdir(gt_result_dir) utils.mkdir(pred_result_dir) # get label lists - segmentation_labels, _ = utils.get_list_labels(label_list=segmentation_labels, labels_dir=labels_dir) + segmentation_labels, _ = utils.get_list_labels( + label_list=segmentation_labels, labels_dir=labels_dir + ) n_labels = np.size(np.unique(segmentation_labels)) # create augmentation model - im_shape, _, n_dims, n_channels, _, atlas_res = utils.get_volume_info(path_images[0], aff_ref=np.eye(4)) - augmentation_model = build_augmentation_model(im_shape, - n_channels, - segmentation_labels, - n_neutral_labels, - n_dims, - atlas_res, - flipping=flipping, - aff=np.eye(4), - scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_shape_factor=nonlin_scale, - min_res=min_res, - max_res_iso=max_res_iso, - max_res_aniso=max_res_aniso, - noise_std_lr=noise_std_lr, - blur_range=blur_range, - bias_field_std=bias_field_std, - bias_shape_factor=bias_scale, - noise_std=noise_std, - gamma_std=gamma_std) + im_shape, _, n_dims, n_channels, _, atlas_res = utils.get_volume_info( + path_images[0], aff_ref=np.eye(4) + ) + augmentation_model = build_augmentation_model( + im_shape, + n_channels, + segmentation_labels, + n_neutral_labels, + n_dims, + atlas_res, + flipping=flipping, + aff=np.eye(4), + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_shape_factor=nonlin_scale, + min_res=min_res, + max_res_iso=max_res_iso, + max_res_aniso=max_res_aniso, + noise_std_lr=noise_std_lr, + blur_range=blur_range, + bias_field_std=bias_field_std, + bias_shape_factor=bias_scale, + noise_std=noise_std, + gamma_std=gamma_std, + ) unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:] # prepare the segmentation model - unet_model = nrn_models.unet(nb_features=24, - input_shape=unet_input_shape, - nb_levels=5, - conv_size=3, - nb_labels=n_labels, - feat_mult=2, - nb_conv_per_level=2, - batch_norm=-1, - activation='elu', - input_model=augmentation_model) + unet_model = nrn_models.unet( + nb_features=24, + input_shape=unet_input_shape, + nb_levels=5, + conv_size=3, + nb_labels=n_labels, + feat_mult=2, + nb_conv_per_level=2, + batch_norm=-1, + activation="elu", + input_model=augmentation_model, + ) # get generative model train_example_gen = build_model_inputs(path_images, path_labels, batchsize) segmentation_labels = np.unique(segmentation_labels) # redefine model to output deformed image, deformed GT labels, and predicted labels - list_output_tensors = [unet_model.get_layer('labels_out').output, unet_model.output] - generation_model = models.Model(inputs=unet_model.inputs, outputs=list_output_tensors) + list_output_tensors = [unet_model.get_layer("labels_out").output, unet_model.output] + generation_model = models.Model( + inputs=unet_model.inputs, outputs=list_output_tensors + ) generation_model.load_weights(path_model, by_name=True) # generate ! @@ -204,103 +219,143 @@ def sample_segmentation_pairs(image_dir, outputs = generation_model.predict(next(train_example_gen)) # save results - for (output, name, res_dir) in zip(outputs, - ['labels_gt', 'labels_pred_argmax_convert'], - [gt_result_dir, pred_result_dir]): + for output, name, res_dir in zip( + outputs, + ["labels_gt", "labels_pred_argmax_convert"], + [gt_result_dir, pred_result_dir], + ): for b in range(batchsize): tmp_name = copy.deepcopy(name) tmp_output = np.squeeze(output[b, ...]) - if '_argmax' in tmp_name: + if "_argmax" in tmp_name: tmp_output = tmp_output.argmax(-1) - tmp_name = tmp_name.replace('_argmax', '') - if '_convert' in tmp_name: + tmp_name = tmp_name.replace("_argmax", "") + if "_convert" in tmp_name: tmp_output = segmentation_labels[tmp_output] - tmp_name = tmp_name.replace('_convert', '') - path = os.path.join(res_dir, tmp_name + '_%.{}d'.format(n) % i + '.nii.gz') + tmp_name = tmp_name.replace("_convert", "") + path = os.path.join( + res_dir, tmp_name + "_%.{}d".format(n) % i + ".nii.gz" + ) if batchsize > 1: - path = path.replace('.nii.gz', '_%s.nii.gz' % (b + 1)) + path = path.replace(".nii.gz", "_%s.nii.gz" % (b + 1)) utils.save_volume(tmp_output, np.eye(4), None, path) i += 1 -def build_augmentation_model(im_shape, - n_channels, - segmentation_labels, - n_neutral_labels, - n_dims, - atlas_res, - flipping=True, - aff=None, - scaling_bounds=0.15, - rotation_bounds=15, - shearing_bounds=0.012, - translation_bounds=False, - nonlin_std=3., - nonlin_shape_factor=.0625, - min_res=1., - max_res_iso=4., - max_res_aniso=8., - noise_std_lr=3., - blur_range=1.03, - bias_field_std=.5, - bias_shape_factor=.025, - noise_std=10, - gamma_std=.5): +def build_augmentation_model( + im_shape, + n_channels, + segmentation_labels, + n_neutral_labels, + n_dims, + atlas_res, + flipping=True, + aff=None, + scaling_bounds=0.15, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_shape_factor=0.0625, + min_res=1.0, + max_res_iso=4.0, + max_res_aniso=8.0, + noise_std_lr=3.0, + blur_range=1.03, + bias_field_std=0.5, + bias_shape_factor=0.025, + noise_std=10, + gamma_std=0.5, +): # define model inputs - image_input = KL.Input(shape=im_shape+[n_channels], name='image_input') - labels_input = KL.Input(shape=im_shape + [1], name='labels_input', dtype='int32') + image_input = KL.Input(shape=im_shape + [n_channels], name="image_input") + labels_input = KL.Input(shape=im_shape + [1], name="labels_input", dtype="int32") # deform labels - labels, image = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_shape_factor, - inter_method=['nearest', 'linear'])([labels_input, image_input]) + labels, image = layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_shape_factor, + inter_method=["nearest", "linear"], + )([labels_input, image_input]) # flipping if flipping: - assert aff is not None, 'aff should not be None if flipping is True' - labels, image = layers.RandomFlip(edit_volumes.get_ras_axes(aff, n_dims)[0], [True, False], - segmentation_labels, n_neutral_labels)([labels, image]) + assert aff is not None, "aff should not be None if flipping is True" + labels, image = layers.RandomFlip( + edit_volumes.get_ras_axes(aff, n_dims)[0], + [True, False], + segmentation_labels, + n_neutral_labels, + )([labels, image]) # apply bias field if bias_field_std > 0: - image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor, False)(image) + image = layers.BiasFieldCorruption(bias_field_std, bias_shape_factor, False)( + image + ) # intensity augmentation - image = layers.IntensityAugmentation(noise_std, gamma_std=gamma_std, contrast_inversion=True)(image) + image = layers.IntensityAugmentation( + noise_std, gamma_std=gamma_std, contrast_inversion=True + )(image) # loop over channels channels = list() - split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) if (n_channels > 1) else [image] + split = ( + KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) + if (n_channels > 1) + else [image] + ) for i, channel in enumerate(split): # reformat resolution range parameters - min_res = np.array(utils.reformat_to_list(min_res, length=n_dims, dtype='float')) - max_res_iso = np.array(utils.reformat_to_list(max_res_iso, length=n_dims, dtype='float')) - max_res_aniso = np.array(utils.reformat_to_list(max_res_aniso, length=n_dims, dtype='float')) + min_res = np.array( + utils.reformat_to_list(min_res, length=n_dims, dtype="float") + ) + max_res_iso = np.array( + utils.reformat_to_list(max_res_iso, length=n_dims, dtype="float") + ) + max_res_aniso = np.array( + utils.reformat_to_list(max_res_aniso, length=n_dims, dtype="float") + ) max_res = np.maximum(max_res_iso, max_res_aniso) # sample resolution and thickness (blurring res) - resolution, blur_res = layers.SampleResolution(min_res, max_res_iso, max_res_aniso)(channel) - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, resolution, thickness=blur_res) + resolution, blur_res = layers.SampleResolution( + min_res, max_res_iso, max_res_aniso + )(channel) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, resolution, thickness=blur_res + ) # blur and downsample/resample - channel = layers.DynamicGaussianBlur(0.75 * max_res / np.array(atlas_res), blur_range)([channel, sigma]) - channel = layers.MimicAcquisition(atlas_res, min_res, im_shape, False, noise_std_lr)([channel, resolution]) + channel = layers.DynamicGaussianBlur( + 0.75 * max_res / np.array(atlas_res), blur_range + )([channel, sigma]) + channel = layers.MimicAcquisition( + atlas_res, min_res, im_shape, False, noise_std_lr + )([channel, resolution]) channels.append(channel) # concatenate all channels back - image = KL.Lambda(lambda x: tf.concat(x, -1))(channels) if len(channels) > 1 else channels[0] + image = ( + KL.Lambda(lambda x: tf.concat(x, -1))(channels) + if len(channels) > 1 + else channels[0] + ) # build model (dummy layer enables to keep the labels when plugging this model to other models) - labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'), name='labels_out')(labels) - image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) - brain_model = models.Model(inputs=[image_input, labels_input], outputs=[image, labels]) + labels = KL.Lambda(lambda x: tf.cast(x, dtype="int32"), name="labels_out")(labels) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) + brain_model = models.Model( + inputs=[image_input, labels_input], outputs=[image, labels] + ) return brain_model @@ -330,7 +385,9 @@ def build_model_inputs(path_images, path_label_maps, batchsize=1): list_images.append(utils.add_axis(image, axis=[0, -1])) # add labels - labels = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4)) + labels = utils.load_volume( + path_label_maps[idx], dtype="int", aff_ref=np.eye(4) + ) list_label_maps.append(utils.add_axis(labels, axis=[0, -1])) # build list of inputs of augmentation model diff --git a/nobrainer/ext/SynthSeg/training.py b/nobrainer/ext/SynthSeg/training.py index ffeca481..c90e4280 100644 --- a/nobrainer/ext/SynthSeg/training.py +++ b/nobrainer/ext/SynthSeg/training.py @@ -17,69 +17,71 @@ License. """ +from inspect import getmembers, isclass # python imports import os -import keras -import numpy as np -import tensorflow as tf -from keras import models -import keras.callbacks as KC -from keras.optimizers import Adam -from inspect import getmembers, isclass # project imports from SynthSeg import metrics_model as metrics from SynthSeg.brain_generator import BrainGenerator # third-party imports -from ext.lab2im import utils, layers +from ext.lab2im import layers, utils from ext.neuron import layers as nrn_layers from ext.neuron import models as nrn_models +import keras +from keras import models +import keras.callbacks as KC +from keras.optimizers import Adam +import numpy as np +import tensorflow as tf -def training(labels_dir, - model_dir, - generation_labels=None, - n_neutral_labels=None, - segmentation_labels=None, - subjects_prob=None, - batchsize=1, - n_channels=1, - target_res=None, - output_shape=None, - generation_classes=None, - prior_distributions='uniform', - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - mix_prior_and_random=False, - flipping=True, - scaling_bounds=.2, - rotation_bounds=15, - shearing_bounds=.012, - translation_bounds=False, - nonlin_std=4., - nonlin_scale=.04, - randomise_res=True, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.7, - bias_scale=.025, - return_gradients=False, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - lr=1e-4, - wl2_epochs=1, - dice_epochs=50, - steps_per_epoch=10000, - checkpoint=None): +def training( + labels_dir, + model_dir, + generation_labels=None, + n_neutral_labels=None, + segmentation_labels=None, + subjects_prob=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, + flipping=True, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.04, + randomise_res=True, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, + return_gradients=False, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + lr=1e-4, + wl2_epochs=1, + dice_epochs=50, + steps_per_epoch=10000, + checkpoint=None, +): """ This function trains a UNet to segment MRI images with synthetic scans generated by sampling a GMM conditioned on label maps. We regroup the parameters in three categories: Generation, Architecture, Training. @@ -227,11 +229,16 @@ def training(labels_dir, """ # check epochs - assert (wl2_epochs > 0) | (dice_epochs > 0), \ - 'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs) + assert (wl2_epochs > 0) | ( + dice_epochs > 0 + ), "either wl2_epochs or dice_epochs must be positive, had {0} and {1}".format( + wl2_epochs, dice_epochs + ) # get label lists - generation_labels, _ = utils.get_list_labels(label_list=generation_labels, labels_dir=labels_dir) + generation_labels, _ = utils.get_list_labels( + label_list=generation_labels, labels_dir=labels_dir + ) if segmentation_labels is not None: segmentation_labels, _ = utils.get_list_labels(label_list=segmentation_labels) else: @@ -239,102 +246,150 @@ def training(labels_dir, n_segmentation_labels = len(np.unique(segmentation_labels)) # instantiate BrainGenerator object - brain_generator = BrainGenerator(labels_dir=labels_dir, - generation_labels=generation_labels, - n_neutral_labels=n_neutral_labels, - output_labels=segmentation_labels, - subjects_prob=subjects_prob, - batchsize=batchsize, - n_channels=n_channels, - target_res=target_res, - output_shape=output_shape, - output_div_by_n=2 ** n_levels, - generation_classes=generation_classes, - prior_distributions=prior_distributions, - prior_means=prior_means, - prior_stds=prior_stds, - use_specific_stats_for_channel=use_specific_stats_for_channel, - mix_prior_and_random=mix_prior_and_random, - flipping=flipping, - scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - randomise_res=randomise_res, - max_res_iso=max_res_iso, - max_res_aniso=max_res_aniso, - data_res=data_res, - thickness=thickness, - bias_field_std=bias_field_std, - bias_scale=bias_scale, - return_gradients=return_gradients) + brain_generator = BrainGenerator( + labels_dir=labels_dir, + generation_labels=generation_labels, + n_neutral_labels=n_neutral_labels, + output_labels=segmentation_labels, + subjects_prob=subjects_prob, + batchsize=batchsize, + n_channels=n_channels, + target_res=target_res, + output_shape=output_shape, + output_div_by_n=2**n_levels, + generation_classes=generation_classes, + prior_distributions=prior_distributions, + prior_means=prior_means, + prior_stds=prior_stds, + use_specific_stats_for_channel=use_specific_stats_for_channel, + mix_prior_and_random=mix_prior_and_random, + flipping=flipping, + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + randomise_res=randomise_res, + max_res_iso=max_res_iso, + max_res_aniso=max_res_aniso, + data_res=data_res, + thickness=thickness, + bias_field_std=bias_field_std, + bias_scale=bias_scale, + return_gradients=return_gradients, + ) # generation model labels_to_image_model = brain_generator.labels_to_image_model unet_input_shape = brain_generator.model_output_shape # prepare the segmentation model - unet_model = nrn_models.unet(input_model=labels_to_image_model, - input_shape=unet_input_shape, - nb_labels=n_segmentation_labels, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1, - name='unet') + unet_model = nrn_models.unet( + input_model=labels_to_image_model, + input_shape=unet_input_shape, + nb_labels=n_segmentation_labels, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + name="unet", + ) # input generator - input_generator = utils.build_training_generator(brain_generator.model_inputs_generator, batchsize) + input_generator = utils.build_training_generator( + brain_generator.model_inputs_generator, batchsize + ) # pre-training with weighted L2, input is fit to the softmax rather than the probabilities if wl2_epochs > 0: - wl2_model = models.Model(unet_model.inputs, [unet_model.get_layer('unet_likelihood').output]) - wl2_model = metrics.metrics_model(wl2_model, segmentation_labels, 'wl2') - train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint) - checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs) + wl2_model = models.Model( + unet_model.inputs, [unet_model.get_layer("unet_likelihood").output] + ) + wl2_model = metrics.metrics_model(wl2_model, segmentation_labels, "wl2") + train_model( + wl2_model, + input_generator, + lr, + wl2_epochs, + steps_per_epoch, + model_dir, + "wl2", + checkpoint, + ) + checkpoint = os.path.join(model_dir, "wl2_%03d.h5" % wl2_epochs) # fine-tuning with dice metric - dice_model = metrics.metrics_model(unet_model, segmentation_labels, 'dice') - train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint) - - -def train_model(model, - generator, - learning_rate, - n_epochs, - n_steps, - model_dir, - metric_type, - path_checkpoint=None, - reinitialise_momentum=False): + dice_model = metrics.metrics_model(unet_model, segmentation_labels, "dice") + train_model( + dice_model, + input_generator, + lr, + dice_epochs, + steps_per_epoch, + model_dir, + "dice", + checkpoint, + ) + + +def train_model( + model, + generator, + learning_rate, + n_epochs, + n_steps, + model_dir, + metric_type, + path_checkpoint=None, + reinitialise_momentum=False, +): # prepare model and log folders utils.mkdir(model_dir) - log_dir = os.path.join(model_dir, 'logs') + log_dir = os.path.join(model_dir, "logs") utils.mkdir(log_dir) # model saving callback - save_file_name = os.path.join(model_dir, '%s_{epoch:03d}.h5' % metric_type) + save_file_name = os.path.join(model_dir, "%s_{epoch:03d}.h5" % metric_type) callbacks = [KC.ModelCheckpoint(save_file_name, verbose=1)] # TensorBoard callback - if metric_type == 'dice': - callbacks.append(KC.TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True, write_images=False)) + if metric_type == "dice": + callbacks.append( + KC.TensorBoard( + log_dir=log_dir, histogram_freq=0, write_graph=True, write_images=False + ) + ) compile_model = True init_epoch = 0 if path_checkpoint is not None: if metric_type in path_checkpoint: - init_epoch = int(os.path.basename(path_checkpoint).split(metric_type)[1][1:-3]) + init_epoch = int( + os.path.basename(path_checkpoint).split(metric_type)[1][1:-3] + ) if (not reinitialise_momentum) & (metric_type in path_checkpoint): - custom_l2i = {key: value for (key, value) in getmembers(layers, isclass) if key != 'Layer'} - custom_nrn = {key: value for (key, value) in getmembers(nrn_layers, isclass) if key != 'Layer'} - custom_objects = {**custom_l2i, **custom_nrn, 'tf': tf, 'keras': keras, 'loss': metrics.IdentityLoss().loss} + custom_l2i = { + key: value + for (key, value) in getmembers(layers, isclass) + if key != "Layer" + } + custom_nrn = { + key: value + for (key, value) in getmembers(nrn_layers, isclass) + if key != "Layer" + } + custom_objects = { + **custom_l2i, + **custom_nrn, + "tf": tf, + "keras": keras, + "loss": metrics.IdentityLoss().loss, + } model = models.load_model(path_checkpoint, custom_objects=custom_objects) compile_model = False else: @@ -342,11 +397,15 @@ def train_model(model, # compile if compile_model: - model.compile(optimizer=Adam(lr=learning_rate), loss=metrics.IdentityLoss().loss) + model.compile( + optimizer=Adam(lr=learning_rate), loss=metrics.IdentityLoss().loss + ) # fit - model.fit_generator(generator, - epochs=n_epochs, - steps_per_epoch=n_steps, - callbacks=callbacks, - initial_epoch=init_epoch) + model.fit_generator( + generator, + epochs=n_epochs, + steps_per_epoch=n_steps, + callbacks=callbacks, + initial_epoch=init_epoch, + ) diff --git a/nobrainer/ext/SynthSeg/training_denoiser.py b/nobrainer/ext/SynthSeg/training_denoiser.py index 2ed4cace..bea6787a 100644 --- a/nobrainer/ext/SynthSeg/training_denoiser.py +++ b/nobrainer/ext/SynthSeg/training_denoiser.py @@ -13,53 +13,54 @@ License. """ - # python imports import os -import numpy as np -import tensorflow as tf -from keras import models -from keras import layers as KL # project imports from SynthSeg import metrics_model as metrics -from SynthSeg.training import train_model from SynthSeg.labels_to_image_model import get_shapes +from SynthSeg.training import train_model from SynthSeg.training_supervised import build_model_inputs # third-party imports -from ext.lab2im import utils, layers +from ext.lab2im import layers, utils from ext.neuron import models as nrn_models +from keras import layers as KL +from keras import models +import numpy as np +import tensorflow as tf -def training(list_paths_input_labels, - list_paths_target_labels, - model_dir, - input_segmentation_labels, - target_segmentation_labels=None, - subjects_prob=None, - batchsize=1, - output_shape=None, - scaling_bounds=.2, - rotation_bounds=15, - shearing_bounds=.012, - nonlin_std=3., - nonlin_scale=.04, - prob_erosion_dilation=0.3, - min_erosion_dilation=4, - max_erosion_dilation=5, - n_levels=5, - nb_conv_per_level=2, - conv_size=5, - unet_feat_count=16, - feat_multiplier=2, - activation='elu', - skip_n_concatenations=2, - lr=1e-4, - wl2_epochs=1, - dice_epochs=50, - steps_per_epoch=10000, - checkpoint=None): +def training( + list_paths_input_labels, + list_paths_target_labels, + model_dir, + input_segmentation_labels, + target_segmentation_labels=None, + subjects_prob=None, + batchsize=1, + output_shape=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + nonlin_std=3.0, + nonlin_scale=0.04, + prob_erosion_dilation=0.3, + min_erosion_dilation=4, + max_erosion_dilation=5, + n_levels=5, + nb_conv_per_level=2, + conv_size=5, + unet_feat_count=16, + feat_multiplier=2, + activation="elu", + skip_n_concatenations=2, + lr=1e-4, + wl2_epochs=1, + dice_epochs=50, + steps_per_epoch=10000, + checkpoint=None, +): """ This function trains a UNet to segment MRI images with synthetic scans generated by sampling a GMM conditioned on @@ -109,9 +110,9 @@ def training(list_paths_input_labels, tensor for synthesising the deformation field. Set to 0 to completely deactivate elastic deformation. :param nonlin_scale: (optional) Ratio between the size of the input label maps and the size of the sampled tensor for synthesising the elastic deformation field. - + # degradation of the input labels - :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with erosion or + :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with erosion or dilation. If 0, then no erosion/dilation is applied to the label maps given as inputs to the network. :param min_erosion_dilation: (optional) when prob_erosion_dilation is not zero, erosion and dilation of random coefficients are applied. Set the minimum erosion/dilation coefficient here. @@ -139,80 +140,115 @@ def training(list_paths_input_labels, """ # check epochs - assert (wl2_epochs > 0) | (dice_epochs > 0), \ - 'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs) + assert (wl2_epochs > 0) | ( + dice_epochs > 0 + ), "either wl2_epochs or dice_epochs must be positive, had {0} and {1}".format( + wl2_epochs, dice_epochs + ) # prepare data files input_label_list, _ = utils.get_list_labels(label_list=input_segmentation_labels) if target_segmentation_labels is None: target_label_list = input_label_list else: - target_label_list, _ = utils.get_list_labels(label_list=target_segmentation_labels) + target_label_list, _ = utils.get_list_labels( + label_list=target_segmentation_labels + ) n_labels = np.size(target_label_list) # create augmentation model - labels_shape, _, _, _, _, _ = utils.get_volume_info(list_paths_input_labels[0], aff_ref=np.eye(4)) - augmentation_model = build_augmentation_model(labels_shape, - input_label_list, - crop_shape=output_shape, - output_div_by_n=2 ** n_levels, - scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - prob_erosion_dilation=prob_erosion_dilation, - min_erosion_dilation=min_erosion_dilation, - max_erosion_dilation=max_erosion_dilation) + labels_shape, _, _, _, _, _ = utils.get_volume_info( + list_paths_input_labels[0], aff_ref=np.eye(4) + ) + augmentation_model = build_augmentation_model( + labels_shape, + input_label_list, + crop_shape=output_shape, + output_div_by_n=2**n_levels, + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + prob_erosion_dilation=prob_erosion_dilation, + min_erosion_dilation=min_erosion_dilation, + max_erosion_dilation=max_erosion_dilation, + ) unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:] # prepare the segmentation model - l2l_model = nrn_models.unet(input_model=augmentation_model, - input_shape=unet_input_shape, - nb_labels=n_labels, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1, - skip_n_concatenations=skip_n_concatenations, - name='l2l') + l2l_model = nrn_models.unet( + input_model=augmentation_model, + input_shape=unet_input_shape, + nb_labels=n_labels, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + skip_n_concatenations=skip_n_concatenations, + name="l2l", + ) # input generator - model_inputs = build_model_inputs(path_inputs=list_paths_input_labels, - path_outputs=list_paths_target_labels, - batchsize=batchsize, - subjects_prob=subjects_prob, - dtype_input='int32') + model_inputs = build_model_inputs( + path_inputs=list_paths_input_labels, + path_outputs=list_paths_target_labels, + batchsize=batchsize, + subjects_prob=subjects_prob, + dtype_input="int32", + ) input_generator = utils.build_training_generator(model_inputs, batchsize) # pre-training with weighted L2, input is fit to the softmax rather than the probabilities if wl2_epochs > 0: - wl2_model = models.Model(l2l_model.inputs, [l2l_model.get_layer('l2l_likelihood').output]) - wl2_model = metrics.metrics_model(wl2_model, target_label_list, 'wl2') - train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint) - checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs) + wl2_model = models.Model( + l2l_model.inputs, [l2l_model.get_layer("l2l_likelihood").output] + ) + wl2_model = metrics.metrics_model(wl2_model, target_label_list, "wl2") + train_model( + wl2_model, + input_generator, + lr, + wl2_epochs, + steps_per_epoch, + model_dir, + "wl2", + checkpoint, + ) + checkpoint = os.path.join(model_dir, "wl2_%03d.h5" % wl2_epochs) # fine-tuning with dice metric - dice_model = metrics.metrics_model(l2l_model, target_label_list, 'dice') - train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint) - - -def build_augmentation_model(labels_shape, - segmentation_labels, - crop_shape=None, - output_div_by_n=None, - scaling_bounds=0.15, - rotation_bounds=15, - shearing_bounds=0.012, - translation_bounds=False, - nonlin_std=3., - nonlin_scale=.0625, - prob_erosion_dilation=0.3, - min_erosion_dilation=4, - max_erosion_dilation=7): + dice_model = metrics.metrics_model(l2l_model, target_label_list, "dice") + train_model( + dice_model, + input_generator, + lr, + dice_epochs, + steps_per_epoch, + model_dir, + "dice", + checkpoint, + ) + + +def build_augmentation_model( + labels_shape, + segmentation_labels, + crop_shape=None, + output_div_by_n=None, + scaling_bounds=0.15, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_scale=0.0625, + prob_erosion_dilation=0.3, + min_erosion_dilation=4, + max_erosion_dilation=7, +): # reformat resolutions and get shapes labels_shape = utils.reformat_to_list(labels_shape) @@ -220,20 +256,32 @@ def build_augmentation_model(labels_shape, n_labels = len(segmentation_labels) # get shapes - crop_shape, _ = get_shapes(labels_shape, crop_shape, np.array([1]*n_dims), np.array([1]*n_dims), output_div_by_n) + crop_shape, _ = get_shapes( + labels_shape, + crop_shape, + np.array([1] * n_dims), + np.array([1] * n_dims), + output_div_by_n, + ) # define model inputs - net_input = KL.Input(shape=labels_shape + [1], name='l2l_noisy_labels_input', dtype='int32') - target_input = KL.Input(shape=labels_shape + [1], name='l2l_target_input', dtype='int32') + net_input = KL.Input( + shape=labels_shape + [1], name="l2l_noisy_labels_input", dtype="int32" + ) + target_input = KL.Input( + shape=labels_shape + [1], name="l2l_target_input", dtype="int32" + ) # deform labels - noisy_labels, target = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - inter_method='nearest')([net_input, target_input]) + noisy_labels, target = layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method="nearest", + )([net_input, target_input]) # cropping if crop_shape != labels_shape: @@ -241,16 +289,19 @@ def build_augmentation_model(labels_shape, # random erosion if prob_erosion_dilation > 0: - noisy_labels = layers.RandomDilationErosion(min_erosion_dilation, - max_erosion_dilation, - prob=prob_erosion_dilation)(noisy_labels) + noisy_labels = layers.RandomDilationErosion( + min_erosion_dilation, max_erosion_dilation, prob=prob_erosion_dilation + )(noisy_labels) # convert input labels (i.e. noisy_labels) to [0, ... N-1] and make them one-hot noisy_labels = layers.ConvertLabels(np.unique(segmentation_labels))(noisy_labels) - target = KL.Lambda(lambda x: tf.cast(x[..., 0], 'int32'), name='labels_out')(target) - noisy_labels = KL.Lambda(lambda x: tf.one_hot(x[0][..., 0], depth=n_labels), - name='noisy_labels_out')([noisy_labels, target]) + target = KL.Lambda(lambda x: tf.cast(x[..., 0], "int32"), name="labels_out")(target) + noisy_labels = KL.Lambda( + lambda x: tf.one_hot(x[0][..., 0], depth=n_labels), name="noisy_labels_out" + )([noisy_labels, target]) # build model and return - brain_model = models.Model(inputs=[net_input, target_input], outputs=[noisy_labels, target]) + brain_model = models.Model( + inputs=[net_input, target_input], outputs=[noisy_labels, target] + ) return brain_model diff --git a/nobrainer/ext/SynthSeg/training_group.py b/nobrainer/ext/SynthSeg/training_group.py index d8713356..c8951b58 100644 --- a/nobrainer/ext/SynthSeg/training_group.py +++ b/nobrainer/ext/SynthSeg/training_group.py @@ -18,72 +18,71 @@ License. """ - # python imports import os -import numpy as np -import tensorflow as tf -from keras import models -import keras.layers as KL # project imports from SynthSeg import metrics_model as metrics -from SynthSeg.training import train_model from SynthSeg.brain_generator import BrainGenerator from SynthSeg.labels_to_image_model import get_shapes +from SynthSeg.training import train_model # third-party imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.neuron import models as nrn_models from ext.lab2im import edit_tensors as l2i_et +from ext.lab2im import layers, utils from ext.lab2im.edit_volumes import get_ras_axes +from ext.neuron import models as nrn_models +from keras import models +import keras.layers as KL +import numpy as np +import tensorflow as tf -def training(labels_dir, - model_dir, - generation_labels=None, - grouping_labels=None, - n_neutral_labels=None, - segmentation_labels=None, - subjects_prob=None, - batchsize=1, - n_channels=1, - target_res=None, - output_shape=None, - generation_classes=None, - prior_distributions='uniform', - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - mix_prior_and_random=False, - flipping=True, - scaling_bounds=.2, - rotation_bounds=15, - shearing_bounds=.012, - translation_bounds=False, - nonlin_std=4., - nonlin_scale=.04, - randomise_res=True, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.7, - bias_scale=.025, - return_gradients=False, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - lr=1e-4, - wl2_epochs=1, - dice_epochs=50, - steps_per_epoch=10000, - checkpoint=None): - +def training( + labels_dir, + model_dir, + generation_labels=None, + grouping_labels=None, + n_neutral_labels=None, + segmentation_labels=None, + subjects_prob=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, + flipping=True, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.04, + randomise_res=True, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, + return_gradients=False, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + lr=1e-4, + wl2_epochs=1, + dice_epochs=50, + steps_per_epoch=10000, + checkpoint=None, +): """ This function trains a UNet to segment MRI images with synthetic scans generated by sampling a GMM conditioned on label maps. The difference with training.py is based on the fact that here the UNet now takes two inputs: @@ -242,11 +241,16 @@ def training(labels_dir, """ # check epochs - assert (wl2_epochs > 0) | (dice_epochs > 0), \ - 'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs) + assert (wl2_epochs > 0) | ( + dice_epochs > 0 + ), "either wl2_epochs or dice_epochs must be positive, had {0} and {1}".format( + wl2_epochs, dice_epochs + ) # get label lists - generation_labels, _ = utils.get_list_labels(label_list=generation_labels, labels_dir=labels_dir) + generation_labels, _ = utils.get_list_labels( + label_list=generation_labels, labels_dir=labels_dir + ) if segmentation_labels is not None: segmentation_labels, _ = utils.get_list_labels(label_list=segmentation_labels) else: @@ -254,69 +258,95 @@ def training(labels_dir, n_segmentation_labels = len(np.unique(segmentation_labels)) # instantiate BrainGenerator object - brain_generator = BrainGeneratorGroup(labels_dir=labels_dir, - generation_labels=generation_labels, - grouping_labels=grouping_labels, - n_neutral_labels=n_neutral_labels, - output_labels=segmentation_labels, - subjects_prob=subjects_prob, - batchsize=batchsize, - n_channels=n_channels, - target_res=target_res, - output_shape=output_shape, - output_div_by_n=2 ** n_levels, - generation_classes=generation_classes, - prior_distributions=prior_distributions, - prior_means=prior_means, - prior_stds=prior_stds, - use_specific_stats_for_channel=use_specific_stats_for_channel, - mix_prior_and_random=mix_prior_and_random, - flipping=flipping, - scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - randomise_res=randomise_res, - max_res_iso=max_res_iso, - max_res_aniso=max_res_aniso, - data_res=data_res, - thickness=thickness, - bias_field_std=bias_field_std, - bias_scale=bias_scale, - return_gradients=return_gradients) + brain_generator = BrainGeneratorGroup( + labels_dir=labels_dir, + generation_labels=generation_labels, + grouping_labels=grouping_labels, + n_neutral_labels=n_neutral_labels, + output_labels=segmentation_labels, + subjects_prob=subjects_prob, + batchsize=batchsize, + n_channels=n_channels, + target_res=target_res, + output_shape=output_shape, + output_div_by_n=2**n_levels, + generation_classes=generation_classes, + prior_distributions=prior_distributions, + prior_means=prior_means, + prior_stds=prior_stds, + use_specific_stats_for_channel=use_specific_stats_for_channel, + mix_prior_and_random=mix_prior_and_random, + flipping=flipping, + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + randomise_res=randomise_res, + max_res_iso=max_res_iso, + max_res_aniso=max_res_aniso, + data_res=data_res, + thickness=thickness, + bias_field_std=bias_field_std, + bias_scale=bias_scale, + return_gradients=return_gradients, + ) # generation model labels_to_image_model = brain_generator.labels_to_image_model unet_input_shape = brain_generator.model_output_shape # prepare the segmentation model - unet_model = nrn_models.unet(input_model=labels_to_image_model, - input_shape=unet_input_shape, - nb_labels=n_segmentation_labels, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1, - name='unet2') + unet_model = nrn_models.unet( + input_model=labels_to_image_model, + input_shape=unet_input_shape, + nb_labels=n_segmentation_labels, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + name="unet2", + ) # input generator - input_generator = utils.build_training_generator(brain_generator.model_inputs_generator, batchsize) + input_generator = utils.build_training_generator( + brain_generator.model_inputs_generator, batchsize + ) # pre-training with weighted L2, input is fit to the softmax rather than the probabilities if wl2_epochs > 0: - wl2_model = models.Model(unet_model.inputs, [unet_model.get_layer('unet2_likelihood').output]) - wl2_model = metrics.metrics_model(wl2_model, segmentation_labels, 'wl2') - train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint) - checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs) + wl2_model = models.Model( + unet_model.inputs, [unet_model.get_layer("unet2_likelihood").output] + ) + wl2_model = metrics.metrics_model(wl2_model, segmentation_labels, "wl2") + train_model( + wl2_model, + input_generator, + lr, + wl2_epochs, + steps_per_epoch, + model_dir, + "wl2", + checkpoint, + ) + checkpoint = os.path.join(model_dir, "wl2_%03d.h5" % wl2_epochs) # fine-tuning with dice metric - dice_model = metrics.metrics_model(unet_model, segmentation_labels, 'dice') - train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint) + dice_model = metrics.metrics_model(unet_model, segmentation_labels, "dice") + train_model( + dice_model, + input_generator, + lr, + dice_epochs, + steps_per_epoch, + model_dir, + "dice", + checkpoint, + ) class BrainGeneratorGroup(BrainGenerator): @@ -324,111 +354,150 @@ class BrainGeneratorGroup(BrainGenerator): def __init__(self, grouping_labels=None, **kwargs): super(BrainGeneratorGroup, self).__init__(**kwargs) self.grouping_labels = utils.load_array_if_path(grouping_labels) - self.labels_to_image_model, self.model_output_shape = self._build_labels_to_image_model_group() + self.labels_to_image_model, self.model_output_shape = ( + self._build_labels_to_image_model_group() + ) def _build_labels_to_image_model_group(self): # build_model - lab_to_im_model = labels_to_image_model_group(labels_shape=self.labels_shape, - n_channels=self.n_channels, - generation_labels=self.generation_labels, - output_labels=self.output_labels, - n_neutral_labels=self.n_neutral_labels, - atlas_res=self.atlas_res, - target_res=self.target_res, - grouping_labels=self.grouping_labels, - output_shape=self.output_shape, - output_div_by_n=self.output_div_by_n, - flipping=self.flipping, - aff=np.eye(4), - scaling_bounds=self.scaling_bounds, - rotation_bounds=self.rotation_bounds, - shearing_bounds=self.shearing_bounds, - translation_bounds=self.translation_bounds, - nonlin_std=self.nonlin_std, - nonlin_scale=self.nonlin_scale, - randomise_res=self.randomise_res, - max_res_iso=self.max_res_iso, - max_res_aniso=self.max_res_aniso, - data_res=self.data_res, - thickness=self.thickness, - bias_field_std=self.bias_field_std, - bias_scale=self.bias_scale, - return_gradients=self.return_gradients) + lab_to_im_model = labels_to_image_model_group( + labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + n_neutral_labels=self.n_neutral_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + grouping_labels=self.grouping_labels, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + flipping=self.flipping, + aff=np.eye(4), + scaling_bounds=self.scaling_bounds, + rotation_bounds=self.rotation_bounds, + shearing_bounds=self.shearing_bounds, + translation_bounds=self.translation_bounds, + nonlin_std=self.nonlin_std, + nonlin_scale=self.nonlin_scale, + randomise_res=self.randomise_res, + max_res_iso=self.max_res_iso, + max_res_aniso=self.max_res_aniso, + data_res=self.data_res, + thickness=self.thickness, + bias_field_std=self.bias_field_std, + bias_scale=self.bias_scale, + return_gradients=self.return_gradients, + ) out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] return lab_to_im_model, out_shape -def labels_to_image_model_group(labels_shape, - n_channels, - generation_labels, - output_labels, - n_neutral_labels, - atlas_res, - target_res, - grouping_labels, - output_shape=None, - output_div_by_n=None, - flipping=True, - aff=None, - scaling_bounds=0.2, - rotation_bounds=15, - shearing_bounds=0.012, - translation_bounds=False, - nonlin_std=3., - nonlin_scale=.0625, - randomise_res=False, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.7, - bias_scale=.025, - return_gradients=False): +def labels_to_image_model_group( + labels_shape, + n_channels, + generation_labels, + output_labels, + n_neutral_labels, + atlas_res, + target_res, + grouping_labels, + output_shape=None, + output_div_by_n=None, + flipping=True, + aff=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_scale=0.0625, + randomise_res=False, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, + return_gradients=False, +): # reformat resolutions labels_shape = utils.reformat_to_list(labels_shape) n_dims, _ = utils.get_dims(labels_shape) atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims, n_channels) - data_res = atlas_res if data_res is None else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) - thickness = data_res if thickness is None else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + data_res = ( + atlas_res + if data_res is None + else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) + ) + thickness = ( + data_res + if thickness is None + else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + ) atlas_res = atlas_res[0] - target_res = atlas_res if target_res is None else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + target_res = ( + atlas_res + if target_res is None + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) # get shapes - crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + crop_shape, output_shape = get_shapes( + labels_shape, output_shape, atlas_res, target_res, output_div_by_n + ) # define model inputs - labels_input = KL.Input(shape=labels_shape + [1], name='labels_input', dtype='int32') - means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') - stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='std_devs_input') + labels_input = KL.Input( + shape=labels_shape + [1], name="labels_input", dtype="int32" + ) + means_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="means_input" + ) + stds_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="std_devs_input" + ) list_inputs = [labels_input, means_input, stds_input] # deform labels - labels = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - inter_method='nearest')(labels_input) + labels = layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method="nearest", + )(labels_input) # get mask and further deforms/dilates it - grouped_labels = layers.ConvertLabels(generation_labels, dest_values=grouping_labels)(labels) - grouped_labels = layers.RandomSpatialDeformation(scaling_bounds=.05, - rotation_bounds=3, - shearing_bounds=0.007, - translation_bounds=2, - nonlin_std=1.5, - nonlin_scale=0.04, - inter_method='nearest', - prob_deform=0.9)(grouped_labels) + grouped_labels = layers.ConvertLabels( + generation_labels, dest_values=grouping_labels + )(labels) + grouped_labels = layers.RandomSpatialDeformation( + scaling_bounds=0.05, + rotation_bounds=3, + shearing_bounds=0.007, + translation_bounds=2, + nonlin_std=1.5, + nonlin_scale=0.04, + inter_method="nearest", + prob_deform=0.9, + )(grouped_labels) # randomly dilate/erode binary mask of each group n_group = len(np.unique(grouping_labels)) - grouped_labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x[..., 0], 'int32'), n_group, axis=-1))(grouped_labels) + grouped_labels = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x[..., 0], "int32"), n_group, axis=-1) + )(grouped_labels) split = KL.Lambda(lambda x: tf.split(x, [1] * n_group, axis=-1))(grouped_labels) - channels = [layers.RandomDilationErosion(1, 1, 5, 0.9, operation='random', return_mask=True)(c) for c in split] - grouped_labels = KL.Lambda(lambda x: tf.concat(x, -1), name='group_morph')(channels) + channels = [ + layers.RandomDilationErosion( + 1, 1, 5, 0.9, operation="random", return_mask=True + )(c) + for c in split + ] + grouped_labels = KL.Lambda(lambda x: tf.concat(x, -1), name="group_morph")(channels) # cropping if crop_shape != labels_shape: @@ -436,61 +505,101 @@ def labels_to_image_model_group(labels_shape, # flipping if flipping: - assert aff is not None, 'aff should not be None if flipping is True' - labels, grouped_labels = layers.RandomFlip(get_ras_axes(aff, n_dims)[0], [True, False], - generation_labels, n_neutral_labels)([labels, grouped_labels]) + assert aff is not None, "aff should not be None if flipping is True" + labels, grouped_labels = layers.RandomFlip( + get_ras_axes(aff, n_dims)[0], + [True, False], + generation_labels, + n_neutral_labels, + )([labels, grouped_labels]) # build synthetic image - image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + image = layers.SampleConditionalGMM(generation_labels)( + [labels, means_input, stds_input] + ) # apply bias field if bias_field_std > 0: image = layers.BiasFieldCorruption(bias_field_std, bias_scale, False)(image) # intensity augmentation - image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.5, separate_channels=True)(image) + image = layers.IntensityAugmentation( + clip=300, normalise=True, gamma_std=0.5, separate_channels=True + )(image) # loop over channels channels = list() - split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) if (n_channels > 1) else [image] + split = ( + KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) + if (n_channels > 1) + else [image] + ) for i, channel in enumerate(split): if randomise_res: - max_res_iso = np.array(utils.reformat_to_list(max_res_iso, length=n_dims, dtype='float')) - max_res_aniso = np.array(utils.reformat_to_list(max_res_aniso, length=n_dims, dtype='float')) + max_res_iso = np.array( + utils.reformat_to_list(max_res_iso, length=n_dims, dtype="float") + ) + max_res_aniso = np.array( + utils.reformat_to_list(max_res_aniso, length=n_dims, dtype="float") + ) max_res = np.maximum(max_res_iso, max_res_aniso) - resolution, blur_res = layers.SampleResolution(atlas_res, max_res_iso, max_res_aniso)(means_input) - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, resolution, thickness=blur_res) - channel = layers.DynamicGaussianBlur(0.75 * max_res / np.array(atlas_res), 1.03)([channel, sigma]) - channel = layers.MimicAcquisition(atlas_res, atlas_res, output_shape, False)([channel, resolution]) + resolution, blur_res = layers.SampleResolution( + atlas_res, max_res_iso, max_res_aniso + )(means_input) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, resolution, thickness=blur_res + ) + channel = layers.DynamicGaussianBlur( + 0.75 * max_res / np.array(atlas_res), 1.03 + )([channel, sigma]) + channel = layers.MimicAcquisition( + atlas_res, atlas_res, output_shape, False + )([channel, resolution]) channels.append(channel) else: - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, data_res[i], thickness=thickness[i]) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, data_res[i], thickness=thickness[i] + ) channel = layers.GaussianBlur(sigma, 1.03)(channel) - resolution = KL.Lambda(lambda x: tf.convert_to_tensor(data_res[i], dtype='float32'))([]) - channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)([channel, resolution]) + resolution = KL.Lambda( + lambda x: tf.convert_to_tensor(data_res[i], dtype="float32") + )([]) + channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)( + [channel, resolution] + ) channels.append(channel) # concatenate all channels back - image = KL.Lambda(lambda x: tf.concat(x, -1))(channels) if len(channels) > 1 else channels[0] + image = ( + KL.Lambda(lambda x: tf.concat(x, -1))(channels) + if len(channels) > 1 + else channels[0] + ) # compute image gradient if return_gradients: - image = layers.ImageGradients('sobel', True, name='image_gradients')(image) + image = layers.ImageGradients("sobel", True, name="image_gradients")(image) image = layers.IntensityAugmentation(clip=10, normalise=True)(image) # resample labels at target resolution if crop_shape != output_shape: - labels = l2i_et.resample_tensor(labels, output_shape, interp_method='nearest') - grouped_labels = l2i_et.resample_tensor(grouped_labels, output_shape, interp_method='nearest') + labels = l2i_et.resample_tensor(labels, output_shape, interp_method="nearest") + grouped_labels = l2i_et.resample_tensor( + grouped_labels, output_shape, interp_method="nearest" + ) # map generation labels to segmentation values - labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + labels = layers.ConvertLabels( + generation_labels, dest_values=output_labels, name="labels_out" + )(labels) # build model (dummy layer enables to keep the labels when plugging this model to other models) - image = KL.Lambda(lambda x: tf.concat([x[0], tf.cast(x[1], dtype=x[0].dtype)], axis=-1), - name='image_out')([image, grouped_labels, labels]) + image = KL.Lambda( + lambda x: tf.concat([x[0], tf.cast(x[1], dtype=x[0].dtype)], axis=-1), + name="image_out", + )([image, grouped_labels, labels]) brain_model = models.Model(inputs=list_inputs, outputs=[image, labels]) return brain_model diff --git a/nobrainer/ext/SynthSeg/training_qc.py b/nobrainer/ext/SynthSeg/training_qc.py index c35576df..ad87a1c0 100644 --- a/nobrainer/ext/SynthSeg/training_qc.py +++ b/nobrainer/ext/SynthSeg/training_qc.py @@ -17,55 +17,56 @@ License. """ +from inspect import getmembers, isclass # python imports import os -import keras -import numpy as np -import tensorflow as tf -from keras import models -import keras.layers as KL -import keras.backend as K -import keras.callbacks as KC -from keras.optimizers import Adam -from inspect import getmembers, isclass # project imports from SynthSeg import metrics_model as metrics # third-party imports -from ext.lab2im import utils from ext.lab2im import layers as l2i_layers -from ext.neuron import utils as nrn_utils +from ext.lab2im import utils from ext.neuron import layers as nrn_layers from ext.neuron import models as nrn_models +from ext.neuron import utils as nrn_utils +import keras +from keras import models +import keras.backend as K +import keras.callbacks as KC +import keras.layers as KL +from keras.optimizers import Adam +import numpy as np +import tensorflow as tf -def training(list_paths_input_labels, - list_paths_target_labels, - model_dir, - labels_list, - labels_list_to_convert=None, - subjects_prob=None, - batchsize=1, - output_shape=None, - scaling_bounds=.2, - rotation_bounds=15, - shearing_bounds=.012, - translation_bounds=False, - nonlin_std=4., - nonlin_scale=.04, - n_levels=5, - nb_conv_per_level=3, - conv_size=5, - unet_feat_count=24, - feat_multiplier=2, - activation='relu', - lr=1e-4, - epochs=300, - steps_per_epoch=1000, - checkpoint=None): - +def training( + list_paths_input_labels, + list_paths_target_labels, + model_dir, + labels_list, + labels_list_to_convert=None, + subjects_prob=None, + batchsize=1, + output_shape=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.04, + n_levels=5, + nb_conv_per_level=3, + conv_size=5, + unet_feat_count=24, + feat_multiplier=2, + activation="relu", + lr=1e-4, + epochs=300, + steps_per_epoch=1000, + checkpoint=None, +): """ This function trains a regressor network to predict Dice scores between segmentations (typically obtained with an automated algorithm), and their ground truth (to which we typically do not have access at test time). @@ -136,55 +137,76 @@ def training(list_paths_input_labels, # prepare data files labels_list, _ = utils.get_list_labels(label_list=labels_list) if labels_list_to_convert is not None: - labels_list_to_convert, _ = utils.get_list_labels(label_list=labels_list_to_convert) + labels_list_to_convert, _ = utils.get_list_labels( + label_list=labels_list_to_convert + ) n_labels = len(np.unique(labels_list)) # create augmentation model - labels_shape, _, n_dims, _, _, _ = utils.get_volume_info(list_paths_target_labels[0], aff_ref=np.eye(4)) - augmentation_model = build_augmentation_model(labels_shape, - labels_list, - labels_list_to_convert=labels_list_to_convert, - output_shape=output_shape, - output_div_by_n=2 ** n_levels, - scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale) + labels_shape, _, n_dims, _, _, _ = utils.get_volume_info( + list_paths_target_labels[0], aff_ref=np.eye(4) + ) + augmentation_model = build_augmentation_model( + labels_shape, + labels_list, + labels_list_to_convert=labels_list_to_convert, + output_shape=output_shape, + output_div_by_n=2**n_levels, + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + ) # prepare QC model - regression_model = build_qc_model(input_model=augmentation_model, - n_labels=n_labels, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation) + regression_model = build_qc_model( + input_model=augmentation_model, + n_labels=n_labels, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + ) qc_model = build_qc_loss(regression_model) # input generator - model_inputs = build_model_inputs(path_input_label_maps=list_paths_input_labels, - path_target_label_maps=list_paths_target_labels, - batchsize=batchsize, - subjects_prob=subjects_prob) + model_inputs = build_model_inputs( + path_input_label_maps=list_paths_input_labels, + path_target_label_maps=list_paths_target_labels, + batchsize=batchsize, + subjects_prob=subjects_prob, + ) input_generator = utils.build_training_generator(model_inputs, batchsize) - train_model(qc_model, input_generator, lr, epochs, steps_per_epoch, model_dir, 'qc', checkpoint) - - -def build_augmentation_model(labels_shape, - labels_list, - labels_list_to_convert=None, - output_shape=None, - output_div_by_n=None, - scaling_bounds=0.15, - rotation_bounds=15, - shearing_bounds=0.012, - translation_bounds=False, - nonlin_std=3., - nonlin_scale=.0625): + train_model( + qc_model, + input_generator, + lr, + epochs, + steps_per_epoch, + model_dir, + "qc", + checkpoint, + ) + + +def build_augmentation_model( + labels_shape, + labels_list, + labels_list_to_convert=None, + output_shape=None, + output_div_by_n=None, + scaling_bounds=0.15, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_scale=0.0625, +): # reformat resolutions and get shapes labels_shape = utils.reformat_to_list(labels_shape) @@ -194,74 +216,105 @@ def build_augmentation_model(labels_shape, output_shape = get_shapes(labels_shape, output_shape, output_div_by_n, n_dims) # define model inputs - net_input = KL.Input(shape=labels_shape + [1], name='noisy_labels_input', dtype='int32') - target_input = KL.Input(shape=labels_shape + [1], name='target_input', dtype='int32') + net_input = KL.Input( + shape=labels_shape + [1], name="noisy_labels_input", dtype="int32" + ) + target_input = KL.Input( + shape=labels_shape + [1], name="target_input", dtype="int32" + ) # convert labels if necessary if labels_list_to_convert is not None: - noisy_labels = l2i_layers.ConvertLabels(labels_list_to_convert, labels_list, name='convert_noisy')(net_input) - target = l2i_layers.ConvertLabels(labels_list_to_convert, labels_list, name='convert_target')(target_input) + noisy_labels = l2i_layers.ConvertLabels( + labels_list_to_convert, labels_list, name="convert_noisy" + )(net_input) + target = l2i_layers.ConvertLabels( + labels_list_to_convert, labels_list, name="convert_target" + )(target_input) else: noisy_labels = net_input target = target_input # deform labels - noisy_labels, target = l2i_layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - inter_method='nearest')([noisy_labels, target]) + noisy_labels, target = l2i_layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method="nearest", + )([noisy_labels, target]) # mask image, compute Dice score with full GT, and crop noisy labels - noisy_labels, scores = SimulatePartialFOV(crop_shape=output_shape[0], - labels_list=np.unique(labels_list), - min_fov_shape=70, - prob_mask=0.3, name='partial_fov')([noisy_labels, target]) + noisy_labels, scores = SimulatePartialFOV( + crop_shape=output_shape[0], + labels_list=np.unique(labels_list), + min_fov_shape=70, + prob_mask=0.3, + name="partial_fov", + )([noisy_labels, target]) # dummy layers - scores = KL.Lambda(lambda x: x, name='dice_gt')(scores) - noisy_labels = KL.Lambda(lambda x: x[0], name='labels_out')([noisy_labels, scores]) + scores = KL.Lambda(lambda x: x, name="dice_gt")(scores) + noisy_labels = KL.Lambda(lambda x: x[0], name="labels_out")([noisy_labels, scores]) # build model and return brain_model = models.Model(inputs=[net_input, target_input], outputs=noisy_labels) return brain_model -def build_qc_model(input_model, - n_labels, - n_levels, - nb_conv_per_level, - conv_size, - unet_feat_count, - feat_multiplier, - activation): +def build_qc_model( + input_model, + n_labels, + n_levels, + nb_conv_per_level, + conv_size, + unet_feat_count, + feat_multiplier, + activation, +): # get prediction last_tensor = input_model.outputs[0] input_shape = last_tensor.get_shape().as_list()[1:] - assert input_shape[-1] == n_labels, 'mismatch between number of predicted labels, and segmentation labels' + assert ( + input_shape[-1] == n_labels + ), "mismatch between number of predicted labels, and segmentation labels" # build model - model = nrn_models.conv_enc(input_model=input_model, - input_shape=input_shape, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1, - use_residuals=True, - name='qc') + model = nrn_models.conv_enc( + input_model=input_model, + input_shape=input_shape, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + use_residuals=True, + name="qc", + ) last = model.outputs[0] - conv_kwargs = {'padding': 'same', 'activation': 'relu', 'data_format': 'channels_last'} - last = KL.MaxPool3D(pool_size=(2, 2, 2), name='qc_maxpool_%s' % (n_levels - 1), padding='same')(last) - last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_0')(last) - last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name='qc_final_conv_1')(last) - last = KL.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name='qc_final_pred')(last) + conv_kwargs = { + "padding": "same", + "activation": "relu", + "data_format": "channels_last", + } + last = KL.MaxPool3D( + pool_size=(2, 2, 2), name="qc_maxpool_%s" % (n_levels - 1), padding="same" + )(last) + last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name="qc_final_conv_0")( + last + ) + last = KL.Conv3D(n_labels, kernel_size=5, **conv_kwargs, name="qc_final_conv_1")( + last + ) + last = KL.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2, 3]), name="qc_final_pred")( + last + ) return models.Model(input_model.inputs, last) @@ -269,20 +322,21 @@ def build_qc_model(input_model, def build_qc_loss(input_model): # get Dice scores - dice_gt = input_model.get_layer('dice_gt').output + dice_gt = input_model.get_layer("dice_gt").output dice_pred = input_model.outputs[0] # get loss - loss = KL.Lambda(lambda x: K.sum(K.mean(K.square(x[0] - x[1]), axis=0)), name='qc_loss')([dice_gt, dice_pred]) + loss = KL.Lambda( + lambda x: K.sum(K.mean(K.square(x[0] - x[1]), axis=0)), name="qc_loss" + )([dice_gt, dice_pred]) loss._keras_shape = tuple(loss.get_shape().as_list()) return models.Model(inputs=input_model.inputs, outputs=loss) -def build_model_inputs(path_input_label_maps, - path_target_label_maps, - batchsize=1, - subjects_prob=None): +def build_model_inputs( + path_input_label_maps, path_target_label_maps, batchsize=1, subjects_prob=None +): # make sure subjects_prob sums to 1 subjects_prob = utils.load_array_if_path(subjects_prob) @@ -293,7 +347,9 @@ def build_model_inputs(path_input_label_maps, while True: # randomly pick as many images as batchsize - indices = np.random.choice(np.arange(len(path_input_label_maps)), size=batchsize, p=subjects_prob) + indices = np.random.choice( + np.arange(len(path_input_label_maps)), size=batchsize, p=subjects_prob + ) # initialise input lists list_input_label_maps = list() @@ -302,17 +358,23 @@ def build_model_inputs(path_input_label_maps, for idx in indices: # load input - input_net = utils.load_volume(path_input_label_maps[idx], dtype='int', aff_ref=np.eye(4)) + input_net = utils.load_volume( + path_input_label_maps[idx], dtype="int", aff_ref=np.eye(4) + ) list_input_label_maps.append(utils.add_axis(input_net, axis=[0, -1])) # load target - target = utils.load_volume(path_target_label_maps[idx], dtype='int', aff_ref=np.eye(4)) + target = utils.load_volume( + path_target_label_maps[idx], dtype="int", aff_ref=np.eye(4) + ) list_target_label_maps.append(utils.add_axis(target, axis=[0, -1])) # build list of training pairs list_training_pairs = [list_input_label_maps, list_target_label_maps] if batchsize > 1: # concatenate individual input types if batchsize > 1 - list_training_pairs = [np.concatenate(item, 0) for item in list_training_pairs] + list_training_pairs = [ + np.concatenate(item, 0) for item in list_training_pairs + ] else: list_training_pairs = [item[0] for item in list_training_pairs] @@ -323,17 +385,27 @@ def get_shapes(labels_shape, cropping_shape, output_div_by_n, n_dims): # cropping shape specified, make sure it's okay if cropping_shape is not None: - cropping_shape = utils.reformat_to_list(cropping_shape, length=n_dims, dtype='int') + cropping_shape = utils.reformat_to_list( + cropping_shape, length=n_dims, dtype="int" + ) # make sure that cropping shape is smaller or equal to label shape - cropping_shape = [min(labels_shape[i], cropping_shape[i]) for i in range(n_dims)] + cropping_shape = [ + min(labels_shape[i], cropping_shape[i]) for i in range(n_dims) + ] # make sure cropping shape is divisible by output_div_by_n if output_div_by_n is not None: - tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in cropping_shape] + tmp_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in cropping_shape + ] if cropping_shape != tmp_shape: - print('output shape {0} not divisible by {1}, changed to {2}'.format(cropping_shape, output_div_by_n, - tmp_shape)) + print( + "output shape {0} not divisible by {1}, changed to {2}".format( + cropping_shape, output_div_by_n, tmp_shape + ) + ) cropping_shape = tmp_shape # no cropping shape specified, so no cropping unless label_shape is not divisible by output_div_by_n @@ -341,7 +413,10 @@ def get_shapes(labels_shape, cropping_shape, output_div_by_n, n_dims): # make sure labels shape is divisible by output_div_by_n if output_div_by_n is not None: - cropping_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in labels_shape] + cropping_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in labels_shape + ] # if no need to be divisible by n, simply take cropping_shape as image_shape, and build output_shape else: @@ -350,33 +425,53 @@ def get_shapes(labels_shape, cropping_shape, output_div_by_n, n_dims): return cropping_shape -def train_model(model, - generator, - learning_rate, - n_epochs, - n_steps, - model_dir, - metric_type, - path_checkpoint=None, - reinitialise_momentum=False): +def train_model( + model, + generator, + learning_rate, + n_epochs, + n_steps, + model_dir, + metric_type, + path_checkpoint=None, + reinitialise_momentum=False, +): # prepare model and log folders utils.mkdir(model_dir) - log_dir = os.path.join(model_dir, 'logs') + log_dir = os.path.join(model_dir, "logs") utils.mkdir(log_dir) # model saving callback - save_file_name = os.path.join(model_dir, 'qc_{epoch:03d}.h5') - callbacks = [KC.ModelCheckpoint(save_file_name, verbose=1), - KC.TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True, write_images=False)] + save_file_name = os.path.join(model_dir, "qc_{epoch:03d}.h5") + callbacks = [ + KC.ModelCheckpoint(save_file_name, verbose=1), + KC.TensorBoard( + log_dir=log_dir, histogram_freq=0, write_graph=True, write_images=False + ), + ] compile_model = True init_epoch = 0 if (path_checkpoint is not None) & (not reinitialise_momentum): init_epoch = int(os.path.basename(path_checkpoint).split(metric_type)[1][1:-3]) - custom_l2i = {key: value for (key, value) in getmembers(l2i_layers, isclass) if key != 'Layer'} - custom_nrn = {key: value for (key, value) in getmembers(nrn_layers, isclass) if key != 'Layer'} - custom_objects = {**custom_l2i, **custom_nrn, 'tf': tf, 'keras': keras, 'loss': metrics.IdentityLoss().loss} + custom_l2i = { + key: value + for (key, value) in getmembers(l2i_layers, isclass) + if key != "Layer" + } + custom_nrn = { + key: value + for (key, value) in getmembers(nrn_layers, isclass) + if key != "Layer" + } + custom_objects = { + **custom_l2i, + **custom_nrn, + "tf": tf, + "keras": keras, + "loss": metrics.IdentityLoss().loss, + } model = models.load_model(path_checkpoint, custom_objects=custom_objects) compile_model = False elif path_checkpoint is not None: @@ -384,14 +479,18 @@ def train_model(model, # compile if compile_model: - model.compile(optimizer=Adam(lr=learning_rate), loss=metrics.IdentityLoss().loss) + model.compile( + optimizer=Adam(lr=learning_rate), loss=metrics.IdentityLoss().loss + ) # fit - model.fit_generator(generator, - epochs=n_epochs, - steps_per_epoch=n_steps, - callbacks=callbacks, - initial_epoch=init_epoch) + model.fit_generator( + generator, + epochs=n_epochs, + steps_per_epoch=n_steps, + callbacks=callbacks, + initial_epoch=init_epoch, + ) class SimulatePartialFOV(KL.Layer): @@ -428,14 +527,20 @@ def get_config(self): def build(self, input_shape): # check shapes - assert len(input_shape) == 2, 'SimulatePartialFOV expects 2 inputs: labels to deform and GT (for Dice scores).' - assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + assert ( + len(input_shape) == 2 + ), "SimulatePartialFOV expects 2 inputs: labels to deform and GT (for Dice scores)." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." self.n_dims = len(input_shape[0]) - 2 - self.inshape = input_shape[0][1:self.n_dims + 1] + self.inshape = input_shape[0][1 : self.n_dims + 1] self.crop_max_val = self.inshape[0] - self.crop_shape self.meshgrid = nrn_utils.volshape_to_ndgrid(self.inshape) - self.lut = tf.convert_to_tensor(utils.get_mapping_lut(self.labels_list), dtype='int32') + self.lut = tf.convert_to_tensor( + utils.get_mapping_lut(self.labels_list), dtype="int32" + ) self.built = True super(SimulatePartialFOV, self).build(input_shape) @@ -448,17 +553,30 @@ def call(self, inputs, **kwargs): # sample cropping indices batchsize = tf.split(tf.shape(x), [1, -1])[0] - sample_shape = tf.concat([batchsize, self.n_dims * tf.ones([1], dtype='int32')], 0) + sample_shape = tf.concat( + [batchsize, self.n_dims * tf.ones([1], dtype="int32")], 0 + ) if self.crop_max_val > 0: - crop_idx_inf = tf.random.uniform(shape=sample_shape, minval=0, maxval=self.crop_max_val, dtype='int32') + crop_idx_inf = tf.random.uniform( + shape=sample_shape, minval=0, maxval=self.crop_max_val, dtype="int32" + ) crop_idx_sup = crop_idx_inf + self.crop_shape else: crop_idx_inf = crop_idx_sup = None # sample masking indices - fov_shape = tf.random.uniform(sample_shape, minval=self.min_fov_shape, maxval=self.inshape[0], dtype='int32') - mask_idx_inf = tf.random.uniform(shape=sample_shape, minval=0, maxval=1, dtype='float32') - mask_idx_inf_tmp = tf.cast(mask_idx_inf * tf.cast(self.inshape[0]-fov_shape, 'float32'), 'int32') + fov_shape = tf.random.uniform( + sample_shape, + minval=self.min_fov_shape, + maxval=self.inshape[0], + dtype="int32", + ) + mask_idx_inf = tf.random.uniform( + shape=sample_shape, minval=0, maxval=1, dtype="float32" + ) + mask_idx_inf_tmp = tf.cast( + mask_idx_inf * tf.cast(self.inshape[0] - fov_shape, "float32"), "int32" + ) mask_idx_sup_tmp = mask_idx_inf_tmp + fov_shape if self.crop_max_val > 0: mask_idx_inf = tf.maximum(mask_idx_inf_tmp, crop_idx_inf) @@ -468,8 +586,14 @@ def call(self, inputs, **kwargs): mask_idx_sup = mask_idx_sup_tmp # mask input labels - mask = tf.map_fn(self._single_build_mask, [x, mask_idx_inf, mask_idx_sup], tf.int32) - x = K.switch(tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), x * mask, x) + mask = tf.map_fn( + self._single_build_mask, [x, mask_idx_inf, mask_idx_sup], tf.int32 + ) + x = K.switch( + tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), + x * mask, + x, + ) # compute dice score for each label value x = tf.one_hot(tf.gather(self.lut, x), depth=self.n_labels, axis=-1) @@ -481,7 +605,9 @@ def call(self, inputs, **kwargs): # crop input labels if self.crop_max_val > 0: - x_cropped = tf.map_fn(self._single_slice, [x, crop_idx_inf], dtype=tf.float32) + x_cropped = tf.map_fn( + self._single_slice, [x, crop_idx_inf], dtype=tf.float32 + ) else: x_cropped = x @@ -491,19 +617,28 @@ def _single_build_mask(self, inputs): vol = inputs[0] mask_idx_inf = inputs[1] mask_idx_sup = inputs[2] - mask = tf.ones(vol.shape, dtype='bool') + mask = tf.ones(vol.shape, dtype="bool") for i in range(self.n_dims): tmp_mask_inf = tf.less(self.meshgrid[i], mask_idx_inf[i]) tmp_mask_sup = tf.greater(self.meshgrid[i], mask_idx_sup[i]) - mask = tf.logical_and(mask, tf.logical_not(tf.logical_or(tmp_mask_inf, tmp_mask_sup))) - return tf.cast(mask, 'int32') + mask = tf.logical_and( + mask, tf.logical_not(tf.logical_or(tmp_mask_inf, tmp_mask_sup)) + ) + return tf.cast(mask, "int32") def _single_slice(self, inputs): vol = inputs[0] crop_idx_inf = inputs[1] - crop_idx_inf = tf.concat([tf.cast(crop_idx_inf, 'int32'), tf.zeros([1], dtype='int32')], axis=0) - crop_size = tf.convert_to_tensor([self.crop_shape] * self.n_dims + [-1], dtype='int32') + crop_idx_inf = tf.concat( + [tf.cast(crop_idx_inf, "int32"), tf.zeros([1], dtype="int32")], axis=0 + ) + crop_size = tf.convert_to_tensor( + [self.crop_shape] * self.n_dims + [-1], dtype="int32" + ) return tf.slice(vol, begin=crop_idx_inf, size=crop_size) def compute_output_shape(self, input_shape): - return [(None, *[self.crop_shape] * self.n_dims, self.n_labels), (None, self.n_labels)] + return [ + (None, *[self.crop_shape] * self.n_dims, self.n_labels), + (None, self.n_labels), + ] diff --git a/nobrainer/ext/SynthSeg/training_supervised.py b/nobrainer/ext/SynthSeg/training_supervised.py index 185a1722..7cfa21f3 100644 --- a/nobrainer/ext/SynthSeg/training_supervised.py +++ b/nobrainer/ext/SynthSeg/training_supervised.py @@ -18,62 +18,62 @@ License. """ - # python imports import os -import numpy as np -import tensorflow as tf -from keras import models -import keras.layers as KL -import numpy.random as npr # project imports from SynthSeg import metrics_model as metrics -from SynthSeg.training import train_model from SynthSeg.labels_to_image_model import get_shapes +from SynthSeg.training import train_model # third-party imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.neuron import models as nrn_models from ext.lab2im import edit_tensors as l2i_et +from ext.lab2im import layers, utils from ext.lab2im.edit_volumes import get_ras_axes +from ext.neuron import models as nrn_models +from keras import models +import keras.layers as KL +import numpy as np +import numpy.random as npr +import tensorflow as tf -def training(image_dir, - labels_dir, - model_dir, - segmentation_labels=None, - n_neutral_labels=None, - subjects_prob=None, - batchsize=1, - target_res=None, - output_shape=None, - flipping=True, - scaling_bounds=.2, - rotation_bounds=15, - shearing_bounds=.012, - translation_bounds=False, - nonlin_std=4., - nonlin_scale=.04, - randomise_res=True, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.7, - bias_scale=.025, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - lr=1e-4, - wl2_epochs=1, - dice_epochs=50, - steps_per_epoch=10000, - checkpoint=None): +def training( + image_dir, + labels_dir, + model_dir, + segmentation_labels=None, + n_neutral_labels=None, + subjects_prob=None, + batchsize=1, + target_res=None, + output_shape=None, + flipping=True, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.04, + randomise_res=True, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + lr=1e-4, + wl2_epochs=1, + dice_epochs=50, + steps_per_epoch=10000, + checkpoint=None, +): """ This function trains a UNet to segment MRI images with real scans and corresponding ground truth labels. We regroup the parameters in four categories: General, Augmentation, Architecture, Training. @@ -184,57 +184,70 @@ def training(image_dir, """ # check epochs - assert (wl2_epochs > 0) | (dice_epochs > 0), \ - 'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs) + assert (wl2_epochs > 0) | ( + dice_epochs > 0 + ), "either wl2_epochs or dice_epochs must be positive, had {0} and {1}".format( + wl2_epochs, dice_epochs + ) # prepare data files path_images = utils.list_images_in_folder(image_dir) path_labels = utils.list_images_in_folder(labels_dir) - assert len(path_images) == len(path_labels), "There should be as many images as label maps." + assert len(path_images) == len( + path_labels + ), "There should be as many images as label maps." # get label lists - label_list, _ = utils.get_list_labels(label_list=segmentation_labels, labels_dir=labels_dir) + label_list, _ = utils.get_list_labels( + label_list=segmentation_labels, labels_dir=labels_dir + ) n_labels = np.size(label_list) # create augmentation model - im_shape, _, _, n_channels, _, atlas_res = utils.get_volume_info(path_images[0], aff_ref=np.eye(4)) - augmentation_model = build_augmentation_model(im_shape, - n_channels, - label_list, - n_neutral_labels, - atlas_res, - target_res, - output_shape=output_shape, - output_div_by_n=2 ** n_levels, - flipping=flipping, - aff=np.eye(4), - scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - randomise_res=randomise_res, - max_res_iso=max_res_iso, - max_res_aniso=max_res_aniso, - data_res=data_res, - thickness=thickness, - bias_field_std=bias_field_std, - bias_scale=bias_scale) + im_shape, _, _, n_channels, _, atlas_res = utils.get_volume_info( + path_images[0], aff_ref=np.eye(4) + ) + augmentation_model = build_augmentation_model( + im_shape, + n_channels, + label_list, + n_neutral_labels, + atlas_res, + target_res, + output_shape=output_shape, + output_div_by_n=2**n_levels, + flipping=flipping, + aff=np.eye(4), + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + randomise_res=randomise_res, + max_res_iso=max_res_iso, + max_res_aniso=max_res_aniso, + data_res=data_res, + thickness=thickness, + bias_field_std=bias_field_std, + bias_scale=bias_scale, + ) unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:] # prepare the segmentation model - unet_model = nrn_models.unet(input_model=augmentation_model, - input_shape=unet_input_shape, - nb_labels=n_labels, - nb_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - nb_features=unet_feat_count, - feat_mult=feat_multiplier, - activation=activation, - batch_norm=-1, - name='unet') + unet_model = nrn_models.unet( + input_model=augmentation_model, + input_shape=unet_input_shape, + nb_labels=n_labels, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + name="unet", + ) # input generator generator = build_model_inputs(path_images, path_labels, batchsize, subjects_prob) @@ -242,65 +255,99 @@ def training(image_dir, # pre-training with weighted L2, input is fit to the softmax rather than the probabilities if wl2_epochs > 0: - wl2_model = models.Model(unet_model.inputs, [unet_model.get_layer('unet_likelihood').output]) - wl2_model = metrics.metrics_model(wl2_model, label_list, 'wl2') - train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint) - checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs) + wl2_model = models.Model( + unet_model.inputs, [unet_model.get_layer("unet_likelihood").output] + ) + wl2_model = metrics.metrics_model(wl2_model, label_list, "wl2") + train_model( + wl2_model, + input_generator, + lr, + wl2_epochs, + steps_per_epoch, + model_dir, + "wl2", + checkpoint, + ) + checkpoint = os.path.join(model_dir, "wl2_%03d.h5" % wl2_epochs) # fine-tuning with dice metric - dice_model = metrics.metrics_model(unet_model, label_list, 'dice') - train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint) - - -def build_augmentation_model(im_shape, - n_channels, - segmentation_labels, - n_neutral_labels, - atlas_res, - target_res, - output_shape=None, - output_div_by_n=None, - flipping=True, - aff=None, - scaling_bounds=0.2, - rotation_bounds=15, - shearing_bounds=0.012, - translation_bounds=False, - nonlin_std=4., - nonlin_scale=.0625, - randomise_res=False, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.7, - bias_scale=.025): + dice_model = metrics.metrics_model(unet_model, label_list, "dice") + train_model( + dice_model, + input_generator, + lr, + dice_epochs, + steps_per_epoch, + model_dir, + "dice", + checkpoint, + ) + + +def build_augmentation_model( + im_shape, + n_channels, + segmentation_labels, + n_neutral_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + flipping=True, + aff=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.0625, + randomise_res=False, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, +): # reformat resolutions and get shapes im_shape = utils.reformat_to_list(im_shape) n_dims, _ = utils.get_dims(im_shape) if data_res is not None: data_res = utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) - thickness = data_res if thickness is None else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) - target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + thickness = ( + data_res + if thickness is None + else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + ) + target_res = ( + atlas_res + if (target_res is None) + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) else: target_res = atlas_res # get shapes - crop_shape, output_shape = get_shapes(im_shape, output_shape, atlas_res, target_res, output_div_by_n) + crop_shape, output_shape = get_shapes( + im_shape, output_shape, atlas_res, target_res, output_div_by_n + ) # define model inputs - image_input = KL.Input(shape=im_shape+[n_channels], name='image_input') - labels_input = KL.Input(shape=im_shape + [1], name='labels_input', dtype='int32') + image_input = KL.Input(shape=im_shape + [n_channels], name="image_input") + labels_input = KL.Input(shape=im_shape + [1], name="labels_input", dtype="int32") # deform labels - labels, image = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - inter_method=['nearest', 'linear'])([labels_input, image_input]) + labels, image = layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method=["nearest", "linear"], + )([labels_input, image_input]) # cropping if crop_shape != im_shape: @@ -308,61 +355,99 @@ def build_augmentation_model(im_shape, # flipping if flipping: - assert aff is not None, 'aff should not be None if flipping is True' - labels, image = layers.RandomFlip(get_ras_axes(aff, n_dims)[0], [True, False], - segmentation_labels, n_neutral_labels)([labels, image]) + assert aff is not None, "aff should not be None if flipping is True" + labels, image = layers.RandomFlip( + get_ras_axes(aff, n_dims)[0], + [True, False], + segmentation_labels, + n_neutral_labels, + )([labels, image]) # apply bias field if bias_field_std > 0: image = layers.BiasFieldCorruption(bias_field_std, bias_scale, False)(image) # intensity augmentation - image = layers.IntensityAugmentation(6, clip=False, normalise=True, gamma_std=.5, separate_channels=True)(image) + image = layers.IntensityAugmentation( + 6, clip=False, normalise=True, gamma_std=0.5, separate_channels=True + )(image) # if necessary, loop over channels to 1) blur, 2) downsample to simulated LR, and 3) upsample to target if data_res is not None: channels = list() - split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) if (n_channels > 1) else [image] + split = ( + KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) + if (n_channels > 1) + else [image] + ) for i, channel in enumerate(split): if randomise_res: - max_res_iso = np.array(utils.reformat_to_list(max_res_iso, length=n_dims, dtype='float')) - max_res_aniso = np.array(utils.reformat_to_list(max_res_aniso, length=n_dims, dtype='float')) + max_res_iso = np.array( + utils.reformat_to_list(max_res_iso, length=n_dims, dtype="float") + ) + max_res_aniso = np.array( + utils.reformat_to_list(max_res_aniso, length=n_dims, dtype="float") + ) max_res = np.maximum(max_res_iso, max_res_aniso) - resolution, blur_res = layers.SampleResolution(atlas_res, max_res_iso, max_res_aniso)(image) - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, resolution, thickness=blur_res) - channel = layers.DynamicGaussianBlur(0.75 * max_res / np.array(atlas_res), 1.03)([channel, sigma]) - channel = layers.MimicAcquisition(atlas_res, atlas_res, output_shape, False)([channel, resolution]) + resolution, blur_res = layers.SampleResolution( + atlas_res, max_res_iso, max_res_aniso + )(image) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, resolution, thickness=blur_res + ) + channel = layers.DynamicGaussianBlur( + 0.75 * max_res / np.array(atlas_res), 1.03 + )([channel, sigma]) + channel = layers.MimicAcquisition( + atlas_res, atlas_res, output_shape, False + )([channel, resolution]) channels.append(channel) else: - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, data_res[i], thickness=thickness[i]) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, data_res[i], thickness=thickness[i] + ) channel = layers.GaussianBlur(sigma, 1.03)(channel) - resolution = KL.Lambda(lambda x: tf.convert_to_tensor(data_res[i], dtype='float32'))([]) - channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)([channel, resolution]) + resolution = KL.Lambda( + lambda x: tf.convert_to_tensor(data_res[i], dtype="float32") + )([]) + channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)( + [channel, resolution] + ) channels.append(channel) # concatenate all channels back - image = KL.Lambda(lambda x: tf.concat(x, -1))(channels) if len(channels) > 1 else channels[0] + image = ( + KL.Lambda(lambda x: tf.concat(x, -1))(channels) + if len(channels) > 1 + else channels[0] + ) # resample labels at target resolution if crop_shape != output_shape: - labels = l2i_et.resample_tensor(labels, output_shape, interp_method='nearest') + labels = l2i_et.resample_tensor( + labels, output_shape, interp_method="nearest" + ) # build model (dummy layer enables to keep the labels when plugging this model to other models) - labels = KL.Lambda(lambda x: tf.cast(x, dtype='int32'), name='labels_out')(labels) - image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) - brain_model = models.Model(inputs=[image_input, labels_input], outputs=[image, labels]) + labels = KL.Lambda(lambda x: tf.cast(x, dtype="int32"), name="labels_out")(labels) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) + brain_model = models.Model( + inputs=[image_input, labels_input], outputs=[image, labels] + ) return brain_model -def build_model_inputs(path_inputs, - path_outputs, - batchsize=1, - subjects_prob=None, - dtype_input='float32', - dtype_output='int32'): +def build_model_inputs( + path_inputs, + path_outputs, + batchsize=1, + subjects_prob=None, + dtype_input="float32", + dtype_output="int32", +): # get label info _, _, _, n_channels, _, _ = utils.get_volume_info(path_inputs[0]) @@ -375,7 +460,9 @@ def build_model_inputs(path_inputs, while True: # randomly pick as many images as batchsize - indices = npr.choice(np.arange(len(path_outputs)), size=batchsize, p=subjects_prob) + indices = npr.choice( + np.arange(len(path_outputs)), size=batchsize, p=subjects_prob + ) # initialise input lists list_batch_inputs = list() @@ -384,20 +471,26 @@ def build_model_inputs(path_inputs, for idx in indices: # get a batch input - batch_input = utils.load_volume(path_inputs[idx], aff_ref=np.eye(4), dtype=dtype_input) + batch_input = utils.load_volume( + path_inputs[idx], aff_ref=np.eye(4), dtype=dtype_input + ) if n_channels > 1: list_batch_inputs.append(utils.add_axis(batch_input, axis=0)) else: list_batch_inputs.append(utils.add_axis(batch_input, axis=[0, -1])) # get a batch output - batch_output = utils.load_volume(path_outputs[idx], aff_ref=np.eye(4), dtype=dtype_output) + batch_output = utils.load_volume( + path_outputs[idx], aff_ref=np.eye(4), dtype=dtype_output + ) list_batch_outputs.append(utils.add_axis(batch_output, axis=[0, -1])) # build list of training pairs list_training_pairs = [list_batch_inputs, list_batch_outputs] if batchsize > 1: # concatenate individual input types if batchsize > 1 - list_training_pairs = [np.concatenate(item, 0) for item in list_training_pairs] + list_training_pairs = [ + np.concatenate(item, 0) for item in list_training_pairs + ] else: list_training_pairs = [item[0] for item in list_training_pairs] diff --git a/nobrainer/ext/SynthSeg/validate.py b/nobrainer/ext/SynthSeg/validate.py index 48bcdd53..bce58606 100644 --- a/nobrainer/ext/SynthSeg/validate.py +++ b/nobrainer/ext/SynthSeg/validate.py @@ -13,45 +13,47 @@ License. """ +import logging # python imports import os import re -import logging -import numpy as np -import matplotlib.pyplot as plt -from tensorflow.python.summary.summary_iterator import summary_iterator # project imports from SynthSeg.predict import predict # third-party imports from ext.lab2im import utils +import matplotlib.pyplot as plt +import numpy as np +from tensorflow.python.summary.summary_iterator import summary_iterator -def validate_training(image_dir, - gt_dir, - models_dir, - validation_main_dir, - labels_segmentation, - n_neutral_labels=None, - evaluation_labels=None, - step_eval=1, - min_pad=None, - cropping=None, - target_res=1., - gradients=False, - flip=False, - topology_classes=None, - sigma_smoothing=0, - keep_biggest_component=False, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - recompute=False): +def validate_training( + image_dir, + gt_dir, + models_dir, + validation_main_dir, + labels_segmentation, + n_neutral_labels=None, + evaluation_labels=None, + step_eval=1, + min_pad=None, + cropping=None, + target_res=1.0, + gradients=False, + flip=False, + topology_classes=None, + sigma_smoothing=0, + keep_biggest_component=False, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + recompute=False, +): """This function validates models saved at different epochs of the same training. All models are assumed to be in the same folder. The results of each model are saved in a subfolder in validation_main_dir. @@ -87,52 +89,70 @@ def validate_training(image_dir, :param unet_feat_count: (optional) number of feature maps for the first level. Default is 24. :param feat_multiplier: (optional) multiply the number of feature by this number at each new level. Default is 1. :param activation: (optional) activation function. Can be 'elu', 'relu'. - :param recompute: (optional) whether to recompute result files even if they already exist.""" + :param recompute: (optional) whether to recompute result files even if they already exist. + """ # create result folder utils.mkdir(validation_main_dir) # loop over models - list_models = utils.list_files(models_dir, expr=['dice', '.h5'], cond_type='and')[::step_eval] + list_models = utils.list_files(models_dir, expr=["dice", ".h5"], cond_type="and")[ + ::step_eval + ] # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 10 == 0] - loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True) + loop_info = utils.LoopInfo(len(list_models), 1, "validating", True) for model_idx, path_model in enumerate(list_models): # build names and create folders - model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', '')) - dice_path = os.path.join(model_val_dir, 'dice.npy') + model_val_dir = os.path.join( + validation_main_dir, os.path.basename(path_model).replace(".h5", "") + ) + dice_path = os.path.join(model_val_dir, "dice.npy") utils.mkdir(model_val_dir) if (not os.path.isfile(dice_path)) | recompute: loop_info.update(model_idx) - predict(path_images=image_dir, - path_segmentations=model_val_dir, - path_model=path_model, - labels_segmentation=labels_segmentation, - n_neutral_labels=n_neutral_labels, - min_pad=min_pad, - cropping=cropping, - target_res=target_res, - gradients=gradients, - flip=flip, - topology_classes=topology_classes, - sigma_smoothing=sigma_smoothing, - keep_biggest_component=keep_biggest_component, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation, - gt_folder=gt_dir, - evaluation_labels=evaluation_labels, - recompute=recompute, - verbose=False) - - -def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_indices=None, - skip_first_dice_row=True, size_max_circle=100, figsize=(11, 6), y_lim=None, fontsize=18, - list_linestyles=None, list_colours=None, plot_legend=False, draw_line=None): + predict( + path_images=image_dir, + path_segmentations=model_val_dir, + path_model=path_model, + labels_segmentation=labels_segmentation, + n_neutral_labels=n_neutral_labels, + min_pad=min_pad, + cropping=cropping, + target_res=target_res, + gradients=gradients, + flip=flip, + topology_classes=topology_classes, + sigma_smoothing=sigma_smoothing, + keep_biggest_component=keep_biggest_component, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + gt_folder=gt_dir, + evaluation_labels=evaluation_labels, + recompute=recompute, + verbose=False, + ) + + +def plot_validation_curves( + list_validation_dirs, + architecture_names=None, + eval_indices=None, + skip_first_dice_row=True, + size_max_circle=100, + figsize=(11, 6), + y_lim=None, + fontsize=18, + list_linestyles=None, + list_colours=None, + plot_legend=False, + draw_line=None, +): """This function plots the validation curves of several networks, based on the results of validate_training(). It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores for the corresponding validated epoch. @@ -150,33 +170,50 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i if isinstance(eval_indices, (np.ndarray, str)): if isinstance(eval_indices, str): eval_indices = np.load(eval_indices) - eval_indices = np.squeeze(utils.reformat_to_n_channels_array(eval_indices, n_dims=len(eval_indices))) + eval_indices = np.squeeze( + utils.reformat_to_n_channels_array( + eval_indices, n_dims=len(eval_indices) + ) + ) eval_indices = [eval_indices] * len(list_validation_dirs) elif isinstance(eval_indices, list): - for (i, e) in enumerate(eval_indices): + for i, e in enumerate(eval_indices): if isinstance(e, np.ndarray): - eval_indices[i] = np.squeeze(utils.reformat_to_n_channels_array(e, n_dims=len(e))) + eval_indices[i] = np.squeeze( + utils.reformat_to_n_channels_array(e, n_dims=len(e)) + ) else: - raise TypeError('if provided as a list, eval_indices should only contain numpy arrays') + raise TypeError( + "if provided as a list, eval_indices should only contain numpy arrays" + ) else: - raise TypeError('eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays.') + raise TypeError( + "eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays." + ) else: eval_indices = [None] * len(list_validation_dirs) # reformat model names if architecture_names is None: - architecture_names = [os.path.basename(os.path.dirname(d)) for d in list_validation_dirs] + architecture_names = [ + os.path.basename(os.path.dirname(d)) for d in list_validation_dirs + ] else: - architecture_names = utils.reformat_to_list(architecture_names, len(list_validation_dirs)) + architecture_names = utils.reformat_to_list( + architecture_names, len(list_validation_dirs) + ) # prepare legend labels if plot_legend is False: - list_legend_labels = ['_nolegend_'] * n_curves + list_legend_labels = ["_nolegend_"] * n_curves elif plot_legend is True: list_legend_labels = architecture_names else: list_legend_labels = architecture_names - list_legend_labels = ['_nolegend_' if i >= plot_legend else list_legend_labels[i] for i in range(n_curves)] + list_legend_labels = [ + "_nolegend_" if i >= plot_legend else list_legend_labels[i] + for i in range(n_curves) + ] # prepare linestyles if list_linestyles is not None: @@ -192,12 +229,23 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i # loop over architectures plt.figure(figsize=figsize) - for idx, (net_val_dir, net_name, linestyle, colour, legend_label, eval_idx) in enumerate(zip(list_validation_dirs, - architecture_names, - list_linestyles, - list_colours, - list_legend_labels, - eval_indices)): + for idx, ( + net_val_dir, + net_name, + linestyle, + colour, + legend_label, + eval_idx, + ) in enumerate( + zip( + list_validation_dirs, + architecture_names, + list_linestyles, + list_colours, + list_legend_labels, + eval_indices, + ) + ): list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False) @@ -207,16 +255,22 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i for epoch_dir in list_epochs_dir: # build names and create folders - path_epoch_scores = os.path.join(net_val_dir, epoch_dir, 'dice.npy') + path_epoch_scores = os.path.join(net_val_dir, epoch_dir, "dice.npy") if os.path.isfile(path_epoch_scores): if eval_idx is not None: - list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[eval_idx, :]))) + list_net_scores.append( + np.mean(np.abs(np.load(path_epoch_scores)[eval_idx, :])) + ) else: if skip_first_dice_row: - list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[1:, :]))) + list_net_scores.append( + np.mean(np.abs(np.load(path_epoch_scores)[1:, :])) + ) else: - list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)))) - list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir))) + list_net_scores.append( + np.mean(np.abs(np.load(path_epoch_scores))) + ) + list_epochs.append(int(re.sub("[^0-9]", "", epoch_dir))) # plot validation scores for current architecture if list_net_scores: # check that archi has been validated for at least 1 epoch @@ -226,33 +280,45 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i list_net_scores = list_net_scores[idx] max_score = np.max(list_net_scores) epoch_max_score = list_epochs[np.argmax(list_net_scores)] - print('\n'+net_name) - print('epoch max score: %d' % epoch_max_score) - print('max score: %0.3f' % max_score) - plt.plot(list_epochs, list_net_scores, label=legend_label, linestyle=linestyle, color=colour) + print("\n" + net_name) + print("epoch max score: %d" % epoch_max_score) + print("max score: %0.3f" % max_score) + plt.plot( + list_epochs, + list_net_scores, + label=legend_label, + linestyle=linestyle, + color=colour, + ) plt.scatter(epoch_max_score, max_score, s=size_max_circle, color=colour) # finalise plot plt.grid() if draw_line is not None: draw_line = utils.reformat_to_list(draw_line) - list_linestyles = ['dotted', 'dashed', 'solid', 'dashdot'][:len(draw_line)] + list_linestyles = ["dotted", "dashed", "solid", "dashdot"][: len(draw_line)] for line, linestyle in zip(draw_line, list_linestyles): - plt.axhline(line, color='black', linestyle=linestyle) - plt.tick_params(axis='both', labelsize=fontsize) - plt.ylabel('Scores', fontsize=fontsize) - plt.xlabel('Epochs', fontsize=fontsize) + plt.axhline(line, color="black", linestyle=linestyle) + plt.tick_params(axis="both", labelsize=fontsize) + plt.ylabel("Scores", fontsize=fontsize) + plt.xlabel("Epochs", fontsize=fontsize) if y_lim is not None: plt.ylim(y_lim[0], y_lim[1] + 0.01) # set right/left limits of plot - plt.title('Validation curves', fontsize=fontsize) + plt.title("Validation curves", fontsize=fontsize) if plot_legend: plt.legend(fontsize=fontsize) plt.tight_layout(pad=1) plt.show() -def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, 6), fontsize=18, - y_lim=None, remove_legend=False): +def draw_learning_curve( + path_tensorboard_files, + architecture_names, + figsize=(11, 6), + fontsize=18, + y_lim=None, + remove_legend=False, +): """This function draws the learning curve of several trainings on the same graph. :param path_tensorboard_files: list of tensorboard files corresponding to the models to plot. :param architecture_names: list of the names of the models @@ -263,7 +329,9 @@ def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, # reformat inputs path_tensorboard_files = utils.reformat_to_list(path_tensorboard_files) architecture_names = utils.reformat_to_list(architecture_names) - assert len(path_tensorboard_files) == len(architecture_names), 'names and tensorboard lists should have same length' + assert len(path_tensorboard_files) == len( + architecture_names + ), "names and tensorboard lists should have same length" # loop over architectures plt.figure(figsize=figsize) @@ -274,24 +342,26 @@ def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, # extract loss at the end of all epochs list_losses = list() list_epochs = list() - logging.getLogger('tensorflow').disabled = True + logging.getLogger("tensorflow").disabled = True for path in path_tensorboard_file: for e in summary_iterator(path): for v in e.summary.value: - if v.tag == 'loss' or v.tag == 'accuracy' or v.tag == 'epoch_loss': + if v.tag == "loss" or v.tag == "accuracy" or v.tag == "epoch_loss": list_losses.append(v.simple_value) list_epochs.append(e.step) - plt.plot(np.array(list_epochs), 1-np.array(list_losses), label=name, linewidth=2) + plt.plot( + np.array(list_epochs), 1 - np.array(list_losses), label=name, linewidth=2 + ) # finalise plot plt.grid() if not remove_legend: plt.legend(fontsize=fontsize) - plt.xlabel('Epochs', fontsize=fontsize) - plt.ylabel('Scores', fontsize=fontsize) + plt.xlabel("Epochs", fontsize=fontsize) + plt.ylabel("Scores", fontsize=fontsize) if y_lim is not None: plt.ylim(y_lim[0], y_lim[1] + 0.01) # set right/left limits of plot - plt.tick_params(axis='both', labelsize=fontsize) - plt.title('Learning curves', fontsize=fontsize) + plt.tick_params(axis="both", labelsize=fontsize) + plt.title("Learning curves", fontsize=fontsize) plt.tight_layout(pad=1) plt.show() diff --git a/nobrainer/ext/SynthSeg/validate_denoiser.py b/nobrainer/ext/SynthSeg/validate_denoiser.py index 82f238a8..155435b8 100644 --- a/nobrainer/ext/SynthSeg/validate_denoiser.py +++ b/nobrainer/ext/SynthSeg/validate_denoiser.py @@ -13,7 +13,6 @@ License. """ - # python imports import os @@ -24,62 +23,70 @@ from ext.lab2im import utils -def validate_training(prediction_dir, - gt_dir, - models_dir, - validation_main_dir, - target_segmentation_labels, - input_segmentation_labels=None, - evaluation_labels=None, - step_eval=1, - min_pad=None, - cropping=None, - topology_classes=None, - sigma_smoothing=0, - keep_biggest_component=False, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - skip_n_concatenations=0, - recompute=True): +def validate_training( + prediction_dir, + gt_dir, + models_dir, + validation_main_dir, + target_segmentation_labels, + input_segmentation_labels=None, + evaluation_labels=None, + step_eval=1, + min_pad=None, + cropping=None, + topology_classes=None, + sigma_smoothing=0, + keep_biggest_component=False, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + skip_n_concatenations=0, + recompute=True, +): # create result folder utils.mkdir(validation_main_dir) # loop over models - list_models = utils.list_files(models_dir, expr=['dice', '.h5'], cond_type='and')[::step_eval] + list_models = utils.list_files(models_dir, expr=["dice", ".h5"], cond_type="and")[ + ::step_eval + ] # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 2 == 0] - loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True) + loop_info = utils.LoopInfo(len(list_models), 1, "validating", True) for model_idx, path_model in enumerate(list_models): # build names and create folders - model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', '')) - dice_path = os.path.join(model_val_dir, 'dice.npy') + model_val_dir = os.path.join( + validation_main_dir, os.path.basename(path_model).replace(".h5", "") + ) + dice_path = os.path.join(model_val_dir, "dice.npy") utils.mkdir(model_val_dir) if (not os.path.isfile(dice_path)) | recompute: loop_info.update(model_idx) - predict(path_predictions=prediction_dir, - path_corrections=model_val_dir, - path_model=path_model, - target_segmentation_labels=target_segmentation_labels, - input_segmentation_labels=input_segmentation_labels, - min_pad=min_pad, - cropping=cropping, - topology_classes=topology_classes, - sigma_smoothing=sigma_smoothing, - keep_biggest_component=keep_biggest_component, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation, - skip_n_concatenations=skip_n_concatenations, - gt_folder=gt_dir, - evaluation_labels=evaluation_labels, - recompute=recompute, - verbose=False) + predict( + path_predictions=prediction_dir, + path_corrections=model_val_dir, + path_model=path_model, + target_segmentation_labels=target_segmentation_labels, + input_segmentation_labels=input_segmentation_labels, + min_pad=min_pad, + cropping=cropping, + topology_classes=topology_classes, + sigma_smoothing=sigma_smoothing, + keep_biggest_component=keep_biggest_component, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + skip_n_concatenations=skip_n_concatenations, + gt_folder=gt_dir, + evaluation_labels=evaluation_labels, + recompute=recompute, + verbose=False, + ) diff --git a/nobrainer/ext/SynthSeg/validate_group.py b/nobrainer/ext/SynthSeg/validate_group.py index bfccf298..9527baf5 100644 --- a/nobrainer/ext/SynthSeg/validate_group.py +++ b/nobrainer/ext/SynthSeg/validate_group.py @@ -13,7 +13,6 @@ License. """ - # python imports import os @@ -24,66 +23,74 @@ from ext.lab2im import utils -def validate_training(image_dir, - mask_dir, - gt_dir, - models_dir, - validation_main_dir, - labels_segmentation, - labels_mask, - evaluation_labels=None, - step_eval=1, - min_pad=None, - cropping=None, - sigma_smoothing=0, - strict_masking=False, - keep_biggest_component=False, - n_levels=5, - nb_conv_per_level=2, - conv_size=3, - unet_feat_count=24, - feat_multiplier=2, - activation='elu', - list_incorrect_labels=None, - list_correct_labels=None, - recompute=False): +def validate_training( + image_dir, + mask_dir, + gt_dir, + models_dir, + validation_main_dir, + labels_segmentation, + labels_mask, + evaluation_labels=None, + step_eval=1, + min_pad=None, + cropping=None, + sigma_smoothing=0, + strict_masking=False, + keep_biggest_component=False, + n_levels=5, + nb_conv_per_level=2, + conv_size=3, + unet_feat_count=24, + feat_multiplier=2, + activation="elu", + list_incorrect_labels=None, + list_correct_labels=None, + recompute=False, +): # create result folder utils.mkdir(validation_main_dir) # loop over models - list_models = utils.list_files(models_dir, expr=['dice', '.h5'], cond_type='and')[::step_eval] + list_models = utils.list_files(models_dir, expr=["dice", ".h5"], cond_type="and")[ + ::step_eval + ] # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 10 == 0] - loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True) + loop_info = utils.LoopInfo(len(list_models), 1, "validating", True) for model_idx, path_model in enumerate(list_models): # build names and create folders - model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', '')) - dice_path = os.path.join(model_val_dir, 'dice.npy') + model_val_dir = os.path.join( + validation_main_dir, os.path.basename(path_model).replace(".h5", "") + ) + dice_path = os.path.join(model_val_dir, "dice.npy") utils.mkdir(model_val_dir) if (not os.path.isfile(dice_path)) | recompute: loop_info.update(model_idx) - predict(path_images=image_dir, - path_masks=mask_dir, - path_segmentations=model_val_dir, - path_model=path_model, - labels_segmentation=labels_segmentation, - labels_mask=labels_mask, - min_pad=min_pad, - cropping=cropping, - sigma_smoothing=sigma_smoothing, - strict_masking=strict_masking, - keep_biggest_component=keep_biggest_component, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation, - gt_folder=gt_dir, - evaluation_labels=evaluation_labels, - list_incorrect_labels=list_incorrect_labels, - list_correct_labels=list_correct_labels, - recompute=recompute, - verbose=False) + predict( + path_images=image_dir, + path_masks=mask_dir, + path_segmentations=model_val_dir, + path_model=path_model, + labels_segmentation=labels_segmentation, + labels_mask=labels_mask, + min_pad=min_pad, + cropping=cropping, + sigma_smoothing=sigma_smoothing, + strict_masking=strict_masking, + keep_biggest_component=keep_biggest_component, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + gt_folder=gt_dir, + evaluation_labels=evaluation_labels, + list_incorrect_labels=list_incorrect_labels, + list_correct_labels=list_correct_labels, + recompute=recompute, + verbose=False, + ) diff --git a/nobrainer/ext/SynthSeg/validate_qc.py b/nobrainer/ext/SynthSeg/validate_qc.py index 35c23563..e5c6e92d 100644 --- a/nobrainer/ext/SynthSeg/validate_qc.py +++ b/nobrainer/ext/SynthSeg/validate_qc.py @@ -13,75 +13,93 @@ License. """ +import logging # python imports import os import re -import logging -import numpy as np -import matplotlib.pyplot as plt -from tensorflow.python.summary.summary_iterator import summary_iterator # project imports from SynthSeg.predict_qc import predict # third-party imports from ext.lab2im import utils +import matplotlib.pyplot as plt +import numpy as np +from tensorflow.python.summary.summary_iterator import summary_iterator -def validate_training(prediction_dir, - gt_dir, - models_dir, - validation_main_dir, - labels_list, - labels_to_convert=None, - convert_gt=False, - shape=224, - n_levels=5, - nb_conv_per_level=3, - conv_size=5, - unet_feat_count=24, - feat_multiplier=2, - activation='relu', - step_eval=1, - recompute=False): +def validate_training( + prediction_dir, + gt_dir, + models_dir, + validation_main_dir, + labels_list, + labels_to_convert=None, + convert_gt=False, + shape=224, + n_levels=5, + nb_conv_per_level=3, + conv_size=5, + unet_feat_count=24, + feat_multiplier=2, + activation="relu", + step_eval=1, + recompute=False, +): # create result folder utils.mkdir(validation_main_dir) # loop over models - list_models = utils.list_files(models_dir, expr=['qc', '.h5'], cond_type='and')[::step_eval] + list_models = utils.list_files(models_dir, expr=["qc", ".h5"], cond_type="and")[ + ::step_eval + ] # list_models = [p for p in list_models if int(os.path.basename(p)[-6:-3]) % 10 == 0] - loop_info = utils.LoopInfo(len(list_models), 1, 'validating', True) + loop_info = utils.LoopInfo(len(list_models), 1, "validating", True) for model_idx, path_model in enumerate(list_models): # build names and create folders - model_val_dir = os.path.join(validation_main_dir, os.path.basename(path_model).replace('.h5', '')) - score_path = os.path.join(model_val_dir, 'pred_qc_results.npy') + model_val_dir = os.path.join( + validation_main_dir, os.path.basename(path_model).replace(".h5", "") + ) + score_path = os.path.join(model_val_dir, "pred_qc_results.npy") utils.mkdir(model_val_dir) if (not os.path.isfile(score_path)) | recompute: loop_info.update(model_idx) - predict(path_predictions=prediction_dir, - path_qc_results=score_path, - path_model=path_model, - labels_list=labels_list, - labels_to_convert=labels_to_convert, - convert_gt=convert_gt, - shape=shape, - n_levels=n_levels, - nb_conv_per_level=nb_conv_per_level, - conv_size=conv_size, - unet_feat_count=unet_feat_count, - feat_multiplier=feat_multiplier, - activation=activation, - path_gts=gt_dir, - verbose=False) - - -def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_indices=None, - skip_first_dice_row=True, size_max_circle=100, figsize=(11, 6), y_lim=None, fontsize=18, - list_linestyles=None, list_colours=None, plot_legend=False): + predict( + path_predictions=prediction_dir, + path_qc_results=score_path, + path_model=path_model, + labels_list=labels_list, + labels_to_convert=labels_to_convert, + convert_gt=convert_gt, + shape=shape, + n_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + unet_feat_count=unet_feat_count, + feat_multiplier=feat_multiplier, + activation=activation, + path_gts=gt_dir, + verbose=False, + ) + + +def plot_validation_curves( + list_validation_dirs, + architecture_names=None, + eval_indices=None, + skip_first_dice_row=True, + size_max_circle=100, + figsize=(11, 6), + y_lim=None, + fontsize=18, + list_linestyles=None, + list_colours=None, + plot_legend=False, +): """This function plots the validation curves of several networks, based on the results of validate_training(). It takes as input a list of validation folders (one for each network), each containing subfolders with dice scores for the corresponding validated epoch. @@ -99,33 +117,50 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i if isinstance(eval_indices, (np.ndarray, str)): if isinstance(eval_indices, str): eval_indices = np.load(eval_indices) - eval_indices = np.squeeze(utils.reformat_to_n_channels_array(eval_indices, n_dims=len(eval_indices))) + eval_indices = np.squeeze( + utils.reformat_to_n_channels_array( + eval_indices, n_dims=len(eval_indices) + ) + ) eval_indices = [eval_indices] * len(list_validation_dirs) elif isinstance(eval_indices, list): - for (i, e) in enumerate(eval_indices): + for i, e in enumerate(eval_indices): if isinstance(e, np.ndarray): - eval_indices[i] = np.squeeze(utils.reformat_to_n_channels_array(e, n_dims=len(e))) + eval_indices[i] = np.squeeze( + utils.reformat_to_n_channels_array(e, n_dims=len(e)) + ) else: - raise TypeError('if provided as a list, eval_indices should only contain numpy arrays') + raise TypeError( + "if provided as a list, eval_indices should only contain numpy arrays" + ) else: - raise TypeError('eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays.') + raise TypeError( + "eval_indices can be a numpy array, a path to a numpy array, or a list of numpy arrays." + ) else: eval_indices = [None] * len(list_validation_dirs) # reformat model names if architecture_names is None: - architecture_names = [os.path.basename(os.path.dirname(d)) for d in list_validation_dirs] + architecture_names = [ + os.path.basename(os.path.dirname(d)) for d in list_validation_dirs + ] else: - architecture_names = utils.reformat_to_list(architecture_names, len(list_validation_dirs)) + architecture_names = utils.reformat_to_list( + architecture_names, len(list_validation_dirs) + ) # prepare legend labels if plot_legend is False: - list_legend_labels = ['_nolegend_'] * n_curves + list_legend_labels = ["_nolegend_"] * n_curves elif plot_legend is True: list_legend_labels = architecture_names else: list_legend_labels = architecture_names - list_legend_labels = ['_nolegend_' if i >= plot_legend else list_legend_labels[i] for i in range(n_curves)] + list_legend_labels = [ + "_nolegend_" if i >= plot_legend else list_legend_labels[i] + for i in range(n_curves) + ] # prepare linestyles if list_linestyles is not None: @@ -141,12 +176,23 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i # loop over architectures plt.figure(figsize=figsize) - for idx, (net_val_dir, net_name, linestyle, colour, legend_label, eval_idx) in enumerate(zip(list_validation_dirs, - architecture_names, - list_linestyles, - list_colours, - list_legend_labels, - eval_indices)): + for idx, ( + net_val_dir, + net_name, + linestyle, + colour, + legend_label, + eval_idx, + ) in enumerate( + zip( + list_validation_dirs, + architecture_names, + list_linestyles, + list_colours, + list_legend_labels, + eval_indices, + ) + ): list_epochs_dir = utils.list_subfolders(net_val_dir, whole_path=False) @@ -156,17 +202,25 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i for epoch_dir in list_epochs_dir: # build names and create folders - path_epoch_scores = utils.list_files(os.path.join(net_val_dir, epoch_dir), expr='diff') + path_epoch_scores = utils.list_files( + os.path.join(net_val_dir, epoch_dir), expr="diff" + ) if len(path_epoch_scores) == 1: path_epoch_scores = path_epoch_scores[0] if eval_idx is not None: - list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[eval_idx, :]))) + list_net_scores.append( + np.mean(np.abs(np.load(path_epoch_scores)[eval_idx, :])) + ) else: if skip_first_dice_row: - list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)[1:, :]))) + list_net_scores.append( + np.mean(np.abs(np.load(path_epoch_scores)[1:, :])) + ) else: - list_net_scores.append(np.mean(np.abs(np.load(path_epoch_scores)))) - list_epochs.append(int(re.sub('[^0-9]', '', epoch_dir))) + list_net_scores.append( + np.mean(np.abs(np.load(path_epoch_scores))) + ) + list_epochs.append(int(re.sub("[^0-9]", "", epoch_dir))) # plot validation scores for current architecture if list_net_scores: # check that archi has been validated for at least 1 epoch @@ -176,28 +230,40 @@ def plot_validation_curves(list_validation_dirs, architecture_names=None, eval_i list_net_scores = list_net_scores[idx] min_score = np.min(list_net_scores) epoch_min_score = list_epochs[np.argmin(list_net_scores)] - print('\n'+net_name) - print('epoch min score: %d' % epoch_min_score) - print('min score: %0.3f' % min_score) - plt.plot(list_epochs, list_net_scores, label=legend_label, linestyle=linestyle, color=colour) + print("\n" + net_name) + print("epoch min score: %d" % epoch_min_score) + print("min score: %0.3f" % min_score) + plt.plot( + list_epochs, + list_net_scores, + label=legend_label, + linestyle=linestyle, + color=colour, + ) plt.scatter(epoch_min_score, min_score, s=size_max_circle, color=colour) # finalise plot plt.grid() - plt.tick_params(axis='both', labelsize=fontsize) - plt.ylabel('Scores', fontsize=fontsize) - plt.xlabel('Epochs', fontsize=fontsize) + plt.tick_params(axis="both", labelsize=fontsize) + plt.ylabel("Scores", fontsize=fontsize) + plt.xlabel("Epochs", fontsize=fontsize) if y_lim is not None: plt.ylim(y_lim[0], y_lim[1] + 0.01) # set right/left limits of plot - plt.title('Validation curves', fontsize=fontsize) + plt.title("Validation curves", fontsize=fontsize) if plot_legend: plt.legend(fontsize=fontsize) plt.tight_layout(pad=1) plt.show() -def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, 6), fontsize=18, - y_lim=None, remove_legend=False): +def draw_learning_curve( + path_tensorboard_files, + architecture_names, + figsize=(11, 6), + fontsize=18, + y_lim=None, + remove_legend=False, +): """This function draws the learning curve of several trainings on the same graph. :param path_tensorboard_files: list of tensorboard files corresponding to the models to plot. :param architecture_names: list of the names of the models @@ -208,7 +274,9 @@ def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, # reformat inputs path_tensorboard_files = utils.reformat_to_list(path_tensorboard_files) architecture_names = utils.reformat_to_list(architecture_names) - assert len(path_tensorboard_files) == len(architecture_names), 'names and tensorboard lists should have same length' + assert len(path_tensorboard_files) == len( + architecture_names + ), "names and tensorboard lists should have same length" # loop over architectures plt.figure(figsize=figsize) @@ -219,11 +287,11 @@ def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, # extract loss at the end of all epochs list_losses = list() list_epochs = list() - logging.getLogger('tensorflow').disabled = True + logging.getLogger("tensorflow").disabled = True for path in path_tensorboard_file: for e in summary_iterator(path): for v in e.summary.value: - if v.tag == 'loss' or v.tag == 'accuracy' or v.tag == 'epoch_loss': + if v.tag == "loss" or v.tag == "accuracy" or v.tag == "epoch_loss": list_losses.append(v.simple_value) list_epochs.append(e.step) plt.plot(np.array(list_epochs), np.array(list_losses), label=name, linewidth=2) @@ -232,11 +300,11 @@ def draw_learning_curve(path_tensorboard_files, architecture_names, figsize=(11, plt.grid() if not remove_legend: plt.legend(fontsize=fontsize) - plt.xlabel('Epochs', fontsize=fontsize) - plt.ylabel('Scores', fontsize=fontsize) + plt.xlabel("Epochs", fontsize=fontsize) + plt.ylabel("Scores", fontsize=fontsize) if y_lim is not None: plt.ylim(y_lim[0], y_lim[1] + 0.01) # set right/left limits of plot - plt.tick_params(axis='both', labelsize=fontsize) - plt.title('Learning curves', fontsize=fontsize) + plt.tick_params(axis="both", labelsize=fontsize) + plt.title("Learning curves", fontsize=fontsize) plt.tight_layout(pad=1) plt.show() diff --git a/nobrainer/models/labels_to_image_model.py b/nobrainer/models/labels_to_image_model.py index fa13a9e2..1068f960 100644 --- a/nobrainer/models/labels_to_image_model.py +++ b/nobrainer/models/labels_to_image_model.py @@ -13,45 +13,45 @@ License. """ +# third-party imports +from ext.lab2im import edit_tensors as l2i_et +from ext.lab2im import layers, utils +from ext.lab2im.edit_volumes import get_ras_axes +import keras.layers as KL +from keras.models import Model # python imports import numpy as np import tensorflow as tf -import keras.layers as KL -from keras.models import Model - -# third-party imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.lab2im import edit_tensors as l2i_et -from ext.lab2im.edit_volumes import get_ras_axes -def labels_to_image_model(labels_shape, - n_channels, - generation_labels, - output_labels, - n_neutral_labels, - atlas_res, - target_res, - output_shape=None, - output_div_by_n=None, - flipping=True, - aff=None, - scaling_bounds=0.2, - rotation_bounds=15, - shearing_bounds=0.012, - translation_bounds=False, - nonlin_std=3., - nonlin_scale=.0625, - randomise_res=False, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.5, - bias_scale=.025, - return_gradients=False): +def labels_to_image_model( + labels_shape, + n_channels, + generation_labels, + output_labels, + n_neutral_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + flipping=True, + aff=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_scale=0.0625, + randomise_res=False, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.5, + bias_scale=0.025, + return_gradients=False, +): """ This function builds a keras/tensorflow model to generate images from provided label maps. The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. @@ -147,28 +147,50 @@ def labels_to_image_model(labels_shape, labels_shape = utils.reformat_to_list(labels_shape) n_dims, _ = utils.get_dims(labels_shape) atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims, n_channels) - data_res = atlas_res if data_res is None else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) - thickness = data_res if thickness is None else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + data_res = ( + atlas_res + if data_res is None + else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) + ) + thickness = ( + data_res + if thickness is None + else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + ) atlas_res = atlas_res[0] - target_res = atlas_res if target_res is None else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + target_res = ( + atlas_res + if target_res is None + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) # get shapes - crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + crop_shape, output_shape = get_shapes( + labels_shape, output_shape, atlas_res, target_res, output_div_by_n + ) # define model inputs - labels_input = KL.Input(shape=labels_shape + [1], name='labels_input', dtype='int32') - means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') - stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='std_devs_input') + labels_input = KL.Input( + shape=labels_shape + [1], name="labels_input", dtype="int32" + ) + means_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="means_input" + ) + stds_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="std_devs_input" + ) list_inputs = [labels_input, means_input, stds_input] # deform labels - labels = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - inter_method='nearest')(labels_input) + labels = layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method="nearest", + )(labels_input) # cropping if crop_shape != labels_shape: @@ -176,58 +198,92 @@ def labels_to_image_model(labels_shape, # flipping if flipping: - assert aff is not None, 'aff should not be None if flipping is True' - labels = layers.RandomFlip(get_ras_axes(aff, n_dims)[0], True, generation_labels, n_neutral_labels)(labels) + assert aff is not None, "aff should not be None if flipping is True" + labels = layers.RandomFlip( + get_ras_axes(aff, n_dims)[0], True, generation_labels, n_neutral_labels + )(labels) # build synthetic image - image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + image = layers.SampleConditionalGMM(generation_labels)( + [labels, means_input, stds_input] + ) # apply bias field if bias_field_std > 0: image = layers.BiasFieldCorruption(bias_field_std, bias_scale, False)(image) # intensity augmentation - image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.5, separate_channels=True)(image) + image = layers.IntensityAugmentation( + clip=300, normalise=True, gamma_std=0.5, separate_channels=True + )(image) # loop over channels channels = list() - split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) if (n_channels > 1) else [image] + split = ( + KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) + if (n_channels > 1) + else [image] + ) for i, channel in enumerate(split): if randomise_res: - max_res_iso = np.array(utils.reformat_to_list(max_res_iso, length=n_dims, dtype='float')) - max_res_aniso = np.array(utils.reformat_to_list(max_res_aniso, length=n_dims, dtype='float')) + max_res_iso = np.array( + utils.reformat_to_list(max_res_iso, length=n_dims, dtype="float") + ) + max_res_aniso = np.array( + utils.reformat_to_list(max_res_aniso, length=n_dims, dtype="float") + ) max_res = np.maximum(max_res_iso, max_res_aniso) - resolution, blur_res = layers.SampleResolution(atlas_res, max_res_iso, max_res_aniso)(means_input) - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, resolution, thickness=blur_res) - channel = layers.DynamicGaussianBlur(0.75 * max_res / np.array(atlas_res), 1.03)([channel, sigma]) - channel = layers.MimicAcquisition(atlas_res, atlas_res, output_shape, False)([channel, resolution]) + resolution, blur_res = layers.SampleResolution( + atlas_res, max_res_iso, max_res_aniso + )(means_input) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, resolution, thickness=blur_res + ) + channel = layers.DynamicGaussianBlur( + 0.75 * max_res / np.array(atlas_res), 1.03 + )([channel, sigma]) + channel = layers.MimicAcquisition( + atlas_res, atlas_res, output_shape, False + )([channel, resolution]) channels.append(channel) else: - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, data_res[i], thickness=thickness[i]) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, data_res[i], thickness=thickness[i] + ) channel = layers.GaussianBlur(sigma, 1.03)(channel) - resolution = KL.Lambda(lambda x: tf.convert_to_tensor(data_res[i], dtype='float32'))([]) - channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)([channel, resolution]) + resolution = KL.Lambda( + lambda x: tf.convert_to_tensor(data_res[i], dtype="float32") + )([]) + channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)( + [channel, resolution] + ) channels.append(channel) # concatenate all channels back - image = KL.Lambda(lambda x: tf.concat(x, -1))(channels) if len(channels) > 1 else channels[0] + image = ( + KL.Lambda(lambda x: tf.concat(x, -1))(channels) + if len(channels) > 1 + else channels[0] + ) # compute image gradient if return_gradients: - image = layers.ImageGradients('sobel', True, name='image_gradients')(image) + image = layers.ImageGradients("sobel", True, name="image_gradients")(image) image = layers.IntensityAugmentation(clip=10, normalise=True)(image) # resample labels at target resolution if crop_shape != output_shape: - labels = l2i_et.resample_tensor(labels, output_shape, interp_method='nearest') + labels = l2i_et.resample_tensor(labels, output_shape, interp_method="nearest") # map generation labels to segmentation values - labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + labels = layers.ConvertLabels( + generation_labels, dest_values=output_labels, name="labels_out" + )(labels) # build model (dummy layer enables to keep the labels when plugging this model to other models) - image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) brain_model = Model(inputs=list_inputs, outputs=[image, labels]) return brain_model @@ -248,25 +304,39 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ # output shape specified, need to get cropping shape, and resample shape if necessary if output_shape is not None: - output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype="int") # make sure that output shape is smaller or equal to label shape if resample_factor is not None: - output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) + for i in range(n_dims) + ] else: - output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(labels_shape[i], output_shape[i]) for i in range(n_dims) + ] # make sure output shape is divisible by output_div_by_n if output_div_by_n is not None: - tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in output_shape] + tmp_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] if output_shape != tmp_shape: - print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, - tmp_shape)) + print( + "output shape {0} not divisible by {1}, changed to {2}".format( + output_shape, output_div_by_n, tmp_shape + ) + ) output_shape = tmp_shape # get cropping and resample shape if resample_factor is not None: - cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] else: cropping_shape = output_shape @@ -278,19 +348,32 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ # if resampling, get the potential output_shape and check if it is divisible by n if resample_factor is not None: - output_shape = [int(labels_shape[i] * resample_factor[i]) for i in range(n_dims)] - output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in output_shape] - cropping_shape = [int(np.around(output_shape[i] / resample_factor[i], 0)) for i in range(n_dims)] + output_shape = [ + int(labels_shape[i] * resample_factor[i]) for i in range(n_dims) + ] + output_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] # if no resampling, simply check if image_shape is divisible by n else: - cropping_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in labels_shape] + cropping_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in labels_shape + ] output_shape = cropping_shape # if no need to be divisible by n, simply take cropping_shape as image_shape, and build output_shape else: cropping_shape = labels_shape if resample_factor is not None: - output_shape = [int(cropping_shape[i] * resample_factor[i]) for i in range(n_dims)] + output_shape = [ + int(cropping_shape[i] * resample_factor[i]) for i in range(n_dims) + ] else: output_shape = cropping_shape diff --git a/nobrainer/processing/brain_generator.py b/nobrainer/processing/brain_generator.py index 42ce58bc..f368f539 100644 --- a/nobrainer/processing/brain_generator.py +++ b/nobrainer/processing/brain_generator.py @@ -13,52 +13,54 @@ License. """ - # python imports import numpy as np +from nobrainer.ext.SynthSeg.labels_to_image_model import labels_to_image_model + # project imports from nobrainer.ext.SynthSeg.model_inputs import build_model_inputs -from nobrainer.ext.SynthSeg.labels_to_image_model import labels_to_image_model # third-party imports -from nobrainer.ext.lab2im import utils, edit_volumes +from nobrainer.ext.lab2im import edit_volumes, utils class BrainGenerator: - def __init__(self, - labels_dir, - generation_labels=None, - n_neutral_labels=None, - output_labels=None, - subjects_prob=None, - batchsize=1, - n_channels=1, - target_res=None, - output_shape=None, - output_div_by_n=None, - prior_distributions='uniform', - generation_classes=None, - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - mix_prior_and_random=False, - flipping=True, - scaling_bounds=.2, - rotation_bounds=15, - shearing_bounds=.012, - translation_bounds=False, - nonlin_std=4., - nonlin_scale=.04, - randomise_res=True, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.7, - bias_scale=.025, - return_gradients=False): + def __init__( + self, + labels_dir, + generation_labels=None, + n_neutral_labels=None, + output_labels=None, + subjects_prob=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + prior_distributions="uniform", + generation_classes=None, + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, + flipping=True, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.04, + randomise_res=True, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, + return_gradients=False, + ): """ This class is wrapper around the labels_to_image_model model. It contains the GPU model that generates images from labels maps, and a python generator that supplies the input data for this model. @@ -195,16 +197,21 @@ def __init__(self, # prepare data files self.labels_paths = utils.list_images_in_folder(labels_dir) if subjects_prob is not None: - self.subjects_prob = np.array(utils.reformat_to_list(subjects_prob, load_as_numpy=True), dtype='float32') - assert len(self.subjects_prob) == len(self.labels_paths), \ - 'subjects_prob should have the same length as labels_path, ' \ - 'had {} and {}'.format(len(self.subjects_prob), len(self.labels_paths)) + self.subjects_prob = np.array( + utils.reformat_to_list(subjects_prob, load_as_numpy=True), + dtype="float32", + ) + assert len(self.subjects_prob) == len(self.labels_paths), ( + "subjects_prob should have the same length as labels_path, " + "had {} and {}".format(len(self.subjects_prob), len(self.labels_paths)) + ) else: self.subjects_prob = None # generation parameters - self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = ( utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + ) self.n_channels = n_channels if generation_labels is not None: self.generation_labels = utils.load_array_if_path(generation_labels) @@ -228,11 +235,13 @@ def __init__(self, self.prior_distributions = prior_distributions if generation_classes is not None: self.generation_classes = utils.load_array_if_path(generation_classes) - assert self.generation_classes.shape == self.generation_labels.shape, \ - 'if provided, generation_classes should have the same shape as generation_labels' + assert ( + self.generation_classes.shape == self.generation_labels.shape + ), "if provided, generation_classes should have the same shape as generation_labels" unique_classes = np.unique(self.generation_classes) - assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ - 'generation_classes should a linear range between 0 and its maximum value.' + assert np.array_equal( + unique_classes, np.arange(np.max(unique_classes) + 1) + ), "generation_classes should a linear range between 0 and its maximum value." else: self.generation_classes = np.arange(self.generation_labels.shape[0]) self.prior_means = utils.load_array_if_path(prior_means) @@ -251,8 +260,9 @@ def __init__(self, self.max_res_iso = max_res_iso self.max_res_aniso = max_res_aniso self.data_res = utils.load_array_if_path(data_res) - assert not (self.randomise_res & (self.data_res is not None)), \ - 'randomise_res and data_res cannot be provided at the same time' + assert not ( + self.randomise_res & (self.data_res is not None) + ), "randomise_res and data_res cannot be provided at the same time" self.thickness = utils.load_array_if_path(thickness) # bias field parameters self.bias_field_std = bias_field_std @@ -260,57 +270,65 @@ def __init__(self, self.return_gradients = return_gradients # build transformation model - self.labels_to_image_model, self.model_output_shape = self._build_labels_to_image_model() + self.labels_to_image_model, self.model_output_shape = ( + self._build_labels_to_image_model() + ) # build generator for model inputs - self.model_inputs_generator = self._build_model_inputs_generator(mix_prior_and_random) + self.model_inputs_generator = self._build_model_inputs_generator( + mix_prior_and_random + ) # build brain generator self.brain_generator = self._build_brain_generator() def _build_labels_to_image_model(self): # build_model - lab_to_im_model = labels_to_image_model(labels_shape=self.labels_shape, - n_channels=self.n_channels, - generation_labels=self.generation_labels, - output_labels=self.output_labels, - n_neutral_labels=self.n_neutral_labels, - atlas_res=self.atlas_res, - target_res=self.target_res, - output_shape=self.output_shape, - output_div_by_n=self.output_div_by_n, - flipping=self.flipping, - aff=np.eye(4), - scaling_bounds=self.scaling_bounds, - rotation_bounds=self.rotation_bounds, - shearing_bounds=self.shearing_bounds, - translation_bounds=self.translation_bounds, - nonlin_std=self.nonlin_std, - nonlin_scale=self.nonlin_scale, - randomise_res=self.randomise_res, - max_res_iso=self.max_res_iso, - max_res_aniso=self.max_res_aniso, - data_res=self.data_res, - thickness=self.thickness, - bias_field_std=self.bias_field_std, - bias_scale=self.bias_scale, - return_gradients=self.return_gradients) + lab_to_im_model = labels_to_image_model( + labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + n_neutral_labels=self.n_neutral_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + flipping=self.flipping, + aff=np.eye(4), + scaling_bounds=self.scaling_bounds, + rotation_bounds=self.rotation_bounds, + shearing_bounds=self.shearing_bounds, + translation_bounds=self.translation_bounds, + nonlin_std=self.nonlin_std, + nonlin_scale=self.nonlin_scale, + randomise_res=self.randomise_res, + max_res_iso=self.max_res_iso, + max_res_aniso=self.max_res_aniso, + data_res=self.data_res, + thickness=self.thickness, + bias_field_std=self.bias_field_std, + bias_scale=self.bias_scale, + return_gradients=self.return_gradients, + ) out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] return lab_to_im_model, out_shape def _build_model_inputs_generator(self, mix_prior_and_random): # build model's inputs generator - model_inputs_generator = build_model_inputs(path_label_maps=self.labels_paths, - n_labels=len(self.generation_labels), - batchsize=self.batchsize, - n_channels=self.n_channels, - subjects_prob=self.subjects_prob, - generation_classes=self.generation_classes, - prior_means=self.prior_means, - prior_stds=self.prior_stds, - prior_distributions=self.prior_distributions, - use_specific_stats_for_channel=self.use_specific_stats_for_channel, - mix_prior_and_random=mix_prior_and_random) + model_inputs_generator = build_model_inputs( + path_label_maps=self.labels_paths, + n_labels=len(self.generation_labels), + batchsize=self.batchsize, + n_channels=self.n_channels, + subjects_prob=self.subjects_prob, + generation_classes=self.generation_classes, + prior_means=self.prior_means, + prior_stds=self.prior_stds, + prior_distributions=self.prior_distributions, + use_specific_stats_for_channel=self.use_specific_stats_for_channel, + mix_prior_and_random=mix_prior_and_random, + ) return model_inputs_generator def _build_brain_generator(self): @@ -326,10 +344,16 @@ def generate_brain(self): list_images = list() list_labels = list() for i in range(self.batchsize): - list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), - aff_ref=self.aff, n_dims=self.n_dims)) - list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), - aff_ref=self.aff, n_dims=self.n_dims)) + list_images.append( + edit_volumes.align_volume_to_ref( + image[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) + list_labels.append( + edit_volumes.align_volume_to_ref( + labels[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) image = np.squeeze(np.stack(list_images, axis=0)) labels = np.squeeze(np.stack(list_labels, axis=0)) return image, labels