forked from Shiaoming/ALIKE
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path demo_test.py
252 lines (221 loc) · 9.85 KB
/
demo_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import csv
import os
import time
import cv2
import glob
import logging
import argparse
import numpy as np
from alike import ALike, configs
from copy import deepcopy
from decimal import Decimal
from ALIKE_code.model_transfor.ckpt2pth2 import main
class ImageLoader(object):
    """Index the images in one directory and read them lazily with OpenCV.

    Only files directly inside ``filepath`` whose names end in ``.png``,
    ``.jpg`` or ``.ppm`` are collected; the resulting paths are sorted so
    that two loaders over parallel directories can be indexed in lockstep.

    Attributes:
        images: Sorted list of matching file paths.
        N: Number of matching files.
        mode: Always the string ``'images'``.
    """

    def __init__(self, filepath: str):
        found = []
        for pattern in ('*.png', '*.jpg', '*.ppm'):
            found.extend(glob.glob(os.path.join(filepath, pattern)))
        self.images = sorted(found)
        self.N = len(self.images)
        logging.info(f'Loading {self.N} images')
        self.mode = 'images'

    def __getitem__(self, item):
        """Return ``(image, path)`` for index ``item``.

        ``image`` is whatever ``cv2.imread`` returns for the path; the demo
        loop checks it against ``None`` before use.
        """
        path = self.images[item]
        return cv2.imread(path), path

    def __len__(self):
        """Number of indexed image files."""
        return self.N
def mnn_mather(desc1, desc2, threshold=0.75):
    """Match two descriptor sets by mutual nearest neighbour on dot-product similarity.

    Args:
        desc1: (N1, D) array of descriptors from the first image.
        desc2: (N2, D) array of descriptors from the second image.
        threshold: Minimum similarity ``desc1[i] @ desc2[j]`` for a pair to be
            accepted. Defaults to 0.75, the value that used to be hard-coded.

    Returns:
        (M, 2) integer array. Row ``[i, j]`` means ``desc1[i]`` and ``desc2[j]``
        are each other's best match and their similarity is >= ``threshold``.
        An empty (0, 2) array is returned when either input has no rows.
    """
    desc1 = np.asarray(desc1)
    desc2 = np.asarray(desc2)
    if desc1.shape[0] == 0 or desc2.shape[0] == 0:
        # np.argmax raises ValueError on an empty axis; no match is possible anyway.
        return np.empty((0, 2), dtype=np.intp)

    raw_sim = desc1 @ desc2.T
    # Suppress weak similarities so they cannot win the nearest-neighbour search.
    sim = np.where(raw_sim < threshold, 0, raw_sim)
    nn12 = np.argmax(sim, axis=1)
    nn21 = np.argmax(sim, axis=0)
    ids1 = np.arange(sim.shape[0])

    # Mutual check plus an explicit threshold test on the chosen pair: a row
    # whose similarities were all suppressed has argmax 0 and could otherwise
    # pass the mutual check as a spurious below-threshold match.
    mask = (ids1 == nn21[nn12]) & (raw_sim[ids1, nn12] >= threshold)
    return np.stack([ids1[mask], nn12[mask]], axis=1)
def plot_keypoints(image, kpts, radius=2, color=(0, 0, 255)):
    """Draw a filled circle at every keypoint on a copy of ``image``.

    Args:
        image: (H, W), (H, W, 1) or (H, W, 3) array. Non-uint8 input is
            multiplied by 255 and cast to uint8 (this presumably expects a
            float image in [0, 1] -- TODO confirm for other dtypes).
            Single-channel input is expanded to three channels.
        kpts: (N, 2) array-like of (x, y) pixel coordinates; rounded to int.
        radius: Circle radius in pixels.
        color: Colour tuple passed to OpenCV (BGR order if the image came
            from ``cv2.imread``).

    Returns:
        A new C-contiguous uint8 three-channel image; ``image`` is not modified.
    """
    # Compare dtypes by equality, not identity: equal dtype objects are not
    # guaranteed to be the same object, and a false mismatch would scale an
    # already-uint8 image by 255.
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    # ndarray.copy() defaults to C order, so the result is contiguous for
    # OpenCV drawing and never aliases the caller's array.
    out = image.copy()
    points = np.round(np.asarray(kpts)).astype(int).reshape(-1, 2)
    for x, y in points:
        # Plain Python ints: some OpenCV builds reject NumPy integer scalars.
        cv2.circle(out, (int(x), int(y)), radius, color, -1, lineType=cv2.LINE_4)
    return out
