
Commit

Add files via upload
wuyongjianCODE authored Jun 2, 2023
1 parent f1feed4 commit ad6afbc
Showing 41 changed files with 7,227 additions and 0 deletions.
Empty file added dataloader/__init__.py
Empty file.
109 changes: 109 additions & 0 deletions dataloader/augs.py
@@ -0,0 +1,109 @@
import cv2
import numpy as np

from scipy.ndimage import measurements


####
def fix_mirror_padding(ann):
"""Deal with duplicated instances due to mirroring in interpolation
during shape augmentation (scale, rotation etc.).
"""
current_max_id = np.amax(ann)
inst_list = list(np.unique(ann))
inst_list.remove(0) # 0 is background
for inst_id in inst_list:
inst_map = np.array(ann == inst_id, np.uint8)
remapped_ids = measurements.label(inst_map)[0]
remapped_ids[remapped_ids > 1] += current_max_id
ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
current_max_id = np.amax(ann)
return ann
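# A hypothetical sanity check (not part of the original file): given a 1x5
# annotation in which instance 1 was duplicated by mirror padding,
#   ann = np.array([[1, 1, 0, 1, 1]], np.int32)
#   fix_mirror_padding(ann)  # -> [[1, 1, 0, 3, 3]]; the mirrored blob gets a fresh id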


####
def gaussian_blur(images, random_state, parents, hooks, max_ksize=3):
"""Apply Gaussian blur to input images."""
    img = images[0]  # imgaug (aleju) passes a batch of images; batch size is always 1 in our case
ksize = random_state.randint(0, max_ksize, size=(2,))
ksize = tuple((ksize * 2 + 1).tolist())

ret = cv2.GaussianBlur(
img, ksize, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE
)
ret = np.reshape(ret, img.shape)
ret = ret.astype(np.uint8)
return [ret]


####
def median_blur(images, random_state, parents, hooks, max_ksize=3):
"""Apply median blur to input images."""
    img = images[0]  # imgaug (aleju) passes a batch of images; batch size is always 1 in our case
ksize = random_state.randint(0, max_ksize)
ksize = ksize * 2 + 1
ret = cv2.medianBlur(img, ksize)
ret = ret.astype(np.uint8)
return [ret]


####
def add_to_hue(images, random_state, parents, hooks, range=None):
"""Perturbe the hue of input images."""
img = images[0] # aleju input batch as default (always=1 in our case)
hue = random_state.uniform(*range)
hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
if hsv.dtype.itemsize == 1:
# OpenCV uses 0-179 for 8-bit images
hsv[..., 0] = (hsv[..., 0] + hue) % 180
else:
# OpenCV uses 0-360 for floating point images
hsv[..., 0] = (hsv[..., 0] + 2 * hue) % 360
ret = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
ret = ret.astype(np.uint8)
return [ret]


####
def add_to_saturation(images, random_state, parents, hooks, range=None):
"""Perturbe the saturation of input images."""
img = images[0] # aleju input batch as default (always=1 in our case)
value = 1 + random_state.uniform(*range)
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
ret = img * value + (gray * (1 - value))[:, :, np.newaxis]
ret = np.clip(ret, 0, 255)
ret = ret.astype(np.uint8)
return [ret]


####
def add_to_contrast(images, random_state, parents, hooks, range=None):
"""Perturbe the contrast of input images."""
img = images[0] # aleju input batch as default (always=1 in our case)
value = random_state.uniform(*range)
mean = np.mean(img, axis=(0, 1), keepdims=True)
ret = img * value + mean * (1 - value)
ret = np.clip(img, 0, 255)
ret = ret.astype(np.uint8)
return [ret]


####
def add_to_brightness(images, random_state, parents, hooks, range=None):
"""Perturbe the brightness of input images."""
img = images[0] # aleju input batch as default (always=1 in our case)
value = random_state.uniform(*range)
ret = np.clip(img + value, 0, 255)
ret = ret.astype(np.uint8)
return [ret]
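
These augmenters follow imgaug's func_images(images, random_state, parents, hooks) callback signature, so a plausible way to wire them up (a minimal sketch under that assumption, with illustrative ranges; the actual wiring is not shown in this commit) is through iaa.Lambda:

import imgaug.augmenters as iaa

# Each Lambda wraps one function above; extra keyword arguments are bound by
# closing over *args = (images, random_state, parents, hooks).
augmentor = iaa.Sequential([
    iaa.Lambda(func_images=lambda *args: gaussian_blur(*args, max_ksize=3)),
    iaa.Lambda(func_images=lambda *args: median_blur(*args, max_ksize=3)),
    iaa.Lambda(func_images=lambda *args: add_to_hue(*args, range=(-8, 8))),
    iaa.Lambda(func_images=lambda *args: add_to_brightness(*args, range=(-26, 26))),
])
# aug_imgs = augmentor.augment_images([img])  # img: HxWx3 uint8 RGB patch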
100 changes: 100 additions & 0 deletions dataloader/infer_loader.py
@@ -0,0 +1,100 @@
import math

import numpy as np

import torch
import torch.utils.data as data


####
class SerializeFileList(data.IterableDataset):
"""Read a single file as multiple patches of same shape, perform the padding beforehand."""

def __init__(self, img_list, patch_info_list, patch_size, preproc=None):
super().__init__()
self.patch_size = patch_size

self.img_list = img_list
self.patch_info_list = patch_info_list

self.worker_start_img_idx = 0
# * for internal worker state
self.curr_img_idx = 0
self.stop_img_idx = 0
self.curr_patch_idx = 0
self.stop_patch_idx = 0
self.preproc = preproc
return

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
self.stop_img_idx = len(self.img_list)
self.stop_patch_idx = len(self.patch_info_list)
return self
else: # in a worker process so split workload, return a reduced copy of self
per_worker = len(self.patch_info_list) / float(worker_info.num_workers)
per_worker = int(math.ceil(per_worker))

global_curr_patch_idx = worker_info.id * per_worker
global_stop_patch_idx = global_curr_patch_idx + per_worker
self.patch_info_list = self.patch_info_list[
global_curr_patch_idx:global_stop_patch_idx
]
self.curr_patch_idx = 0
self.stop_patch_idx = len(self.patch_info_list)
# * check img indexer, implicit protocol in infer.py
global_curr_img_idx = self.patch_info_list[0][-1]
global_stop_img_idx = self.patch_info_list[-1][-1] + 1
self.worker_start_img_idx = global_curr_img_idx
self.img_list = self.img_list[global_curr_img_idx:global_stop_img_idx]
self.curr_img_idx = 0
self.stop_img_idx = len(self.img_list)
return self # does it mutate source copy?

def __next__(self):

if self.curr_patch_idx >= self.stop_patch_idx:
raise StopIteration # when there is nothing more to yield
patch_info = self.patch_info_list[self.curr_patch_idx]
img_ptr = self.img_list[patch_info[-1] - self.worker_start_img_idx]
patch_data = img_ptr[
patch_info[0] : patch_info[0] + self.patch_size,
patch_info[1] : patch_info[1] + self.patch_size,
]
self.curr_patch_idx += 1
if self.preproc is not None:
patch_data = self.preproc(patch_data)
return patch_data, patch_info
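
# A hypothetical usage sketch (not part of the original file): since this is an
# IterableDataset, __iter__ above shards patch_info_list across workers, so it
# plugs into a DataLoader without a sampler:
#   dataset = SerializeFileList(img_list, patch_info_list, patch_size=256)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4)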


####
class SerializeArray(data.Dataset):
def __init__(self, mmap_array_path, patch_info_list, patch_size, preproc=None):
super().__init__()
self.patch_size = patch_size

        # use mmap as an intermediate for sharing, else the array will be
        # duplicated across torch workers => OOM error; open in read-only mode
self.image = np.load(mmap_array_path, mmap_mode="r")

self.patch_info_list = patch_info_list
self.preproc = preproc
return

def __len__(self):
return len(self.patch_info_list)

def __getitem__(self, idx):
patch_info = self.patch_info_list[idx]
patch_data = self.image[
patch_info[0] : patch_info[0] + self.patch_size[0],
patch_info[1] : patch_info[1] + self.patch_size[1],
]
if self.preproc is not None:
patch_data = self.preproc(patch_data)
return patch_data, patch_info
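
SerializeArray is a map-style Dataset, so a plausible inference loop (a minimal sketch with hypothetical names such as cache.npy and patch_info_list; not shown in this commit) looks like:

import numpy as np
import torch.utils.data as data

# np.save("cache.npy", image)  # hypothetical pre-saved HxWxC image array
dataset = SerializeArray("cache.npy", patch_info_list, patch_size=(270, 270))
loader = data.DataLoader(dataset, batch_size=8, num_workers=4)
for patch_batch, patch_info_batch in loader:
    pass  # run the network on each batch of patches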