-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f1feed4
commit ad6afbc
Showing
41 changed files
with
7,227 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import math | ||
|
||
import cv2 | ||
import matplotlib.cm as cm | ||
import numpy as np | ||
|
||
from scipy import ndimage | ||
from scipy.ndimage import measurements | ||
from scipy.ndimage.filters import gaussian_filter | ||
from scipy.ndimage.interpolation import affine_transform, map_coordinates | ||
|
||
from skimage import morphology as morph | ||
|
||
from misc.utils import cropping_center, get_bounding_box | ||
|
||
|
||
#### | ||
def fix_mirror_padding(ann): | ||
"""Deal with duplicated instances due to mirroring in interpolation | ||
during shape augmentation (scale, rotation etc.). | ||
""" | ||
current_max_id = np.amax(ann) | ||
inst_list = list(np.unique(ann)) | ||
inst_list.remove(0) # 0 is background | ||
for inst_id in inst_list: | ||
inst_map = np.array(ann == inst_id, np.uint8) | ||
remapped_ids = measurements.label(inst_map)[0] | ||
remapped_ids[remapped_ids > 1] += current_max_id | ||
ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1] | ||
current_max_id = np.amax(ann) | ||
return ann | ||
|
||
|
||
#### | ||
def gaussian_blur(images, random_state, parents, hooks, max_ksize=3): | ||
"""Apply Gaussian blur to input images.""" | ||
img = images[0] # aleju input batch as default (always=1 in our case) | ||
ksize = random_state.randint(0, max_ksize, size=(2,)) | ||
ksize = tuple((ksize * 2 + 1).tolist()) | ||
|
||
ret = cv2.GaussianBlur( | ||
img, ksize, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE | ||
) | ||
ret = np.reshape(ret, img.shape) | ||
ret = ret.astype(np.uint8) | ||
return [ret] | ||
|
||
|
||
#### | ||
def median_blur(images, random_state, parents, hooks, max_ksize=3): | ||
"""Apply median blur to input images.""" | ||
img = images[0] # aleju input batch as default (always=1 in our case) | ||
ksize = random_state.randint(0, max_ksize) | ||
ksize = ksize * 2 + 1 | ||
ret = cv2.medianBlur(img, ksize) | ||
ret = ret.astype(np.uint8) | ||
return [ret] | ||
|
||
|
||
#### | ||
def add_to_hue(images, random_state, parents, hooks, range=None): | ||
"""Perturbe the hue of input images.""" | ||
img = images[0] # aleju input batch as default (always=1 in our case) | ||
hue = random_state.uniform(*range) | ||
hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) | ||
if hsv.dtype.itemsize == 1: | ||
# OpenCV uses 0-179 for 8-bit images | ||
hsv[..., 0] = (hsv[..., 0] + hue) % 180 | ||
else: | ||
# OpenCV uses 0-360 for floating point images | ||
hsv[..., 0] = (hsv[..., 0] + 2 * hue) % 360 | ||
ret = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) | ||
ret = ret.astype(np.uint8) | ||
return [ret] | ||
|
||
|
||
#### | ||
def add_to_saturation(images, random_state, parents, hooks, range=None): | ||
"""Perturbe the saturation of input images.""" | ||
img = images[0] # aleju input batch as default (always=1 in our case) | ||
value = 1 + random_state.uniform(*range) | ||
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) | ||
ret = img * value + (gray * (1 - value))[:, :, np.newaxis] | ||
ret = np.clip(ret, 0, 255) | ||
ret = ret.astype(np.uint8) | ||
return [ret] | ||
|
||
|
||
#### | ||
def add_to_contrast(images, random_state, parents, hooks, range=None): | ||
"""Perturbe the contrast of input images.""" | ||
img = images[0] # aleju input batch as default (always=1 in our case) | ||
value = random_state.uniform(*range) | ||
mean = np.mean(img, axis=(0, 1), keepdims=True) | ||
ret = img * value + mean * (1 - value) | ||
ret = np.clip(img, 0, 255) | ||
ret = ret.astype(np.uint8) | ||
return [ret] | ||
|
||
|
||
#### | ||
def add_to_brightness(images, random_state, parents, hooks, range=None): | ||
"""Perturbe the brightness of input images.""" | ||
img = images[0] # aleju input batch as default (always=1 in our case) | ||
value = random_state.uniform(*range) | ||
ret = np.clip(img + value, 0, 255) | ||
ret = ret.astype(np.uint8) | ||
return [ret] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import sys | ||
import math | ||
import numpy as np | ||
import cv2 | ||
import matplotlib.pyplot as plt | ||
|
||
import torch | ||
import torch.utils.data as data | ||
|
||
import psutil | ||
|
||
|
||
#### | ||
class SerializeFileList(data.IterableDataset): | ||
"""Read a single file as multiple patches of same shape, perform the padding beforehand.""" | ||
|
||
def __init__(self, img_list, patch_info_list, patch_size, preproc=None): | ||
super().__init__() | ||
self.patch_size = patch_size | ||
|
||
self.img_list = img_list | ||
self.patch_info_list = patch_info_list | ||
|
||
self.worker_start_img_idx = 0 | ||
# * for internal worker state | ||
self.curr_img_idx = 0 | ||
self.stop_img_idx = 0 | ||
self.curr_patch_idx = 0 | ||
self.stop_patch_idx = 0 | ||
self.preproc = preproc | ||
return | ||
|
||
def __iter__(self): | ||
worker_info = torch.utils.data.get_worker_info() | ||
if worker_info is None: # single-process data loading, return the full iterator | ||
self.stop_img_idx = len(self.img_list) | ||
self.stop_patch_idx = len(self.patch_info_list) | ||
return self | ||
else: # in a worker process so split workload, return a reduced copy of self | ||
per_worker = len(self.patch_info_list) / float(worker_info.num_workers) | ||
per_worker = int(math.ceil(per_worker)) | ||
|
||
global_curr_patch_idx = worker_info.id * per_worker | ||
global_stop_patch_idx = global_curr_patch_idx + per_worker | ||
self.patch_info_list = self.patch_info_list[ | ||
global_curr_patch_idx:global_stop_patch_idx | ||
] | ||
self.curr_patch_idx = 0 | ||
self.stop_patch_idx = len(self.patch_info_list) | ||
# * check img indexer, implicit protocol in infer.py | ||
global_curr_img_idx = self.patch_info_list[0][-1] | ||
global_stop_img_idx = self.patch_info_list[-1][-1] + 1 | ||
self.worker_start_img_idx = global_curr_img_idx | ||
self.img_list = self.img_list[global_curr_img_idx:global_stop_img_idx] | ||
self.curr_img_idx = 0 | ||
self.stop_img_idx = len(self.img_list) | ||
return self # does it mutate source copy? | ||
|
||
def __next__(self): | ||
|
||
if self.curr_patch_idx >= self.stop_patch_idx: | ||
raise StopIteration # when there is nothing more to yield | ||
patch_info = self.patch_info_list[self.curr_patch_idx] | ||
img_ptr = self.img_list[patch_info[-1] - self.worker_start_img_idx] | ||
patch_data = img_ptr[ | ||
patch_info[0] : patch_info[0] + self.patch_size, | ||
patch_info[1] : patch_info[1] + self.patch_size, | ||
] | ||
self.curr_patch_idx += 1 | ||
if self.preproc is not None: | ||
patch_data = self.preproc(patch_data) | ||
return patch_data, patch_info | ||
|
||
|
||
#### | ||
class SerializeArray(data.Dataset): | ||
def __init__(self, mmap_array_path, patch_info_list, patch_size, preproc=None): | ||
super().__init__() | ||
self.patch_size = patch_size | ||
|
||
# use mmap as intermediate sharing, else variable will be duplicated | ||
# accross torch worker => OOM error, open in read only mode | ||
self.image = np.load(mmap_array_path, mmap_mode="r") | ||
|
||
self.patch_info_list = patch_info_list | ||
self.preproc = preproc | ||
return | ||
|
||
def __len__(self): | ||
return len(self.patch_info_list) | ||
|
||
def __getitem__(self, idx): | ||
patch_info = self.patch_info_list[idx] | ||
patch_data = self.image[ | ||
patch_info[0] : patch_info[0] + self.patch_size[0], | ||
patch_info[1] : patch_info[1] + self.patch_size[1], | ||
] | ||
if self.preproc is not None: | ||
patch_data = self.preproc(patch_data) | ||
return patch_data, patch_info |
Oops, something went wrong.