-
Notifications
You must be signed in to change notification settings - Fork 13
/
train_and_test.py
141 lines (96 loc) · 7.87 KB
/
train_and_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# train_and_test.py
import cv2
import numpy as np
import operator
import os
# module level variables ##########################################################################
# Minimum contour area (as computed by cv2.contourArea); contours with a smaller
# area are rejected by ContourWithData.checkIfContourIsValid.
MIN_CONTOUR_AREA = 100
# Width and height that every cropped character region is resized to before it is
# flattened into a single row of RESIZED_IMAGE_WIDTH * RESIZED_IMAGE_HEIGHT values
# and passed to the KNN lookup.
RESIZED_IMAGE_WIDTH = 20
RESIZED_IMAGE_HEIGHT = 30
###################################################################################################
class ContourWithData():
    """Bundle a contour with its bounding rectangle and its area.

    The caller assigns npaContour, boundingRect and fltArea; the helper
    methods then derive the rectangle fields and decide whether the contour
    is large enough to be kept.
    """

    # member variables (class-level defaults, overridden per instance) ###########################
    npaContour = None       # the contour itself
    boundingRect = None     # bounding rect of the contour as (x, y, width, height)
    intRectX = 0            # x of the bounding rect's top left corner
    intRectY = 0            # y of the bounding rect's top left corner
    intRectWidth = 0        # width of the bounding rect
    intRectHeight = 0       # height of the bounding rect
    fltArea = 0.0           # area of the contour

    def calculateRectTopLeftPointAndWidthAndHeight(self):
        """Split self.boundingRect into the four individual rect fields."""
        # boundingRect must hold exactly four values: x, y, width, height
        (self.intRectX,
         self.intRectY,
         self.intRectWidth,
         self.intRectHeight) = self.boundingRect

    def checkIfContourIsValid(self):
        """Return False when the contour area is below MIN_CONTOUR_AREA, True otherwise.

        This is deliberately simplistic; a production grade program would need
        much stronger validity checking.
        """
        return not (self.fltArea < MIN_CONTOUR_AREA)
