forked from wushidiguo/hello-lottery
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrecognizer.py
69 lines (50 loc) · 1.92 KB
/
recognizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import inspect
import math

import torch
from PIL import Image
from torch.nn.functional import softmax
from torch.utils.data import DataLoader

from utils import *
class Recognizer:
    """Text-line recognizer wrapping a pickled PyTorch model and its label converter.

    The checkpoint at ``model_file`` must be a dict with the keys
    ``"model"`` and ``"converter"``.

    - ``"model"``: a module that exposes an ``opt`` namespace with ``imgH``,
      ``imgW`` and ``input_channel``.
    - ``"converter"``: an object with a ``decode_greedy`` method.
    """

    def __init__(self, model_file, device="cpu"):
        """Load the checkpoint and put the model into inference mode on *device*.

        Args:
            model_file: Path or file-like object accepted by ``torch.load``.
            device: Device the model and the input tensors are placed on.
                Checkpoint tensors are remapped onto it at load time, so a
                checkpoint saved on GPU can be loaded on a CPU-only machine.
        """
        weights = self._load_checkpoint(model_file, device)
        self.model = weights["model"]
        self.converter = weights["converter"]
        self.opt = self.model.opt
        self.imgH = self.opt.imgH
        self.imgW = self.opt.imgW
        self.input_channel = self.opt.input_channel
        self.device = device
        self.model.to(device)
        self.model.eval()

    @staticmethod
    def _load_checkpoint(model_file, device):
        """Unpickle the checkpoint dict, mapping its tensors onto *device*."""
        load_kwargs = {"map_location": device}
        # The checkpoint stores whole Python objects (model and converter),
        # not just tensors.
        # - torch >= 2.6 defaults to weights_only=True, which rejects such objects.
        # - torch < 1.13 does not accept the keyword at all.
        # So pass weights_only=False only when torch.load supports it.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        # SECURITY: full unpickling can execute arbitrary code. Only load
        # checkpoint files that come from a trusted source.
        return torch.load(model_file, **load_kwargs)

    @staticmethod
    def _target_width(w, h, img_h, img_w):
        """Return the width that keeps the w:h aspect ratio at height img_h, capped at img_w."""
        ratio = w / float(h)
        return min(img_w, math.ceil(img_h * ratio))

    def __call__(self, imgs):
        """Recognize the text in each image.

        Each image is converted to grayscale ("L") and resized to the model
        height ``imgH``, keeping its aspect ratio with the width capped at
        ``imgW``. It is then normalized/padded by ``NormalizePAD`` and run
        through the model.

        Args:
            imgs: Iterable of image arrays accepted by ``PIL.Image.fromarray``.

        Returns:
            A list with one ``[text, confidence]`` pair per input image, in
            input order. ``text`` is the converter's greedy decoding.
            ``confidence`` is ``custom_mean`` of the per-step maximum class
            probabilities.
        """
        results = []
        transform = NormalizePAD((self.input_channel, self.imgH, self.imgW))
        with torch.no_grad():
            for img in imgs:
                pil_img = Image.fromarray(img).convert("L")
                w, h = pil_img.size
                resized_w = self._target_width(w, h, self.imgH, self.imgW)
                pil_img = pil_img.resize((resized_w, self.imgH), Image.BICUBIC)
                batch = transform(pil_img).unsqueeze(0).to(self.device)
                # All-zero placeholder decoder input. Its length follows the
                # ORIGINAL (pre-resize) width, as before.
                # NOTE(review): confirm w rather than resized_w is intended.
                text_for_pred = torch.zeros(
                    1, w // 10 + 1, dtype=torch.long, device=self.device
                )
                preds = self.model(batch, text_for_pred)
                # Axis 1 is treated as the sequence axis, which suggests preds
                # is (batch=1, T, num_classes). TODO confirm against the model.
                preds_size = [preds.size(1)]
                # Drop only the batch axis. A bare squeeze() would also collapse
                # a length-1 sequence axis and turn the per-step arrays below
                # into scalars.
                preds_prob = (
                    softmax(preds, dim=-1).squeeze(0).detach().cpu().numpy()
                )
                values = preds_prob.max(axis=-1)
                indices = preds_prob.argmax(axis=-1)
                preds_str = self.converter.decode_greedy(
                    indices.ravel(), preds_size
                )[0]
                confidence_score = custom_mean(values)
                results.append([preds_str, confidence_score])
        return results