-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathKNN_builder.py
79 lines (59 loc) · 1.71 KB
/
KNN_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by iFantastic on 6/29/18
# Author : ZOU Zijie
# Email : [email protected]
# Platform : PyCharm
"""
This module reads the training set, trains a KNN classifier and outputs KNN classification results.
-------------ZOU Zijie :)
"""
import os
import csv
import time
import numpy as np
from sklearn.neighbors import KNeighborsRegressor,KNeighborsClassifier,DistanceMetric
from sklearn.externals import joblib
import sys
# Make modules located in the current working directory importable.
sys.path.append(os.path.realpath(os.curdir))
# Timestamp suffix captured once at import time, e.g. "_29-Jun-2018_14:05:09".
# NOTE(review): historically labelled "logger", but no logger is configured in
# the code shown; if this ends up in file names, ':' is invalid on Windows.
logi_time = time.strftime("_%d-%b-%Y_%H:%M:%S")
def falttern(path):
    """Load a NumPy array from *path* and return it as a one-dimensional copy.

    The name is presumably a misspelling of "flatten"; it is kept so existing
    callers continue to work.

    Args:
        path: Location of a file readable by ``np.load`` (e.g. a ``.npy`` file).

    Returns:
        numpy.ndarray: A 1-D copy of the stored array, in row-major order.
    """
    return np.load(path).flatten()
def generate_dataset(csv_path):
    """Build a labelled training set from a directory tree of CSV index files.

    Every ``*.csv`` file found under *csv_path* (recursively, via ``os.walk``)
    contributes one sample per data row. The first row of each file is a
    header and is skipped; blank rows are ignored. Column index 3 of each data
    row must hold the path of a NumPy array file, which is loaded and
    flattened with ``falttern``. The label of every sample is the CSV file
    name without its extension (``cat.csv`` -> ``"cat"``).

    Args:
        csv_path: Directory containing the CSV index files.

    Returns:
        tuple[list, list]: ``(x_train, y_train)`` -- flattened feature vectors
        and their string labels, in matching order.

    Raises:
        FileNotFoundError: If *csv_path* is not an existing directory.
        IndexError: If a non-blank data row has fewer than four columns.
    """
    if not os.path.isdir(csv_path):
        # os.walk silently yields nothing for a missing path; fail loudly instead.
        raise FileNotFoundError("CSV directory not found: %s" % csv_path)
    x_train = []
    y_train = []
    # Read files from every directory os.walk visits, not just the last one.
    for root, dirs, files in os.walk(csv_path):
        dirs.sort()  # deterministic traversal order -> reproducible sample order
        for file_name in sorted(files):
            label, ext = os.path.splitext(file_name)
            if ext.lower() != ".csv":
                # Skip non-CSV files (e.g. .npy vectors or OS metadata files).
                continue
            # newline='' is required by the csv module for correct parsing.
            with open(os.path.join(root, file_name), "r", newline="") as rf:
                reader = csv.reader(rf)
                next(reader, None)  # skip the header row
                for row in reader:
                    if not row:
                        continue  # csv.reader yields [] for blank lines
                    x_train.append(falttern(row[3]))
                    y_train.append(label)
    return x_train, y_train
def knn_classifier(input, knn_path):
    """Load a persisted KNN model from disk and classify *input* with it.

    Args:
        input: Samples passed straight to the model's ``predict_proba`` and
            ``predict`` (for scikit-learn estimators, a 2-D array-like of
            shape ``(n_samples, n_features)``).
        knn_path: Path to a model file saved with ``joblib.dump`` (presumably
            the file written by ``training_KNN``).

    Returns:
        tuple: ``(pred, prob)`` -- the output of ``predict`` and of
        ``predict_proba`` for the same samples.

    Note:
        ``joblib.load`` unpickles the file, so *knn_path* must come from a
        trusted source.
    """
    model = joblib.load(knn_path)
    class_scores = model.predict_proba(input)
    labels = model.predict(input)
    return labels, class_scores
def training_KNN(csv_path, model_path='./models/knn.model', n_neighbors=5):
    """Train a KNN classifier on the CSV-indexed dataset and save it to disk.

    Args:
        csv_path: Directory of CSV index files, read by ``generate_dataset``.
        model_path: Destination file for ``joblib.dump``. Its parent directory
            is created if missing. Defaults to ``./models/knn.model``, relative
            to the current working directory.
        n_neighbors: Number of neighbours consulted per prediction (passed to
            scikit-learn's ``KNeighborsClassifier``).

    Returns:
        KNeighborsClassifier: The fitted model (the same object that was saved).
    """
    X_train, Y_train = generate_dataset(csv_path)
    knn = KNeighborsClassifier(n_neighbors=n_neighbors, algorithm='ball_tree')
    knn.fit(X_train, Y_train)
    # joblib.dump does not create directories; make sure the target exists.
    model_dir = os.path.dirname(model_path)
    if model_dir:  # a bare file name has no directory component to create
        os.makedirs(model_dir, exist_ok=True)
    joblib.dump(knn, model_path)
    return knn