###################################################################################################
def main():
    """Train a KNN classifier from saved data and classify the characters in test_numbers.png.

    Steps:
      1. Load the training classifications ("classifications.txt") and the
         flattened training images ("flattened_images.txt") and train a
         k-nearest-neighbours model on them.
      2. Read "test_numbers.png", convert it to grayscale, blur it and apply an
         inverted adaptive threshold so the foreground is white.
      3. Find the outermost contours, keep those that pass
         ContourWithData.checkIfContourIsValid, and order them left to right.
      4. For each kept contour, draw a green rectangle on the original image,
         crop the matching region from the threshold image, resize it to
         RESIZED_IMAGE_WIDTH x RESIZED_IMAGE_HEIGHT and look up its nearest
         neighbour (k = 1); the integer label found is appended to the result.
      5. Print the resulting string and show the annotated image until a key
         is pressed.

    The function runs under Python 2 and Python 3 and supports both the
    OpenCV 2.4 KNN API (cv2.KNearest) and the OpenCV 3+ one (cv2.ml).

    Returns:
        None. If the test image cannot be read, an error message is printed
        and the function returns before doing any recognition.
    """
    npaClassifications = np.loadtxt("classifications.txt", np.float32)      # read in training classifications
    npaFlattenedImages = np.loadtxt("flattened_images.txt", np.float32)     # read in training images

    # the train call needs the classifications as a column, one row per training sample
    npaClassifications = npaClassifications.reshape((npaClassifications.size, 1))

    if hasattr(cv2, "ml"):
        # OpenCV 3+: KNN lives in cv2.ml and train() needs the sample layout flag
        kNearest = cv2.ml.KNearest_create()
        kNearest.train(npaFlattenedImages, cv2.ml.ROW_SAMPLE, npaClassifications)
        findNearest = kNearest.findNearest
    else:
        # OpenCV 2.4: legacy KNN API
        kNearest = cv2.KNearest()
        kNearest.train(npaFlattenedImages, npaClassifications)
        findNearest = kNearest.find_nearest
    # end if

    imgTestingNumbers = cv2.imread("test_numbers.png")          # read in testing numbers image

    if imgTestingNumbers is None:                               # imread returns None when reading fails
        print("error: image not read from file \n\n")           # single-argument form prints identically on Python 2 and 3
        if os.name == "nt":                                     # "pause" is a Windows shell command only
            os.system("pause")                                  # pause so user can see error message
        # end if
        return
    # end if

    imgGray = cv2.cvtColor(imgTestingNumbers, cv2.COLOR_BGR2GRAY)       # get grayscale image
    imgBlurred = cv2.GaussianBlur(imgGray, (5, 5), 0)                   # blur

    # filter image from grayscale to black and white
    imgThresh = cv2.adaptiveThreshold(imgBlurred,                       # input image
                                      255,                              # make pixels that pass the threshold full white
                                      cv2.ADAPTIVE_THRESH_GAUSSIAN_C,   # gaussian weighted neighbourhood rather than plain mean
                                      cv2.THRESH_BINARY_INV,            # invert so foreground is white, background is black
                                      11,                               # size of the pixel neighbourhood used for the threshold
                                      2)                                # constant subtracted from the (weighted) mean

    # Pass a copy because some OpenCV versions modify the input image while finding contours.
    # findContours returns (contours, hierarchy) in OpenCV 2.4 / 4.x but
    # (image, contours, hierarchy) in 3.x; the contours are always second from last.
    npaContours = cv2.findContours(imgThresh.copy(),
                                   cv2.RETR_EXTERNAL,                   # retrieve the outermost contours only
                                   cv2.CHAIN_APPROX_SIMPLE)[-2]         # keep only the end points of straight segments

    validContoursWithData = []
    for npaContour in npaContours:
        contourWithData = ContourWithData()
        contourWithData.npaContour = npaContour
        contourWithData.boundingRect = cv2.boundingRect(npaContour)
        contourWithData.calculateRectTopLeftPointAndWidthAndHeight()
        contourWithData.fltArea = cv2.contourArea(npaContour)
        if contourWithData.checkIfContourIsValid():                     # drop contours that fail the validity check
            validContoursWithData.append(contourWithData)
        # end if
    # end for

    validContoursWithData.sort(key=operator.attrgetter("intRectX"))     # sort contours from left to right

    lstChars = []                                                       # recognised characters, in left to right order

    for contourWithData in validContoursWithData:
        intX = contourWithData.intRectX
        intY = contourWithData.intRectY
        intWidth = contourWithData.intRectWidth
        intHeight = contourWithData.intRectHeight

        # draw a green rect around the current char on the original testing image
        cv2.rectangle(imgTestingNumbers,
                      (intX, intY),                                     # upper left corner
                      (intX + intWidth, intY + intHeight),              # lower right corner
                      (0, 255, 0),                                      # green
                      2)                                                # thickness

        imgROI = imgThresh[intY:intY + intHeight, intX:intX + intWidth]     # crop char out of threshold image

        # resize so every sample has the same number of pixels
        imgROIResized = cv2.resize(imgROI, (RESIZED_IMAGE_WIDTH, RESIZED_IMAGE_HEIGHT))

        # flatten into a single row and convert to float32 for the KNN lookup
        npaROIResized = np.float32(imgROIResized.reshape((1, RESIZED_IMAGE_WIDTH * RESIZED_IMAGE_HEIGHT)))

        # both KNN APIs return (retval, results, neighbour responses, distances); only results is needed
        npaResults = findNearest(npaROIResized, 1)[1]

        lstChars.append(str(int(npaResults[0][0])))                     # label of the nearest neighbour as text
    # end for

    strFinalString = "".join(lstChars)

    print("\n" + strFinalString + "\n")                     # show the full string

    cv2.imshow("imgTestingNumbers", imgTestingNumbers)      # show input image with green boxes drawn around found chars
    cv2.waitKey(0)                                          # wait for user key press

    cv2.destroyAllWindows()                                 # remove windows from memory

    return
###################################################################################################
# run main() only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()