-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_LFW.py
59 lines (52 loc) · 2.18 KB
/
eval_LFW.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.backends.cudnn as cudnn
from nets.facenet import Facenet
from utils.dataloader import LFWDataset
from utils.utils_metrics import test
if __name__ == "__main__":
    #--------------------------------------#
    #   Whether to run on CUDA.
    #   Set to False on machines without a GPU. If left True but no CUDA
    #   device is visible, the script falls back to CPU below.
    #--------------------------------------#
    cuda = True
    #--------------------------------------#
    #   Backbone feature extractor, one of:
    #   mobilenet
    #   inception_resnetv1
    #--------------------------------------#
    backbone = "mobilenet"
    #--------------------------------------------------------#
    #   Input image size [height, width, channels];
    #   [112, 112, 3] is another commonly used setting.
    #   NOTE(review): presumably must match the size the weights in
    #   model_path were trained with — confirm.
    #--------------------------------------------------------#
    input_shape = [160, 160, 3]
    #--------------------------------------#
    #   Trained weights file.
    #--------------------------------------#
    model_path = "model_data/facenet_mobilenet.pth"
    #--------------------------------------#
    #   LFW evaluation image directory and
    #   the matching pairs txt file.
    #--------------------------------------#
    lfw_dir_path = "lfw"
    lfw_pairs_path = "model_data/lfw_pair.txt"
    #--------------------------------------#
    #   Evaluation batch size and logging interval.
    #--------------------------------------#
    batch_size = 256
    log_interval = 1
    #--------------------------------------#
    #   Where to save the ROC plot.
    #--------------------------------------#
    png_save_path = "model_data/roc_test.png"

    # Only honour the CUDA request if a device actually exists. Previously the
    # device was auto-detected independently of `cuda`, so cuda=True on a
    # CPU-only machine loaded the weights fine and then crashed at model.cuda().
    if cuda and not torch.cuda.is_available():
        print('CUDA requested but not available; falling back to CPU.')
        cuda = False
    device = torch.device('cuda' if cuda else 'cpu')

    test_loader = torch.utils.data.DataLoader(
        LFWDataset(dir=lfw_dir_path, pairs_path=lfw_pairs_path, image_size=input_shape),
        batch_size=batch_size,
        shuffle=False,
    )

    model = Facenet(backbone=backbone, mode="predict")
    print('Loading weights into state dict...')
    # torch.load unpickles the file: only point model_path at trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    # strict=False tolerates checkpoint keys that the predict-mode model does not
    # define (presumably a training-only head — TODO confirm). Missing keys are
    # different: those parameters stay randomly initialised and the LFW numbers
    # become meaningless, so report them instead of discarding the result.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print('Warning: %d parameters missing from %s, e.g. %s'
              % (len(incompatible.missing_keys), model_path,
                 incompatible.missing_keys[:5]))
    model = model.eval()

    if cuda:
        # Replicate the model over all visible GPUs and let cuDNN pick the
        # fastest convolution algorithms for the fixed input size.
        model = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        model = model.cuda()

    test(test_loader, model, png_save_path, log_interval, batch_size, cuda)