Skip to content

Commit

Permalink
Validation during training. Could be slow.
Browse files Browse the repository at this point in the history
  • Loading branch information
anatolix committed Mar 10, 2018
1 parent 7e79fd3 commit c130f05
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
30 changes: 14 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,42 +1,40 @@
### About this fork

This fork contains pure python version of **rmpe_dataset_server**.
It have less code(19kb vs 35kb), and significantly faster (140 images/s vs 30 images/s C++ code on my machine)
Could be run as iterator inside **train_pose.py** (default), or as **./rmpe_server.py**
This fork contains **pure python version** of [Realtime Multi-Person Pose Estimation](https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation). Initially it was forked from [Michal Faber fork](https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation), all credit for porting original work to Keras goes to him.

I this fork I've reimplemented images argumentation in pure python, it is significanly shorter(**285** lines vs **1202** lines in Michal Faber's C++ **rmpe_server**, and way less than in original work)

Despite of Python language this code is **significantly faster** than original implementation(140 images/s vs 30 images/s C++ code on my machine). This is not really useful since most of people doesn't have 5 GPUs, but has large hack value. The magic is in combining all affine transformations to one matrix, and calling single **warpAffine**, and vectorized numpy computation of PAFs and Heatmaps.

Could be run as iterator inside **train_pose.py** (default), or as separate **./rmpe_server.py**

#### Current status
- [x] image augmentation: rotate, shift, scale, crop, flip (implemented as single affine transform, i.e. much faster)
- [x] mask calculation: rotate, shift, scale, crop, flip
- [x] joint heatmaps
- [x] limbs part affinity fields
- [x] tested using rmpe_server_tester.py, found some differences from C++ version, but looks like it is C++ code is buggy

- [x] quality is same as original work and bit better than Michal's version.

#### How to help
- re-generate val_dataset.h5 with new version of generate_hdf5.py (will be backward compatible, just one attribute 'meta' added)
- since augmentation is very fast now, by default it works inside train_pose.py (separate thread)
- if you want to run external augmentation server run ./rmpe_server.py and change use_client_gen = True in train_pose.py
- test result with **inspect_dataset.ipynb** or **rmpe_server_tester.py raw save** (saves all images, heatmaps and PAFs to disk)
- look to the code and give feedback
- try to train
#### Current work
- [ ] Ability to easily modify config and train different models. See addins submodule for head detector example and example how to add new datasets(MPII, Brainwash)


# Realtime Multi-Person Pose Estimation
This is a keras version of [Realtime Multi-Person Pose Estimation](https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation) project
This is a keras version of project

## Introduction
Code repo for reproducing [2017 CVPR](https://arxiv.org/abs/1611.08050) paper using keras.

## Results

<p align="center">
<img src="https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation/blob/master/readme/dance.gif", width="720">
<img src="https://github.com/anatolix/keras_Realtime_Multi-Person_Pose_Estimation/blob/master/readme/dance.gif", width="720">
</p>

<div align="center">
<img src="https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation/blob/master/sample_images/ski.jpg", width="300", height="300">
<img src="https://github.com/anatolix/keras_Realtime_Multi-Person_Pose_Estimation/blob/master/sample_images/ski.jpg", width="300", height="300">
&nbsp;
<img src="https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation/blob/master/readme/result.png", width="300", height="300">
<img src="https://github.com/anatolix/keras_Realtime_Multi-Person_Pose_Estimation/blob/master/readme/result.png", width="300", height="300">
</div>


Expand Down
55 changes: 35 additions & 20 deletions training/train_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,39 +160,54 @@ def train(config, model, train_client, val_client, iterations_per_epoch, validat
def validate(config, model, val_client, validation_steps, metrics_id, epoch):

val_di = val_client.gen()
from keras.utils import GeneratorEnqueuer

for i in range(validation_steps):
val_thre = GeneratorEnqueuer(val_di)
val_thre.start()

model_metrics = []
inhouse_metrics = []

metrics = []
for i in range(validation_steps):

X, GT = next(val_di)
X, GT = next(val_thre.get())

Y = model.predict(X)

if config.paf_layers > 0 and config.heat_layers > 0:
GT = np.concatenate([GT[-2], GT[-1]], axis=3)
Y = np.concatenate([Y[-2], Y[-1]], axis=3)
model_losses = [ (np.sum((gt - y) ** 2) / gt.shape[0] / 2) for gt, y in zip(GT,Y) ]
mm = sum(model_losses)

if config.paf_layers > 0 and config.heat_layers > 0:
GTL6 = np.concatenate([GT[-2], GT[-1]], axis=3)
YL6 = np.concatenate([Y[-2], Y[-1]], axis=3)
mm6l1 = model_losses[-2]
mm6l2 = model_losses[-1]
elif config.paf_layers == 0 and config.heat_layers > 0:
GT = GT[-1]
Y = Y[-1]
GTL6 = GT[-1]
YL6 = Y[-1]
mm6l1 = None
mm6l2 = model_losses[-1]
else:
assert False, "Wtf or not implemented"

m = calc_batch_metrics(i, GT, Y, range(config.heat_start, config.bkg_start))
metrics.append(m)
print("Validating[BATCH: %d] MAE: %0.4f, RMSE: %0.4f, DIST: %0.2f" % (i,m["MAE"].mean(), m["RMSE"].mean(),m["DIST"].mean()))

metrics = pd.concat(metrics)
metrics['epoch']=epoch
metrics.to_csv("logs/val_scores.%s.%04d.txt" % (metrics_id, epoch), sep="\t")
del metrics["batch"]
del metrics["item"]
del metrics["layer"]
metrics = metrics.groupby(["epoch"]).mean()
m = calc_batch_metrics(i, GTL6, YL6, range(config.heat_start, config.bkg_start))
inhouse_metrics += [m]

model_metrics += [ (i, mm, mm6l1, mm6l2, m["MAE"].sum()/GTL6.shape[0], m["RMSE"].sum()/GTL6.shape[0], m["DIST"].mean()) ]
print("Validating[BATCH: %d] LOSS: %0.4f, S6L1: %0.4f, S6L2: %0.4f, MAE: %0.4f, RMSE: %0.4f, DIST: %0.2f" % model_metrics[-1] )

inhouse_metrics = pd.concat(inhouse_metrics)
inhouse_metrics['epoch']=epoch
inhouse_metrics.to_csv("logs/val_scores.%s.%04d.txt" % (metrics_id, epoch), sep="\t")

model_metrics = pd.DataFrame(model_metrics, columns=("batch","loss","stage6l1","stage6l2","mae","rmse","dist") )
model_metrics['epoch']=epoch
del model_metrics['batch']
model_metrics = model_metrics.groupby('epoch').mean()
with open('%s.val.tsv' % metrics_id, 'a') as f:
metrics.to_csv(f, header=(epoch==1), sep="\t")
model_metrics.to_csv(f, header=(epoch==1), sep="\t", float_format='%.4f')

val_thre.stop()

def save_network_input_output(model, val_client, validation_steps, metrics_id, batch_size, epoch=None):

Expand Down

0 comments on commit c130f05

Please sign in to comment.