-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathamaz_tester.py
47 lines (41 loc) · 1.46 KB
/
amaz_tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import amaz_augumentationCustom
import amaz_datashaping
import amaz_augumentation
from chainer import serializers
import pickle
class Tester(object):
def __init__(self,model=None,dataset=None,dataaugumentation=amaz_augumentationCustom.Normalize128):
self.name = "tester"
self.model = model
self.dataset = dataset
self.dataaugumentation = dataaugumentation
self.model_file_path = "model/model_299.pkl"
self.meta = self.init_meta()
self.load_model()
self.xp = np
self.datashaping = amaz_datashaping.DataShaping(self.xp)
def init_meta(self):
meta = self.dataset["meta"]
return meta
def load_model(self):
print("loading model...")
with open(self.model_file_path, 'rb') as i:
self.model = pickle.load(i)
return
def executeOne(self,x):
x = amaz_augumentation.Augumentation().Z_score(x)
da_x = self.dataaugumentation.test(x)
xin = self.datashaping.prepareinput([da_x],dtype=np.float32,volatile=True)
y = self.model(xin,train=False)
res = {}
score_of_each = list(y.data)
predict_index = np.argmax(score_of_each, axis=1)
print(self.meta)
predict_label = self.meta[predict_index]
res["score_of_each"] = score_of_each
res["predict_index"] = int(predict_index)
res["predict_label"] = predict_label
return res
# if __name__ == "__main__":
#