def plot_matches(image0,
                 image1,
                 kpts0,
                 kpts1,
                 matches,
                 radius=2,
                 color=(255, 0, 0)):
    """Render two images side by side with their keypoints and match lines.

    Args:
        image0, image1: Images accepted by ``plot_keypoints``; they may differ in size.
        kpts0, kpts1: (N, 2) arrays of (x, y) keypoint coordinates for each image.
        matches: (M, 2) integer array; column 0 indexes ``kpts0`` and column 1
            indexes ``kpts1``.
        radius: Keypoint circle radius in pixels.
        color: Keypoint colour passed to OpenCV.

    Returns:
        ``(out, points_out)``. Both are uint8 canvases of height
        ``max(H0, H1)`` and width ``W0 + W1`` on a white background.
        ``out`` additionally has one randomly coloured line per match and the
        match count drawn near its bottom-right corner. ``points_out`` shows
        the keypoints only.
    """
    out0 = plot_keypoints(image0, kpts0, radius, color)
    out1 = plot_keypoints(image1, kpts1, radius, color)
    H0, W0 = image0.shape[0], image0.shape[1]
    H1, W1 = image1.shape[0], image1.shape[1]
    out = np.full((max(H0, H1), W0 + W1, 3), 255, np.uint8)
    out[:H0, :W0, :] = out0
    out[:H1, W0:, :] = out1
    # Snapshot before any match lines are drawn.
    points_out = out.copy()

    matches = np.asarray(matches, dtype=int).reshape(-1, 2)
    kpts0 = np.asarray(kpts0)
    kpts1 = np.asarray(kpts1)
    mkpts0 = np.round(kpts0[matches[:, 0]]).astype(int)
    mkpts1 = np.round(kpts1[matches[:, 1]]).astype(int)
    for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
        # Python ints (not NumPy scalars) for both the points and the colour;
        # some OpenCV builds reject NumPy integers here. Upper bound 256 so
        # each channel can reach 255.
        mcolor = tuple(int(c) for c in np.random.randint(0, 256, size=3))
        cv2.line(out, (int(x0), int(y0)), (int(x1) + W0, int(y1)),
                 color=mcolor,
                 thickness=1,
                 lineType=cv2.LINE_AA)
    cv2.putText(out, str(len(mkpts0)),
                (out.shape[1] - 150, out.shape[0] - 50),
                cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 255), 2)
    return out, points_out
def _build_arg_parser():
    """Return the command-line parser for the image-pair demo."""
    parser = argparse.ArgumentParser(description='ALIKED image pair Demo.')
    parser.add_argument('--input', type=str, default='',
                        help='Image directory.')
    parser.add_argument('--input2', type=str, default='',
                        help='Image directory.')
    parser.add_argument('--model', choices=['alike-t', 'alike-s', 'alike-n', 'alike-l'], default="alike-n",
                        help="The model configuration")
    parser.add_argument('--model_path', default="default", help="The model path, The default is open source model")
    parser.add_argument('--device', type=str, default='cuda', help="Running device (default: cuda).")
    parser.add_argument('--top_k', type=int, default=-1,
                        help='Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)')
    parser.add_argument('--scores_th', type=float, default=0.2,
                        help='Detector score threshold (default: 0.2).')
    parser.add_argument('--n_limit', type=int, default=5000,
                        help='Maximum number of keypoints to be detected (default: 5000).')
    parser.add_argument('--radius', type=int, default=2,
                        help='The radius of non-maximum suppression (default: 2).')
    parser.add_argument('--write_dir', type=str, default='', help='Image save directory.')
    parser.add_argument('--version', type=str, default='', help='version')
    # Previously hard-coded values, exposed with the same defaults.
    parser.add_argument('--model_save_dir', type=str,
                        default='/media/xin/work1/github_pro/ALIKE/test_model_save',
                        help='Root directory for .pth files converted from .ckpt checkpoints.')
    parser.add_argument('--stop_at', type=int, default=110,
                        help='Stop after processing the pair with this index; negative processes all pairs (default: 110).')
    return parser


def _resolve_model_path(args):
    """Return the .pth file to load, converting a .ckpt checkpoint first if needed.

    ``'default'`` maps to the bundled ``models/<model>.pth`` next to this file.
    A ``.ckpt`` path is converted once, via the imported ``main`` converter,
    into ``<model_save_dir>/<version>/<checkpoint stem>.pth``. The cached file
    is reused on later runs. Any other path is returned unchanged.

    Raises:
        ValueError: ``--version`` is empty while converting a ``.ckpt`` file.
    """
    model_path = args.model_path
    if model_path == 'default':
        return os.path.join(os.path.dirname(__file__), 'models', f'{args.model}.pth')
    if not model_path.endswith('.ckpt'):
        return model_path
    if not args.version:
        raise ValueError('--version must not be empty when --model_path is a .ckpt file')
    save_dir = os.path.join(args.model_save_dir, args.version)
    os.makedirs(save_dir, exist_ok=True)
    # splitext keeps dots inside the stem, so e.g. "val_loss=0.12.ckpt" and
    # "val_loss=0.34.ckpt" no longer collide on the same cached "val_loss=0.pth".
    stem = os.path.splitext(os.path.basename(model_path))[0]
    output_model_path = os.path.join(save_dir, f'{stem}.pth')
    if not os.path.exists(output_model_path):
        main(model_path, output_model_path, args.device)
        print("模型转换成功!")
    return output_model_path


def _fps(seconds):
    """Convert a per-frame duration in seconds to frames per second."""
    return 1.0 / seconds if seconds > 0 else float('inf')


