Skip to content

Commit

Permalink
evaluate accuracy when training
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaOvodov committed Nov 18, 2020
1 parent b08ef84 commit 328b0e0
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 34 deletions.
61 changes: 34 additions & 27 deletions model/infer_retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,18 @@ class OrientationAttempts(enum.IntEnum):


class BraileInferenceImpl(torch.nn.Module):
def __init__(self, params, model_weights_fn, label_is_valid, verbose=1):
def __init__(self, params, model, device, label_is_valid, verbose=1):
super(BraileInferenceImpl, self).__init__()
self.verbose = verbose
self.model_weights_fn = model_weights_fn

#self.model = model
self.model, _, _ = create_model_retinanet.create_model_retinanet(params, device=device)
self.model = self.model.to(device)
self.model.load_state_dict(torch.load(self.model_weights_fn, map_location = 'cpu'))
self.device = device
if isinstance(model, torch.nn.Module):
self.model_weights_fn = ""
self.model = model
else:
self.model_weights_fn = model
self.model, _, _ = create_model_retinanet.create_model_retinanet(params, device=device)
self.model = self.model.to(device)
self.model.load_state_dict(torch.load(self.model_weights_fn, map_location = 'cpu'))
self.model.eval()
#self.model = torch.jit.script(self.model)

