Commit 04bf716: fix

ternaus committed Nov 16, 2020
1 parent c7ef2cb
Showing 6 changed files with 31 additions and 57 deletions.
4 changes: 1 addition & 3 deletions retinaface/configs/2020-11-15.yaml
@@ -14,7 +14,6 @@ model:
     in_channels: 256
     out_channels: 256
 
-
 optimizer:
   type: torch.optim.SGD
   lr: 0.001
@@ -39,9 +38,8 @@ scheduler:
   T_mult: 2
 
 train_parameters:
-  batch_size: 4
+  batch_size: 6
   rotate90: True
-  box_min_size: 5
 
 checkpoint_callback:
   type: pytorch_lightning.callbacks.ModelCheckpoint
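The training batch size goes from 4 to 6, and box_min_size disappears along with the filtering that used it. A short sketch of how this config is consumed, mirroring the yaml-plus-Adict pattern in retinaface/train.py; the config path and the addict alias are assumptions:

```python
# Minimal sketch of loading this config the way train.py does.
# Assumes Adict is addict.Dict, as the alias in train.py suggests.
import yaml
from addict import Dict as Adict

with open("retinaface/configs/2020-11-15.yaml") as f:
    config = Adict(yaml.load(f, Loader=yaml.SafeLoader))

print(config.train_parameters.batch_size)  # 6 after this commit
print(config.train_parameters.rotate90)    # True
```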
25 changes: 15 additions & 10 deletions retinaface/data_augment.py
@@ -6,19 +6,20 @@
 from retinaface.box_utils import matrix_iof
 
 
-def _crop(
+def random_crop(
     image: np.ndarray, boxes: np.ndarray, labels: np.ndarray, landm: np.ndarray, img_dim: int
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, bool]:
-    """
-    if random.uniform(0, 1) <= 0.2:
-        scale = 1.0
-    else:
-        scale = random.uniform(0.3, 1.0)
-    """
     height, width = image.shape[:2]
     pad_image_flag = True
 
     for _ in range(250):
+        """
+        if random.uniform(0, 1) <= 0.2:
+            scale = 1.0
+        else:
+            scale = random.uniform(0.3, 1.0)
+        """
+
         pre_scales = [0.3, 0.45, 0.6, 0.8, 1.0]
         scale = random.choice(pre_scales)
         short_side = min(width, height)
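The crop sampler draws from the fixed pre_scales list instead of the continuous uniform draw, which survives above only as a commented-out reference. A minimal runnable sketch of just that sampling step; the helper name and the square-ROI return convention are assumptions for illustration:

```python
import random
from typing import Tuple

def sample_square_roi(height: int, width: int) -> Tuple[int, int, int, int]:
    # Pick the crop side as a preset fraction of the short image side,
    # as random_crop does, then place the square uniformly at random.
    pre_scales = [0.3, 0.45, 0.6, 0.8, 1.0]
    scale = random.choice(pre_scales)
    side = int(scale * min(height, width))
    left = random.randrange(width - side + 1)
    top = random.randrange(height - side + 1)
    return left, top, left + side, top + side
```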
@@ -80,7 +81,9 @@ def _crop(
     return image, boxes, labels, landm, pad_image_flag
 
 
-def _mirror(image: np.ndarray, boxes: np.ndarray, landms: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+def random_horizontal_flip(
+    image: np.ndarray, boxes: np.ndarray, landms: np.ndarray
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     width = image.shape[1]
     if random.randrange(2):
         image = image[:, ::-1]
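The diff is truncated just after the image flip; the rest of random_horizontal_flip mirrors box x-coordinates around the image width. A self-contained sketch of that idea, assuming (x_min, y_min, x_max, y_max) box layout; the helper below is illustrative, not the repository function:

```python
import random

import numpy as np

def horizontal_flip(image: np.ndarray, boxes: np.ndarray):
    # Flip the image left-right half the time; when flipping, mirror the
    # box x-coordinates and swap x_min/x_max so boxes stay well-formed.
    width = image.shape[1]
    if random.randrange(2):
        image = image[:, ::-1]
        boxes = boxes.copy()
        boxes[:, 0::2] = width - boxes[:, 2::-2]  # columns [2, 0] -> [0, 2]
    return image, boxes
```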
@@ -124,10 +127,12 @@ def __call__(self, image: np.ndarray, targets: np.ndarray) -> Tuple[np.ndarray,
         landmarks = targets[:, 4:-1].copy()
         labels = targets[:, -1:].copy()
 
-        image_t, boxes_t, labels_t, landmarks_t, pad_image_flag = _crop(image, boxes, labels, landmarks, self.img_dim)
+        image_t, boxes_t, labels_t, landmarks_t, pad_image_flag = random_crop(
+            image, boxes, labels, landmarks, self.img_dim
+        )
 
         image_t = _pad_to_square(image_t, pad_image_flag)
-        image_t, boxes_t, landmarks_t = _mirror(image_t, boxes_t, landmarks_t)
+        image_t, boxes_t, landmarks_t = random_horizontal_flip(image_t, boxes_t, landmarks_t)
         height, width = image_t.shape[:2]
 
         boxes_t[:, 0::2] = boxes_t[:, 0::2] / width
16 changes: 12 additions & 4 deletions retinaface/dataset.py
@@ -10,7 +10,6 @@
 from torch.utils import data
 
 from retinaface.data_augment import Preproc
-from retinaface.utils import filer_labels
 
 
 class FaceDetectionDataset(data.Dataset):
@@ -21,7 +20,6 @@ def __init__(
         transform: albu.Compose,
         preproc: Preproc,
         rotate90: bool = False,
-        box_min_size: int = 5,
     ) -> None:
         self.preproc = preproc
 
@@ -33,7 +31,7 @@ def __init__(
         with open(label_path) as f:
             labels = json.load(f)
 
-        self.labels = filer_labels(labels, image_path, min_size=box_min_size)
+        self.labels = [x for x in labels if (image_path / x["file_name"]).exists()]
 
     def __len__(self) -> int:
         return len(self.labels)
@@ -45,14 +43,24 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
 
         image = load_rgb(self.image_path / file_name)
 
+        image_height, image_width = image.shape[:2]
+
         # annotations will have the format
         # 4: box, 10 landmarks, 1: landmarks / no landmarks
         num_annotations = 4 + 10 + 1
         annotations = np.zeros((0, num_annotations))
 
         for label in labels["annotations"]:
             annotation = np.zeros((1, num_annotations))
-            annotation[0, :4] = label["bbox"]
+
+            x_min, y_min, x_max, y_max = label["bbox"]
+
+            x_min = np.clip(x_min, 0, image_width - 1)
+            y_min = np.clip(y_min, 0, image_height - 1)
+            x_max = np.clip(x_max, x_min + 1, image_width - 1)
+            y_max = np.clip(y_max, y_min + 1, image_height - 1)
+
+            annotation[0, :4] = x_min, y_min, x_max, y_max
 
             if "landmarks" in label and label["landmarks"]:
                 landmarks = np.array(label["landmarks"])
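Box sanitising moves from the deleted filer_labels into __getitem__: each box is clamped to the image bounds and forced to a positive extent before it becomes a training target. The same rule as a standalone helper (the function name and signature are assumptions):

```python
import numpy as np

def clip_bbox(bbox, image_height: int, image_width: int):
    # Clamp the box to the image; x_max/y_max are forced past x_min/y_min
    # so degenerate zero-area boxes cannot reach the loss.
    x_min, y_min, x_max, y_max = bbox
    x_min = np.clip(x_min, 0, image_width - 1)
    y_min = np.clip(y_min, 0, image_height - 1)
    x_max = np.clip(x_max, x_min + 1, image_width - 1)
    y_max = np.clip(y_max, y_min + 1, image_height - 1)
    return x_min, y_min, x_max, y_max
```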
2 changes: 1 addition & 1 deletion retinaface/predict_single.py
@@ -6,14 +6,14 @@
 import albumentations as A
 import numpy as np
 import torch
-from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
 from iglovikov_helper_functions.utils.image_utils import pad_to_size, unpad_from_size
 from torch.nn import functional as F
 from torchvision.ops import nms
 
 from retinaface.box_utils import decode, decode_landm
 from retinaface.network import RetinaFace
 from retinaface.prior_box import priorbox
+from retinaface.utils import tensor_from_rgb_image
 
 
 class Model:
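tensor_from_rgb_image now ships with retinaface.utils instead of iglovikov_helper_functions. The helper itself is not shown in this diff, so the following is a plausible minimal implementation, an assumption based on how predict_single.py uses it:

```python
import numpy as np
import torch

def tensor_from_rgb_image(image: np.ndarray) -> torch.Tensor:
    # HWC numpy image -> CHW torch tensor, the layout torch models expect.
    image = np.transpose(image, (2, 0, 1))
    return torch.from_numpy(image)
```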
4 changes: 2 additions & 2 deletions retinaface/train.py
@@ -66,7 +66,6 @@ def train_dataloader(self):
                 transform=from_dict(self.config.train_aug),
                 preproc=self.preproc,
                 rotate90=self.config.train_parameters.rotate90,
-                box_min_size=self.config.train_parameters.box_min_size,
             ),
             batch_size=self.config.train_parameters.batch_size,
             num_workers=self.config.num_workers,
@@ -86,7 +85,6 @@ def val_dataloader(self):
                 transform=from_dict(self.config.val_aug),
                 preproc=self.preproc,
                 rotate90=self.config.val_parameters.rotate90,
-                box_min_size=self.config.val_parameters.box_min_size,
             ),
             batch_size=self.config.val_parameters.batch_size,
             num_workers=self.config.num_workers,
@@ -236,6 +234,8 @@ def main():
     with open(args.config_path) as f:
         config = Adict(yaml.load(f, Loader=yaml.SafeLoader))
 
+    pl.trainer.seed_everything(config.seed)
+
     pipeline = RetinaFace(config)
 
     Path(config.checkpoint_callback.filepath).mkdir(exist_ok=True, parents=True)
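main() now seeds all RNGs before building the pipeline, so augmentation draws and weight initialisation become reproducible per config. A usage sketch; the literal seed value is an assumption, train.py reads it from config.seed:

```python
import pytorch_lightning as pl

# One call seeds python.random, numpy and torch, so repeated runs
# with the same config produce the same crops, flips and init.
pl.trainer.seed_everything(42)
```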
37 changes: 0 additions & 37 deletions retinaface/utils.py
@@ -1,9 +1,7 @@
-from pathlib import Path
 from typing import Any, Dict, List
 
 import cv2
 import numpy as np
-from PIL import Image
 
 
 def vis_annotations(image: np.ndarray, annotations: List[Dict[str, Any]]) -> np.ndarray:
@@ -24,38 +22,3 @@ def vis_annotations(image: np.ndarray, annotations: List[Dict[str, Any]]) -> np.ndarray:
 
         vis_image = cv2.rectangle(vis_image, (x_min, y_min), (x_max, y_max), color=(0, 255, 0), thickness=2)
     return vis_image
-
-
-def filer_labels(labels: List[Dict], image_path: Path, min_size: int) -> List[Dict]:
-    result: List[Dict[str, Any]] = []
-
-    print("Before = ", len(labels))
-
-    for label in labels:
-        if not (image_path / label["file_name"]).exists():
-            continue
-
-        temp: List[Dict[str, Any]] = []
-
-        width, height = Image.open(image_path / label["file_name"]).size
-
-        for annotation in label["annotations"]:
-            x_min, y_min, x_max, y_max = annotation["bbox"]
-
-            x_min = np.clip(x_min, 0, width - 1)
-            y_min = np.clip(y_min, 0, height - 1)
-            x_max = np.clip(x_max, x_min + 1, width - 1)
-            y_max = np.clip(y_max, y_min + 1, height - 1)
-
-            annotation["bbox"] = x_min, y_min, x_max, y_max
-
-            if x_max - x_min >= min_size and y_max - y_min >= min_size:
-                temp += [annotation]
-
-        if len(temp) > 0:
-            label["annotation"] = temp
-            result += [label]
-
-    print("After = ", len(result))
-
-    return result

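With filer_labels deleted, dataset.py keeps only records whose image exists on disk; the min-size filter is gone and box clipping now happens per sample in __getitem__. The replacement filter as a tiny helper (hypothetical name; the dataset inlines it as a list comprehension):

```python
from pathlib import Path
from typing import Dict, List

def keep_existing_images(labels: List[Dict], image_path: Path) -> List[Dict]:
    # Successor to the deleted filer_labels: no size threshold, no clipping,
    # just drop records whose image file is missing.
    return [x for x in labels if (image_path / x["file_name"]).exists()]
```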