Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add foot keypoints #526

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
4 changes: 3 additions & 1 deletion alphapose/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from .concat_dataset import ConcatDataset
from .custom import CustomDataset
from .mscoco import Mscoco
from .mscoco_with_foot import Mscoco_with_foot
from .coco_det_with_foot import Mscoco_det_with_foot
from .mpii import Mpii

__all__ = ['CustomDataset', 'Mscoco', 'Mscoco_det', 'Mpii', 'ConcatDataset']
__all__ = ['CustomDataset', 'Mscoco', 'Mscoco_det', 'Mpii', 'ConcatDataset', 'Mscoco_with_foot', 'Mscoco_det_with_foot']
108 changes: 108 additions & 0 deletions alphapose/datasets/coco_det_with_foot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# -----------------------------------------------------
# Copyright (c) Shanghai Jiao Tong University. All rights reserved.
# Written by Jiefeng Li ([email protected])
# -----------------------------------------------------

"""MS COCO Human Detection Box dataset. (including foot keypoints)"""
import json
import os

import scipy.misc
import torch
import torch.utils.data as data
from tqdm import tqdm

from alphapose.utils.presets import SimpleTransform
from detector.apis import get_detector
from alphapose.models.builder import DATASET


@DATASET.register_module
class Mscoco_det_with_foot(data.Dataset):
    """COCO human detection box dataset with foot keypoints (23 joints).

    Every item corresponds to one detected person box read from a detection
    JSON file.  If that file does not exist yet it is generated by running the
    configured detector over every image listed in the annotation file.

    Parameters
    ----------
    det_file: str, optional
        Path of the detection-result JSON.  Falls back to ``cfg['DET_FILE']``
        when empty or ``None``.
    opt: object, optional
        Options forwarded to ``get_detector`` when detections must be created.
    **cfg:
        Dataset configuration; the keys ``PRESET``, ``ROOT``, ``IMG_PREFIX``,
        ``ANN`` (and ``DET_FILE`` when ``det_file`` is not given) are read.
    """
    # 0 ~ 16: body keypoints, 17 ~ 22: foot keypoints.
    EVAL_JOINTS = list(range(23))

    def __init__(self,
                 det_file=None,
                 opt=None,
                 **cfg):

        self._cfg = cfg
        self._opt = opt
        self._preset_cfg = cfg['PRESET']
        self._root = cfg['ROOT']
        self._img_prefix = cfg['IMG_PREFIX']
        if not det_file:
            det_file = cfg['DET_FILE']
        self._ann_file = os.path.join(self._root, cfg['ANN'])

        if os.path.exists(det_file):
            print("Detection results exist, will use it")
        else:
            print("Will create detection results to {}".format(det_file))
            self.write_coco_json(det_file)

        assert os.path.exists(det_file), "Error: no detection results found"
        with open(det_file, 'r') as fid:
            self._det_json = json.load(fid)

        self._input_size = self._preset_cfg['IMAGE_SIZE']
        self._output_size = self._preset_cfg['HEATMAP_SIZE']

        self._sigma = self._preset_cfg['SIGMA']

        # NOTE(review): only the 'simple' preset creates `self.transformation`;
        # any other TYPE leaves it undefined and `__getitem__` will fail.
        if self._preset_cfg['TYPE'] == 'simple':
            self.transformation = SimpleTransform(
                self, scale_factor=0,
                input_size=self._input_size,
                output_size=self._output_size,
                rot=0, sigma=self._sigma,
                train=False, add_dpg=False)

    def __getitem__(self, index):
        """Return the network input and metadata for detection ``index``.

        Returns
        -------
        tuple
            ``(inp, cropped_bbox, original_bbox_xywh, image_id, score,
            image_height, image_width)``; everything except ``inp`` is wrapped
            in a ``torch.Tensor``.
        """
        det_res = self._det_json[index]
        raw_id = det_res['image_id']
        if not isinstance(raw_id, int):
            # The id may be stored as a file name/path: use the numeric stem.
            stem, _ = os.path.splitext(os.path.basename(raw_id))
            img_id = int(stem)
        else:
            img_id = raw_id
        # NOTE(review): this path is hard-coded and ignores `self._root` /
        # `self._img_prefix` used by `write_coco_json` -- confirm intended.
        img_path = './data/coco/trainval2017/%012d.jpg' % img_id

        # Load image
        # NOTE(review): `scipy.misc.imread` is not available in recent SciPy
        # releases -- confirm the pinned SciPy version still provides it.
        image = scipy.misc.imread(img_path, mode='RGB')

        # The RGB image array is laid out as (height, width, channels).
        imght, imgwidth = image.shape[0], image.shape[1]
        x1, y1, w, h = det_res['bbox']
        bbox = [x1, y1, x1 + w, y1 + h]  # xywh -> xyxy
        inp, bbox = self.transformation.test_transform(image, bbox)
        # Use the parsed integer id so that string image ids do not break the
        # tensor construction.
        return (inp,
                torch.Tensor(bbox),
                torch.Tensor([det_res['bbox']]),
                torch.Tensor([img_id]),
                torch.Tensor([det_res['score']]),
                torch.Tensor([imght]),
                torch.Tensor([imgwidth]))

    def __len__(self):
        """Number of detection entries loaded from the detection file."""
        return len(self._det_json)

    def write_coco_json(self, det_file):
        """Run the detector on every annotated image and dump results as JSON.

        Parameters
        ----------
        det_file: str
            Output path; missing parent directories are created.
        """
        from pycocotools.coco import COCO
        import pathlib

        _coco = COCO(self._ann_file)
        image_ids = sorted(_coco.getImgIds())
        det_model = get_detector(self._opt)
        dets = []
        for entry in tqdm(_coco.loadImgs(image_ids)):
            abs_path = os.path.join(
                self._root, self._img_prefix, entry['file_name'])
            det = det_model.detect_one_img(abs_path)
            if det:
                dets += det
        pathlib.Path(os.path.split(det_file)[0]).mkdir(parents=True, exist_ok=True)
        # Close the file deterministically so the JSON is fully flushed before
        # `__init__` re-opens it for reading.
        with open(det_file, 'w') as fid:
            json.dump(dets, fid)

    @property
    def joint_pairs(self):
        """Joint pairs which defines the pairs of joint to be swapped
        when the image is flipped horizontally."""
        return [[1, 2], [3, 4], [5, 6], [7, 8],
                [9, 10], [11, 12], [13, 14], [15, 16],
                [17, 20], [18, 21], [19, 22]]
