diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..c1ba535
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,18 @@
+loss/
+data/
+cache/
+tf_cache/
+debug/
+results/
+
+misc/outputs
+
+evaluation/evaluate_object
+evaluation/analyze_object
+
+nnet/__pycache__/
+
+*.swp
+
+*.pyc
+*.o*
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..9f3e4eb
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2019, Princeton University
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0a36645
--- /dev/null
+++ b/README.md
@@ -0,0 +1,151 @@
+# CornerNet-Lite: Training, Evaluation and Testing Code
+Code for reproducing results in the following paper:
+
+**CornerNet-Lite: Efficient Keypoint Based Object Detection**
+Hei Law, Yun Teng, Olga Russakovsky, Jia Deng
+*arXiv*
+
+## Getting Started
+### Software Requirement
+- Python 3.7
+- PyTorch 1.0.0
+- CUDA 10
+- GCC 4.9.2 or above
+
+### Installing Dependencies
+Please first install [Anaconda](https://anaconda.org) and create an Anaconda environment using the provided package list `conda_packagelist.txt`.
+```
+conda create --name CornerNet_Lite --file conda_packagelist.txt --channel pytorch
+```
+
+After you create the environment, please activate it.
+```
+source activate CornerNet_Lite
+```
+
+### Compiling Corner Pooling Layers
+Compile the C++ implementation of the corner pooling layers. (GCC 4.9.2 or above is required.)
+```
+cd <CornerNet-Lite dir>/core/models/py_utils/_cpools/
+python setup.py install --user
+```
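+
+A quick way to sanity-check the build from the repo root (a minimal sketch; it
+assumes the compiled extension and a CUDA device are available):
+```python
+import torch
+from core.models.py_utils import TopPool
+
+x = torch.rand(1, 3, 8, 8).cuda()
+print(TopPool()(x).shape)  # corner pooling preserves the input shape
+```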
+
+### Compiling NMS
+Compile the NMS code, which is originally from [Faster R-CNN](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/cpu_nms.pyx) and [Soft-NMS](https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx).
+```
+cd <CornerNet-Lite dir>/core/external
+make
+```
+
+### Downloading Models
+In this repo, we provide models for the following detectors:
+- [CornerNet-Saccade](https://drive.google.com/file/d/1MQDyPRI0HgDHxHToudHqQ-2m8TVBciaa/view?usp=sharing)
+- [CornerNet-Squeeze](https://drive.google.com/file/d/1qM8BBYCLUBcZx_UmLT0qMXNTh-Yshp4X/view?usp=sharing)
+- [CornerNet](https://drive.google.com/file/d/1e8At_iZWyXQgLlMwHkB83kN-AN85Uff1/view?usp=sharing)
+
+Put the CornerNet-Saccade model under `<CornerNet-Lite dir>/cache/nnet/CornerNet_Saccade/`, the CornerNet-Squeeze model under `<CornerNet-Lite dir>/cache/nnet/CornerNet_Squeeze/` and the CornerNet model under `<CornerNet-Lite dir>/cache/nnet/CornerNet/`. (\* Note that we use underscores instead of dashes in the directory names for CornerNet-Saccade and CornerNet-Squeeze.)
+
+Note: The CornerNet model is the same as the one in the original [CornerNet repo](https://github.com/princeton-vl/CornerNet). We just ported it to this new repo.
+
+After downloading the models, you should be able to use the detectors on your own images. We provide a demo script `demo.py` to test if the repo is installed correctly.
+```
+python demo.py
+```
+This script applies CornerNet-Saccade to `demo.jpg` and writes the results to `demo_out.jpg`.
+
+In the demo script, the default detector is CornerNet-Saccade. You can modify the demo script to test different detectors. For example, if you want to test CornerNet-Squeeze:
+```python
+#!/usr/bin/env python
+
+import cv2
+from core.detectors import CornerNet_Squeeze
+from core.vis_utils import draw_bboxes
+
+detector = CornerNet_Squeeze()
+image = cv2.imread("demo.jpg")
+
+bboxes = detector(image)
+image = draw_bboxes(image, bboxes)
+cv2.imwrite("demo_out.jpg", image)
+```
+
+### Using CornerNet-Lite in Your Project
+It is also easy to use CornerNet-Lite in your project. You will need to change the directory name from `CornerNet-Lite` to `CornerNet_Lite`. Otherwise, you won't be able to import CornerNet-Lite.
+```
+Your project
+│   README.md
+│   ...
+│   foo.py
+│
+└───CornerNet_Lite
+│
+└───directory1
+│
+└───...
+```
+
+In `foo.py`, you can easily import CornerNet-Saccade by adding:
+```python
+import cv2
+from CornerNet_Lite import CornerNet_Saccade
+
+def foo():
+    cornernet = CornerNet_Saccade()
+    # CornerNet_Saccade is ready to use
+
+    image = cv2.imread('/path/to/your/image')
+    bboxes = cornernet(image)
+```
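+
+The returned `bboxes` is a dictionary mapping each class name to an `(N, 5)`
+array of detections, where each row is `[x1, y1, x2, y2, score]`. A minimal
+sketch of keeping only confident detections (the `0.5` threshold is an
+arbitrary example value, not a repo default):
+```python
+for name, dets in bboxes.items():
+    confident = dets[dets[:, 4] > 0.5]
+    for x1, y1, x2, y2, score in confident:
+        print(name, score, (x1, y1, x2, y2))
+```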
+
+If you want to train or evaluate the detectors on COCO, please move on to the following steps.
+
+## Training and Evaluation
+
+### Installing MS COCO APIs
+```
+mkdir -p <CornerNet-Lite dir>/data
+cd <CornerNet-Lite dir>/data
+git clone git@github.com:cocodataset/cocoapi.git coco
+cd <CornerNet-Lite dir>/data/coco/PythonAPI
+make install
+```
+
+### Downloading MS COCO Data
+- Download the training/validation split we use in our paper from [here](https://drive.google.com/file/d/1dop4188xo5lXDkGtOZUzy2SHOD_COXz4/view?usp=sharing) (originally from [Faster R-CNN](https://github.com/rbgirshick/py-faster-rcnn/tree/master/data))
+- Unzip the file and place `annotations` under `<CornerNet-Lite dir>/data/coco`
+- Download the images (2014 Train, 2014 Val, 2017 Test) from [here](http://cocodataset.org/#download)
+- Create 3 directories, `trainval2014`, `minival2014` and `testdev2017`, under `<CornerNet-Lite dir>/data/coco/images/`
+- Copy the training/validation/testing images to the corresponding directories according to the annotation files
+
+To train and evaluate a network, you will need to create a configuration file, which defines the hyperparameters, and a model file, which defines the network architecture. The configuration file should be in JSON format and placed in `<CornerNet-Lite dir>/configs/`. Each configuration file should have a corresponding model file in `<CornerNet-Lite dir>/core/models/`, i.e., if there is a `<model>.json` in `<CornerNet-Lite dir>/configs/`, there should be a `<model>.py` in `<CornerNet-Lite dir>/core/models/`. There is only one exception, which we will mention later.
+
+### Training and Evaluating a Model
+To train a model:
+```
+python train.py <model>
+```
+
+We provide the configuration files and the model files for CornerNet-Saccade, CornerNet-Squeeze and CornerNet in this repo. Please check the configuration files in `<CornerNet-Lite dir>/configs/`.
+
+To train CornerNet-Saccade:
+```
+python train.py CornerNet_Saccade
+```
+Please adjust the batch size in `CornerNet_Saccade.json` to accommodate the number of GPUs that are available to you.
+
+To evaluate the trained model:
+```
+python evaluate.py CornerNet_Saccade --testiter 500000 --split <split>
+```
+
+If you want to test different hyperparameters during evaluation and do not want to overwrite the original configuration file, you can do so by creating a configuration file with a suffix (`<model>-<suffix>.json`). There is no need to create `<model>-<suffix>.py` in `<CornerNet-Lite dir>/core/models/`.
+
+To use the new configuration file:
+```
+python evaluate.py <model> --testiter <iter> --split <split> --suffix <suffix>
+```
+
+We also include a configuration file for CornerNet under the multi-scale setting, which is `CornerNet-multi_scale.json`, in this repo.
+
+To use the multi-scale configuration file:
+```
+python evaluate.py CornerNet --testiter <iter> --split <split> --suffix multi_scale
+```
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..dd957ac
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,2 @@
+from .core.detectors import CornerNet, CornerNet_Squeeze, CornerNet_Saccade
+from .core.vis_utils import draw_bboxes
diff --git a/conda_packagelist.txt b/conda_packagelist.txt
new file mode 100644
index 0000000..346f63b
--- /dev/null
+++ b/conda_packagelist.txt
@@ -0,0 +1,81 @@
+# This file may be used to create an environment using:
+# $ conda create --name <env> --file <this file>
+# platform: linux-64
+blas=1.0=mkl
+bzip2=1.0.6=h14c3975_5
+ca-certificates=2018.12.5=0
+cairo=1.14.12=h8948797_3
+certifi=2018.11.29=py37_0
+cffi=1.11.5=py37he75722e_1
+cuda100=1.0=0
+cycler=0.10.0=py37_0
+cython=0.28.5=py37hf484d3e_0
+dbus=1.13.2=h714fa37_1
+expat=2.2.6=he6710b0_0
+ffmpeg=4.0=hcdf2ecd_0
+fontconfig=2.13.0=h9420a91_0
+freeglut=3.0.0=hf484d3e_5
+freetype=2.9.1=h8a8886c_1
+glib=2.56.2=hd408876_0
+graphite2=1.3.12=h23475e2_2
+gst-plugins-base=1.14.0=hbbd80ab_1
+gstreamer=1.14.0=hb453b48_1
+harfbuzz=1.8.8=hffaf4a1_0
+hdf5=1.10.2=hba1933b_1
+icu=58.2=h9c2bf20_1
+intel-openmp=2019.0=118
+jasper=2.0.14=h07fcdf6_1
+jpeg=9b=h024ee3a_2
+kiwisolver=1.0.1=py37hf484d3e_0
+libedit=3.1.20170329=h6b74fdf_2
+libffi=3.2.1=hd88cf55_4
+libgcc-ng=8.2.0=hdf63c60_1
+libgfortran-ng=7.3.0=hdf63c60_0
+libglu=9.0.0=hf484d3e_1
+libopencv=3.4.2=hb342d67_1
+libopus=1.2.1=hb9ed12e_0
+libpng=1.6.35=hbc83047_0
+libstdcxx-ng=8.2.0=hdf63c60_1
+libtiff=4.0.9=he85c1e1_2
+libuuid=1.0.3=h1bed415_2
+libvpx=1.7.0=h439df22_0
+libxcb=1.13=h1bed415_1
+libxml2=2.9.8=h26e45fe_1
+matplotlib=3.0.2=py37h5429711_0
+mkl=2018.0.3=1
+mkl_fft=1.0.6=py37h7dd41cf_0
+mkl_random=1.0.1=py37h4414c95_1
+ncurses=6.1=hf484d3e_0
+ninja=1.8.2=py37h6bb024c_1
+numpy=1.15.4=py37h1d66e8a_0
+numpy-base=1.15.4=py37h81de0dd_0
+olefile=0.46=py37_0
+opencv=3.4.2=py37h6fd60c2_1
+openssl=1.1.1a=h7b6447c_0
+pcre=8.42=h439df22_0
+pillow=5.2.0=py37heded4f4_0
+pip=10.0.1=py37_0
+pixman=0.34.0=hceecf20_3
+py-opencv=3.4.2=py37hb342d67_1
+pycparser=2.18=py37_1
+pyparsing=2.2.0=py37_1
+pyqt=5.9.2=py37h05f1152_2
+python=3.7.1=h0371630_3
+python-dateutil=2.7.3=py37_0
+pytorch=1.0.0=py3.7_cuda10.0.130_cudnn7.4.1_1 +pytz=2018.5=py37_0 +qt=5.9.7=h5867ecd_1 +readline=7.0=h7b6447c_5 +scikit-learn=0.19.1=py37hedc7406_0 +scipy=1.1.0=py37hfa4b5c9_1 +setuptools=40.2.0=py37_0 +sip=4.19.8=py37hf484d3e_0 +six=1.11.0=py37_1 +sqlite=3.25.3=h7b6447c_0 +tk=8.6.8=hbc83047_0 +torchvision=0.2.1=py37_1 +tornado=5.1=py37h14c3975_0 +tqdm=4.25.0=py37h28b3542_0 +wheel=0.31.1=py37_0 +xz=5.2.4=h14c3975_4 +zlib=1.2.11=ha838bed_2 diff --git a/configs/CornerNet-multi_scale.json b/configs/CornerNet-multi_scale.json new file mode 100644 index 0000000..9c73f7c --- /dev/null +++ b/configs/CornerNet-multi_scale.json @@ -0,0 +1,54 @@ +{ + "system": { + "dataset": "COCO", + "batch_size": 49, + "sampling_function": "cornernet", + + "train_split": "trainval", + "val_split": "minival", + + "learning_rate": 0.00025, + "decay_rate": 10, + + "val_iter": 100, + + "opt_algo": "adam", + "prefetch_size": 5, + + "max_iter": 500000, + "stepsize": 450000, + "snapshot": 5000, + + "chunk_sizes": [4, 5, 5, 5, 5, 5, 5, 5, 5, 5], + + "data_dir": "./data" + }, + + "db": { + "rand_scale_min": 0.6, + "rand_scale_max": 1.4, + "rand_scale_step": 0.1, + "rand_scales": null, + + "rand_crop": true, + "rand_color": true, + + "border": 128, + "gaussian_bump": true, + + "input_size": [511, 511], + "output_sizes": [[128, 128]], + + "test_scales": [0.5, 0.75, 1, 1.25, 1.5], + + "top_k": 100, + "categories": 80, + "ae_threshold": 0.5, + "nms_threshold": 0.5, + + "merge_bbox": true, + "weight_exp": 10, + + "max_per_image": 100 + } +} diff --git a/configs/CornerNet.json b/configs/CornerNet.json new file mode 100755 index 0000000..fd319c9 --- /dev/null +++ b/configs/CornerNet.json @@ -0,0 +1,52 @@ +{ + "system": { + "dataset": "COCO", + "batch_size": 49, + "sampling_function": "cornernet", + + "train_split": "trainval", + "val_split": "minival", + + "learning_rate": 0.00025, + "decay_rate": 10, + + "val_iter": 100, + + "opt_algo": "adam", + "prefetch_size": 5, + + "max_iter": 500000, + "stepsize": 450000, + "snapshot": 5000, + + "chunk_sizes": [4, 5, 5, 5, 5, 5, 5, 5, 5, 5], + + "data_dir": "./data" + }, + + "db": { + "rand_scale_min": 0.6, + "rand_scale_max": 1.4, + "rand_scale_step": 0.1, + "rand_scales": null, + + "rand_crop": true, + "rand_color": true, + + "border": 128, + "gaussian_bump": true, + "gaussian_iou": 0.3, + + "input_size": [511, 511], + "output_sizes": [[128, 128]], + + "test_scales": [1], + + "top_k": 100, + "categories": 80, + "ae_threshold": 0.5, + "nms_threshold": 0.5, + + "max_per_image": 100 + } +} diff --git a/configs/CornerNet_Saccade.json b/configs/CornerNet_Saccade.json new file mode 100755 index 0000000..533b3fa --- /dev/null +++ b/configs/CornerNet_Saccade.json @@ -0,0 +1,56 @@ +{ + "system": { + "dataset": "COCO", + "batch_size": 48, + "sampling_function": "cornernet_saccade", + + "train_split": "trainval", + "val_split": "minival", + + "learning_rate": 0.00025, + "decay_rate": 10, + + "val_iter": 100, + + "opt_algo": "adam", + "prefetch_size": 5, + + "max_iter": 500000, + "stepsize": 450000, + "snapshot": 5000, + + "chunk_sizes": [12, 12, 12, 12] + }, + + "db": { + "rand_scale_min": 0.5, + "rand_scale_max": 1.1, + "rand_scale_step": 0.1, + "rand_scales": null, + + "rand_full_crop": true, + "gaussian_bump": true, + "gaussian_iou": 0.5, + + "min_scale": 16, + "view_sizes": [], + + "height_mult": 31, + "width_mult": 31, + + "input_size": [255, 255], + "output_sizes": [[64, 64]], + + "att_max_crops": 30, + "att_scales": [[1, 2, 4]], + "att_thresholds": [0.3], + + "top_k": 12, + "num_dets": 12, 
+ "categories": 80, + "ae_threshold": 0.3, + "nms_threshold": 0.5, + + "max_per_image": 100 + } +} diff --git a/configs/CornerNet_Squeeze.json b/configs/CornerNet_Squeeze.json new file mode 100755 index 0000000..6d36798 --- /dev/null +++ b/configs/CornerNet_Squeeze.json @@ -0,0 +1,54 @@ +{ + "system": { + "dataset": "COCO", + "batch_size": 55, + "sampling_function": "cornernet", + + "train_split": "trainval", + "val_split": "minival", + + "learning_rate": 0.00025, + "decay_rate": 10, + + "val_iter": 100, + + "opt_algo": "adam", + "prefetch_size": 5, + + "max_iter": 500000, + "stepsize": 450000, + "snapshot": 5000, + + "chunk_sizes": [13, 14, 14, 14], + + "data_dir": "./data" + }, + + "db": { + "rand_scale_min": 0.6, + "rand_scale_max": 1.4, + "rand_scale_step": 0.1, + "rand_scales": null, + + "rand_crop": true, + "rand_color": true, + + "border": 128, + "gaussian_bump": true, + "gaussian_iou": 0.3, + + "input_size": [511, 511], + "output_sizes": [[64, 64]], + + "test_scales": [1], + "test_flipped": false, + + "top_k": 20, + "num_dets": 100, + "categories": 80, + "ae_threshold": 0.5, + "nms_threshold": 0.5, + + "max_per_image": 100 + } +} diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/base.py b/core/base.py new file mode 100644 index 0000000..03d41a8 --- /dev/null +++ b/core/base.py @@ -0,0 +1,36 @@ +import json + +from .nnet.py_factory import NetworkFactory + +class Base(object): + def __init__(self, db, nnet, func, model=None): + super(Base, self).__init__() + + self._db = db + self._nnet = nnet + self._func = func + + if model is not None: + self._nnet.load_pretrained_params(model) + + self._nnet.cuda() + self._nnet.eval_mode() + + def _inference(self, image, *args, **kwargs): + return self._func(self._db, self._nnet, image.copy(), *args, **kwargs) + + def __call__(self, image, *args, **kwargs): + categories = self._db.configs["categories"] + bboxes = self._inference(image, *args, **kwargs) + return {self._db.cls2name(j): bboxes[j] for j in range(1, categories + 1)} + +def load_cfg(cfg_file): + with open(cfg_file, "r") as f: + cfg = json.load(f) + + cfg_sys = cfg["system"] + cfg_db = cfg["db"] + return cfg_sys, cfg_db + +def load_nnet(cfg_sys, model): + return NetworkFactory(cfg_sys, model) diff --git a/core/config.py b/core/config.py new file mode 100755 index 0000000..bba2a35 --- /dev/null +++ b/core/config.py @@ -0,0 +1,162 @@ +import os +import numpy as np + +class SystemConfig(object): + def __init__(self): + self._configs = {} + self._configs["dataset"] = None + self._configs["sampling_function"] = "coco_detection" + + # Training Config + self._configs["display"] = 5 + self._configs["snapshot"] = 400 + self._configs["stepsize"] = 5000 + self._configs["learning_rate"] = 0.001 + self._configs["decay_rate"] = 10 + self._configs["max_iter"] = 100000 + self._configs["val_iter"] = 20 + self._configs["batch_size"] = 1 + self._configs["snapshot_name"] = None + self._configs["prefetch_size"] = 100 + self._configs["pretrain"] = None + self._configs["opt_algo"] = "adam" + self._configs["chunk_sizes"] = None + + # Directories + self._configs["data_dir"] = "./data" + self._configs["cache_dir"] = "./cache" + self._configs["config_dir"] = "./config" + self._configs["result_dir"] = "./results" + + # Split + self._configs["train_split"] = "training" + self._configs["val_split"] = "validation" + self._configs["test_split"] = "testdev" + + # Rng + self._configs["data_rng"] = np.random.RandomState(123) + self._configs["nnet_rng"] = 
np.random.RandomState(317) + + @property + def chunk_sizes(self): + return self._configs["chunk_sizes"] + + @property + def train_split(self): + return self._configs["train_split"] + + @property + def val_split(self): + return self._configs["val_split"] + + @property + def test_split(self): + return self._configs["test_split"] + + @property + def full(self): + return self._configs + + @property + def sampling_function(self): + return self._configs["sampling_function"] + + @property + def data_rng(self): + return self._configs["data_rng"] + + @property + def nnet_rng(self): + return self._configs["nnet_rng"] + + @property + def opt_algo(self): + return self._configs["opt_algo"] + + @property + def prefetch_size(self): + return self._configs["prefetch_size"] + + @property + def pretrain(self): + return self._configs["pretrain"] + + @property + def result_dir(self): + result_dir = os.path.join(self._configs["result_dir"], self.snapshot_name) + if not os.path.exists(result_dir): + os.makedirs(result_dir) + return result_dir + + @property + def dataset(self): + return self._configs["dataset"] + + @property + def snapshot_name(self): + return self._configs["snapshot_name"] + + @property + def snapshot_dir(self): + snapshot_dir = os.path.join(self.cache_dir, "nnet", self.snapshot_name) + + if not os.path.exists(snapshot_dir): + os.makedirs(snapshot_dir) + return snapshot_dir + + @property + def snapshot_file(self): + snapshot_file = os.path.join(self.snapshot_dir, self.snapshot_name + "_{}.pkl") + return snapshot_file + + @property + def config_dir(self): + return self._configs["config_dir"] + + @property + def batch_size(self): + return self._configs["batch_size"] + + @property + def max_iter(self): + return self._configs["max_iter"] + + @property + def learning_rate(self): + return self._configs["learning_rate"] + + @property + def decay_rate(self): + return self._configs["decay_rate"] + + @property + def stepsize(self): + return self._configs["stepsize"] + + @property + def snapshot(self): + return self._configs["snapshot"] + + @property + def display(self): + return self._configs["display"] + + @property + def val_iter(self): + return self._configs["val_iter"] + + @property + def data_dir(self): + return self._configs["data_dir"] + + @property + def cache_dir(self): + if not os.path.exists(self._configs["cache_dir"]): + os.makedirs(self._configs["cache_dir"]) + return self._configs["cache_dir"] + + def update_config(self, new): + for key in new: + if key in self._configs: + self._configs[key] = new[key] + return self diff --git a/core/dbs/__init__.py b/core/dbs/__init__.py new file mode 100755 index 0000000..fa7f7b3 --- /dev/null +++ b/core/dbs/__init__.py @@ -0,0 +1,6 @@ +from .coco import COCO + +datasets = { + "COCO": COCO +} + diff --git a/core/dbs/base.py b/core/dbs/base.py new file mode 100644 index 0000000..e627b11 --- /dev/null +++ b/core/dbs/base.py @@ -0,0 +1,72 @@ +import os +import numpy as np + +class BASE(object): + def __init__(self): + self._split = None + self._db_inds = [] + self._image_ids = [] + + self._mean = np.zeros((3, ), dtype=np.float32) + self._std = np.ones((3, ), dtype=np.float32) + self._eig_val = np.ones((3, ), dtype=np.float32) + self._eig_vec = np.zeros((3, 3), dtype=np.float32) + + self._configs = {} + self._configs["data_aug"] = True + + self._data_rng = None + + @property + def configs(self): + return self._configs + + @property + def mean(self): + return self._mean + + @property + def std(self): + return self._std + + @property + def eig_val(self): + 
return self._eig_val + + @property + def eig_vec(self): + return self._eig_vec + + @property + def db_inds(self): + return self._db_inds + + @property + def split(self): + return self._split + + def update_config(self, new): + for key in new: + if key in self._configs: + self._configs[key] = new[key] + + def image_ids(self, ind): + return self._image_ids[ind] + + def image_path(self, ind): + pass + + def write_result(self, ind, all_bboxes, all_scores): + pass + + def evaluate(self, name): + pass + + def shuffle_inds(self, quiet=False): + if self._data_rng is None: + self._data_rng = np.random.RandomState(os.getpid()) + + if not quiet: + print("shuffling indices...") + rand_perm = self._data_rng.permutation(len(self._db_inds)) + self._db_inds = self._db_inds[rand_perm] diff --git a/core/dbs/coco.py b/core/dbs/coco.py new file mode 100755 index 0000000..5969581 --- /dev/null +++ b/core/dbs/coco.py @@ -0,0 +1,169 @@ +import os +import json +import numpy as np + +from .detection import DETECTION +from ..paths import get_file_path + +# COCO bounding boxes are 0-indexed + +class COCO(DETECTION): + def __init__(self, db_config, split=None, sys_config=None): + assert split is None or sys_config is not None + super(COCO, self).__init__(db_config) + + self._mean = np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32) + self._std = np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32) + self._eig_val = np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32) + self._eig_vec = np.array([ + [-0.58752847, -0.69563484, 0.41340352], + [-0.5832747, 0.00994535, -0.81221408], + [-0.56089297, 0.71832671, 0.41158938] + ], dtype=np.float32) + + self._coco_cls_ids = [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, + 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 84, 85, 86, 87, 88, 89, 90 + ] + + self._coco_cls_names = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', + 'bus', 'train', 'truck', 'boat', 'traffic light', + 'fire hydrant', 'stop sign', 'parking meter', 'bench', + 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', + 'bear', 'zebra','giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', + 'snowboard','sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', + 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', + 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', + 'toothbrush' + ] + + self._cls2coco = {ind + 1: coco_id for ind, coco_id in enumerate(self._coco_cls_ids)} + self._coco2cls = {coco_id: cls_id for cls_id, coco_id in self._cls2coco.items()} + self._coco2name = {cls_id: cls_name for cls_id, cls_name in zip(self._coco_cls_ids, self._coco_cls_names)} + self._name2coco = {cls_name: cls_id for cls_name, cls_id in self._coco2name.items()} + + if split is not None: + coco_dir = os.path.join(sys_config.data_dir, "coco") + + self._split = { + "trainval": "trainval2014", + "minival": "minival2014", + "testdev": "testdev2017" + }[split] + self._data_dir = 
os.path.join(coco_dir, "images", self._split) + self._anno_file = os.path.join(coco_dir, "annotations", "instances_{}.json".format(self._split)) + + self._detections, self._eval_ids = self._load_coco_annos() + self._image_ids = list(self._detections.keys()) + self._db_inds = np.arange(len(self._image_ids)) + + def _load_coco_annos(self): + from pycocotools.coco import COCO + + coco = COCO(self._anno_file) + self._coco = coco + + class_ids = coco.getCatIds() + image_ids = coco.getImgIds() + + eval_ids = {} + detections = {} + for image_id in image_ids: + image = coco.loadImgs(image_id)[0] + dets = [] + + eval_ids[image["file_name"]] = image_id + for class_id in class_ids: + annotation_ids = coco.getAnnIds(imgIds=image["id"], catIds=class_id) + annotations = coco.loadAnns(annotation_ids) + category = self._coco2cls[class_id] + for annotation in annotations: + det = annotation["bbox"] + [category] + det[2] += det[0] + det[3] += det[1] + dets.append(det) + + file_name = image["file_name"] + if len(dets) == 0: + detections[file_name] = np.zeros((0, 5), dtype=np.float32) + else: + detections[file_name] = np.array(dets, dtype=np.float32) + return detections, eval_ids + + def image_path(self, ind): + if self._data_dir is None: + raise ValueError("Data directory is not set") + + db_ind = self._db_inds[ind] + file_name = self._image_ids[db_ind] + return os.path.join(self._data_dir, file_name) + + def detections(self, ind): + db_ind = self._db_inds[ind] + file_name = self._image_ids[db_ind] + return self._detections[file_name].copy() + + def cls2name(self, cls): + coco = self._cls2coco[cls] + return self._coco2name[coco] + + def _to_float(self, x): + return float("{:.2f}".format(x)) + + def convert_to_coco(self, all_bboxes): + detections = [] + for image_id in all_bboxes: + coco_id = self._eval_ids[image_id] + for cls_ind in all_bboxes[image_id]: + category_id = self._cls2coco[cls_ind] + for bbox in all_bboxes[image_id][cls_ind]: + bbox[2] -= bbox[0] + bbox[3] -= bbox[1] + + score = bbox[4] + bbox = list(map(self._to_float, bbox[0:4])) + + detection = { + "image_id": coco_id, + "category_id": category_id, + "bbox": bbox, + "score": float("{:.2f}".format(score)) + } + + detections.append(detection) + return detections + + def evaluate(self, result_json, cls_ids, image_ids): + from pycocotools.cocoeval import COCOeval + + if self._split == "testdev": + return None + + coco = self._coco + + eval_ids = [self._eval_ids[image_id] for image_id in image_ids] + cat_ids = [self._cls2coco[cls_id] for cls_id in cls_ids] + + coco_dets = coco.loadRes(result_json) + coco_eval = COCOeval(coco, coco_dets, "bbox") + coco_eval.params.imgIds = eval_ids + coco_eval.params.catIds = cat_ids + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + return coco_eval.stats[0], coco_eval.stats[12:] diff --git a/core/dbs/detection.py b/core/dbs/detection.py new file mode 100644 index 0000000..6d90aa7 --- /dev/null +++ b/core/dbs/detection.py @@ -0,0 +1,70 @@ +import numpy as np + +from .base import BASE + +class DETECTION(BASE): + def __init__(self, db_config): + super(DETECTION, self).__init__() + + # Configs for training + self._configs["categories"] = 80 + self._configs["rand_scales"] = [1] + self._configs["rand_scale_min"] = 0.8 + self._configs["rand_scale_max"] = 1.4 + self._configs["rand_scale_step"] = 0.2 + + # Configs for both training and testing + self._configs["input_size"] = [383, 383] + self._configs["output_sizes"] = [[96, 96], [48, 48], [24, 24], [12, 12]] + + self._configs["score_threshold"] = 
0.05 + self._configs["nms_threshold"] = 0.7 + self._configs["max_per_set"] = 40 + self._configs["max_per_image"] = 100 + self._configs["top_k"] = 20 + self._configs["ae_threshold"] = 1 + self._configs["nms_kernel"] = 3 + self._configs["num_dets"] = 1000 + + self._configs["nms_algorithm"] = "exp_soft_nms" + self._configs["weight_exp"] = 8 + self._configs["merge_bbox"] = False + + self._configs["data_aug"] = True + self._configs["lighting"] = True + + self._configs["border"] = 64 + self._configs["gaussian_bump"] = False + self._configs["gaussian_iou"] = 0.7 + self._configs["gaussian_radius"] = -1 + self._configs["rand_crop"] = False + self._configs["rand_color"] = False + self._configs["rand_center"] = True + + self._configs["init_sizes"] = [192, 255] + self._configs["view_sizes"] = [] + + self._configs["min_scale"] = 16 + self._configs["max_scale"] = 32 + + self._configs["att_sizes"] = [[16, 16], [32, 32], [64, 64]] + self._configs["att_ranges"] = [[96, 256], [32, 96], [0, 32]] + self._configs["att_ratios"] = [16, 8, 4] + self._configs["att_scales"] = [1, 1.5, 2] + self._configs["att_thresholds"] = [0.3, 0.3, 0.3, 0.3] + self._configs["att_nms_ks"] = [3, 3, 3] + self._configs["att_max_crops"] = 8 + self._configs["ref_dets"] = True + + # Configs for testing + self._configs["test_scales"] = [1] + self._configs["test_flipped"] = True + + self.update_config(db_config) + + if self._configs["rand_scales"] is None: + self._configs["rand_scales"] = np.arange( + self._configs["rand_scale_min"], + self._configs["rand_scale_max"], + self._configs["rand_scale_step"] + ) diff --git a/core/detectors.py b/core/detectors.py new file mode 100644 index 0000000..9307abf --- /dev/null +++ b/core/detectors.py @@ -0,0 +1,49 @@ +from .base import Base, load_cfg, load_nnet +from .paths import get_file_path +from .config import SystemConfig +from .dbs.coco import COCO + +class CornerNet(Base): + def __init__(self): + from .test.cornernet import cornernet_inference + from .models.CornerNet import model + + cfg_path = get_file_path("..", "configs", "CornerNet.json") + model_path = get_file_path("..", "cache", "nnet", "CornerNet", "CornerNet_500000.pkl") + + cfg_sys, cfg_db = load_cfg(cfg_path) + sys_cfg = SystemConfig().update_config(cfg_sys) + coco = COCO(cfg_db) + + cornernet = load_nnet(sys_cfg, model()) + super(CornerNet, self).__init__(coco, cornernet, cornernet_inference, model=model_path) + +class CornerNet_Squeeze(Base): + def __init__(self): + from .test.cornernet import cornernet_inference + from .models.CornerNet_Squeeze import model + + cfg_path = get_file_path("..", "configs", "CornerNet_Squeeze.json") + model_path = get_file_path("..", "cache", "nnet", "CornerNet_Squeeze", "CornerNet_Squeeze_500000.pkl") + + cfg_sys, cfg_db = load_cfg(cfg_path) + sys_cfg = SystemConfig().update_config(cfg_sys) + coco = COCO(cfg_db) + + cornernet = load_nnet(sys_cfg, model()) + super(CornerNet_Squeeze, self).__init__(coco, cornernet, cornernet_inference, model=model_path) + +class CornerNet_Saccade(Base): + def __init__(self): + from .test.cornernet_saccade import cornernet_saccade_inference + from .models.CornerNet_Saccade import model + + cfg_path = get_file_path("..", "configs", "CornerNet_Saccade.json") + model_path = get_file_path("..", "cache", "nnet", "CornerNet_Saccade", "CornerNet_Saccade_500000.pkl") + + cfg_sys, cfg_db = load_cfg(cfg_path) + sys_cfg = SystemConfig().update_config(cfg_sys) + coco = COCO(cfg_db) + + cornernet = load_nnet(sys_cfg, model()) + super(CornerNet_Saccade, self).__init__(coco, 
cornernet, cornernet_saccade_inference, model=model_path) diff --git a/core/external/.gitignore b/core/external/.gitignore new file mode 100644 index 0000000..f7c8c1a --- /dev/null +++ b/core/external/.gitignore @@ -0,0 +1,7 @@ +bbox.c +bbox.cpython-35m-x86_64-linux-gnu.so +bbox.cpython-36m-x86_64-linux-gnu.so + +nms.c +nms.cpython-35m-x86_64-linux-gnu.so +nms.cpython-36m-x86_64-linux-gnu.so diff --git a/core/external/Makefile b/core/external/Makefile new file mode 100644 index 0000000..a482398 --- /dev/null +++ b/core/external/Makefile @@ -0,0 +1,3 @@ +all: + python setup.py build_ext --inplace + rm -rf build diff --git a/core/external/__init__.py b/core/external/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/external/bbox.pyx b/core/external/bbox.pyx new file mode 100644 index 0000000..e14780d --- /dev/null +++ b/core/external/bbox.pyx @@ -0,0 +1,55 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Sergey Karayev +# -------------------------------------------------------- + +cimport cython +import numpy as np +cimport numpy as np + +DTYPE = np.float +ctypedef np.float_t DTYPE_t + +def bbox_overlaps( + np.ndarray[DTYPE_t, ndim=2] boxes, + np.ndarray[DTYPE_t, ndim=2] query_boxes): + """ + Parameters + ---------- + boxes: (N, 4) ndarray of float + query_boxes: (K, 4) ndarray of float + Returns + ------- + overlaps: (N, K) ndarray of overlap between boxes and query_boxes + """ + cdef unsigned int N = boxes.shape[0] + cdef unsigned int K = query_boxes.shape[0] + cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE) + cdef DTYPE_t iw, ih, box_area + cdef DTYPE_t ua + cdef unsigned int k, n + for k in range(K): + box_area = ( + (query_boxes[k, 2] - query_boxes[k, 0] + 1) * + (query_boxes[k, 3] - query_boxes[k, 1] + 1) + ) + for n in range(N): + iw = ( + min(boxes[n, 2], query_boxes[k, 2]) - + max(boxes[n, 0], query_boxes[k, 0]) + 1 + ) + if iw > 0: + ih = ( + min(boxes[n, 3], query_boxes[k, 3]) - + max(boxes[n, 1], query_boxes[k, 1]) + 1 + ) + if ih > 0: + ua = float( + (boxes[n, 2] - boxes[n, 0] + 1) * + (boxes[n, 3] - boxes[n, 1] + 1) + + box_area - iw * ih + ) + overlaps[n, k] = iw * ih / ua + return overlaps diff --git a/core/external/nms.pyx b/core/external/nms.pyx new file mode 100644 index 0000000..2916328 --- /dev/null +++ b/core/external/nms.pyx @@ -0,0 +1,279 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import numpy as np +cimport numpy as np + +cdef inline np.float32_t max(np.float32_t a, np.float32_t b): + return a if a >= b else b + +cdef inline np.float32_t min(np.float32_t a, np.float32_t b): + return a if a <= b else b + +def nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): + cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] + cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] + cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] + cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] + cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] + + cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) + cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] + + cdef int ndets = dets.shape[0] + cdef np.ndarray[np.int_t, ndim=1] suppressed = 
\ + np.zeros((ndets), dtype=np.int) + + # nominal indices + cdef int _i, _j + # sorted indices + cdef int i, j + # temp variables for box i's (the box currently under consideration) + cdef np.float32_t ix1, iy1, ix2, iy2, iarea + # variables for computing overlap with box j (lower scoring box) + cdef np.float32_t xx1, yy1, xx2, yy2 + cdef np.float32_t w, h + cdef np.float32_t inter, ovr + + keep = [] + for _i in range(ndets): + i = order[_i] + if suppressed[i] == 1: + continue + keep.append(i) + ix1 = x1[i] + iy1 = y1[i] + ix2 = x2[i] + iy2 = y2[i] + iarea = areas[i] + for _j in range(_i + 1, ndets): + j = order[_j] + if suppressed[j] == 1: + continue + xx1 = max(ix1, x1[j]) + yy1 = max(iy1, y1[j]) + xx2 = min(ix2, x2[j]) + yy2 = min(iy2, y2[j]) + w = max(0.0, xx2 - xx1 + 1) + h = max(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (iarea + areas[j] - inter) + if ovr >= thresh: + suppressed[j] = 1 + + return keep + +def soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): + cdef unsigned int N = boxes.shape[0] + cdef float iw, ih, box_area + cdef float ua + cdef int pos = 0 + cdef float maxscore = 0 + cdef int maxpos = 0 + cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov + + for i in range(N): + maxscore = boxes[i, 4] + maxpos = i + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # get max box + while pos < N: + if maxscore < boxes[pos, 4]: + maxscore = boxes[pos, 4] + maxpos = pos + pos = pos + 1 + + # add max box as a detection + boxes[i,0] = boxes[maxpos,0] + boxes[i,1] = boxes[maxpos,1] + boxes[i,2] = boxes[maxpos,2] + boxes[i,3] = boxes[maxpos,3] + boxes[i,4] = boxes[maxpos,4] + + # swap ith box with position of max box + boxes[maxpos,0] = tx1 + boxes[maxpos,1] = ty1 + boxes[maxpos,2] = tx2 + boxes[maxpos,3] = ty2 + boxes[maxpos,4] = ts + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # NMS iterations, note that N changes if detection boxes fall below threshold + while pos < N: + x1 = boxes[pos, 0] + y1 = boxes[pos, 1] + x2 = boxes[pos, 2] + y2 = boxes[pos, 3] + s = boxes[pos, 4] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + iw = (min(tx2, x2) - max(tx1, x1) + 1) + if iw > 0: + ih = (min(ty2, y2) - max(ty1, y1) + 1) + if ih > 0: + ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) + ov = iw * ih / ua #iou between max box and detection box + + if method == 1: # linear + if ov > Nt: + weight = 1 - ov + else: + weight = 1 + elif method == 2: # gaussian + weight = np.exp(-(ov * ov)/sigma) + else: # original NMS + if ov > Nt: + weight = 0 + else: + weight = 1 + + boxes[pos, 4] = weight*boxes[pos, 4] + + # if box score falls below threshold, discard the box by swapping with last box + # update N + if boxes[pos, 4] < threshold: + boxes[pos,0] = boxes[N-1, 0] + boxes[pos,1] = boxes[N-1, 1] + boxes[pos,2] = boxes[N-1, 2] + boxes[pos,3] = boxes[N-1, 3] + boxes[pos,4] = boxes[N-1, 4] + N = N - 1 + pos = pos - 1 + + pos = pos + 1 + + keep = [i for i in range(N)] + return keep + +def soft_nms_merge(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0, float weight_exp=6): + cdef unsigned int N = boxes.shape[0] + cdef float iw, ih, box_area + cdef float ua + cdef int pos = 0 + cdef float maxscore = 0 + cdef int maxpos = 0 + cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov + cdef float mx1,mx2,my1,my2,mts,mbs,mw + + for i 
in range(N): + maxscore = boxes[i, 4] + maxpos = i + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # get max box + while pos < N: + if maxscore < boxes[pos, 4]: + maxscore = boxes[pos, 4] + maxpos = pos + pos = pos + 1 + + # add max box as a detection + boxes[i,0] = boxes[maxpos,0] + boxes[i,1] = boxes[maxpos,1] + boxes[i,2] = boxes[maxpos,2] + boxes[i,3] = boxes[maxpos,3] + boxes[i,4] = boxes[maxpos,4] + + mx1 = boxes[i, 0] * boxes[i, 5] + my1 = boxes[i, 1] * boxes[i, 5] + mx2 = boxes[i, 2] * boxes[i, 6] + my2 = boxes[i, 3] * boxes[i, 6] + mts = boxes[i, 5] + mbs = boxes[i, 6] + + # swap ith box with position of max box + boxes[maxpos,0] = tx1 + boxes[maxpos,1] = ty1 + boxes[maxpos,2] = tx2 + boxes[maxpos,3] = ty2 + boxes[maxpos,4] = ts + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # NMS iterations, note that N changes if detection boxes fall below threshold + while pos < N: + x1 = boxes[pos, 0] + y1 = boxes[pos, 1] + x2 = boxes[pos, 2] + y2 = boxes[pos, 3] + s = boxes[pos, 4] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + iw = (min(tx2, x2) - max(tx1, x1) + 1) + if iw > 0: + ih = (min(ty2, y2) - max(ty1, y1) + 1) + if ih > 0: + ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) + ov = iw * ih / ua #iou between max box and detection box + + if method == 1: # linear + if ov > Nt: + weight = 1 - ov + else: + weight = 1 + elif method == 2: # gaussian + weight = np.exp(-(ov * ov)/sigma) + else: # original NMS + if ov > Nt: + weight = 0 + else: + weight = 1 + + mw = (1 - weight) ** weight_exp + mx1 = mx1 + boxes[pos, 0] * boxes[pos, 5] * mw + my1 = my1 + boxes[pos, 1] * boxes[pos, 5] * mw + mx2 = mx2 + boxes[pos, 2] * boxes[pos, 6] * mw + my2 = my2 + boxes[pos, 3] * boxes[pos, 6] * mw + mts = mts + boxes[pos, 5] * mw + mbs = mbs + boxes[pos, 6] * mw + + boxes[pos, 4] = weight*boxes[pos, 4] + + # if box score falls below threshold, discard the box by swapping with last box + # update N + if boxes[pos, 4] < threshold: + boxes[pos,0] = boxes[N-1, 0] + boxes[pos,1] = boxes[N-1, 1] + boxes[pos,2] = boxes[N-1, 2] + boxes[pos,3] = boxes[N-1, 3] + boxes[pos,4] = boxes[N-1, 4] + N = N - 1 + pos = pos - 1 + + pos = pos + 1 + + boxes[i, 0] = mx1 / mts + boxes[i, 1] = my1 / mts + boxes[i, 2] = mx2 / mbs + boxes[i, 3] = my2 / mbs + + keep = [i for i in range(N)] + return keep diff --git a/core/external/setup.py b/core/external/setup.py new file mode 100644 index 0000000..ca3bd04 --- /dev/null +++ b/core/external/setup.py @@ -0,0 +1,23 @@ +import numpy +from distutils.core import setup +from distutils.extension import Extension +from Cython.Build import cythonize + +extensions = [ + Extension( + "bbox", + ["bbox.pyx"], + extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] + ), + Extension( + "nms", + ["nms.pyx"], + extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] + ) +] + +setup( + name="coco", + ext_modules=cythonize(extensions), + include_dirs=[numpy.get_include()] +) diff --git a/core/models/CornerNet.py b/core/models/CornerNet.py new file mode 100755 index 0000000..6f29940 --- /dev/null +++ b/core/models/CornerNet.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn + +from .py_utils import TopPool, BottomPool, LeftPool, RightPool + +from .py_utils.utils import convolution, residual, corner_pool +from .py_utils.losses import CornerNet_Loss +from .py_utils.modules import hg_module, hg, hg_net + +def make_pool_layer(dim): + return nn.Sequential() + 
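+# Downsampling in each hourglass stage is done by the first stride-2 residual
+# block in make_hg_layer below; max pooling is not used, which is why
+# make_pool_layer above returns an empty nn.Sequential.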
+def make_hg_layer(inp_dim, out_dim, modules): + layers = [residual(inp_dim, out_dim, stride=2)] + layers += [residual(out_dim, out_dim) for _ in range(1, modules)] + return nn.Sequential(*layers) + +class model(hg_net): + def _pred_mod(self, dim): + return nn.Sequential( + convolution(3, 256, 256, with_bn=False), + nn.Conv2d(256, dim, (1, 1)) + ) + + def _merge_mod(self): + return nn.Sequential( + nn.Conv2d(256, 256, (1, 1), bias=False), + nn.BatchNorm2d(256) + ) + + def __init__(self): + stacks = 2 + pre = nn.Sequential( + convolution(7, 3, 128, stride=2), + residual(128, 256, stride=2) + ) + hg_mods = nn.ModuleList([ + hg_module( + 5, [256, 256, 384, 384, 384, 512], [2, 2, 2, 2, 2, 4], + make_pool_layer=make_pool_layer, + make_hg_layer=make_hg_layer + ) for _ in range(stacks) + ]) + cnvs = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)]) + inters = nn.ModuleList([residual(256, 256) for _ in range(stacks - 1)]) + cnvs_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) + inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) + + hgs = hg(pre, hg_mods, cnvs, inters, cnvs_, inters_) + + tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)]) + br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)]) + + tl_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)]) + br_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)]) + for tl_heat, br_heat in zip(tl_heats, br_heats): + torch.nn.init.constant_(tl_heat[-1].bias, -2.19) + torch.nn.init.constant_(br_heat[-1].bias, -2.19) + + tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)]) + br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)]) + + tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)]) + br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)]) + + super(model, self).__init__( + hgs, tl_modules, br_modules, tl_heats, br_heats, + tl_tags, br_tags, tl_offs, br_offs + ) + + self.loss = CornerNet_Loss(pull_weight=1e-1, push_weight=1e-1) diff --git a/core/models/CornerNet_Saccade.py b/core/models/CornerNet_Saccade.py new file mode 100644 index 0000000..549bd2c --- /dev/null +++ b/core/models/CornerNet_Saccade.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +from .py_utils import TopPool, BottomPool, LeftPool, RightPool + +from .py_utils.utils import convolution, residual, corner_pool +from .py_utils.losses import CornerNet_Saccade_Loss +from .py_utils.modules import saccade_net, saccade_module, saccade + +def make_pool_layer(dim): + return nn.Sequential() + +def make_hg_layer(inp_dim, out_dim, modules): + layers = [residual(inp_dim, out_dim, stride=2)] + layers += [residual(out_dim, out_dim) for _ in range(1, modules)] + return nn.Sequential(*layers) + +class model(saccade_net): + def _pred_mod(self, dim): + return nn.Sequential( + convolution(3, 256, 256, with_bn=False), + nn.Conv2d(256, dim, (1, 1)) + ) + + def _merge_mod(self): + return nn.Sequential( + nn.Conv2d(256, 256, (1, 1), bias=False), + nn.BatchNorm2d(256) + ) + + def __init__(self): + stacks = 3 + pre = nn.Sequential( + convolution(7, 3, 128, stride=2), + residual(128, 256, stride=2) + ) + hg_mods = nn.ModuleList([ + saccade_module( + 3, [256, 384, 384, 512], [1, 1, 1, 1], + make_pool_layer=make_pool_layer, + make_hg_layer=make_hg_layer + ) for _ in range(stacks) + ]) + cnvs = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)]) + inters = nn.ModuleList([residual(256, 
256) for _ in range(stacks - 1)]) + cnvs_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) + inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) + + att_mods = nn.ModuleList([ + nn.ModuleList([ + nn.Sequential( + convolution(3, 384, 256, with_bn=False), + nn.Conv2d(256, 1, (1, 1)) + ), + nn.Sequential( + convolution(3, 384, 256, with_bn=False), + nn.Conv2d(256, 1, (1, 1)) + ), + nn.Sequential( + convolution(3, 256, 256, with_bn=False), + nn.Conv2d(256, 1, (1, 1)) + ) + ]) for _ in range(stacks) + ]) + for att_mod in att_mods: + for att in att_mod: + torch.nn.init.constant_(att[-1].bias, -2.19) + + hgs = saccade(pre, hg_mods, cnvs, inters, cnvs_, inters_) + + tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)]) + br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)]) + + tl_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)]) + br_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)]) + for tl_heat, br_heat in zip(tl_heats, br_heats): + torch.nn.init.constant_(tl_heat[-1].bias, -2.19) + torch.nn.init.constant_(br_heat[-1].bias, -2.19) + + tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)]) + br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)]) + + tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)]) + br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)]) + + super(model, self).__init__( + hgs, tl_modules, br_modules, tl_heats, br_heats, + tl_tags, br_tags, tl_offs, br_offs, att_mods + ) + + self.loss = CornerNet_Saccade_Loss(pull_weight=1e-1, push_weight=1e-1) diff --git a/core/models/CornerNet_Squeeze.py b/core/models/CornerNet_Squeeze.py new file mode 100644 index 0000000..4003a15 --- /dev/null +++ b/core/models/CornerNet_Squeeze.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from .py_utils import TopPool, BottomPool, LeftPool, RightPool + +from .py_utils.utils import convolution, corner_pool, residual +from .py_utils.losses import CornerNet_Loss +from .py_utils.modules import hg_module, hg, hg_net + +class fire_module(nn.Module): + def __init__(self, inp_dim, out_dim, sr=2, stride=1): + super(fire_module, self).__init__() + self.conv1 = nn.Conv2d(inp_dim, out_dim // sr, kernel_size=1, stride=1, bias=False) + self.bn1 = nn.BatchNorm2d(out_dim // sr) + self.conv_1x1 = nn.Conv2d(out_dim // sr, out_dim // 2, kernel_size=1, stride=stride, bias=False) + self.conv_3x3 = nn.Conv2d(out_dim // sr, out_dim // 2, kernel_size=3, padding=1, + stride=stride, groups=out_dim // sr, bias=False) + self.bn2 = nn.BatchNorm2d(out_dim) + self.skip = (stride == 1 and inp_dim == out_dim) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + conv1 = self.conv1(x) + bn1 = self.bn1(conv1) + conv2 = torch.cat((self.conv_1x1(bn1), self.conv_3x3(bn1)), 1) + bn2 = self.bn2(conv2) + if self.skip: + return self.relu(bn2 + x) + else: + return self.relu(bn2) + +def make_pool_layer(dim): + return nn.Sequential() + +def make_unpool_layer(dim): + return nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1) + +def make_layer(inp_dim, out_dim, modules): + layers = [fire_module(inp_dim, out_dim)] + layers += [fire_module(out_dim, out_dim) for _ in range(1, modules)] + return nn.Sequential(*layers) + +def make_layer_revr(inp_dim, out_dim, modules): + layers = [fire_module(inp_dim, inp_dim) for _ in range(modules - 1)] + layers += [fire_module(inp_dim, out_dim)] + return nn.Sequential(*layers) + +def 
make_hg_layer(inp_dim, out_dim, modules): + layers = [fire_module(inp_dim, out_dim, stride=2)] + layers += [fire_module(out_dim, out_dim) for _ in range(1, modules)] + return nn.Sequential(*layers) + +class model(hg_net): + def _pred_mod(self, dim): + return nn.Sequential( + convolution(1, 256, 256, with_bn=False), + nn.Conv2d(256, dim, (1, 1)) + ) + + def _merge_mod(self): + return nn.Sequential( + nn.Conv2d(256, 256, (1, 1), bias=False), + nn.BatchNorm2d(256) + ) + + def __init__(self): + stacks = 2 + pre = nn.Sequential( + convolution(7, 3, 128, stride=2), + residual(128, 256, stride=2), + residual(256, 256, stride=2) + ) + hg_mods = nn.ModuleList([ + hg_module( + 4, [256, 256, 384, 384, 512], [2, 2, 2, 2, 4], + make_pool_layer=make_pool_layer, + make_unpool_layer=make_unpool_layer, + make_up_layer=make_layer, + make_low_layer=make_layer, + make_hg_layer_revr=make_layer_revr, + make_hg_layer=make_hg_layer + ) for _ in range(stacks) + ]) + cnvs = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)]) + inters = nn.ModuleList([residual(256, 256) for _ in range(stacks - 1)]) + cnvs_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) + inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) + + hgs = hg(pre, hg_mods, cnvs, inters, cnvs_, inters_) + + tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)]) + br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)]) + + tl_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)]) + br_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)]) + for tl_heat, br_heat in zip(tl_heats, br_heats): + torch.nn.init.constant_(tl_heat[-1].bias, -2.19) + torch.nn.init.constant_(br_heat[-1].bias, -2.19) + + tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)]) + br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)]) + + tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)]) + br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)]) + + super(model, self).__init__( + hgs, tl_modules, br_modules, tl_heats, br_heats, + tl_tags, br_tags, tl_offs, br_offs + ) + + self.loss = CornerNet_Loss(pull_weight=1e-1, push_weight=1e-1) diff --git a/core/models/__init__.py b/core/models/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/core/models/py_utils/__init__.py b/core/models/py_utils/__init__.py new file mode 100644 index 0000000..c1fde01 --- /dev/null +++ b/core/models/py_utils/__init__.py @@ -0,0 +1 @@ +from ._cpools import TopPool, BottomPool, LeftPool, RightPool diff --git a/core/models/py_utils/_cpools/.gitignore b/core/models/py_utils/_cpools/.gitignore new file mode 100644 index 0000000..6a0882d --- /dev/null +++ b/core/models/py_utils/_cpools/.gitignore @@ -0,0 +1,3 @@ +build/ +cpools.egg-info/ +dist/ diff --git a/core/models/py_utils/_cpools/__init__.py b/core/models/py_utils/_cpools/__init__.py new file mode 100644 index 0000000..1b4e76f --- /dev/null +++ b/core/models/py_utils/_cpools/__init__.py @@ -0,0 +1,74 @@ +import torch + +from torch import nn +from torch.autograd import Function + +import top_pool, bottom_pool, left_pool, right_pool + +class TopPoolFunction(Function): + @staticmethod + def forward(ctx, input): + output = top_pool.forward(input)[0] + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_variables[0] + output = top_pool.backward(input, grad_output)[0] + return output + +class 
BottomPoolFunction(Function):
+    @staticmethod
+    def forward(ctx, input):
+        output = bottom_pool.forward(input)[0]
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_variables[0]
+        output = bottom_pool.backward(input, grad_output)[0]
+        return output
+
+class LeftPoolFunction(Function):
+    @staticmethod
+    def forward(ctx, input):
+        output = left_pool.forward(input)[0]
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_variables[0]
+        output = left_pool.backward(input, grad_output)[0]
+        return output
+
+class RightPoolFunction(Function):
+    @staticmethod
+    def forward(ctx, input):
+        output = right_pool.forward(input)[0]
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_variables[0]
+        output = right_pool.backward(input, grad_output)[0]
+        return output
+
+class TopPool(nn.Module):
+    def forward(self, x):
+        return TopPoolFunction.apply(x)
+
+class BottomPool(nn.Module):
+    def forward(self, x):
+        return BottomPoolFunction.apply(x)
+
+class LeftPool(nn.Module):
+    def forward(self, x):
+        return LeftPoolFunction.apply(x)
+
+class RightPool(nn.Module):
+    def forward(self, x):
+        return RightPoolFunction.apply(x)
diff --git a/core/models/py_utils/_cpools/setup.py b/core/models/py_utils/_cpools/setup.py
new file mode 100644
index 0000000..9682833
--- /dev/null
+++ b/core/models/py_utils/_cpools/setup.py
@@ -0,0 +1,15 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CppExtension
+
+setup(
+    name="cpools",
+    ext_modules=[
+        CppExtension("top_pool", ["src/top_pool.cpp"]),
+        CppExtension("bottom_pool", ["src/bottom_pool.cpp"]),
+        CppExtension("left_pool", ["src/left_pool.cpp"]),
+        CppExtension("right_pool", ["src/right_pool.cpp"])
+    ],
+    cmdclass={
+        "build_ext": BuildExtension
+    }
+)
diff --git a/core/models/py_utils/_cpools/src/bottom_pool.cpp b/core/models/py_utils/_cpools/src/bottom_pool.cpp
new file mode 100644
index 0000000..8a20a43
--- /dev/null
+++ b/core/models/py_utils/_cpools/src/bottom_pool.cpp
@@ -0,0 +1,80 @@
+#include <torch/torch.h>
+
+#include <vector>
+
+std::vector<at::Tensor> pool_forward(
+    at::Tensor input
+) {
+    // Initialize output
+    at::Tensor output = at::zeros_like(input);
+
+    // Get height
+    int64_t height = input.size(2);
+
+    output.copy_(input);
+
+    for (int64_t ind = 1; ind < height; ind <<= 1) {
+        at::Tensor max_temp  = at::slice(output, 2, ind, height);
+        at::Tensor cur_temp  = at::slice(output, 2, ind, height);
+        at::Tensor next_temp = at::slice(output, 2, 0, height-ind);
+        at::max_out(max_temp, cur_temp, next_temp);
+    }
+
+    return {
+        output
+    };
+}
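+
+// The loop above computes a running max along the height axis in O(log H)
+// passes: after the k-th pass, each output row holds the max over a window
+// of 2^k input rows.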
+
+std::vector<at::Tensor> pool_backward(
+    at::Tensor input,
+    at::Tensor grad_output
+) {
+    auto output = at::zeros_like(input);
+
+    int32_t batch   = input.size(0);
+    int32_t channel = input.size(1);
+    int32_t height  = input.size(2);
+    int32_t width   = input.size(3);
+
+    auto max_val = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
+    auto max_ind = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kLong));
+
+    auto input_temp = input.select(2, 0);
+    max_val.copy_(input_temp);
+
+    max_ind.fill_(0);
+
+    auto output_temp      = output.select(2, 0);
+    auto grad_output_temp = grad_output.select(2, 0);
+    output_temp.copy_(grad_output_temp);
+
+    auto un_max_ind = max_ind.unsqueeze(2);
+    auto gt_mask    = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kByte));
+    auto max_temp   = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
+    for (int32_t ind = 0; ind < height - 1; ++ind) {
+        input_temp = input.select(2, ind + 1);
+        at::gt_out(gt_mask, input_temp, max_val);
+
+        at::masked_select_out(max_temp, input_temp, gt_mask);
+        max_val.masked_scatter_(gt_mask, max_temp);
+        max_ind.masked_fill_(gt_mask, ind + 1);
+
+        grad_output_temp = grad_output.select(2, ind + 1).unsqueeze(2);
+        output.scatter_add_(2, un_max_ind, grad_output_temp);
+    }
+
+    return {
+        output
+    };
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "forward", &pool_forward, "Bottom Pool Forward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+    m.def(
+        "backward", &pool_backward, "Bottom Pool Backward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+}
diff --git a/core/models/py_utils/_cpools/src/left_pool.cpp b/core/models/py_utils/_cpools/src/left_pool.cpp
new file mode 100644
index 0000000..c36fc1b
--- /dev/null
+++ b/core/models/py_utils/_cpools/src/left_pool.cpp
@@ -0,0 +1,80 @@
+#include <torch/torch.h>
+
+#include <vector>
+
+std::vector<at::Tensor> pool_forward(
+    at::Tensor input
+) {
+    // Initialize output
+    at::Tensor output = at::zeros_like(input);
+
+    // Get width
+    int64_t width = input.size(3);
+
+    output.copy_(input);
+
+    for (int64_t ind = 1; ind < width; ind <<= 1) {
+        at::Tensor max_temp  = at::slice(output, 3, 0, width-ind);
+        at::Tensor cur_temp  = at::slice(output, 3, 0, width-ind);
+        at::Tensor next_temp = at::slice(output, 3, ind, width);
+        at::max_out(max_temp, cur_temp, next_temp);
+    }
+
+    return {
+        output
+    };
+}
+
+std::vector<at::Tensor> pool_backward(
+    at::Tensor input,
+    at::Tensor grad_output
+) {
+    auto output = at::zeros_like(input);
+
+    int32_t batch   = input.size(0);
+    int32_t channel = input.size(1);
+    int32_t height  = input.size(2);
+    int32_t width   = input.size(3);
+
+    auto max_val = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
+    auto max_ind = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kLong));
+
+    auto input_temp = input.select(3, width - 1);
+    max_val.copy_(input_temp);
+
+    max_ind.fill_(width - 1);
+
+    auto output_temp      = output.select(3, width - 1);
+    auto grad_output_temp = grad_output.select(3, width - 1);
+    output_temp.copy_(grad_output_temp);
+
+    auto un_max_ind = max_ind.unsqueeze(3);
+    auto gt_mask    = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kByte));
+    auto max_temp   = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
+    for (int32_t ind = 1; ind < width; ++ind) {
+        input_temp = input.select(3, width - ind - 1);
+        at::gt_out(gt_mask, input_temp, max_val);
+
+        at::masked_select_out(max_temp, input_temp, gt_mask);
+        max_val.masked_scatter_(gt_mask, max_temp);
+        max_ind.masked_fill_(gt_mask, width - ind - 1);
+
+        grad_output_temp = grad_output.select(3, width - ind - 1).unsqueeze(3);
+        output.scatter_add_(3, un_max_ind, grad_output_temp);
+    }
+
+    return {
+        output
+    };
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "forward", &pool_forward, "Left Pool Forward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+    m.def(
+        "backward", &pool_backward, "Left Pool Backward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+}
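All four pooling layers implement the same directional scan, so for intuition
here is a rough PyTorch-level equivalent of the four forward passes (a
reference sketch only, not the code built by this extension; `torch.cummax`
requires a newer PyTorch release than the 1.0.0 pinned in
`conda_packagelist.txt`):
```python
import torch

def bottom_pool_ref(x: torch.Tensor) -> torch.Tensor:
    # Running max from the top edge downward along the height axis (dim 2).
    return x.cummax(dim=2).values

def top_pool_ref(x: torch.Tensor) -> torch.Tensor:
    # The same scan in the opposite direction: flip, scan, flip back.
    return x.flip(2).cummax(dim=2).values.flip(2)

def right_pool_ref(x: torch.Tensor) -> torch.Tensor:
    # Running max from the left edge rightward along the width axis (dim 3).
    return x.cummax(dim=3).values

def left_pool_ref(x: torch.Tensor) -> torch.Tensor:
    return x.flip(3).cummax(dim=3).values.flip(3)
```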
diff --git a/core/models/py_utils/_cpools/src/right_pool.cpp b/core/models/py_utils/_cpools/src/right_pool.cpp
new file mode 100644
index 0000000..1b2da43
--- /dev/null
+++ b/core/models/py_utils/_cpools/src/right_pool.cpp
@@ -0,0 +1,80 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+std::vector<at::Tensor> pool_forward(
+    at::Tensor input
+) {
+    // Initialize output
+    at::Tensor output = at::zeros_like(input);
+
+    // Get width
+    int64_t width = input.size(3);
+
+    output.copy_(input);
+
+    for (int64_t ind = 1; ind < width; ind <<= 1) {
+        at::Tensor max_temp = at::slice(output, 3, ind, width);
+        at::Tensor cur_temp = at::slice(output, 3, ind, width);
+        at::Tensor next_temp = at::slice(output, 3, 0, width-ind);
+        at::max_out(max_temp, cur_temp, next_temp);
+    }
+
+    return {
+        output
+    };
+}
+
+std::vector<at::Tensor> pool_backward(
+    at::Tensor input,
+    at::Tensor grad_output
+) {
+    at::Tensor output = at::zeros_like(input);
+
+    int32_t batch = input.size(0);
+    int32_t channel = input.size(1);
+    int32_t height = input.size(2);
+    int32_t width = input.size(3);
+
+    auto max_val = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
+    auto max_ind = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kLong));
+
+    auto input_temp = input.select(3, 0);
+    max_val.copy_(input_temp);
+
+    max_ind.fill_(0);
+
+    auto output_temp = output.select(3, 0);
+    auto grad_output_temp = grad_output.select(3, 0);
+    output_temp.copy_(grad_output_temp);
+
+    auto un_max_ind = max_ind.unsqueeze(3);
+    auto gt_mask = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kByte));
+    auto max_temp = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
+    for (int32_t ind = 0; ind < width - 1; ++ind) {
+        input_temp = input.select(3, ind + 1);
+        at::gt_out(gt_mask, input_temp, max_val);
+
+        at::masked_select_out(max_temp, input_temp, gt_mask);
+        max_val.masked_scatter_(gt_mask, max_temp);
+        max_ind.masked_fill_(gt_mask, ind + 1);
+
+        grad_output_temp = grad_output.select(3, ind + 1).unsqueeze(3);
+        output.scatter_add_(3, un_max_ind, grad_output_temp);
+    }
+
+    return {
+        output
+    };
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "forward", &pool_forward, "Right Pool Forward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+    m.def(
+        "backward", &pool_backward, "Right Pool Backward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+}
diff --git a/core/models/py_utils/_cpools/src/top_pool.cpp b/core/models/py_utils/_cpools/src/top_pool.cpp
new file mode 100644
index 0000000..bc63c49
--- /dev/null
+++ b/core/models/py_utils/_cpools/src/top_pool.cpp
@@ -0,0 +1,80 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+std::vector<at::Tensor> top_pool_forward(
+    at::Tensor input
+) {
+    // Initialize output
+    at::Tensor output = at::zeros_like(input);
+
+    // Get height
+    int64_t height = input.size(2);
+
+    output.copy_(input);
+
+    for (int64_t ind = 1; ind < height; ind <<= 1) {
+        at::Tensor max_temp = at::slice(output, 2, 0, height-ind);
+        at::Tensor cur_temp = at::slice(output, 2, 0, height-ind);
+        at::Tensor next_temp = at::slice(output, 2, ind, height);
+        at::max_out(max_temp, cur_temp, next_temp);
+    }
+
+    return {
+        output
+    };
+}
+
+std::vector<at::Tensor> top_pool_backward(
+    at::Tensor input,
+    at::Tensor grad_output
+) {
+    auto output = at::zeros_like(input);
+
+    int32_t batch = input.size(0);
+    int32_t channel = input.size(1);
+    int32_t height = input.size(2);
+    int32_t width = input.size(3);
+
+    auto max_val = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
+    auto max_ind = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kLong));
+
+    auto input_temp = input.select(2, height - 1);
+    max_val.copy_(input_temp);
+
+    max_ind.fill_(height - 1);
+
+    auto output_temp = output.select(2, height - 1);
+    auto grad_output_temp = grad_output.select(2, height - 1);
+    output_temp.copy_(grad_output_temp);
+
+    auto un_max_ind = max_ind.unsqueeze(2);
+    auto gt_mask = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kByte));
+    auto max_temp = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
+    for (int32_t ind = 1; ind < height; ++ind) {
+        input_temp = input.select(2, height - ind - 1);
+        at::gt_out(gt_mask, input_temp, max_val);
+
+        at::masked_select_out(max_temp, input_temp, gt_mask);
+        max_val.masked_scatter_(gt_mask, max_temp);
+        max_ind.masked_fill_(gt_mask, height - ind - 1);
+
+        grad_output_temp = grad_output.select(2, height - ind - 1).unsqueeze(2);
+        output.scatter_add_(2, un_max_ind, grad_output_temp);
+    }
+
+    return {
+        output
+    };
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "forward", &top_pool_forward, "Top Pool Forward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+    m.def(
+        "backward", &top_pool_backward, "Top Pool Backward",
+        py::call_guard<py::gil_scoped_release>()
+    );
+}
diff --git a/core/models/py_utils/data_parallel.py b/core/models/py_utils/data_parallel.py
new file mode 100644
index 0000000..cc2a5a8
--- /dev/null
+++ b/core/models/py_utils/data_parallel.py
@@ -0,0 +1,116 @@
+import torch
+from torch.nn.modules import Module
+from torch.nn.parallel.scatter_gather import gather
+from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.parallel_apply import parallel_apply
+
+from .scatter_gather import scatter_kwargs
+
+class DataParallel(Module):
+    r"""Implements data parallelism at the module level.
+
+    This container parallelizes the application of the given module by
+    splitting the input across the specified devices by chunking in the batch
+    dimension. In the forward pass, the module is replicated on each device,
+    and each replica handles a portion of the input. During the backwards
+    pass, gradients from each replica are summed into the original module.
+
+    The batch size should be larger than the number of GPUs used. It should
+    also be an integer multiple of the number of GPUs so that each chunk is
+    the same size (so that each GPU processes the same number of samples).
+
+    See also: :ref:`cuda-nn-dataparallel-instead`
+
+    Arbitrary positional and keyword inputs are allowed to be passed into
+    DataParallel EXCEPT Tensors. All variables will be scattered on the dim
+    specified (default 0). Primitive types will be broadcast, but all
+    other types will be a shallow copy and can be corrupted if written to in
+    the model's forward pass.
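+
+    Unlike the stock ``torch.nn.DataParallel``, this version also accepts a
+    ``chunk_sizes`` argument, which lets the batch be split into chunks of
+    unequal sizes across the GPUs (handled by ``scatter_kwargs``); in that
+    case the equal-chunk requirement above does not apply.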
+ + Args: + module: module to be parallelized + device_ids: CUDA devices (default: all devices) + output_device: device location of output (default: device_ids[0]) + + Example:: + + >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) + >>> output = net(input_var) + """ + + # TODO: update notes/cuda.rst when this class handles 8+ GPUs well + + def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None): + super(DataParallel, self).__init__() + + if not torch.cuda.is_available(): + self.module = module + self.device_ids = [] + return + + if device_ids is None: + device_ids = list(range(torch.cuda.device_count())) + if output_device is None: + output_device = device_ids[0] + self.dim = dim + self.module = module + self.device_ids = device_ids + self.chunk_sizes = chunk_sizes + self.output_device = output_device + if len(self.device_ids) == 1: + self.module.cuda(device_ids[0]) + + def forward(self, *inputs, **kwargs): + if not self.device_ids: + return self.module(*inputs, **kwargs) + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes) + if len(self.device_ids) == 1: + return self.module(*inputs[0], **kwargs[0]) + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + outputs = self.parallel_apply(replicas, inputs, kwargs) + return self.gather(outputs, self.output_device) + + def replicate(self, module, device_ids): + return replicate(module, device_ids) + + def scatter(self, inputs, kwargs, device_ids, chunk_sizes): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes) + + def parallel_apply(self, replicas, inputs, kwargs): + return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + + def gather(self, outputs, output_device): + return gather(outputs, output_device, dim=self.dim) + + +def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None): + r"""Evaluates module(input) in parallel across the GPUs given in device_ids. + + This is the functional version of the DataParallel module. + + Args: + module: the module to evaluate in parallel + inputs: inputs to the module + device_ids: GPU ids on which to replicate module + output_device: GPU location of the output Use -1 to indicate the CPU. 
+ (default: device_ids[0]) + Returns: + a Variable containing the result of module(input) located on + output_device + """ + if not isinstance(inputs, tuple): + inputs = (inputs,) + + if device_ids is None: + device_ids = list(range(torch.cuda.device_count())) + + if output_device is None: + output_device = device_ids[0] + + inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) + if len(device_ids) == 1: + return module(*inputs[0], **module_kwargs[0]) + used_device_ids = device_ids[:len(inputs)] + replicas = replicate(module, used_device_ids) + outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) + return gather(outputs, output_device, dim) diff --git a/core/models/py_utils/losses.py b/core/models/py_utils/losses.py new file mode 100644 index 0000000..21bffa2 --- /dev/null +++ b/core/models/py_utils/losses.py @@ -0,0 +1,224 @@ +import torch +import torch.nn as nn + +from .utils import _tranpose_and_gather_feat + +def _sigmoid(x): + return torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4) + +def _ae_loss(tag0, tag1, mask): + num = mask.sum(dim=1, keepdim=True).float() + tag0 = tag0.squeeze() + tag1 = tag1.squeeze() + + tag_mean = (tag0 + tag1) / 2 + + tag0 = torch.pow(tag0 - tag_mean, 2) / (num + 1e-4) + tag0 = tag0[mask].sum() + tag1 = torch.pow(tag1 - tag_mean, 2) / (num + 1e-4) + tag1 = tag1[mask].sum() + pull = tag0 + tag1 + + mask = mask.unsqueeze(1) + mask.unsqueeze(2) + mask = mask.eq(2) + num = num.unsqueeze(2) + num2 = (num - 1) * num + dist = tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2) + dist = 1 - torch.abs(dist) + dist = nn.functional.relu(dist, inplace=True) + dist = dist - 1 / (num + 1e-4) + dist = dist / (num2 + 1e-4) + dist = dist[mask] + push = dist.sum() + return pull, push + +def _off_loss(off, gt_off, mask): + num = mask.float().sum() + mask = mask.unsqueeze(2).expand_as(gt_off) + + off = off[mask] + gt_off = gt_off[mask] + + off_loss = nn.functional.smooth_l1_loss(off, gt_off, reduction="sum") + off_loss = off_loss / (num + 1e-4) + return off_loss + +def _focal_loss_mask(preds, gt, mask): + pos_inds = gt.eq(1) + neg_inds = gt.lt(1) + + neg_weights = torch.pow(1 - gt[neg_inds], 4) + + pos_mask = mask[pos_inds] + neg_mask = mask[neg_inds] + + loss = 0 + for pred in preds: + pos_pred = pred[pos_inds] + neg_pred = pred[neg_inds] + + pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * pos_mask + neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights * neg_mask + + num_pos = pos_inds.float().sum() + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if pos_pred.nelement() == 0: + loss = loss - neg_loss + else: + loss = loss - (pos_loss + neg_loss) / num_pos + return loss + +def _focal_loss(preds, gt): + pos_inds = gt.eq(1) + neg_inds = gt.lt(1) + + neg_weights = torch.pow(1 - gt[neg_inds], 4) + + loss = 0 + for pred in preds: + pos_pred = pred[pos_inds] + neg_pred = pred[neg_inds] + + pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) + neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights + + num_pos = pos_inds.float().sum() + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if pos_pred.nelement() == 0: + loss = loss - neg_loss + else: + loss = loss - (pos_loss + neg_loss) / num_pos + return loss + +class CornerNet_Saccade_Loss(nn.Module): + def __init__(self, pull_weight=1, push_weight=1, off_weight=1, focal_loss=_focal_loss_mask): + super(CornerNet_Saccade_Loss, self).__init__() + + self.pull_weight = pull_weight + self.push_weight = push_weight + 
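+        # pull draws the two corner embeddings of the same object together,
+        # push separates the mean embeddings of different objects, and off
+        # scales the smooth-L1 loss on the sub-pixel corner offsets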
self.off_weight = off_weight + self.focal_loss = focal_loss + self.ae_loss = _ae_loss + self.off_loss = _off_loss + + def forward(self, outs, targets): + tl_heats = outs[0] + br_heats = outs[1] + tl_tags = outs[2] + br_tags = outs[3] + tl_offs = outs[4] + br_offs = outs[5] + atts = outs[6] + + gt_tl_heat = targets[0] + gt_br_heat = targets[1] + gt_mask = targets[2] + gt_tl_off = targets[3] + gt_br_off = targets[4] + gt_tl_ind = targets[5] + gt_br_ind = targets[6] + gt_tl_valid = targets[7] + gt_br_valid = targets[8] + gt_atts = targets[9] + + # focal loss + focal_loss = 0 + + tl_heats = [_sigmoid(t) for t in tl_heats] + br_heats = [_sigmoid(b) for b in br_heats] + + focal_loss += self.focal_loss(tl_heats, gt_tl_heat, gt_tl_valid) + focal_loss += self.focal_loss(br_heats, gt_br_heat, gt_br_valid) + + atts = [[_sigmoid(a) for a in att] for att in atts] + atts = [[att[ind] for att in atts] for ind in range(len(gt_atts))] + + att_loss = 0 + for att, gt_att in zip(atts, gt_atts): + att_loss += _focal_loss(att, gt_att) / max(len(att), 1) + + # tag loss + pull_loss = 0 + push_loss = 0 + tl_tags = [_tranpose_and_gather_feat(tl_tag, gt_tl_ind) for tl_tag in tl_tags] + br_tags = [_tranpose_and_gather_feat(br_tag, gt_br_ind) for br_tag in br_tags] + for tl_tag, br_tag in zip(tl_tags, br_tags): + pull, push = self.ae_loss(tl_tag, br_tag, gt_mask) + pull_loss += pull + push_loss += push + pull_loss = self.pull_weight * pull_loss + push_loss = self.push_weight * push_loss + + off_loss = 0 + tl_offs = [_tranpose_and_gather_feat(tl_off, gt_tl_ind) for tl_off in tl_offs] + br_offs = [_tranpose_and_gather_feat(br_off, gt_br_ind) for br_off in br_offs] + for tl_off, br_off in zip(tl_offs, br_offs): + off_loss += self.off_loss(tl_off, gt_tl_off, gt_mask) + off_loss += self.off_loss(br_off, gt_br_off, gt_mask) + off_loss = self.off_weight * off_loss + + loss = (focal_loss + att_loss + pull_loss + push_loss + off_loss) / max(len(tl_heats), 1) + return loss.unsqueeze(0) + +class CornerNet_Loss(nn.Module): + def __init__(self, pull_weight=1, push_weight=1, off_weight=1, focal_loss=_focal_loss): + super(CornerNet_Loss, self).__init__() + + self.pull_weight = pull_weight + self.push_weight = push_weight + self.off_weight = off_weight + self.focal_loss = focal_loss + self.ae_loss = _ae_loss + self.off_loss = _off_loss + + def forward(self, outs, targets): + tl_heats = outs[0] + br_heats = outs[1] + tl_tags = outs[2] + br_tags = outs[3] + tl_offs = outs[4] + br_offs = outs[5] + + gt_tl_heat = targets[0] + gt_br_heat = targets[1] + gt_mask = targets[2] + gt_tl_off = targets[3] + gt_br_off = targets[4] + gt_tl_ind = targets[5] + gt_br_ind = targets[6] + + # focal loss + focal_loss = 0 + + tl_heats = [_sigmoid(t) for t in tl_heats] + br_heats = [_sigmoid(b) for b in br_heats] + + focal_loss += self.focal_loss(tl_heats, gt_tl_heat) + focal_loss += self.focal_loss(br_heats, gt_br_heat) + + # tag loss + pull_loss = 0 + push_loss = 0 + tl_tags = [_tranpose_and_gather_feat(tl_tag, gt_tl_ind) for tl_tag in tl_tags] + br_tags = [_tranpose_and_gather_feat(br_tag, gt_br_ind) for br_tag in br_tags] + for tl_tag, br_tag in zip(tl_tags, br_tags): + pull, push = self.ae_loss(tl_tag, br_tag, gt_mask) + pull_loss += pull + push_loss += push + pull_loss = self.pull_weight * pull_loss + push_loss = self.push_weight * push_loss + + off_loss = 0 + tl_offs = [_tranpose_and_gather_feat(tl_off, gt_tl_ind) for tl_off in tl_offs] + br_offs = [_tranpose_and_gather_feat(br_off, gt_br_ind) for br_off in br_offs] + for tl_off, br_off in 
zip(tl_offs, br_offs): + off_loss += self.off_loss(tl_off, gt_tl_off, gt_mask) + off_loss += self.off_loss(br_off, gt_br_off, gt_mask) + off_loss = self.off_weight * off_loss + + loss = (focal_loss + pull_loss + push_loss + off_loss) / max(len(tl_heats), 1) + return loss.unsqueeze(0) diff --git a/core/models/py_utils/modules.py b/core/models/py_utils/modules.py new file mode 100644 index 0000000..d28590b --- /dev/null +++ b/core/models/py_utils/modules.py @@ -0,0 +1,292 @@ +import torch +import torch.nn as nn + +from .utils import residual, upsample, merge, _decode + +def _make_layer(inp_dim, out_dim, modules): + layers = [residual(inp_dim, out_dim)] + layers += [residual(out_dim, out_dim) for _ in range(1, modules)] + return nn.Sequential(*layers) + +def _make_layer_revr(inp_dim, out_dim, modules): + layers = [residual(inp_dim, inp_dim) for _ in range(modules - 1)] + layers += [residual(inp_dim, out_dim)] + return nn.Sequential(*layers) + +def _make_pool_layer(dim): + return nn.MaxPool2d(kernel_size=2, stride=2) + +def _make_unpool_layer(dim): + return upsample(scale_factor=2) + +def _make_merge_layer(dim): + return merge() + +class hg_module(nn.Module): + def __init__( + self, n, dims, modules, make_up_layer=_make_layer, + make_pool_layer=_make_pool_layer, make_hg_layer=_make_layer, + make_low_layer=_make_layer, make_hg_layer_revr=_make_layer_revr, + make_unpool_layer=_make_unpool_layer, make_merge_layer=_make_merge_layer + ): + super(hg_module, self).__init__() + + curr_mod = modules[0] + next_mod = modules[1] + + curr_dim = dims[0] + next_dim = dims[1] + + self.n = n + self.up1 = make_up_layer(curr_dim, curr_dim, curr_mod) + self.max1 = make_pool_layer(curr_dim) + self.low1 = make_hg_layer(curr_dim, next_dim, curr_mod) + self.low2 = hg_module( + n - 1, dims[1:], modules[1:], + make_up_layer=make_up_layer, + make_pool_layer=make_pool_layer, + make_hg_layer=make_hg_layer, + make_low_layer=make_low_layer, + make_hg_layer_revr=make_hg_layer_revr, + make_unpool_layer=make_unpool_layer, + make_merge_layer=make_merge_layer + ) if n > 1 else make_low_layer(next_dim, next_dim, next_mod) + self.low3 = make_hg_layer_revr(next_dim, curr_dim, curr_mod) + self.up2 = make_unpool_layer(curr_dim) + self.merg = make_merge_layer(curr_dim) + + def forward(self, x): + up1 = self.up1(x) + max1 = self.max1(x) + low1 = self.low1(max1) + low2 = self.low2(low1) + low3 = self.low3(low2) + up2 = self.up2(low3) + merg = self.merg(up1, up2) + return merg + +class hg(nn.Module): + def __init__(self, pre, hg_modules, cnvs, inters, cnvs_, inters_): + super(hg, self).__init__() + + self.pre = pre + self.hgs = hg_modules + self.cnvs = cnvs + + self.inters = inters + self.inters_ = inters_ + self.cnvs_ = cnvs_ + + def forward(self, x): + inter = self.pre(x) + + cnvs = [] + for ind, (hg_, cnv_) in enumerate(zip(self.hgs, self.cnvs)): + hg = hg_(inter) + cnv = cnv_(hg) + cnvs.append(cnv) + + if ind < len(self.hgs) - 1: + inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv) + inter = nn.functional.relu_(inter) + inter = self.inters[ind](inter) + return cnvs + +class hg_net(nn.Module): + def __init__( + self, hg, tl_modules, br_modules, tl_heats, br_heats, + tl_tags, br_tags, tl_offs, br_offs + ): + super(hg_net, self).__init__() + + self._decode = _decode + + self.hg = hg + + self.tl_modules = tl_modules + self.br_modules = br_modules + + self.tl_heats = tl_heats + self.br_heats = br_heats + + self.tl_tags = tl_tags + self.br_tags = br_tags + + self.tl_offs = tl_offs + self.br_offs = br_offs + + def _train(self, *xs): + 
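+        # training mode runs every hourglass stack; the lists below hold one
+        # prediction per stack so the loss can apply intermediate supervision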
image = xs[0] + cnvs = self.hg(image) + + tl_modules = [tl_mod_(cnv) for tl_mod_, cnv in zip(self.tl_modules, cnvs)] + br_modules = [br_mod_(cnv) for br_mod_, cnv in zip(self.br_modules, cnvs)] + tl_heats = [tl_heat_(tl_mod) for tl_heat_, tl_mod in zip(self.tl_heats, tl_modules)] + br_heats = [br_heat_(br_mod) for br_heat_, br_mod in zip(self.br_heats, br_modules)] + tl_tags = [tl_tag_(tl_mod) for tl_tag_, tl_mod in zip(self.tl_tags, tl_modules)] + br_tags = [br_tag_(br_mod) for br_tag_, br_mod in zip(self.br_tags, br_modules)] + tl_offs = [tl_off_(tl_mod) for tl_off_, tl_mod in zip(self.tl_offs, tl_modules)] + br_offs = [br_off_(br_mod) for br_off_, br_mod in zip(self.br_offs, br_modules)] + return [tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs] + + def _test(self, *xs, **kwargs): + image = xs[0] + cnvs = self.hg(image) + + tl_mod = self.tl_modules[-1](cnvs[-1]) + br_mod = self.br_modules[-1](cnvs[-1]) + + tl_heat, br_heat = self.tl_heats[-1](tl_mod), self.br_heats[-1](br_mod) + tl_tag, br_tag = self.tl_tags[-1](tl_mod), self.br_tags[-1](br_mod) + tl_off, br_off = self.tl_offs[-1](tl_mod), self.br_offs[-1](br_mod) + + outs = [tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off] + return self._decode(*outs, **kwargs), tl_heat, br_heat, tl_tag, br_tag + + def forward(self, *xs, test=False, **kwargs): + if not test: + return self._train(*xs, **kwargs) + return self._test(*xs, **kwargs) + +class saccade_module(nn.Module): + def __init__( + self, n, dims, modules, make_up_layer=_make_layer, + make_pool_layer=_make_pool_layer, make_hg_layer=_make_layer, + make_low_layer=_make_layer, make_hg_layer_revr=_make_layer_revr, + make_unpool_layer=_make_unpool_layer, make_merge_layer=_make_merge_layer + ): + super(saccade_module, self).__init__() + + curr_mod = modules[0] + next_mod = modules[1] + + curr_dim = dims[0] + next_dim = dims[1] + + self.n = n + self.up1 = make_up_layer(curr_dim, curr_dim, curr_mod) + self.max1 = make_pool_layer(curr_dim) + self.low1 = make_hg_layer(curr_dim, next_dim, curr_mod) + self.low2 = saccade_module( + n - 1, dims[1:], modules[1:], + make_up_layer=make_up_layer, + make_pool_layer=make_pool_layer, + make_hg_layer=make_hg_layer, + make_low_layer=make_low_layer, + make_hg_layer_revr=make_hg_layer_revr, + make_unpool_layer=make_unpool_layer, + make_merge_layer=make_merge_layer + ) if n > 1 else make_low_layer(next_dim, next_dim, next_mod) + self.low3 = make_hg_layer_revr(next_dim, curr_dim, curr_mod) + self.up2 = make_unpool_layer(curr_dim) + self.merg = make_merge_layer(curr_dim) + + def forward(self, x): + up1 = self.up1(x) + max1 = self.max1(x) + low1 = self.low1(max1) + if self.n > 1: + low2, mergs = self.low2(low1) + else: + low2, mergs = self.low2(low1), [] + low3 = self.low3(low2) + up2 = self.up2(low3) + merg = self.merg(up1, up2) + mergs.append(merg) + return merg, mergs + +class saccade(nn.Module): + def __init__(self, pre, hg_modules, cnvs, inters, cnvs_, inters_): + super(saccade, self).__init__() + + self.pre = pre + self.hgs = hg_modules + self.cnvs = cnvs + + self.inters = inters + self.inters_ = inters_ + self.cnvs_ = cnvs_ + + def forward(self, x): + inter = self.pre(x) + + cnvs = [] + atts = [] + for ind, (hg_, cnv_) in enumerate(zip(self.hgs, self.cnvs)): + hg, ups = hg_(inter) + cnv = cnv_(hg) + cnvs.append(cnv) + atts.append(ups) + + if ind < len(self.hgs) - 1: + inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv) + inter = nn.functional.relu_(inter) + inter = self.inters[ind](inter) + return cnvs, atts + +class saccade_net(nn.Module): + def 
__init__( + self, hg, tl_modules, br_modules, tl_heats, br_heats, + tl_tags, br_tags, tl_offs, br_offs, att_modules, up_start=0 + ): + super(saccade_net, self).__init__() + + self._decode = _decode + + self.hg = hg + + self.tl_modules = tl_modules + self.br_modules = br_modules + self.tl_heats = tl_heats + self.br_heats = br_heats + self.tl_tags = tl_tags + self.br_tags = br_tags + self.tl_offs = tl_offs + self.br_offs = br_offs + + self.att_modules = att_modules + self.up_start = up_start + + def _train(self, *xs): + image = xs[0] + + cnvs, ups = self.hg(image) + ups = [up[self.up_start:] for up in ups] + + tl_modules = [tl_mod_(cnv) for tl_mod_, cnv in zip(self.tl_modules, cnvs)] + br_modules = [br_mod_(cnv) for br_mod_, cnv in zip(self.br_modules, cnvs)] + tl_heats = [tl_heat_(tl_mod) for tl_heat_, tl_mod in zip(self.tl_heats, tl_modules)] + br_heats = [br_heat_(br_mod) for br_heat_, br_mod in zip(self.br_heats, br_modules)] + tl_tags = [tl_tag_(tl_mod) for tl_tag_, tl_mod in zip(self.tl_tags, tl_modules)] + br_tags = [br_tag_(br_mod) for br_tag_, br_mod in zip(self.br_tags, br_modules)] + tl_offs = [tl_off_(tl_mod) for tl_off_, tl_mod in zip(self.tl_offs, tl_modules)] + br_offs = [br_off_(br_mod) for br_off_, br_mod in zip(self.br_offs, br_modules)] + atts = [[att_mod_(u) for att_mod_, u in zip(att_mods, up)] for att_mods, up in zip(self.att_modules, ups)] + return [tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs, atts] + + def _test(self, *xs, no_att=False, **kwargs): + image = xs[0] + cnvs, ups = self.hg(image) + ups = [up[self.up_start:] for up in ups] + + if not no_att: + atts = [att_mod_(up) for att_mod_, up in zip(self.att_modules[-1], ups[-1])] + atts = [torch.sigmoid(att) for att in atts] + + tl_mod = self.tl_modules[-1](cnvs[-1]) + br_mod = self.br_modules[-1](cnvs[-1]) + + tl_heat, br_heat = self.tl_heats[-1](tl_mod), self.br_heats[-1](br_mod) + tl_tag, br_tag = self.tl_tags[-1](tl_mod), self.br_tags[-1](br_mod) + tl_off, br_off = self.tl_offs[-1](tl_mod), self.br_offs[-1](br_mod) + + outs = [tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off] + if not no_att: + return self._decode(*outs, **kwargs), atts + else: + return self._decode(*outs, **kwargs) + + def forward(self, *xs, test=False, **kwargs): + if not test: + return self._train(*xs, **kwargs) + return self._test(*xs, **kwargs) diff --git a/core/models/py_utils/scatter_gather.py b/core/models/py_utils/scatter_gather.py new file mode 100644 index 0000000..9a46058 --- /dev/null +++ b/core/models/py_utils/scatter_gather.py @@ -0,0 +1,38 @@ +import torch +from torch.autograd import Variable +from torch.nn.parallel._functions import Scatter, Gather + + +def scatter(inputs, target_gpus, dim=0, chunk_sizes=None): + r""" + Slices variables into approximately equal chunks and + distributes them across given GPUs. Duplicates + references to objects that are not variables. Does not + support Tensors. + """ + def scatter_map(obj): + if isinstance(obj, Variable): + return Scatter.apply(target_gpus, chunk_sizes, dim, obj) + assert not torch.is_tensor(obj), "Tensors not supported in scatter." 
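+        # recurse into tuples/lists/dicts so nested Variables are scattered
+        # while the container structure is rebuilt once per target GPU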
+ if isinstance(obj, tuple): + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list): + return list(map(list, zip(*map(scatter_map, obj)))) + if isinstance(obj, dict): + return list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return [obj for targets in target_gpus] + + return scatter_map(inputs) + + +def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None): + r"""Scatter with support for kwargs dictionary""" + inputs = scatter(inputs, target_gpus, dim, chunk_sizes) if inputs else [] + kwargs = scatter(kwargs, target_gpus, dim, chunk_sizes) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs diff --git a/core/models/py_utils/utils.py b/core/models/py_utils/utils.py new file mode 100644 index 0000000..85d3e23 --- /dev/null +++ b/core/models/py_utils/utils.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn + +def _gather_feat(feat, ind, mask=None): + dim = feat.size(2) + ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) + feat = feat.gather(1, ind) + if mask is not None: + mask = mask.unsqueeze(2).expand_as(feat) + feat = feat[mask] + feat = feat.view(-1, dim) + return feat + +def _nms(heat, kernel=1): + pad = (kernel - 1) // 2 + + hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad) + keep = (hmax == heat).float() + return heat * keep + +def _tranpose_and_gather_feat(feat, ind): + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = _gather_feat(feat, ind) + return feat + +def _topk(scores, K=20): + batch, cat, height, width = scores.size() + + topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K) + + topk_clses = (topk_inds / (height * width)).int() + + topk_inds = topk_inds % (height * width) + topk_ys = (topk_inds / width).int().float() + topk_xs = (topk_inds % width).int().float() + return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs + +def _decode( + tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr, + K=100, kernel=1, ae_threshold=1, num_dets=1000, no_border=False +): + batch, cat, height, width = tl_heat.size() + + tl_heat = torch.sigmoid(tl_heat) + br_heat = torch.sigmoid(br_heat) + + # perform nms on heatmaps + tl_heat = _nms(tl_heat, kernel=kernel) + br_heat = _nms(br_heat, kernel=kernel) + + tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K) + br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K) + + tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K) + tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K) + br_ys = br_ys.view(batch, 1, K).expand(batch, K, K) + br_xs = br_xs.view(batch, 1, K).expand(batch, K, K) + + if no_border: + tl_ys_binds = (tl_ys == 0) + tl_xs_binds = (tl_xs == 0) + br_ys_binds = (br_ys == height - 1) + br_xs_binds = (br_xs == width - 1) + + if tl_regr is not None and br_regr is not None: + tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds) + tl_regr = tl_regr.view(batch, K, 1, 2) + br_regr = _tranpose_and_gather_feat(br_regr, br_inds) + br_regr = br_regr.view(batch, 1, K, 2) + + tl_xs = tl_xs + tl_regr[..., 0] + tl_ys = tl_ys + tl_regr[..., 1] + br_xs = br_xs + br_regr[..., 0] + br_ys = br_ys + br_regr[..., 1] + + # all possible boxes based on top k corners (ignoring class) + bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3) + + tl_tag = 
_tranpose_and_gather_feat(tl_tag, tl_inds) + tl_tag = tl_tag.view(batch, K, 1) + br_tag = _tranpose_and_gather_feat(br_tag, br_inds) + br_tag = br_tag.view(batch, 1, K) + dists = torch.abs(tl_tag - br_tag) + + tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K) + br_scores = br_scores.view(batch, 1, K).expand(batch, K, K) + scores = (tl_scores + br_scores) / 2 + + # reject boxes based on classes + tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K) + br_clses = br_clses.view(batch, 1, K).expand(batch, K, K) + cls_inds = (tl_clses != br_clses) + + # reject boxes based on distances + dist_inds = (dists > ae_threshold) + + # reject boxes based on widths and heights + width_inds = (br_xs < tl_xs) + height_inds = (br_ys < tl_ys) + + if no_border: + scores[tl_ys_binds] = -1 + scores[tl_xs_binds] = -1 + scores[br_ys_binds] = -1 + scores[br_xs_binds] = -1 + + scores[cls_inds] = -1 + scores[dist_inds] = -1 + scores[width_inds] = -1 + scores[height_inds] = -1 + + scores = scores.view(batch, -1) + scores, inds = torch.topk(scores, num_dets) + scores = scores.unsqueeze(2) + + bboxes = bboxes.view(batch, -1, 4) + bboxes = _gather_feat(bboxes, inds) + + clses = tl_clses.contiguous().view(batch, -1, 1) + clses = _gather_feat(clses, inds).float() + + tl_scores = tl_scores.contiguous().view(batch, -1, 1) + tl_scores = _gather_feat(tl_scores, inds).float() + br_scores = br_scores.contiguous().view(batch, -1, 1) + br_scores = _gather_feat(br_scores, inds).float() + + detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2) + return detections + +class upsample(nn.Module): + def __init__(self, scale_factor): + super(upsample, self).__init__() + self.scale_factor = scale_factor + + def forward(self, x): + return nn.functional.interpolate(x, scale_factor=self.scale_factor) + +class merge(nn.Module): + def forward(self, x, y): + return x + y + +class convolution(nn.Module): + def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): + super(convolution, self).__init__() + + pad = (k - 1) // 2 + self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn) + self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + conv = self.conv(x) + bn = self.bn(conv) + relu = self.relu(bn) + return relu + +class residual(nn.Module): + def __init__(self, inp_dim, out_dim, k=3, stride=1): + super(residual, self).__init__() + p = (k - 1) // 2 + + self.conv1 = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(p, p), stride=(stride, stride), bias=False) + self.bn1 = nn.BatchNorm2d(out_dim) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(out_dim, out_dim, (k, k), padding=(p, p), bias=False) + self.bn2 = nn.BatchNorm2d(out_dim) + + self.skip = nn.Sequential( + nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False), + nn.BatchNorm2d(out_dim) + ) if stride != 1 or inp_dim != out_dim else nn.Sequential() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + conv1 = self.conv1(x) + bn1 = self.bn1(conv1) + relu1 = self.relu1(bn1) + + conv2 = self.conv2(relu1) + bn2 = self.bn2(conv2) + + skip = self.skip(x) + return self.relu(bn2 + skip) + +class corner_pool(nn.Module): + def __init__(self, dim, pool1, pool2): + super(corner_pool, self).__init__() + self._init_layers(dim, pool1, pool2) + + def _init_layers(self, dim, pool1, pool2): + self.p1_conv1 = convolution(3, dim, 128) + self.p2_conv1 = convolution(3, dim, 128) + + self.p_conv1 = 
nn.Conv2d(128, dim, (3, 3), padding=(1, 1), bias=False) + self.p_bn1 = nn.BatchNorm2d(dim) + + self.conv1 = nn.Conv2d(dim, dim, (1, 1), bias=False) + self.bn1 = nn.BatchNorm2d(dim) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = convolution(3, dim, dim) + + self.pool1 = pool1() + self.pool2 = pool2() + + def forward(self, x): + # pool 1 + p1_conv1 = self.p1_conv1(x) + pool1 = self.pool1(p1_conv1) + + # pool 2 + p2_conv1 = self.p2_conv1(x) + pool2 = self.pool2(p2_conv1) + + # pool 1 + pool 2 + p_conv1 = self.p_conv1(pool1 + pool2) + p_bn1 = self.p_bn1(p_conv1) + + conv1 = self.conv1(x) + bn1 = self.bn1(conv1) + relu1 = self.relu1(p_bn1 + bn1) + + conv2 = self.conv2(relu1) + return conv2 diff --git a/core/nnet/__init__.py b/core/nnet/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/core/nnet/py_factory.py b/core/nnet/py_factory.py new file mode 100755 index 0000000..ed867cd --- /dev/null +++ b/core/nnet/py_factory.py @@ -0,0 +1,137 @@ +import os +import torch +import pickle +import importlib +import torch.nn as nn + +from ..models.py_utils.data_parallel import DataParallel + +torch.manual_seed(317) + +class Network(nn.Module): + def __init__(self, model, loss): + super(Network, self).__init__() + + self.model = model + self.loss = loss + + def forward(self, xs, ys, **kwargs): + preds = self.model(*xs, **kwargs) + loss = self.loss(preds, ys, **kwargs) + return loss + +# for model backward compatibility +# previously model was wrapped by DataParallel module +class DummyModule(nn.Module): + def __init__(self, model): + super(DummyModule, self).__init__() + self.module = model + + def forward(self, *xs, **kwargs): + return self.module(*xs, **kwargs) + +class NetworkFactory(object): + def __init__(self, system_config, model, distributed=False, gpu=None): + super(NetworkFactory, self).__init__() + + self.system_config = system_config + + self.gpu = gpu + self.model = DummyModule(model) + self.loss = model.loss + self.network = Network(self.model, self.loss) + + if distributed: + from apex.parallel import DistributedDataParallel, convert_syncbn_model + torch.cuda.set_device(gpu) + self.network = self.network.cuda(gpu) + self.network = convert_syncbn_model(self.network) + self.network = DistributedDataParallel(self.network) + else: + self.network = DataParallel(self.network, chunk_sizes=system_config.chunk_sizes) + + total_params = 0 + for params in self.model.parameters(): + num_params = 1 + for x in params.size(): + num_params *= x + total_params += num_params + print("total parameters: {}".format(total_params)) + + if system_config.opt_algo == "adam": + self.optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, self.model.parameters()) + ) + elif system_config.opt_algo == "sgd": + self.optimizer = torch.optim.SGD( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=system_config.learning_rate, + momentum=0.9, weight_decay=0.0001 + ) + else: + raise ValueError("unknown optimizer") + + def cuda(self): + self.model.cuda() + + def train_mode(self): + self.network.train() + + def eval_mode(self): + self.network.eval() + + def _t_cuda(self, xs): + if type(xs) is list: + return [x.cuda(self.gpu, non_blocking=True) for x in xs] + return xs.cuda(self.gpu, non_blocking=True) + + def train(self, xs, ys, **kwargs): + xs = [self._t_cuda(x) for x in xs] + ys = [self._t_cuda(y) for y in ys] + + self.optimizer.zero_grad() + loss = self.network(xs, ys) + loss = loss.mean() + loss.backward() + self.optimizer.step() + + return loss + + def validate(self, xs, ys, 
**kwargs): + with torch.no_grad(): + xs = [self._t_cuda(x) for x in xs] + ys = [self._t_cuda(y) for y in ys] + + loss = self.network(xs, ys) + loss = loss.mean() + return loss + + def test(self, xs, **kwargs): + with torch.no_grad(): + xs = [self._t_cuda(x) for x in xs] + return self.model(*xs, **kwargs) + + def set_lr(self, lr): + print("setting learning rate to: {}".format(lr)) + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + def load_pretrained_params(self, pretrained_model): + print("loading from {}".format(pretrained_model)) + with open(pretrained_model, "rb") as f: + params = torch.load(f) + self.model.load_state_dict(params) + + def load_params(self, iteration): + cache_file = self.system_config.snapshot_file.format(iteration) + print("loading model from {}".format(cache_file)) + with open(cache_file, "rb") as f: + params = torch.load(f) + self.model.load_state_dict(params) + + def save_params(self, iteration): + cache_file = self.system_config.snapshot_file.format(iteration) + print("saving model to {}".format(cache_file)) + with open(cache_file, "wb") as f: + params = self.model.state_dict() + torch.save(params, f) diff --git a/core/paths.py b/core/paths.py new file mode 100644 index 0000000..b487c3d --- /dev/null +++ b/core/paths.py @@ -0,0 +1,7 @@ +import pkg_resources + +_package_name = __name__ + +def get_file_path(*paths): + path = "/".join(paths) + return pkg_resources.resource_filename(_package_name, path) diff --git a/core/sample/__init__.py b/core/sample/__init__.py new file mode 100644 index 0000000..2241054 --- /dev/null +++ b/core/sample/__init__.py @@ -0,0 +1,5 @@ +from .cornernet import cornernet +from .cornernet_saccade import cornernet_saccade + +def data_sampling_func(sys_configs, db, k_ind, data_aug=True, debug=False): + return globals()[sys_configs.sampling_function](sys_configs, db, k_ind, data_aug, debug) diff --git a/core/sample/cornernet.py b/core/sample/cornernet.py new file mode 100644 index 0000000..a4e0796 --- /dev/null +++ b/core/sample/cornernet.py @@ -0,0 +1,160 @@ +import cv2 +import math +import numpy as np +import torch + +from .utils import random_crop, draw_gaussian, gaussian_radius, normalize_, color_jittering_, lighting_ + +def _resize_image(image, detections, size): + detections = detections.copy() + height, width = image.shape[0:2] + new_height, new_width = size + + image = cv2.resize(image, (new_width, new_height)) + + height_ratio = new_height / height + width_ratio = new_width / width + detections[:, 0:4:2] *= width_ratio + detections[:, 1:4:2] *= height_ratio + return image, detections + +def _clip_detections(image, detections): + detections = detections.copy() + height, width = image.shape[0:2] + + detections[:, 0:4:2] = np.clip(detections[:, 0:4:2], 0, width - 1) + detections[:, 1:4:2] = np.clip(detections[:, 1:4:2], 0, height - 1) + keep_inds = ((detections[:, 2] - detections[:, 0]) > 0) & \ + ((detections[:, 3] - detections[:, 1]) > 0) + detections = detections[keep_inds] + return detections + +def cornernet(system_configs, db, k_ind, data_aug, debug): + data_rng = system_configs.data_rng + batch_size = system_configs.batch_size + + categories = db.configs["categories"] + input_size = db.configs["input_size"] + output_size = db.configs["output_sizes"][0] + + border = db.configs["border"] + lighting = db.configs["lighting"] + rand_crop = db.configs["rand_crop"] + rand_color = db.configs["rand_color"] + rand_scales = db.configs["rand_scales"] + gaussian_bump = db.configs["gaussian_bump"] + gaussian_iou = 
db.configs["gaussian_iou"] + gaussian_rad = db.configs["gaussian_radius"] + + max_tag_len = 128 + + # allocating memory + images = np.zeros((batch_size, 3, input_size[0], input_size[1]), dtype=np.float32) + tl_heatmaps = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) + br_heatmaps = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) + tl_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32) + br_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32) + tl_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64) + br_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64) + tag_masks = np.zeros((batch_size, max_tag_len), dtype=np.uint8) + tag_lens = np.zeros((batch_size, ), dtype=np.int32) + + db_size = db.db_inds.size + for b_ind in range(batch_size): + if not debug and k_ind == 0: + db.shuffle_inds() + + db_ind = db.db_inds[k_ind] + k_ind = (k_ind + 1) % db_size + + # reading image + image_path = db.image_path(db_ind) + image = cv2.imread(image_path) + + # reading detections + detections = db.detections(db_ind) + + # cropping an image randomly + if not debug and rand_crop: + image, detections = random_crop(image, detections, rand_scales, input_size, border=border) + + image, detections = _resize_image(image, detections, input_size) + detections = _clip_detections(image, detections) + + width_ratio = output_size[1] / input_size[1] + height_ratio = output_size[0] / input_size[0] + + # flipping an image randomly + if not debug and np.random.uniform() > 0.5: + image[:] = image[:, ::-1, :] + width = image.shape[1] + detections[:, [0, 2]] = width - detections[:, [2, 0]] - 1 + + if not debug: + image = image.astype(np.float32) / 255. + if rand_color: + color_jittering_(data_rng, image) + if lighting: + lighting_(data_rng, image, 0.1, db.eig_val, db.eig_vec) + normalize_(image, db.mean, db.std) + images[b_ind] = image.transpose((2, 0, 1)) + + for ind, detection in enumerate(detections): + category = int(detection[-1]) - 1 + + xtl, ytl = detection[0], detection[1] + xbr, ybr = detection[2], detection[3] + + fxtl = (xtl * width_ratio) + fytl = (ytl * height_ratio) + fxbr = (xbr * width_ratio) + fybr = (ybr * height_ratio) + + xtl = int(fxtl) + ytl = int(fytl) + xbr = int(fxbr) + ybr = int(fybr) + + if gaussian_bump: + width = detection[2] - detection[0] + height = detection[3] - detection[1] + + width = math.ceil(width * width_ratio) + height = math.ceil(height * height_ratio) + + if gaussian_rad == -1: + radius = gaussian_radius((height, width), gaussian_iou) + radius = max(0, int(radius)) + else: + radius = gaussian_rad + + draw_gaussian(tl_heatmaps[b_ind, category], [xtl, ytl], radius) + draw_gaussian(br_heatmaps[b_ind, category], [xbr, ybr], radius) + else: + tl_heatmaps[b_ind, category, ytl, xtl] = 1 + br_heatmaps[b_ind, category, ybr, xbr] = 1 + + tag_ind = tag_lens[b_ind] + tl_regrs[b_ind, tag_ind, :] = [fxtl - xtl, fytl - ytl] + br_regrs[b_ind, tag_ind, :] = [fxbr - xbr, fybr - ybr] + tl_tags[b_ind, tag_ind] = ytl * output_size[1] + xtl + br_tags[b_ind, tag_ind] = ybr * output_size[1] + xbr + tag_lens[b_ind] += 1 + + for b_ind in range(batch_size): + tag_len = tag_lens[b_ind] + tag_masks[b_ind, :tag_len] = 1 + + images = torch.from_numpy(images) + tl_heatmaps = torch.from_numpy(tl_heatmaps) + br_heatmaps = torch.from_numpy(br_heatmaps) + tl_regrs = torch.from_numpy(tl_regrs) + br_regrs = torch.from_numpy(br_regrs) + tl_tags = torch.from_numpy(tl_tags) + br_tags = torch.from_numpy(br_tags) + 
tag_masks = torch.from_numpy(tag_masks) + + return { + "xs": [images], + "ys": [tl_heatmaps, br_heatmaps, tag_masks, tl_regrs, br_regrs, tl_tags, br_tags] + }, k_ind diff --git a/core/sample/cornernet_saccade.py b/core/sample/cornernet_saccade.py new file mode 100644 index 0000000..14e368b --- /dev/null +++ b/core/sample/cornernet_saccade.py @@ -0,0 +1,285 @@ +import cv2 +import math +import torch +import numpy as np + +from .utils import draw_gaussian, gaussian_radius, normalize_, color_jittering_, lighting_, crop_image + +def bbox_overlaps(a_dets, b_dets): + a_widths = a_dets[:, 2] - a_dets[:, 0] + a_heights = a_dets[:, 3] - a_dets[:, 1] + a_areas = a_widths * a_heights + + b_widths = b_dets[:, 2] - b_dets[:, 0] + b_heights = b_dets[:, 3] - b_dets[:, 1] + b_areas = b_widths * b_heights + + return a_areas / b_areas + +def clip_detections(border, detections): + detections = detections.copy() + + y0, y1, x0, x1 = border + det_xs = detections[:, 0:4:2] + det_ys = detections[:, 1:4:2] + np.clip(det_xs, x0, x1 - 1, out=det_xs) + np.clip(det_ys, y0, y1 - 1, out=det_ys) + + keep_inds = ((det_xs[:, 1] - det_xs[:, 0]) > 0) & \ + ((det_ys[:, 1] - det_ys[:, 0]) > 0) + keep_inds = np.where(keep_inds)[0] + return detections[keep_inds], keep_inds + +def crop_image_dets(image, dets, ind, input_size, output_size=None, random_crop=True, rand_center=True): + if ind is not None: + det_x0, det_y0, det_x1, det_y1 = dets[ind, 0:4] + else: + det_x0, det_y0, det_x1, det_y1 = None, None, None, None + + input_height, input_width = input_size + image_height, image_width = image.shape[0:2] + + centered = rand_center and np.random.uniform() > 0.5 + if not random_crop or image_width <= input_width: + xc = image_width // 2 + elif ind is None or not centered: + xmin = max(det_x1 - input_width, 0) if ind is not None else 0 + xmax = min(image_width - input_width, det_x0) if ind is not None else image_width - input_width + xrand = np.random.randint(int(xmin), int(xmax) + 1) + xc = xrand + input_width // 2 + else: + xmin = max((det_x0 + det_x1) // 2 - np.random.randint(0, 15), 0) + xmax = min((det_x0 + det_x1) // 2 + np.random.randint(0, 15), image_width - 1) + xc = np.random.randint(int(xmin), int(xmax) + 1) + + if not random_crop or image_height <= input_height: + yc = image_height // 2 + elif ind is None or not centered: + ymin = max(det_y1 - input_height, 0) if ind is not None else 0 + ymax = min(image_height - input_height, det_y0) if ind is not None else image_height - input_height + yrand = np.random.randint(int(ymin), int(ymax) + 1) + yc = yrand + input_height // 2 + else: + ymin = max((det_y0 + det_y1) // 2 - np.random.randint(0, 15), 0) + ymax = min((det_y0 + det_y1) // 2 + np.random.randint(0, 15), image_height - 1) + yc = np.random.randint(int(ymin), int(ymax) + 1) + + image, border, offset = crop_image(image, [yc, xc], input_size, output_size=output_size) + dets[:, 0:4:2] -= offset[1] + dets[:, 1:4:2] -= offset[0] + return image, dets, border + +def scale_image_detections(image, dets, scale): + height, width = image.shape[0:2] + + new_height = int(height * scale) + new_width = int(width * scale) + + image = cv2.resize(image, (new_width, new_height)) + dets = dets.copy() + dets[:, 0:4] *= scale + return image, dets + +def ref_scale(detections, random_crop=False): + if detections.shape[0] == 0: + return None, None + + if random_crop and np.random.uniform() > 0.7: + return None, None + + ref_ind = np.random.randint(detections.shape[0]) + ref_det = detections[ref_ind].copy() + ref_h = ref_det[3] - ref_det[1] + 
ref_w = ref_det[2] - ref_det[0] + ref_hw = max(ref_h, ref_w) + + if ref_hw > 96: + return np.random.randint(low=96, high=255) / ref_hw, ref_ind + elif ref_hw > 32: + return np.random.randint(low=32, high=97) / ref_hw, ref_ind + return np.random.randint(low=16, high=33) / ref_hw, ref_ind + +def create_attention_mask(atts, ratios, sizes, detections): + for det in detections: + width = det[2] - det[0] + height = det[3] - det[1] + + max_hw = max(width, height) + for att, ratio, size in zip(atts, ratios, sizes): + if max_hw >= size[0] and max_hw <= size[1]: + x = (det[0] + det[2]) / 2 + y = (det[1] + det[3]) / 2 + x = (x / ratio).astype(np.int32) + y = (y / ratio).astype(np.int32) + att[y, x] = 1 + +def cornernet_saccade(system_configs, db, k_ind, data_aug, debug): + data_rng = system_configs.data_rng + batch_size = system_configs.batch_size + + categories = db.configs["categories"] + input_size = db.configs["input_size"] + output_size = db.configs["output_sizes"][0] + rand_scales = db.configs["rand_scales"] + rand_crop = db.configs["rand_crop"] + rand_center = db.configs["rand_center"] + view_sizes = db.configs["view_sizes"] + + gaussian_iou = db.configs["gaussian_iou"] + gaussian_rad = db.configs["gaussian_radius"] + + att_ratios = db.configs["att_ratios"] + att_ranges = db.configs["att_ranges"] + att_sizes = db.configs["att_sizes"] + + min_scale = db.configs["min_scale"] + max_scale = db.configs["max_scale"] + max_objects = 128 + + images = np.zeros((batch_size, 3, input_size[0], input_size[1]), dtype=np.float32) + tl_heats = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) + br_heats = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) + tl_valids = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) + br_valids = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) + tl_regrs = np.zeros((batch_size, max_objects, 2), dtype=np.float32) + br_regrs = np.zeros((batch_size, max_objects, 2), dtype=np.float32) + tl_tags = np.zeros((batch_size, max_objects), dtype=np.int64) + br_tags = np.zeros((batch_size, max_objects), dtype=np.int64) + tag_masks = np.zeros((batch_size, max_objects), dtype=np.uint8) + tag_lens = np.zeros((batch_size, ), dtype=np.int32) + attentions = [np.zeros((batch_size, 1, att_size[0], att_size[1]), dtype=np.float32) for att_size in att_sizes] + + db_size = db.db_inds.size + for b_ind in range(batch_size): + if not debug and k_ind == 0: + # if k_ind == 0: + db.shuffle_inds() + + db_ind = db.db_inds[k_ind] + k_ind = (k_ind + 1) % db_size + + image_path = db.image_path(db_ind) + image = cv2.imread(image_path) + + orig_detections = db.detections(db_ind) + keep_inds = np.arange(orig_detections.shape[0]) + + # clip the detections + detections = orig_detections.copy() + border = [0, image.shape[0], 0, image.shape[1]] + detections, clip_inds = clip_detections(border, detections) + keep_inds = keep_inds[clip_inds] + + scale, ref_ind = ref_scale(detections, random_crop=rand_crop) + scale = np.random.choice(rand_scales) if scale is None else scale + + orig_detections[:, 0:4:2] *= scale + orig_detections[:, 1:4:2] *= scale + + image, detections = scale_image_detections(image, detections, scale) + ref_detection = detections[ref_ind].copy() + + image, detections, border = crop_image_dets(image, detections, ref_ind, input_size, rand_center=rand_center) + + detections, clip_inds = clip_detections(border, detections) + keep_inds = keep_inds[clip_inds] + + 
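+        # these ratios map corner coordinates from input-image pixels to
+        # output heatmap cells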
width_ratio = output_size[1] / input_size[1] + height_ratio = output_size[0] / input_size[0] + + # flipping an image randomly + if not debug and np.random.uniform() > 0.5: + image[:] = image[:, ::-1, :] + width = image.shape[1] + detections[:, [0, 2]] = width - detections[:, [2, 0]] - 1 + create_attention_mask([att[b_ind, 0] for att in attentions], att_ratios, att_ranges, detections) + + if debug: + dimage = image.copy() + for det in detections.astype(np.int32): + cv2.rectangle(dimage, + (det[0], det[1]), + (det[2], det[3]), + (0, 255, 0), 2 + ) + cv2.imwrite('debug/{:03d}.jpg'.format(b_ind), dimage) + overlaps = bbox_overlaps(detections, orig_detections[keep_inds]) > 0.5 + + if not debug: + image = image.astype(np.float32) / 255. + color_jittering_(data_rng, image) + lighting_(data_rng, image, 0.1, db.eig_val, db.eig_vec) + normalize_(image, db.mean, db.std) + images[b_ind] = image.transpose((2, 0, 1)) + + for ind, (detection, overlap) in enumerate(zip(detections, overlaps)): + category = int(detection[-1]) - 1 + + xtl, ytl = detection[0], detection[1] + xbr, ybr = detection[2], detection[3] + + det_height = int(ybr) - int(ytl) + det_width = int(xbr) - int(xtl) + det_max = max(det_height, det_width) + + valid = det_max >= min_scale + + fxtl = (xtl * width_ratio) + fytl = (ytl * height_ratio) + fxbr = (xbr * width_ratio) + fybr = (ybr * height_ratio) + + xtl = int(fxtl) + ytl = int(fytl) + xbr = int(fxbr) + ybr = int(fybr) + + width = detection[2] - detection[0] + height = detection[3] - detection[1] + + width = math.ceil(width * width_ratio) + height = math.ceil(height * height_ratio) + + if gaussian_rad == -1: + radius = gaussian_radius((height, width), gaussian_iou) + radius = max(0, int(radius)) + else: + radius = gaussian_rad + + if overlap and valid: + draw_gaussian(tl_heats[b_ind, category], [xtl, ytl], radius) + draw_gaussian(br_heats[b_ind, category], [xbr, ybr], radius) + + tag_ind = tag_lens[b_ind] + tl_regrs[b_ind, tag_ind, :] = [fxtl - xtl, fytl - ytl] + br_regrs[b_ind, tag_ind, :] = [fxbr - xbr, fybr - ybr] + tl_tags[b_ind, tag_ind] = ytl * output_size[1] + xtl + br_tags[b_ind, tag_ind] = ybr * output_size[1] + xbr + tag_lens[b_ind] += 1 + else: + draw_gaussian(tl_valids[b_ind, category], [xtl, ytl], radius) + draw_gaussian(br_valids[b_ind, category], [xbr, ybr], radius) + + tl_valids = (tl_valids == 0).astype(np.float32) + br_valids = (br_valids == 0).astype(np.float32) + + for b_ind in range(batch_size): + tag_len = tag_lens[b_ind] + tag_masks[b_ind, :tag_len] = 1 + + images = torch.from_numpy(images) + tl_heats = torch.from_numpy(tl_heats) + br_heats = torch.from_numpy(br_heats) + tl_regrs = torch.from_numpy(tl_regrs) + br_regrs = torch.from_numpy(br_regrs) + tl_tags = torch.from_numpy(tl_tags) + br_tags = torch.from_numpy(br_tags) + tag_masks = torch.from_numpy(tag_masks) + tl_valids = torch.from_numpy(tl_valids) + br_valids = torch.from_numpy(br_valids) + attentions = [torch.from_numpy(att) for att in attentions] + + return { + "xs": [images], + "ys": [tl_heats, br_heats, tag_masks, tl_regrs, br_regrs, tl_tags, br_tags, tl_valids, br_valids, attentions] + }, k_ind diff --git a/core/sample/utils.py b/core/sample/utils.py new file mode 100644 index 0000000..fd8f437 --- /dev/null +++ b/core/sample/utils.py @@ -0,0 +1,163 @@ +import cv2 +import numpy as np +import random + +def grayscale(image): + return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + +def normalize_(image, mean, std): + image -= mean + image /= std + +def lighting_(data_rng, image, alphastd, eigval, eigvec): + 
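+    # AlexNet-style PCA lighting augmentation: shift the image along the
+    # dataset's RGB principal components (eigvec/eigval), scaled by a
+    # N(0, alphastd) draw per component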
alpha = data_rng.normal(scale=alphastd, size=(3, )) + image += np.dot(eigvec, eigval * alpha) + +def blend_(alpha, image1, image2): + image1 *= alpha + image2 *= (1 - alpha) + image1 += image2 + +def saturation_(data_rng, image, gs, gs_mean, var): + alpha = 1. + data_rng.uniform(low=-var, high=var) + blend_(alpha, image, gs[:, :, None]) + +def brightness_(data_rng, image, gs, gs_mean, var): + alpha = 1. + data_rng.uniform(low=-var, high=var) + image *= alpha + +def contrast_(data_rng, image, gs, gs_mean, var): + alpha = 1. + data_rng.uniform(low=-var, high=var) + blend_(alpha, image, gs_mean) + +def color_jittering_(data_rng, image): + functions = [brightness_, contrast_, saturation_] + random.shuffle(functions) + + gs = grayscale(image) + gs_mean = gs.mean() + for f in functions: + f(data_rng, image, gs, gs_mean, 0.4) + +def gaussian2D(shape, sigma=1): + m, n = [(ss - 1.) / 2. for ss in shape] + y, x = np.ogrid[-m:m+1,-n:n+1] + + h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + return h + +def draw_gaussian(heatmap, center, radius, k=1): + diameter = 2 * radius + 1 + gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) + + x, y = center + + height, width = heatmap.shape[0:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] + np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) + +def gaussian_radius(det_size, min_overlap): + height, width = det_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) + r1 = (b1 - sq1) / (2 * a1) + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) + r2 = (b2 - sq2) / (2 * a2) + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) + r3 = (b3 + sq3) / (2 * a3) + return min(r1, r2, r3) + +def _get_border(border, size): + i = 1 + while size - border // i <= border // i: + i *= 2 + return border // i + +def random_crop(image, detections, random_scales, view_size, border=64): + view_height, view_width = view_size + image_height, image_width = image.shape[0:2] + + scale = np.random.choice(random_scales) + height = int(view_height * scale) + width = int(view_width * scale) + + cropped_image = np.zeros((height, width, 3), dtype=image.dtype) + + w_border = _get_border(border, image_width) + h_border = _get_border(border, image_height) + + ctx = np.random.randint(low=w_border, high=image_width - w_border) + cty = np.random.randint(low=h_border, high=image_height - h_border) + + x0, x1 = max(ctx - width // 2, 0), min(ctx + width // 2, image_width) + y0, y1 = max(cty - height // 2, 0), min(cty + height // 2, image_height) + + left_w, right_w = ctx - x0, x1 - ctx + top_h, bottom_h = cty - y0, y1 - cty + + # crop image + cropped_ctx, cropped_cty = width // 2, height // 2 + x_slice = slice(cropped_ctx - left_w, cropped_ctx + right_w) + y_slice = slice(cropped_cty - top_h, cropped_cty + bottom_h) + cropped_image[y_slice, x_slice, :] = image[y0:y1, x0:x1, :] + + # crop detections + cropped_detections = detections.copy() + cropped_detections[:, 0:4:2] -= x0 + cropped_detections[:, 1:4:2] -= y0 + cropped_detections[:, 0:4:2] += 
cropped_ctx - left_w
+    cropped_detections[:, 1:4:2] += cropped_cty - top_h
+
+    return cropped_image, cropped_detections
+
+def crop_image(image, center, size, output_size=None):
+    if output_size is None:
+        output_size = size
+
+    cty, ctx = center
+    height, width = size
+    o_height, o_width = output_size
+    im_height, im_width = image.shape[0:2]
+    cropped_image = np.zeros((o_height, o_width, 3), dtype=image.dtype)
+
+    x0, x1 = max(0, ctx - width // 2), min(ctx + width // 2, im_width)
+    y0, y1 = max(0, cty - height // 2), min(cty + height // 2, im_height)
+
+    left, right = ctx - x0, x1 - ctx
+    top, bottom = cty - y0, y1 - cty
+
+    cropped_cty, cropped_ctx = o_height // 2, o_width // 2
+    y_slice = slice(cropped_cty - top, cropped_cty + bottom)
+    x_slice = slice(cropped_ctx - left, cropped_ctx + right)
+    cropped_image[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]
+
+    border = np.array([
+        cropped_cty - top,
+        cropped_cty + bottom,
+        cropped_ctx - left,
+        cropped_ctx + right
+    ], dtype=np.float32)
+
+    offset = np.array([
+        cty - o_height // 2,
+        ctx - o_width // 2
+    ])
+
+    return cropped_image, border, offset
diff --git a/core/test/__init__.py b/core/test/__init__.py
new file mode 100644
index 0000000..080cc8f
--- /dev/null
+++ b/core/test/__init__.py
@@ -0,0 +1,5 @@
+from .cornernet import cornernet
+from .cornernet_saccade import cornernet_saccade
+
+def test_func(sys_config, db, nnet, result_dir, debug=False):
+    return globals()[sys_config.sampling_function](db, nnet, result_dir, debug=debug)
diff --git a/core/test/cornernet.py b/core/test/cornernet.py
new file mode 100644
index 0000000..e20f014
--- /dev/null
+++ b/core/test/cornernet.py
@@ -0,0 +1,176 @@
+import os
+import cv2
+import json
+import numpy as np
+import torch
+
+from tqdm import tqdm
+
+from ..utils import Timer
+from ..vis_utils import draw_bboxes
+from ..sample.utils import crop_image
+from ..external.nms import soft_nms, soft_nms_merge
+
+def rescale_dets_(detections, ratios, borders, sizes):
+    xs, ys = detections[..., 0:4:2], detections[..., 1:4:2]
+    xs /= ratios[:, 1][:, None, None]
+    ys /= ratios[:, 0][:, None, None]
+    xs -= borders[:, 2][:, None, None]
+    ys -= borders[:, 0][:, None, None]
+    np.clip(xs, 0, sizes[:, 1][:, None, None], out=xs)
+    np.clip(ys, 0, sizes[:, 0][:, None, None], out=ys)
+
+def decode(nnet, images, K, ae_threshold=0.5, kernel=3, num_dets=1000):
+    detections = nnet.test([images], ae_threshold=ae_threshold, test=True, K=K, kernel=kernel, num_dets=num_dets)[0]
+    return detections.data.cpu().numpy()
+
+def cornernet(db, nnet, result_dir, debug=False, decode_func=decode):
+    debug_dir = os.path.join(result_dir, "debug")
+    if not os.path.exists(debug_dir):
+        os.makedirs(debug_dir)
+
+    if db.split != "trainval2014":
+        db_inds = db.db_inds[:100] if debug else db.db_inds
+    else:
+        db_inds = db.db_inds[:100] if debug else db.db_inds[:5000]
+
+    num_images = db_inds.size
+    categories = db.configs["categories"]
+
+    timer = Timer()
+    top_bboxes = {}
+    for ind in tqdm(range(0, num_images), ncols=80, desc="locating kps"):
+        db_ind = db_inds[ind]
+
+        image_id = db.image_ids(db_ind)
+        image_path = db.image_path(db_ind)
+        image = cv2.imread(image_path)
+
+        timer.tic()
+        top_bboxes[image_id] = cornernet_inference(db, nnet, image)
+        timer.toc()
+
+        if debug:
+            image_path = db.image_path(db_ind)
+            image = cv2.imread(image_path)
+            bboxes = {
+                db.cls2name(j): top_bboxes[image_id][j]
+                for j in range(1, categories + 1)
+            }
+            image = draw_bboxes(image, bboxes)
+            debug_file = os.path.join(debug_dir, "{}.jpg".format(db_ind))
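+            # one visualization per image, written under <result_dir>/debug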
cv2.imwrite(debug_file, image) + print('average time: {}'.format(timer.average_time)) + + result_json = os.path.join(result_dir, "results.json") + detections = db.convert_to_coco(top_bboxes) + with open(result_json, "w") as f: + json.dump(detections, f) + + cls_ids = list(range(1, categories + 1)) + image_ids = [db.image_ids(ind) for ind in db_inds] + db.evaluate(result_json, cls_ids, image_ids) + return 0 + +def cornernet_inference(db, nnet, image, decode_func=decode): + K = db.configs["top_k"] + ae_threshold = db.configs["ae_threshold"] + nms_kernel = db.configs["nms_kernel"] + num_dets = db.configs["num_dets"] + test_flipped = db.configs["test_flipped"] + + input_size = db.configs["input_size"] + output_size = db.configs["output_sizes"][0] + + scales = db.configs["test_scales"] + weight_exp = db.configs["weight_exp"] + merge_bbox = db.configs["merge_bbox"] + categories = db.configs["categories"] + nms_threshold = db.configs["nms_threshold"] + max_per_image = db.configs["max_per_image"] + nms_algorithm = { + "nms": 0, + "linear_soft_nms": 1, + "exp_soft_nms": 2 + }[db.configs["nms_algorithm"]] + + height, width = image.shape[0:2] + + height_scale = (input_size[0] + 1) // output_size[0] + width_scale = (input_size[1] + 1) // output_size[1] + + im_mean = torch.cuda.FloatTensor(db.mean).reshape(1, 3, 1, 1) + im_std = torch.cuda.FloatTensor(db.std).reshape(1, 3, 1, 1) + + detections = [] + for scale in scales: + new_height = int(height * scale) + new_width = int(width * scale) + new_center = np.array([new_height // 2, new_width // 2]) + + inp_height = new_height | 127 + inp_width = new_width | 127 + + images = np.zeros((1, 3, inp_height, inp_width), dtype=np.float32) + ratios = np.zeros((1, 2), dtype=np.float32) + borders = np.zeros((1, 4), dtype=np.float32) + sizes = np.zeros((1, 2), dtype=np.float32) + + out_height, out_width = (inp_height + 1) // height_scale, (inp_width + 1) // width_scale + height_ratio = out_height / inp_height + width_ratio = out_width / inp_width + + resized_image = cv2.resize(image, (new_width, new_height)) + resized_image, border, offset = crop_image(resized_image, new_center, [inp_height, inp_width]) + + resized_image = resized_image / 255. 
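+        # pixels are only rescaled to [0, 1] here; mean/std normalization is
+        # applied on the GPU below, after the (optionally flipped) batch is built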
+ + images[0] = resized_image.transpose((2, 0, 1)) + borders[0] = border + sizes[0] = [int(height * scale), int(width * scale)] + ratios[0] = [height_ratio, width_ratio] + + if test_flipped: + images = np.concatenate((images, images[:, :, :, ::-1]), axis=0) + images = torch.from_numpy(images).cuda() + images -= im_mean + images /= im_std + + dets = decode_func(nnet, images, K, ae_threshold=ae_threshold, kernel=nms_kernel, num_dets=num_dets) + if test_flipped: + dets[1, :, [0, 2]] = out_width - dets[1, :, [2, 0]] + dets = dets.reshape(1, -1, 8) + + rescale_dets_(dets, ratios, borders, sizes) + dets[:, :, 0:4] /= scale + detections.append(dets) + + detections = np.concatenate(detections, axis=1) + + classes = detections[..., -1] + classes = classes[0] + detections = detections[0] + + # reject detections with negative scores + keep_inds = (detections[:, 4] > -1) + detections = detections[keep_inds] + classes = classes[keep_inds] + + top_bboxes = {} + for j in range(categories): + keep_inds = (classes == j) + top_bboxes[j + 1] = detections[keep_inds][:, 0:7].astype(np.float32) + if merge_bbox: + soft_nms_merge(top_bboxes[j + 1], Nt=nms_threshold, method=nms_algorithm, weight_exp=weight_exp) + else: + soft_nms(top_bboxes[j + 1], Nt=nms_threshold, method=nms_algorithm) + top_bboxes[j + 1] = top_bboxes[j + 1][:, 0:5] + + scores = np.hstack([top_bboxes[j][:, -1] for j in range(1, categories + 1)]) + if len(scores) > max_per_image: + kth = len(scores) - max_per_image + thresh = np.partition(scores, kth)[kth] + for j in range(1, categories + 1): + keep_inds = (top_bboxes[j][:, -1] >= thresh) + top_bboxes[j] = top_bboxes[j][keep_inds] + return top_bboxes diff --git a/core/test/cornernet_saccade.py b/core/test/cornernet_saccade.py new file mode 100644 index 0000000..eb0e6bc --- /dev/null +++ b/core/test/cornernet_saccade.py @@ -0,0 +1,394 @@ +import os +import cv2 +import math +import json +import torch +import torch.nn as nn +import numpy as np + +from tqdm import tqdm + +from ..utils import Timer +from ..vis_utils import draw_bboxes +from ..external.nms import soft_nms + +def crop_image_gpu(image, center, size, out_image): + cty, ctx = center + height, width = size + o_height, o_width = out_image.shape[1:3] + im_height, im_width = image.shape[1:3] + + scale = o_height / max(height, width) + x0, x1 = max(0, ctx - width // 2), min(ctx + width // 2, im_width) + y0, y1 = max(0, cty - height // 2), min(cty + height // 2, im_height) + + left, right = ctx - x0, x1 - ctx + top, bottom = cty - y0, y1 - cty + + cropped_cty, cropped_ctx = o_height // 2, o_width // 2 + out_y0, out_y1 = cropped_cty - int(top * scale), cropped_cty + int(bottom * scale) + out_x0, out_x1 = cropped_ctx - int(left * scale), cropped_ctx + int(right * scale) + + new_height = out_y1 - out_y0 + new_width = out_x1 - out_x0 + image = image[:, y0:y1, x0:x1].unsqueeze(0) + out_image[:, out_y0:out_y1, out_x0:out_x1] = nn.functional.interpolate( + image, size=[new_height, new_width], mode='bilinear' + )[0] + + return np.array([cty - height // 2, ctx - width // 2]) + +def remap_dets_(detections, scales, offsets): + xs, ys = detections[..., 0:4:2], detections[..., 1:4:2] + + xs /= scales.reshape(-1, 1, 1) + ys /= scales.reshape(-1, 1, 1) + xs += offsets[:, 1][:, None, None] + ys += offsets[:, 0][:, None, None] + +def att_nms(atts, ks): + pads = [(k - 1) // 2 for k in ks] + pools = [nn.functional.max_pool2d(att, (k, k), stride=1, padding=pad) for k, att, pad in zip(ks, atts, pads)] + keeps = [(att == pool).float() for att, pool in zip(atts, 
pools)] + atts = [att * keep for att, keep in zip(atts, keeps)] + return atts + +def batch_decode(db, nnet, images, no_att=False): + K = db.configs["top_k"] + ae_threshold = db.configs["ae_threshold"] + kernel = db.configs["nms_kernel"] + num_dets = db.configs["num_dets"] + + att_nms_ks = db.configs["att_nms_ks"] + att_ranges = db.configs["att_ranges"] + + num_images = images.shape[0] + detections = [] + attentions = [[] for _ in range(len(att_ranges))] + + batch_size = 32 + for b_ind in range(math.ceil(num_images / batch_size)): + b_start = b_ind * batch_size + b_end = min(num_images, (b_ind + 1) * batch_size) + + b_images = images[b_start:b_end] + b_outputs = nnet.test( + [b_images], ae_threshold=ae_threshold, K=K, kernel=kernel, + test=True, num_dets=num_dets, no_border=True, no_att=no_att + ) + if no_att: + b_detections = b_outputs + else: + b_detections = b_outputs[0] + b_attentions = b_outputs[1] + b_attentions = att_nms(b_attentions, att_nms_ks) + b_attentions = [b_attention.data.cpu().numpy() for b_attention in b_attentions] + + b_detections = b_detections.data.cpu().numpy() + + detections.append(b_detections) + if not no_att: + for attention, b_attention in zip(attentions, b_attentions): + attention.append(b_attention) + + if not no_att: + attentions = [np.concatenate(atts, axis=0) for atts in attentions] if detections else None + detections = np.concatenate(detections, axis=0) if detections else np.zeros((0, num_dets, 8)) + return detections, attentions + +def decode_atts(db, atts, att_scales, scales, offsets, height, width, thresh, ignore_same=False): + att_ranges = db.configs["att_ranges"] + att_ratios = db.configs["att_ratios"] + input_size = db.configs["input_size"] + + next_ys, next_xs, next_scales, next_scores = [], [], [], [] + + num_atts = atts[0].shape[0] + for aind in range(num_atts): + for att, att_range, att_ratio, att_scale in zip(atts, att_ranges, att_ratios, att_scales): + ys, xs = np.where(att[aind, 0] > thresh) + scores = att[aind, 0, ys, xs] + + ys = ys * att_ratio / scales[aind] + offsets[aind, 0] + xs = xs * att_ratio / scales[aind] + offsets[aind, 1] + + keep = (ys >= 0) & (ys < height) & (xs >= 0) & (xs < width) + ys, xs, scores = ys[keep], xs[keep], scores[keep] + + next_scale = att_scale * scales[aind] + if (ignore_same and att_scale <= 1) or scales[aind] > 2 or next_scale > 4: + continue + + next_scales += [next_scale] * len(xs) + next_scores += scores.tolist() + next_ys += ys.tolist() + next_xs += xs.tolist() + next_ys = np.array(next_ys) + next_xs = np.array(next_xs) + next_scales = np.array(next_scales) + next_scores = np.array(next_scores) + return np.stack((next_ys, next_xs, next_scales, next_scores), axis=1) + +def get_ref_locs(dets): + keep = dets[:, 4] > 0.5 + dets = dets[keep] + + ref_xs = (dets[:, 0] + dets[:, 2]) / 2 + ref_ys = (dets[:, 1] + dets[:, 3]) / 2 + + ref_maxhws = np.maximum(dets[:, 2] - dets[:, 0], dets[:, 3] - dets[:, 1]) + ref_scales = np.zeros_like(ref_maxhws) + ref_scores = dets[:, 4] + + large_inds = ref_maxhws > 96 + medium_inds = (ref_maxhws > 32) & (ref_maxhws <= 96) + small_inds = ref_maxhws <= 32 + + ref_scales[large_inds] = 192 / ref_maxhws[large_inds] + ref_scales[medium_inds] = 64 / ref_maxhws[medium_inds] + ref_scales[small_inds] = 24 / ref_maxhws[small_inds] + + new_locations = np.stack((ref_ys, ref_xs, ref_scales, ref_scores), axis=1) + new_locations[:, 3] = 1 + return new_locations + +def get_locs(db, nnet, image, im_mean, im_std, att_scales, thresh, sizes, ref_dets=True): + att_ranges = db.configs["att_ranges"] + 
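+    # first saccade pass: run the network on downsized full-image views to get
+    # coarse detections plus attention maps proposing where to zoom in next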
att_ratios = db.configs["att_ratios"] + input_size = db.configs["input_size"] + + height, width = image.shape[1:3] + + locations = [] + for size in sizes: + scale = size / max(height, width) + location = [height // 2, width // 2, scale] + locations.append(location) + + locations = np.array(locations, dtype=np.float32) + images, offsets = prepare_images(db, image, locations, flipped=False) + + images -= im_mean + images /= im_std + + dets, atts = batch_decode(db, nnet, images) + + scales = locations[:, 2] + next_locations = decode_atts(db, atts, att_scales, scales, offsets, height, width, thresh) + + rescale_dets_(db, dets) + remap_dets_(dets, scales, offsets) + + dets = dets.reshape(-1, 8) + keep = dets[:, 4] > 0.3 + dets = dets[keep] + + if ref_dets: + ref_locations = get_ref_locs(dets) + next_locations = np.concatenate((next_locations, ref_locations), axis=0) + next_locations = location_nms(next_locations, thresh=16) + return dets, next_locations, atts + +def location_nms(locations, thresh=15): + next_locations = [] + sorted_inds = np.argsort(locations[:, -1])[::-1] + + locations = locations[sorted_inds] + ys = locations[:, 0] + xs = locations[:, 1] + scales = locations[:, 2] + + dist_ys = np.absolute(ys.reshape(-1, 1) - ys.reshape(1, -1)) + dist_xs = np.absolute(xs.reshape(-1, 1) - xs.reshape(1, -1)) + dists = np.minimum(dist_ys, dist_xs) + ratios = scales.reshape(-1, 1) / scales.reshape(1, -1) + while dists.shape[0] > 0: + next_locations.append(locations[0]) + + scale = scales[0] + dist = dists[0] + ratio = ratios[0] + + keep = (dist > (thresh / scale)) | (ratio > 1.2) | (ratio < 0.8) + + locations = locations[keep] + + scales = scales[keep] + dists = dists[keep, :] + dists = dists[:, keep] + ratios = ratios[keep, :] + ratios = ratios[:, keep] + return np.stack(next_locations) if next_locations else np.zeros((0, 4)) + +def prepare_images(db, image, locs, flipped=True): + input_size = db.configs["input_size"] + num_patches = locs.shape[0] + + images = torch.cuda.FloatTensor(num_patches, 3, input_size[0], input_size[1]).fill_(0) + offsets = np.zeros((num_patches, 2), dtype=np.float32) + for ind, (y, x, scale) in enumerate(locs[:, :3]): + crop_height = int(input_size[0] / scale) + crop_width = int(input_size[1] / scale) + offsets[ind] = crop_image_gpu(image, [int(y), int(x)], [crop_height, crop_width], images[ind]) + return images, offsets + +def rescale_dets_(db, dets): + input_size = db.configs["input_size"] + output_size = db.configs["output_sizes"][0] + + ratios = [o / i for o, i in zip(output_size, input_size)] + dets[..., 0:4:2] /= ratios[1] + dets[..., 1:4:2] /= ratios[0] + +def cornernet_saccade(db, nnet, result_dir, debug=False, decode_func=batch_decode): + debug_dir = os.path.join(result_dir, "debug") + if not os.path.exists(debug_dir): + os.makedirs(debug_dir) + + if db.split != "trainval2014": + db_inds = db.db_inds[:500] if debug else db.db_inds + else: + db_inds = db.db_inds[:100] if debug else db.db_inds[:5000] + + num_images = db_inds.size + categories = db.configs["categories"] + + timer = Timer() + top_bboxes = {} + for k_ind in tqdm(range(0, num_images), ncols=80, desc="locating kps"): + db_ind = db_inds[k_ind] + + image_id = db.image_ids(db_ind) + image_path = db.image_path(db_ind) + image = cv2.imread(image_path) + + timer.tic() + top_bboxes[image_id] = cornernet_saccade_inference(db, nnet, image) + timer.toc() + + if debug: + image_path = db.image_path(db_ind) + image = cv2.imread(image_path) + bboxes = { + db.cls2name(j): top_bboxes[image_id][j] + for j in range(1, 
categories + 1) + } + image = draw_bboxes(image, bboxes) + debug_file = os.path.join(debug_dir, "{}.jpg".format(db_ind)) + cv2.imwrite(debug_file, image) + print('average time: {}'.format(timer.average_time)) + + result_json = os.path.join(result_dir, "results.json") + detections = db.convert_to_coco(top_bboxes) + with open(result_json, "w") as f: + json.dump(detections, f) + + cls_ids = list(range(1, categories + 1)) + image_ids = [db.image_ids(ind) for ind in db_inds] + db.evaluate(result_json, cls_ids, image_ids) + return 0 + +def cornernet_saccade_inference(db, nnet, image, decode_func=batch_decode): + init_sizes = db.configs["init_sizes"] + ref_dets = db.configs["ref_dets"] + + att_thresholds = db.configs["att_thresholds"] + att_scales = db.configs["att_scales"] + att_max_crops = db.configs["att_max_crops"] + + categories = db.configs["categories"] + nms_threshold = db.configs["nms_threshold"] + max_per_image = db.configs["max_per_image"] + nms_algorithm = { + "nms": 0, + "linear_soft_nms": 1, + "exp_soft_nms": 2 + }[db.configs["nms_algorithm"]] + + num_iterations = len(att_thresholds) + + im_mean = torch.cuda.FloatTensor(db.mean).reshape(1, 3, 1, 1) + im_std = torch.cuda.FloatTensor(db.std).reshape(1, 3, 1, 1) + + detections = [] + height, width = image.shape[0:2] + + image = image / 255. + image = image.transpose((2, 0, 1)).copy() + image = torch.from_numpy(image).cuda(non_blocking=True) + + dets, locations, atts = get_locs( + db, nnet, image, im_mean, im_std, + att_scales[0], att_thresholds[0], + init_sizes, ref_dets=ref_dets + ) + + detections = [dets] + num_patches = locations.shape[0] + + num_crops = 0 + for ind in range(1, num_iterations + 1): + if num_patches == 0: + break + + if num_crops + num_patches > att_max_crops: + max_crops = min(att_max_crops - num_crops, num_patches) + locations = locations[:max_crops] + + num_patches = locations.shape[0] + num_crops += locations.shape[0] + no_att = (ind == num_iterations) + + images, offsets = prepare_images(db, image, locations, flipped=False) + images -= im_mean + images /= im_std + + dets, atts = decode_func(db, nnet, images, no_att=no_att) + dets = dets.reshape(num_patches, -1, 8) + + rescale_dets_(db, dets) + remap_dets_(dets, locations[:, 2], offsets) + + dets = dets.reshape(-1, 8) + keeps = (dets[:, 4] > -1) + dets = dets[keeps] + + detections.append(dets) + + if num_crops == att_max_crops: + break + + if ind < num_iterations: + att_threshold = att_thresholds[ind] + att_scale = att_scales[ind] + + next_locations = decode_atts( + db, atts, att_scale, locations[:, 2], offsets, height, width, att_threshold, ignore_same=True + ) + + if ref_dets: + ref_locations = get_ref_locs(dets) + next_locations = np.concatenate((next_locations, ref_locations), axis=0) + next_locations = location_nms(next_locations, thresh=16) + + locations = next_locations + num_patches = locations.shape[0] + + detections = np.concatenate(detections, axis=0) + classes = detections[..., -1] + + top_bboxes = {} + for j in range(categories): + keep_inds = (classes == j) + top_bboxes[j + 1] = detections[keep_inds][:, 0:7].astype(np.float32) + keep_inds = soft_nms(top_bboxes[j + 1], Nt=nms_threshold, method=nms_algorithm, sigma=0.7) + top_bboxes[j + 1] = top_bboxes[j + 1][keep_inds, 0:5] + + scores = np.hstack([top_bboxes[j][:, -1] for j in range(1, categories + 1)]) + if len(scores) > max_per_image: + kth = len(scores) - max_per_image + thresh = np.partition(scores, kth)[kth] + for j in range(1, categories + 1): + keep_inds = (top_bboxes[j][:, -1] >= thresh) + 
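+            # keep only the boxes that survive the global max_per_image score cut-off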
top_bboxes[j] = top_bboxes[j][keep_inds]
+    return top_bboxes
diff --git a/core/utils/__init__.py b/core/utils/__init__.py
new file mode 100644
index 0000000..9f60522
--- /dev/null
+++ b/core/utils/__init__.py
@@ -0,0 +1,2 @@
+from .tqdm import stdout_to_tqdm
+from .timer import Timer
diff --git a/core/utils/timer.py b/core/utils/timer.py
new file mode 100644
index 0000000..80c29de
--- /dev/null
+++ b/core/utils/timer.py
@@ -0,0 +1,25 @@
+import time
+
+class Timer(object):
+    """A simple timer."""
+    def __init__(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+        self.average_time = 0.
+
+    def tic(self):
+        # use time.time instead of time.clock because time.clock
+        # does not normalize for multithreading
+        self.start_time = time.time()
+
+    def toc(self, average=True):
+        self.diff = time.time() - self.start_time
+        self.total_time += self.diff
+        self.calls += 1
+        self.average_time = self.total_time / self.calls
+        if average:
+            return self.average_time
+        else:
+            return self.diff
diff --git a/core/utils/tqdm.py b/core/utils/tqdm.py
new file mode 100755
index 0000000..334dfc1
--- /dev/null
+++ b/core/utils/tqdm.py
@@ -0,0 +1,25 @@
+import sys
+import numpy as np
+import contextlib
+
+from tqdm import tqdm
+
+class TqdmFile(object):
+    dummy_file = None
+    def __init__(self, dummy_file):
+        self.dummy_file = dummy_file
+
+    def write(self, x):
+        if len(x.rstrip()) > 0:
+            tqdm.write(x, file=self.dummy_file)
+
+@contextlib.contextmanager
+def stdout_to_tqdm():
+    save_stdout = sys.stdout
+    try:
+        sys.stdout = TqdmFile(sys.stdout)
+        yield save_stdout
+    except Exception as exc:
+        raise exc
+    finally:
+        sys.stdout = save_stdout
diff --git a/core/vis_utils.py b/core/vis_utils.py
new file mode 100644
index 0000000..9cbe169
--- /dev/null
+++ b/core/vis_utils.py
@@ -0,0 +1,62 @@
+import cv2
+import numpy as np
+
+def draw_bboxes(image, bboxes, font_size=0.5, thresh=0.5, colors=None):
+    """Draws bounding boxes on an image.
+
+    Args:
+        image: An image in OpenCV format.
+        bboxes: A dictionary representing bounding boxes of different object
+            categories, where the keys are the names of the categories and the
+            values are the bounding boxes. The bounding boxes of each category
+            should be stored in a 2D NumPy array, where each row is a bounding
+            box (x1, y1, x2, y2, score).
+        font_size: (Optional) Font size of the category names.
+        thresh: (Optional) Only bounding boxes with scores above the threshold
+            will be drawn.
+        colors: (Optional) Color of the bounding boxes for each category. If it
+            is not provided, this function will use a random color for each
+            category.
+
+    Returns:
+        An image with bounding boxes.
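+
+    Example (a minimal sketch; the image path and the boxes are placeholders):
+        image = cv2.imread("demo.jpg")
+        bboxes = {"person": np.array([[50, 60, 180, 300, 0.9]], dtype=np.float32)}
+        image = draw_bboxes(image, bboxes)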
+ """ + + image = image.copy() + for cat_name in bboxes: + keep_inds = bboxes[cat_name][:, -1] > thresh + cat_size = cv2.getTextSize(cat_name, cv2.FONT_HERSHEY_SIMPLEX, font_size, 2)[0] + + if colors is None: + color = np.random.random((3, )) * 0.6 + 0.4 + color = (color * 255).astype(np.int32).tolist() + else: + color = colors[cat_name] + + for bbox in bboxes[cat_name][keep_inds]: + bbox = bbox[0:4].astype(np.int32) + if bbox[1] - cat_size[1] - 2 < 0: + cv2.rectangle(image, + (bbox[0], bbox[1] + 2), + (bbox[0] + cat_size[0], bbox[1] + cat_size[1] + 2), + color, -1 + ) + cv2.putText(image, cat_name, + (bbox[0], bbox[1] + cat_size[1] + 2), + cv2.FONT_HERSHEY_SIMPLEX, font_size, (0, 0, 0), thickness=1 + ) + else: + cv2.rectangle(image, + (bbox[0], bbox[1] - cat_size[1] - 2), + (bbox[0] + cat_size[0], bbox[1] - 2), + color, -1 + ) + cv2.putText(image, cat_name, + (bbox[0], bbox[1] - 2), + cv2.FONT_HERSHEY_SIMPLEX, font_size, (0, 0, 0), thickness=1 + ) + cv2.rectangle(image, + (bbox[0], bbox[1]), + (bbox[2], bbox[3]), + color, 2 + ) + return image diff --git a/demo.jpg b/demo.jpg new file mode 100644 index 0000000..fbf2b61 Binary files /dev/null and b/demo.jpg differ diff --git a/demo.py b/demo.py new file mode 100755 index 0000000..3fcdaf7 --- /dev/null +++ b/demo.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python + +import cv2 +from core.detectors import CornerNet_Saccade +from core.vis_utils import draw_bboxes + +detector = CornerNet_Saccade() +image = cv2.imread("demo.jpg") + +bboxes = detector(image) +image = draw_bboxes(image, bboxes) +cv2.imwrite("demo_out.jpg", image) diff --git a/evaluate.py b/evaluate.py new file mode 100755 index 0000000..fd276d4 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +import os +import json +import torch +import pprint +import argparse +import importlib + +from core.dbs import datasets +from core.test import test_func +from core.config import SystemConfig +from core.nnet.py_factory import NetworkFactory + +torch.backends.cudnn.benchmark = False + +def parse_args(): + parser = argparse.ArgumentParser(description="Evaluation Script") + parser.add_argument("cfg_file", help="config file", type=str) + parser.add_argument("--testiter", dest="testiter", + help="test at iteration i", + default=None, type=int) + parser.add_argument("--split", dest="split", + help="which split to use", + default="validation", type=str) + parser.add_argument("--suffix", dest="suffix", default=None, type=str) + parser.add_argument("--debug", action="store_true") + + args = parser.parse_args() + return args + +def make_dirs(directories): + for directory in directories: + if not os.path.exists(directory): + os.makedirs(directory) + +def test(db, system_config, model, args): + split = args.split + testiter = args.testiter + debug = args.debug + suffix = args.suffix + + result_dir = system_config.result_dir + result_dir = os.path.join(result_dir, str(testiter), split) + + if suffix is not None: + result_dir = os.path.join(result_dir, suffix) + + make_dirs([result_dir]) + + test_iter = system_config.max_iter if testiter is None else testiter + print("loading parameters at iteration: {}".format(test_iter)) + + print("building neural network...") + nnet = NetworkFactory(system_config, model) + print("loading parameters...") + nnet.load_params(test_iter) + + nnet.cuda() + nnet.eval_mode() + test_func(system_config, db, nnet, result_dir, debug=debug) + +def main(args): + if args.suffix is None: + cfg_file = os.path.join("./configs", args.cfg_file + ".json") + else: + cfg_file 
= os.path.join("./configs", args.cfg_file + "-{}.json".format(args.suffix)) + print("cfg_file: {}".format(cfg_file)) + + with open(cfg_file, "r") as f: + config = json.load(f) + + config["system"]["snapshot_name"] = args.cfg_file + system_config = SystemConfig().update_config(config["system"]) + + model_file = "core.models.{}".format(args.cfg_file) + model_file = importlib.import_module(model_file) + model = model_file.model() + + train_split = system_config.train_split + val_split = system_config.val_split + test_split = system_config.test_split + + split = { + "training": train_split, + "validation": val_split, + "testing": test_split + }[args.split] + + print("loading all datasets...") + dataset = system_config.dataset + print("split: {}".format(split)) + testing_db = datasets[dataset](config["db"], split=split, sys_config=system_config) + + print("system config...") + pprint.pprint(system_config.full) + + print("db config...") + pprint.pprint(testing_db.configs) + + test(testing_db, system_config, model, args) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/train.py b/train.py new file mode 100755 index 0000000..88ad348 --- /dev/null +++ b/train.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python +import os +import json +import torch +import numpy as np +import queue +import pprint +import random +import argparse +import importlib +import threading +import traceback +import torch.distributed as dist +import torch.multiprocessing as mp + +from tqdm import tqdm +from torch.multiprocessing import Process, Queue, Pool + +from core.dbs import datasets +from core.utils import stdout_to_tqdm +from core.config import SystemConfig +from core.sample import data_sampling_func +from core.nnet.py_factory import NetworkFactory + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True + +def parse_args(): + parser = argparse.ArgumentParser(description="Training Script") + parser.add_argument("cfg_file", help="config file", type=str) + parser.add_argument("--iter", dest="start_iter", + help="train at iteration i", + default=0, type=int) + parser.add_argument("--workers", default=4, type=int) + parser.add_argument("--initialize", action="store_true") + + parser.add_argument("--distributed", action="store_true") + parser.add_argument("--world-size", default=-1, type=int, + help="number of nodes of distributed training") + parser.add_argument("--rank", default=0, type=int, + help="node rank for distributed training") + parser.add_argument("--dist-url", default=None, type=str, + help="url used to set up distributed training") + parser.add_argument("--dist-backend", default="nccl", type=str) + + args = parser.parse_args() + return args + +def prefetch_data(system_config, db, queue, sample_data, data_aug): + ind = 0 + print("start prefetching data...") + np.random.seed(os.getpid()) + while True: + try: + data, ind = sample_data(system_config, db, ind, data_aug=data_aug) + queue.put(data) + except Exception as e: + traceback.print_exc() + raise e + +def _pin_memory(ts): + if type(ts) is list: + return [t.pin_memory() for t in ts] + return ts.pin_memory() + +def pin_memory(data_queue, pinned_data_queue, sema): + while True: + data = data_queue.get() + + data["xs"] = [_pin_memory(x) for x in data["xs"]] + data["ys"] = [_pin_memory(y) for y in data["ys"]] + + pinned_data_queue.put(data) + + if sema.acquire(blocking=False): + return + +def init_parallel_jobs(system_config, dbs, queue, fn, data_aug): + tasks = [Process(target=prefetch_data, args=(system_config, db, queue, 
fn, data_aug)) for db in dbs] + for task in tasks: + task.daemon = True + task.start() + return tasks + +def terminate_tasks(tasks): + for task in tasks: + task.terminate() + +def train(training_dbs, validation_db, system_config, model, args): + # reading arguments from command + start_iter = args.start_iter + distributed = args.distributed + world_size = args.world_size + initialize = args.initialize + gpu = args.gpu + rank = args.rank + + # reading arguments from json file + batch_size = system_config.batch_size + learning_rate = system_config.learning_rate + max_iteration = system_config.max_iter + pretrained_model = system_config.pretrain + stepsize = system_config.stepsize + snapshot = system_config.snapshot + val_iter = system_config.val_iter + display = system_config.display + decay_rate = system_config.decay_rate + stepsize = system_config.stepsize + + print("Process {}: building model...".format(rank)) + nnet = NetworkFactory(system_config, model, distributed=distributed, gpu=gpu) + if initialize: + nnet.save_params(0) + exit(0) + + # queues storing data for training + training_queue = Queue(system_config.prefetch_size) + validation_queue = Queue(5) + + # queues storing pinned data for training + pinned_training_queue = queue.Queue(system_config.prefetch_size) + pinned_validation_queue = queue.Queue(5) + + # allocating resources for parallel reading + training_tasks = init_parallel_jobs(system_config, training_dbs, training_queue, data_sampling_func, True) + if val_iter: + validation_tasks = init_parallel_jobs(system_config, [validation_db], validation_queue, data_sampling_func, False) + + training_pin_semaphore = threading.Semaphore() + validation_pin_semaphore = threading.Semaphore() + training_pin_semaphore.acquire() + validation_pin_semaphore.acquire() + + training_pin_args = (training_queue, pinned_training_queue, training_pin_semaphore) + training_pin_thread = threading.Thread(target=pin_memory, args=training_pin_args) + training_pin_thread.daemon = True + training_pin_thread.start() + + validation_pin_args = (validation_queue, pinned_validation_queue, validation_pin_semaphore) + validation_pin_thread = threading.Thread(target=pin_memory, args=validation_pin_args) + validation_pin_thread.daemon = True + validation_pin_thread.start() + + if pretrained_model is not None: + if not os.path.exists(pretrained_model): + raise ValueError("pretrained model does not exist") + print("Process {}: loading from pretrained model".format(rank)) + nnet.load_pretrained_params(pretrained_model) + + if start_iter: + nnet.load_params(start_iter) + learning_rate /= (decay_rate ** (start_iter // stepsize)) + nnet.set_lr(learning_rate) + print("Process {}: training starts from iteration {} with learning_rate {}".format(rank, start_iter + 1, learning_rate)) + else: + nnet.set_lr(learning_rate) + + if rank == 0: + print("training start...") + nnet.cuda() + nnet.train_mode() + with stdout_to_tqdm() as save_stdout: + for iteration in tqdm(range(start_iter + 1, max_iteration + 1), file=save_stdout, ncols=80): + training = pinned_training_queue.get(block=True) + training_loss = nnet.train(**training) + + if display and iteration % display == 0: + print("Process {}: training loss at iteration {}: {}".format(rank, iteration, training_loss.item())) + del training_loss + + if val_iter and validation_db.db_inds.size and iteration % val_iter == 0: + nnet.eval_mode() + validation = pinned_validation_queue.get(block=True) + validation_loss = nnet.validate(**validation) + print("Process {}: validation loss at 
iteration {}: {}".format(rank, iteration, validation_loss.item())) + nnet.train_mode() + + if iteration % snapshot == 0 and rank == 0: + nnet.save_params(iteration) + + if iteration % stepsize == 0: + learning_rate /= decay_rate + nnet.set_lr(learning_rate) + + # sending signal to kill the thread + training_pin_semaphore.release() + validation_pin_semaphore.release() + + # terminating data fetching processes + terminate_tasks(training_tasks) + terminate_tasks(validation_tasks) + +def main(gpu, ngpus_per_node, args): + args.gpu = gpu + if args.distributed: + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + rank = args.rank + + cfg_file = os.path.join("./configs", args.cfg_file + ".json") + with open(cfg_file, "r") as f: + config = json.load(f) + + config["system"]["snapshot_name"] = args.cfg_file + system_config = SystemConfig().update_config(config["system"]) + + model_file = "core.models.{}".format(args.cfg_file) + model_file = importlib.import_module(model_file) + model = model_file.model() + + train_split = system_config.train_split + val_split = system_config.val_split + + print("Process {}: loading all datasets...".format(rank)) + dataset = system_config.dataset + workers = args.workers + print("Process {}: using {} workers".format(rank, workers)) + training_dbs = [datasets[dataset](config["db"], split=train_split, sys_config=system_config) for _ in range(workers)] + validation_db = datasets[dataset](config["db"], split=val_split, sys_config=system_config) + + if rank == 0: + print("system config...") + pprint.pprint(system_config.full) + + print("db config...") + pprint.pprint(training_dbs[0].configs) + + print("len of db: {}".format(len(training_dbs[0].db_inds))) + print("distributed: {}".format(args.distributed)) + + train(training_dbs, validation_db, system_config, model, args) + +if __name__ == "__main__": + args = parse_args() + + distributed = args.distributed + world_size = args.world_size + + if distributed and world_size < 0: + raise ValueError("world size must be greater than 0 in distributed training") + + ngpus_per_node = torch.cuda.device_count() + if distributed: + args.world_size = ngpus_per_node * args.world_size + mp.spawn(main, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + main(None, ngpus_per_node, args)