def _avg_fps(times):
    """Average frame rate for a list of per-frame durations in seconds.

    The first and last samples are dropped when more than two are available,
    so warm-up and tail frames are excluded. Returns nan for an empty list.
    """
    window = times[1:-1] if len(times) > 2 else times
    if not window:
        return float('nan')
    return _fps(float(np.mean(window)))


def _append_log_row(path, row):
    """Append one CSV row to ``path`` and close the file immediately."""
    with open(path, 'a', newline='', encoding='utf-8') as fh:
        csv.writer(fh).writerow(row)


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    logging.basicConfig(level=logging.INFO)
    image_loader = ImageLoader(args.input)
    model_path = _resolve_model_path(args)
    model = ALike(**configs[args.model],
                  model_path=model_path,
                  device=args.device,
                  top_k=args.top_k,
                  radius=args.radius,
                  scores_th=args.scores_th,
                  n_limit=args.n_limit
                  )
    logging.info("Press 'space' to start. \nPress 'q' or 'ESC' to stop!")
    image_loader2 = ImageLoader(args.input2)

    # Pairs are formed by sorted index; never index past the shorter directory.
    n_pairs = min(len(image_loader), len(image_loader2))
    if len(image_loader) != len(image_loader2):
        logging.warning('Input directories differ in size (%d vs %d); using the first %d pairs.',
                        len(image_loader), len(image_loader2), n_pairs)

    save_img_path = args.write_dir
    log_file = os.path.join(save_img_path, "log.csv") if save_img_path else None
    if save_img_path:
        os.makedirs(save_img_path, exist_ok=True)

    # Per-frame durations in seconds. Only the first 102 frames are kept so
    # that, after dropping the first and last, the average covers 100 frames.
    net_times = []
    net_match_times = []
    total_times = []
    vis_img = None
    for i in range(n_pairs):
        start = time.perf_counter()
        img, img_name = image_loader[i]
        img2, img2_name = image_loader2[i]
        if img is None or img2 is None:
            break
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_rgb2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
        start1 = time.perf_counter()
        pred = model(img_rgb)
        pred_ref = model(img_rgb2)
        end1 = time.perf_counter()
        kpts = pred['keypoints']
        desc = pred['descriptors']
        kpts_ref = pred_ref['keypoints']
        desc_ref = pred_ref['descriptors']
        try:
            matches = mnn_mather(desc, desc_ref)
        except ValueError as err:
            # e.g. an image with no keypoints (argmax over an empty axis).
            logging.warning('Skipping pair %d: matching failed (%s).', i, err)
            continue
        end2 = time.perf_counter()

        status = f"matches/keypoints: {len(matches)}/{len(kpts)}"
        vis_img, points_out = plot_matches(img, img2, kpts, kpts_ref, matches)
        cv2.namedWindow(args.model)
        cv2.setWindowTitle(args.model, args.model + ': ' + status)
        cv2.putText(vis_img, "Press 'q' or 'ESC' to stop.", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2,
                    cv2.LINE_AA)
        cv2.putText(points_out, str(len(kpts)),
                    (points_out.shape[1] - 150, points_out.shape[0] - 50),
                    cv2.FONT_HERSHEY_COMPLEX, 2, (0, 0, 255), 2)
        cv2.imshow('points', points_out)
        cv2.imshow(args.model, vis_img)
        end = time.perf_counter()

        net_t = end1 - start1
        net_matches_t = end2 - start1
        total_t = end - start
        # The labels say FPS, so report frames per second rather than seconds.
        print('Processed image %d (net: %.3f FPS,net+matches: %.3f FPS, total: %.3f FPS).' % (
            i, _fps(net_t), _fps(net_matches_t), _fps(total_t)))
        if len(net_times) < 102:
            net_times.append(net_t)
            net_match_times.append(net_matches_t)
            total_times.append(total_t)

        if log_file:  # save the rendered images and log keypoint/match counts
            base_name = os.path.basename(img_name)
            cv2.imwrite(os.path.join(save_img_path, "t" + base_name), points_out)
            cv2.imwrite(os.path.join(save_img_path, "d" + base_name), vis_img)
            _append_log_row(log_file, [base_name, len(kpts), len(matches)])

        c = cv2.waitKey(1)
        if c == 32:  # space pauses until space is pressed again
            while cv2.waitKey(1) != 32:
                pass
        if c == ord('q') or c == 27:
            break
        if 0 <= args.stop_at <= i:
            break

    # Average frame rates over the retained window.
    avg_net_FPS = _avg_fps(net_times)
    avg_net_matches_FPS = _avg_fps(net_match_times)
    avg_total_FPS = _avg_fps(total_times)
    summary = (f'avg_net_FPS:{avg_net_FPS:.3f},avg_net+matches_FPS:{avg_net_matches_FPS:.3f},'
               f'avg_total_FPS:{avg_total_FPS:.3f}')
    if log_file:
        _append_log_row(log_file, [summary])
    print(f'avg_FPS:\n {summary}')
    logging.info('Finished!')
    logging.info('Press any key to exit!')
    if vis_img is not None:  # nothing was displayed if no pair was processed
        cv2.putText(vis_img, "Finished! Press any key to exit.", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2,
                    cv2.LINE_AA)
        cv2.imshow(args.model, vis_img)
        cv2.waitKey()