-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtests.py
121 lines (94 loc) · 4.1 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from utils import load_classes_from_folder, load_class_from_wav
from process_signal import autocorrelation, calculate_lsp, get_energy, \
get_edges, run_whole_signal, in_region,\
euclidian_distance, dtw, get_new_matrix, \
get_global_distance
# Analysis parameters. They are forwarded positionally to
# run_whole_signal(signal, ws, wa, pf, k1, k2, p); p is also handed to dtw().
# NOTE(review): their exact meaning is defined in process_signal, not here.
pf = 146
ws = 80
wa = ws
p = 16
k1 = 1e-4
k2 = 3e-4

# Multiplier applied to a class's maximum centroid distance when deciding
# whether a prediction is accepted as being inside the lexicon.
tol = 1.2

# Which speaker group supplies the training examples and which the tests.
gender = "female"
gender_test = "male"
train_folder = "./corpus_digitos/training-examples/" + gender
test_folder = "./corpus_digitos/test-examples/" + gender_test
print("Train", gender, "Test", gender_test)

# Class identifiers 1..11; they double as 1-based confusion-matrix indices.
classes = list(range(1, 12))

# The centroid file depends on the training gender; the max-distance file
# is shared by both configurations.
centroids_path = (
    r'centroids/centroids-female.pickle'
    if gender == 'female'
    else r'centroids/centroids.pickle'
)
centroids = pd.read_pickle(centroids_path)
centroid_distances = pd.read_pickle(r'centroids/centroids-max-distances.pickle')
# On-disk labels for the classes whose label is not simply the class number.
_SPECIAL_CLASS_LABELS = {10: 'O', 11: 'Z'}

# Per-class cache of training feature sequences, so each training utterance
# is loaded and analysed once instead of once per test utterance.
# Assumes run_whole_signal() is deterministic for a given signal.
_training_lsfs_cache = {}


def _dtw_distance(reference_lsfs, candidate_lsfs):
    """Return the global DTW distance between two feature sequences."""
    dtw_matrix = dtw(reference_lsfs, candidate_lsfs, p, to_plot=False)
    min_matrix = get_new_matrix(dtw_matrix, to_plot=False)
    distance, _ = get_global_distance(min_matrix, to_plot=False)
    return distance


def _training_lsfs(digit_class):
    """Return the feature sequences of every training example of a class."""
    if digit_class not in _training_lsfs_cache:
        label = _SPECIAL_CLASS_LABELS.get(digit_class, str(digit_class))
        examples = load_class_from_wav(train_folder, label)
        _training_lsfs_cache[digit_class] = [
            run_whole_signal(example[0], ws, wa, pf, k1, k2, p, to_plot=False)[0]
            for example in examples
        ]
    return _training_lsfs_cache[digit_class]


# The loaded structure is only iterated, so it is kept as returned. Wrapping
# it in np.array() fails on recent NumPy when the nested sequences have
# different lengths.
tests = load_classes_from_folder(test_folder, extension=".wav")

# Rows are the actual class, columns the predicted class (both 0-based).
confusion_matrix = np.zeros((len(classes), len(classes)))

for test_c, test_class in enumerate(tests):
    for test in test_class:
        # The signal is the first element of each loaded test entry.
        lsfs_test, _, _ = run_whole_signal(
            test[0], ws, wa, pf, k1, k2, p, to_plot=False
        )

        # Stage 1: distance from the test utterance to every class centroid.
        # NOTE(review): key 10 is skipped, and the position of each distance
        # is later used as an index into `classes`; this assumes the
        # iteration order of `centroids` lines up with `classes` - confirm.
        distances = []
        for key in centroids:
            if key == 10:
                continue
            distances = np.append(
                distances, _dtw_distance(centroids[key], lsfs_test)
            )

        # Keep the two classes whose centroids are closest.
        min_indexes = np.argsort(distances)[:2]
        min_classes = [classes[i] for i in min_indexes]

        # Stage 2: compare against every training example of those classes
        # and pick the class of the single nearest example.
        candidate_classes = []
        new_distances = np.array([])
        for nearest_class in min_classes:
            for lsfs_train in _training_lsfs(nearest_class):
                new_distances = np.append(
                    new_distances, _dtw_distance(lsfs_train, lsfs_test)
                )
                candidate_classes.append(nearest_class)

        idx = np.argmin(new_distances)
        # Kept as a plain int so it is a valid index for any container type.
        predicted_class = candidate_classes[idx]

        # Accept the prediction only if it lies within the tolerated distance
        # for that class; otherwise count it in the last class's column.
        max_distance = centroid_distances[predicted_class - 1]
        if new_distances[idx] < max_distance * tol:
            confusion_matrix[classes[test_c] - 1, predicted_class - 1] += 1
        else:
            confusion_matrix[classes[test_c] - 1, classes[-1] - 1] += 1
print(confusion_matrix)

# Accuracy: sum of the diagonal divided by the sum of all matrix entries.
accuracy = confusion_matrix.trace() / confusion_matrix.sum()
print("Accuracy:", accuracy)

# One tick label per class, in the same order as the matrix rows/columns.
tick_labels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'Z']

# Annotated heatmap of the confusion matrix.
ax = sns.heatmap(confusion_matrix, annot=True, cmap='Blues')
ax.set_title('Confusion Matrix\n\n')
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.xaxis.set_ticklabels(tick_labels)
ax.yaxis.set_ticklabels(tick_labels)

# Show the figure.
plt.show()