150 changes: 150 additions & 0 deletions alphapose/datasets/mscoco_with_foot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# -----------------------------------------------------
# Copyright (c) Shanghai Jiao Tong University. All rights reserved.
# Written by Jiefeng Li ([email protected])
# -----------------------------------------------------

"""MS COCO Human keypoint dataset. (including foot keypoints)"""
import os

import numpy as np

from alphapose.models.builder import DATASET
from alphapose.utils.bbox import bbox_clip_xyxy, bbox_xywh_to_xyxy

from .custom import CustomDataset


@DATASET.register_module
class Mscoco_with_foot(CustomDataset):
    """ COCO Person dataset (body keypoints plus foot keypoints).

    Parameters
    ----------
    train: bool, default is True
        If true, will set as training mode.
    skip_empty: bool, default is False
        Whether skip entire image if no valid label is found. Use `False` if this dataset is
        for validation to avoid COCO metric error.
    dpg: bool, default is False
        If true, will activate `dpg` for data augmentation.
    """
    CLASSES = ['person']
    EVAL_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
    num_joints = 23
    # 0 ~ 16: origin person keypoints; 17 ~ 22: foot keypoints.
    # 17: left big toe, 18: left small toe, 19: left heel;
    # 20: right big toe, 21: right small toe, 22: right heel.
    # NOTE(review): this assigns on the *base* class, so importing this module
    # changes `lower_body_ids` for every CustomDataset subclass -- confirm that
    # this is intended rather than a plain class attribute on this subclass.
    CustomDataset.lower_body_ids = (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)

    @property
    def joint_pairs(self):
        """Joint pairs which defines the pairs of joint to be swapped
        when the image is flipped horizontally."""
        return [[1, 2], [3, 4], [5, 6], [7, 8],
                [9, 10], [11, 12], [13, 14], [15, 16],
                [17, 20], [18, 21], [19, 22]]

    def _load_jsons(self):
        """Load all image paths and labels from JSON annotation files into buffer.

        Returns
        -------
        (list, list)
            Parallel lists with one entry per person: the absolute image path
            and the label dict produced by `_check_load_keypoints`.

        Raises
        ------
        IOError
            If an image referenced by the annotation file does not exist.
        """
        items = []
        labels = []

        _coco = self._lazy_load_ann_file()
        classes = [c['name'] for c in _coco.loadCats(_coco.getCatIds())]
        assert classes == self.CLASSES, "Incompatible category names with COCO. "

        self.json_id_to_contiguous = {
            v: k for k, v in enumerate(_coco.getCatIds())}

        # iterate through the annotations
        image_ids = sorted(_coco.getImgIds())
        for entry in _coco.loadImgs(image_ids):
            # The last two URL components are used as <dirname>/<filename>
            # below the dataset root.
            dirname, filename = entry['coco_url'].split('/')[-2:]
            abs_path = os.path.join(self._root, dirname, filename)
            if not os.path.exists(abs_path):
                raise IOError('Image: {} not exists.'.format(abs_path))
            label = self._check_load_keypoints(_coco, entry)
            if not label:
                continue

            # num of items are relative to person, not image
            for obj in label:
                items.append(abs_path)
                labels.append(obj)

        return items, labels

    def _check_load_keypoints(self, coco, entry):
        """Check and load ground-truth keypoints of one image.

        Parameters
        ----------
        coco: COCO API object used to query the annotations of the image.
        entry: dict
            Image record; `id`, `width` and `height` are read.

        Returns
        -------
        list of dict
            One dict per valid person with keys `bbox` (xmin, ymin, xmax, ymax),
            `width`, `height` and `joints_3d`.  When nothing is valid and
            `self._skip_empty` is false, a single dummy entry is returned.
        """
        ann_ids = coco.getAnnIds(imgIds=entry['id'], iscrowd=False)
        objs = coco.loadAnns(ann_ids)
        # check valid bboxes
        valid_objs = []
        width = entry['width']
        height = entry['height']

        for obj in objs:
            contiguous_cid = self.json_id_to_contiguous[obj['category_id']]
            if contiguous_cid >= self.num_class:
                # not class of interest
                continue
            if max(obj['keypoints']) == 0:
                # every coordinate and visibility flag is zero: nothing annotated
                continue
            # convert from (x, y, w, h) to (xmin, ymin, xmax, ymax) and clip bound
            xmin, ymin, xmax, ymax = bbox_clip_xyxy(bbox_xywh_to_xyxy(obj['bbox']), width, height)
            # require non-zero box area
            if obj['area'] <= 0 or xmax <= xmin or ymax <= ymin:
                continue
            if obj['num_keypoints'] == 0:
                continue
            # joints 3d: (num_joints, 3, 2); 3 is for x, y, z; 2 is for position, visibility
            joints_3d = np.zeros((self.num_joints, 3, 2), dtype=np.float32)
            for i in range(self.num_joints):
                joints_3d[i, 0, 0] = obj['keypoints'][i * 3 + 0]
                joints_3d[i, 1, 0] = obj['keypoints'][i * 3 + 1]
                # Clamp the visibility flag to at most 1.
                visible = min(1, obj['keypoints'][i * 3 + 2])
                joints_3d[i, :2, 1] = visible

            if np.sum(joints_3d[:, 0, 1]) < 1:
                # no visible keypoint
                continue

            if self._check_centers and self._train:
                bbox_center, bbox_area = self._get_box_center_area((xmin, ymin, xmax, ymax))
                kp_center, num_vis = self._get_keypoints_center_count(joints_3d)
                # ks decays with the squared distance between the box centre
                # and the keypoint centre, normalised by the box area.
                ks = np.exp(-2 * np.sum(np.square(bbox_center - kp_center)) / bbox_area)
                # Drop the person when ks falls below a threshold that grows
                # with the number of visible keypoints.
                if (num_vis / 80.0 + 47 / 80.0) > ks:
                    continue

            valid_objs.append({
                'bbox': (xmin, ymin, xmax, ymax),
                'width': width,
                'height': height,
                'joints_3d': joints_3d
            })

        if not valid_objs:
            if not self._skip_empty:
                # dummy invalid labels if no valid objects are found; the
                # joints array keeps the same (num_joints, 3, 2) layout as
                # real labels so downstream code sees a consistent shape.
                valid_objs.append({
                    'bbox': np.array([-1, -1, 0, 0]),
                    'width': width,
                    'height': height,
                    'joints_3d': np.zeros((self.num_joints, 3, 2), dtype=np.float32)
                })
        return valid_objs

    def _get_box_center_area(self, bbox):
        """Return the centre point and the area of an (xmin, ymin, xmax, ymax) box."""
        c = np.array([(bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0])
        area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
        return c, area

    def _get_keypoints_center_count(self, keypoints):
        """Return the mean position of the visible keypoints and their count.

        `keypoints` uses the `joints_3d` layout; joints whose visibility entry
        is not positive contribute nothing to the sums.  The caller must make
        sure at least one keypoint is visible, otherwise the division is by
        zero.
        """
        keypoint_x = np.sum(keypoints[:, 0, 0] * (keypoints[:, 0, 1] > 0))
        keypoint_y = np.sum(keypoints[:, 1, 0] * (keypoints[:, 1, 1] > 0))
        num = float(np.sum(keypoints[:, 0, 1]))
        return np.array([keypoint_x / num, keypoint_y / num]), num
37 changes: 36 additions & 1 deletion alphapose/utils/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,24 @@ def vis_frame_fast(frame, im_res, add_bbox=False, format='coco'):
(77, 255, 222), (77, 196, 255), (77, 135, 255), (191, 255, 77), (77, 255, 77),
(77, 222, 255), (255, 156, 127),
(0, 127, 255), (255, 127, 77), (0, 77, 255), (255, 77, 36)]
elif format == 'coco_with_foot':
l_pair = [
(0, 1), (0, 2), (1, 3), (2, 4), # Head
(5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
(23, 11), (23, 12), # Body
(11, 13), (12, 14), (13, 15), (14, 16),
(17, 19), (18, 19), (20, 22), (21, 22), # Foot
(15, 19), (16, 22)] # Foot
p_color = [(0, 255, 255), (0, 191, 255), (0, 255, 102), (0, 77, 255), (0, 255, 0), # Nose, LEye, REye, LEar, REar
(77, 255, 255), (77, 255, 204), (77, 204, 255), (191, 255, 77), (77, 191, 255), (191, 255, 77), # LShoulder, RShoulder, LElbow, RElbow, LWrist, RWrist
(204, 77, 255), (77, 255, 204), (191, 77, 255), (77, 255, 191), (127, 77, 255), (77, 255, 127), (0, 255, 255), # LHip, RHip, LKnee, Rknee, LAnkle, RAnkle, Neck
(0, 255, 255), (0, 191, 255), (0, 255, 102), (0, 77, 255), (0, 255, 0), (77, 255, 255)] # Foot
line_color = [(0, 215, 255), (0, 255, 204), (0, 134, 255), (0, 255, 50),
(77, 255, 222), (77, 196, 255), (77, 135, 255), (191, 255, 77), (77, 255, 77),
(77, 222, 255), (255, 156, 127),
(0, 127, 255), (255, 127, 77), (0, 77, 255), (255, 77, 36),
(0, 77, 255), (0, 77, 255), (0, 77, 255), (0, 77, 255),
(255, 156, 127), (255, 156, 127)] # Foot
elif format == 'mpii':
l_pair = [
(8, 9), (11, 12), (11, 10), (2, 1), (1, 0),
Expand Down Expand Up @@ -199,14 +217,31 @@ def vis_frame(frame, im_res, add_bbox=False, format='coco'):
(17, 11), (17, 12), # Body
(11, 13), (12, 14), (13, 15), (14, 16)
]

p_color = [(0, 255, 255), (0, 191, 255), (0, 255, 102), (0, 77, 255), (0, 255, 0), # Nose, LEye, REye, LEar, REar
(77, 255, 255), (77, 255, 204), (77, 204, 255), (191, 255, 77), (77, 191, 255), (191, 255, 77), # LShoulder, RShoulder, LElbow, RElbow, LWrist, RWrist
(204, 77, 255), (77, 255, 204), (191, 77, 255), (77, 255, 191), (127, 77, 255), (77, 255, 127), (0, 255, 255)] # LHip, RHip, LKnee, Rknee, LAnkle, RAnkle, Neck
line_color = [(0, 215, 255), (0, 255, 204), (0, 134, 255), (0, 255, 50),
(77, 255, 222), (77, 196, 255), (77, 135, 255), (191, 255, 77), (77, 255, 77),
(77, 222, 255), (255, 156, 127),
(0, 127, 255), (255, 127, 77), (0, 77, 255), (255, 77, 36)]
elif format == 'coco_with_foot':
l_pair = [
(0, 1), (0, 2), (1, 3), (2, 4), # Head
(5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
(23, 11), (23, 12), # Body
(11, 13), (12, 14), (13, 15), (14, 16),
(17, 19), (18, 19), (20, 22), (21, 22), # Foot
(15, 19), (16, 22)] # Foot
p_color = [(0, 255, 255), (0, 191, 255), (0, 255, 102), (0, 77, 255), (0, 255, 0), # Nose, LEye, REye, LEar, REar
(77, 255, 255), (77, 255, 204), (77, 204, 255), (191, 255, 77), (77, 191, 255), (191, 255, 77), # LShoulder, RShoulder, LElbow, RElbow, LWrist, RWrist
(204, 77, 255), (77, 255, 204), (191, 77, 255), (77, 255, 191), (127, 77, 255), (77, 255, 127), (0, 255, 255), # LHip, RHip, LKnee, Rknee, LAnkle, RAnkle, Neck
(0, 255, 255), (0, 191, 255), (0, 255, 102), (0, 77, 255), (0, 255, 0), (77, 255, 255)] # Foot
line_color = [(0, 215, 255), (0, 255, 204), (0, 134, 255), (0, 255, 50),
(77, 255, 222), (77, 196, 255), (77, 135, 255), (191, 255, 77), (77, 255, 77),
(77, 222, 255), (255, 156, 127),
(0, 127, 255), (255, 127, 77), (0, 77, 255), (255, 77, 36),
(0, 77, 255), (0, 77, 255), (0, 77, 255), (0, 77, 255),
(255, 156, 127), (255, 156, 127)] # Foot
elif format == 'mpii':
l_pair = [
(8, 9), (11, 12), (11, 10), (2, 1), (1, 0),
Expand Down
4 changes: 1 addition & 3 deletions alphapose/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
'frameSize': (640, 480)
}

EVAL_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]


class DataWriter():
def __init__(self, cfg, opt, save_video=False,
Expand All @@ -29,7 +27,7 @@ def __init__(self, cfg, opt, save_video=False,
self.opt = opt
self.video_save_opt = video_save_opt

self.eval_joints = EVAL_JOINTS
self.eval_joints = list(range(cfg.DATA_PRESET.NUM_JOINTS))
self.save_video = save_video
self.final_result = []
self.heatmap_to_coord = get_func_heatmap_to_coord(cfg)
Expand Down
Loading