Expand Down Expand Up @@ -134,7 +137,7 @@ def forward(self, input_tensor, input_tensor_rotated, find_orientation, process_
else:
best_idx, err_score = OrientationAttempts.NONE, (torch.tensor([0.]),torch.tensor([0.]),torch.tensor([0.]))
if self.verbose >= 2:
torch.cuda.synchronize(device)
torch.cuda.synchronize(self.device)

if best_idx in [OrientationAttempts.INV, OrientationAttempts.INV_ROT180, OrientationAttempts.INV_ROT90, OrientationAttempts.INV_ROT270]:
best_idx -= 2
Expand Down Expand Up @@ -169,7 +172,8 @@ class BrailleInference:
DRAW_BOTH = DRAW_ORIGINAL | DRAW_REFINED # 3
DRAW_FULL_CHARS = 4

def __init__(self, params_fn=params_fn, model_weights_fn=model_weights_fn, create_script = None, verbose=1, inference_width=inference_width):
def __init__(self, params_fn=params_fn, model_weights_fn=model_weights_fn, create_script = None,
verbose=1, inference_width=inference_width, device=device):
self.verbose = verbose
params = AttrDict.load(params_fn, verbose=verbose)
params.data.net_hw = (inference_width,inference_width,) #(512,768) ###### (1024,1536) #
Expand All @@ -180,23 +184,26 @@ def __init__(self, params_fn=params_fn, model_weights_fn=model_weights_fn, creat
rotate_limit=0,
)
self.preprocessor = data.ImagePreprocessor(params, mode = 'inference')
model_script_fn = model_weights_fn + '.pth'

if create_script != False:
self.impl = BraileInferenceImpl(params, model_weights_fn, lt.label_is_valid, verbose=verbose).to(device)
if create_script is not None:
self.impl = torch.jit.script(self.impl)
if isinstance(self.impl, torch.jit.ScriptModule):
torch.jit.save(self.impl, model_script_fn)
if verbose >= 1:
print("Model loaded and saved to " + model_script_fn)

if isinstance(model_weights_fn, torch.nn.Module):
self.impl = BraileInferenceImpl(params, model_weights_fn, device, lt.label_is_valid, verbose=verbose)
else:
model_script_fn = model_weights_fn + '.pth'
if create_script != False:
self.impl = BraileInferenceImpl(params, model_weights_fn, device, lt.label_is_valid, verbose=verbose)
if create_script is not None:
self.impl = torch.jit.script(self.impl)
if isinstance(self.impl, torch.jit.ScriptModule):
torch.jit.save(self.impl, model_script_fn)
if verbose >= 1:
print("Model loaded and saved to " + model_script_fn)
else:
if verbose >= 1:
print("Model loaded")
else:
self.impl = torch.jit.load(model_script_fn)
if verbose >= 1:
print("Model loaded")
else:
self.impl = torch.jit.load(model_script_fn)
if verbose >= 1:
print("Model pth loaded")
print("Model pth loaded")
self.impl.to(device)

def load_pdf(self, img_fn):
Expand Down Expand Up @@ -292,15 +299,15 @@ def run_impl(self, img, lang, draw_refined, find_orientation, process_2_sides, a
np_img = np.asarray(img)
aug_img, aug_gt_rects = self.preprocessor.preprocess_and_augment(np_img, gt_rects)
aug_img = data.unify_shape(aug_img)
input_tensor = self.preprocessor.to_normalized_tensor(aug_img).to(device)
input_tensor_rotated = torch.tensor(0).to(device)
input_tensor = self.preprocessor.to_normalized_tensor(aug_img).to(self.impl.device)
input_tensor_rotated = torch.tensor(0).to(self.impl.device)

aug_img_rot = None
if find_orientation:
np_img_rot = np.rot90(np_img, 1, (0,1))
aug_img_rot = self.preprocessor.preprocess_and_augment(np_img_rot)[0]
aug_img_rot = data.unify_shape(aug_img_rot)
input_tensor_rotated = self.preprocessor.to_normalized_tensor(aug_img_rot).to(device)
input_tensor_rotated = self.preprocessor.to_normalized_tensor(aug_img_rot).to(self.impl.device)

if self.verbose >= 2:
print(" run_impl.make_batch", time.clock() - t)
Expand Down
12 changes: 12 additions & 0 deletions model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
sys.path.append(local_config.global_3rd_party)
from collections import OrderedDict
import os
import torch
import ignite
from ignite.engine import Events
Expand All @@ -20,6 +21,7 @@
from data_utils import data
from model import create_model_retinanet
from model.params import params, settings
import model.validate_retinanet as validate_retinanet

if settings.findLR:
params.model_name += '_findLR'
Expand Down Expand Up @@ -109,6 +111,16 @@ def lr_scheduler_step(engine):
engine.state.metrics['lr'] = ctx.optimizer.param_groups[0]['lr']
ctx.lr_scheduler.step(**call_params)

@trainer.on(Events.EPOCH_COMPLETED)
def eval_accuracy(engine):
if engine.state.epoch % 100 == 1:
data_set = validate_retinanet.prepare_data(ctx.params.data.val_list_file_names)
for key, data_list in data_set.items():
acc_res = validate_retinanet.evaluate_accuracy(os.path.join(ctx.params.get_base_filename(), 'param.txt'),
model, settings.device, data_list)
for rk, rv in acc_res.items():
engine.state.metrics[key+ ':' + rk] = rv

#@trainer.on(Events.EPOCH_COMPLETED)
#def save_model(engine):
# if save_every and (engine.state.epoch % save_every) == 0:
Expand Down
66 changes: 59 additions & 7 deletions model/validate_retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@

# Для отладки
inference_width = 850
verbose = 2
verbose = 0

models = [
#('NN_results/dsbi_tst_as_fcdca3_c63909', 'models/clr.099.t7'),
#
('NN_results/dsbi_lay0_526290', 'models/clr.006.t7'),
('NN_results/dsbi_lay1_047eca', 'models/clr.007.t7'),
('NN_results/dsbi_lay3_c4ca62', 'models/clr.005.t7'),
('NN_results/dsbi_lay3_c4ca62', 'models/clr.006.t7'),
('NN_results/dsbi_lay5_0d8197', 'models/clr.006.t7'),
]

model_dirs = [
Expand All @@ -28,6 +26,7 @@
# r'DSBI\data\test.txt',
# ],
'val': [r'DSBI/data/val_li2.txt', ],
'test': [r'DSBI/data/test_li2.txt', ],
}

lang = 'RU'
Expand All @@ -54,10 +53,8 @@
(str(md[0]), str(Path('models')/m.name))
for m in (Path(local_config.data_path)/md[0]).glob(md[1])
]
for m in models:
print(m)

def prepare_data():
def prepare_data(datasets=datasets):
"""
data (datasets defined above as global) -> dict: key - list of dict (image_fn":full image filename, "gt_text": groundtruth pseudotext, "gt_rects": groundtruth rects + label 0..64)
:return:
Expand Down Expand Up @@ -474,9 +471,64 @@ def validate_model(recognizer, data_list, do_filter_lonely_rects, metrics_for_li
'd_by_char_avg': sum_d1/len(data_list)
}

def evaluate_accuracy(params_fn, model, device, data_list, do_filter_lonely_rects = False, metrics_for_lines = True):
"""
:param recognizer: infer_retinanet.BrailleInference instance
:param data_list: list of (image filename, groundtruth pseudotext)
:return: (<distance> avg. by documents, <distance> avg. by char, <<distance> avg. by char> avg. by documents>)
"""
# по символам
recognizer = infer_retinanet.BrailleInference(
params_fn=params_fn,
model_weights_fn=model,
create_script=None,
inference_width=inference_width,
device=device,
verbose=verbose)

tp_c = 0
fp_c = 0
fn_c = 0
for gt_dict in data_list:
img_fn, gt_text, gt_rects = gt_dict['image_fn'], gt_dict['gt_text'], gt_dict['gt_rects']
res_dict = recognizer.run(img_fn,
lang=lang,
draw_refined=infer_retinanet.BrailleInference.DRAW_NONE,
find_orientation=False,
process_2_sides=False,
align_results=False,
repeat_on_aligned=False,
gt_rects=gt_rects)
lines = res_dict['lines']
if do_filter_lonely_rects:
lines, filtered_chars = postprocess.filter_lonely_rects_for_lines(lines)
if metrics_for_lines:
boxes = []
labels = []
for ln in lines:
boxes += [ch.refined_box for ch in ln.chars]
labels += [ch.label for ch in ln.chars]
else:
boxes = res_dict['boxes']
labels = res_dict['labels']
tpi, fpi, fni = char_metrics_rects(boxes = boxes, labels = labels,
gt_rects = res_dict['gt_rects'], image_wh = (res_dict['labeled_image'].width, res_dict['labeled_image'].height),
img=None, do_filter_lonely_rects=do_filter_lonely_rects)
tp_c += tpi
fp_c += fpi
fn_c += fni
precision_c = tp_c/(tp_c+fp_c) if tp_c+fp_c != 0 else 0.
recall_c = tp_c/(tp_c+fn_c) if tp_c+fn_c != 0 else 0.
return {
'precision': precision_c,
'recall': recall_c,
'f1': 2*precision_c*recall_c/(precision_c+recall_c) if precision_c+recall_c != 0 else 0.,
}

def main(table_like_format):
# make data list
for m in models:
print(m)
data_set = prepare_data()
prev_model_root = None

Expand Down

0 comments on commit 328b0e0

Please sign in to comment.