Skip to content

Commit

Permalink
优化代码结构,新增pytorch转keras模型代码,新增darknet训练代码,优化web界面显示,更新相关模型
Browse files Browse the repository at this point in the history
  • Loading branch information
wenlihaoyu committed Apr 3, 2019
1 parent f2fd9c2 commit 2da42b1
Show file tree
Hide file tree
Showing 46 changed files with 1,649 additions and 207 deletions.
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

# 实现功能
- [x] 文字方向检测 0、90、180、270度检测(支持dnn/tensorflow)
- [x] 支持(darknet/opencv dnn /keras)文字检测,暂时公布(keras版本训练)
- [x] 不定长OCR训练(英文、中英文) crnn\dense ocr
- [x] 支持(darknet/opencv dnn /keras)文字检测,支持darknet/keras训练
- [x] 不定长OCR训练(英文、中英文) crnn\dense ocr 识别及训练 ,新增pytorch转keras模型代码(tools/pytorch_to_keras.py)
- [x] 新增对身份证/火车票结构化数据识别

## 环境部署
Expand Down Expand Up @@ -45,9 +45,10 @@ ipython app.py 8080 ##8080端口号,可以设置任意端口

## 识别结果展示

<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/train1.png"/>
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/card1.png"/>
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/demo2.png"/>
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/train-demo.png"/>
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/idcard-demo.png"/>
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/img-demo.png"/>
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/line-demo.png"/>

## Play with Docker Container(镜像有些滞后)
``` Bash
Expand All @@ -58,7 +59,7 @@ docker run -d -p 8080:8080 zergmk2/chineseocr
## 访问服务
http://127.0.0.1:8080/ocr

<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/demo1.png"/>
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/demo.png"/>


## 参考
Expand Down
97 changes: 56 additions & 41 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
@author: lywen
"""
import os
from PIL import Image
import cv2
import json
import time
import uuid
import base64
import web
from PIL import Image
web.config.debug = True
import model
render = web.template.render('templates', base='base')
from config import DETECTANGLE
from apphelper.image import union_rbox
from apphelper.image import union_rbox,adjust_box_to_origin
from application import trainTicket,idcard


Expand All @@ -37,49 +38,65 @@ def POST(self):
data = web.data()
data = json.loads(data)
billModel = data.get('billModel','')
textAngle = data.get('textAngle',False)##文字检测
textLine = data.get('textLine',False)##只进行单行识别

imgString = data['imgString'].encode().split(b';base64,')[-1]
imgString = base64.b64decode(imgString)
jobid = uuid.uuid1().__str__()
path = '/tmp/{}.jpg'.format(jobid)
path = 'test/{}.jpg'.format(jobid)
with open(path,'wb') as f:
f.write(imgString)
img = Image.open(path).convert("RGB")
W,H = img.size
img = cv2.imread(path)##GBR
H,W = img.shape[:2]
timeTake = time.time()
_,result,angle= model.model(img,
detectAngle=DETECTANGLE,##是否进行文字方向检测
config=dict(MAX_HORIZONTAL_GAP=100,##字符之间的最大间隔,用于文本行的合并
MIN_V_OVERLAPS=0.7,
MIN_SIZE_SIM=0.7,
TEXT_PROPOSALS_MIN_SCORE=0.1,
TEXT_PROPOSALS_NMS_THRESH=0.3,
TEXT_LINE_NMS_THRESH = 0.99,##文本行之间测iou值
MIN_RATIO=1.0,
LINE_MIN_SCORE=0.2,
TEXT_PROPOSALS_WIDTH=0,
MIN_NUM_PROPOSALS=0,
),
leftAdjust=True,##对检测的文本行进行向左延伸
rightAdjust=True,##对检测的文本行进行向右延伸
alph=0.2,##对检测的文本行进行向右、左延伸的倍数
ifadjustDegree=False##是否先小角度调整文字倾斜角度
)



if billModel=='' or billModel=='通用OCR' :
result = union_rbox(result,0.2)
res = [{'text':x['text'],'name':str(i)} for i,x in enumerate(result)]
elif billModel=='火车票':
res = trainTicket.trainTicket(result)
res = res.res
res =[ {'text':res[key],'name':key} for key in res]

elif billModel=='身份证':

res = idcard.idcard(result)
res = res.res
res =[ {'text':res[key],'name':key} for key in res]
if textLine:
##单行识别
partImg = Image.fromarray(img)
text = model.crnnOcr(partImg.convert('L'))
res =[ {'text':text,'name':'0','box':[0,0,W,0,W,H,0,H]} ]
else:
detectAngle = textAngle
_,result,angle= model.model(img,
detectAngle=detectAngle,##是否进行文字方向检测,通过web传参控制
config=dict(MAX_HORIZONTAL_GAP=50,##字符之间的最大间隔,用于文本行的合并
MIN_V_OVERLAPS=0.6,
MIN_SIZE_SIM=0.6,
TEXT_PROPOSALS_MIN_SCORE=0.1,
TEXT_PROPOSALS_NMS_THRESH=0.3,
TEXT_LINE_NMS_THRESH = 0.7,##文本行之间测iou值
),
leftAdjust=True,##对检测的文本行进行向左延伸
rightAdjust=True,##对检测的文本行进行向右延伸
alph=0.01,##对检测的文本行进行向右、左延伸的倍数
)



if billModel=='' or billModel=='通用OCR' :
result = union_rbox(result,0.2)
res = [{'text':x['text'],
'name':str(i),
'box':{'cx':x['cx'],
'cy':x['cy'],
'w':x['w'],
'h':x['h'],
'angle':x['degree']

}
} for i,x in enumerate(result)]
res = adjust_box_to_origin(img,angle, res)##修正box

elif billModel=='火车票':
res = trainTicket.trainTicket(result)
res = res.res
res =[ {'text':res[key],'name':key,'box':{}} for key in res]

elif billModel=='身份证':

res = idcard.idcard(result)
res = res.res
res =[ {'text':res[key],'name':key,'box':{}} for key in res]


timeTake = time.time()-timeTake
Expand All @@ -89,8 +106,6 @@ def POST(self):
return json.dumps({'res':res,'timeTake':round(timeTake,4)},ensure_ascii=False)




urls = ('/ocr','OCR',)

if __name__ == "__main__":
Expand Down
25 changes: 24 additions & 1 deletion apphelper/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def diff(box1,box2):
h1 = box1['h']
h2 = box2['h']

return abs(cy1-cy2)/max(0.01,min(h1/2,h1/2))
return abs(cy1-cy2)/max(0.01,min(h1/2,h2/2))

def sort_group_box(boxes):
"""
Expand Down Expand Up @@ -589,3 +589,26 @@ def sort_group_box(boxes):
return newBox


def adjust_box_to_origin(img,angle, result):
"""
调整box到原图坐标
"""
h,w = img.shape[:2]
if angle in [90,270]:
imgW,imgH = img.shape[:2]

else:
imgH,imgW= img.shape[:2]
newresult = []
for line in result:
cx =line['box']['cx']
cy = line['box']['cy']
degree =line['box']['angle']
w = line['box']['w']
h = line['box']['h']
x1,y1,x2,y2,x3,y3,x4,y4 = xy_rotate_box(cx, cy, w, h, degree/180*np.pi)
x1,y1,x2,y2,x3,y3,x4,y4 = box_rotate([x1,y1,x2,y2,x3,y3,x4,y4],angle=(360-angle)%360,imgH=imgH,imgW=imgW)
box = x1,y1,x2,y2,x3,y3,x4,y4
newresult.append({'name':line['name'],'text':line['text'],'box':box})

return newresult
15 changes: 8 additions & 7 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
########################文字检测########################
##文字检测引擎 keras,opencv,darknet
##文字检测引擎
pwd = os.getcwd()
opencvFlag = 'keras'
opencvFlag = 'keras' ##keras,opencv,darknet,模型性能 keras>darknet>opencv
IMGSIZE = (608,608)## yolo3 输入图像尺寸
## keras 版本anchors
keras_anchors = '8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800'
keras_anchors = '8,11, 8,16, 8,23, 8,33, 8,48, 8,97, 8,139, 8,198, 8,283'
class_names = ['none','text',]
kerasTextModel=os.path.join(pwd,"models","text.h5")##keras版本模型权重文件

Expand All @@ -23,8 +23,9 @@
GPUID=0##调用GPU序号

## nms选择,支持cython,gpu,python
nmsFlag='gpu'## cython/gpu/python

nmsFlag='gpu'## cython/gpu/python ##容错性 优先启动GPU,其次是cpython 最后是python
if not GPU:
nmsFlag='cython'


##vgg文字方向检测模型
Expand All @@ -38,9 +39,9 @@
##OCR模型是否调用LSTM层
LSTMFLAG = True
##模型选择 True:中英文模型 False:英文模型

ocrFlag = 'torch'##ocr模型 支持 keras torch版本
chinsesModel = True

ocrModelKeras = os.path.join(pwd,"models","ocr-dense-keras.h5")##keras版本OCR,暂时支持dense
if chinsesModel:
if LSTMFLAG:
ocrModel = os.path.join(pwd,"models","ocr-lstm.pth")
Expand Down
42 changes: 42 additions & 0 deletions crnn/crnn_keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#coding:utf-8
from crnn.utils import strLabelConverter,resizeNormalize

from crnn.network_keras import keras_crnn as CRNN
import tensorflow as tf
graph = tf.get_default_graph()##解决web.py 相关报错问题

from crnn import keys
from config import ocrModelKeras
import numpy as np
def crnnSource():
alphabet = keys.alphabetChinese##中英文模型
converter = strLabelConverter(alphabet)
model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=False)
model.load_weights(ocrModelKeras)
return model,converter

##加载模型
model,converter = crnnSource()

def crnnOcr(image):
"""
crnn模型,ocr识别
image:PIL.Image.convert("L")
"""
scale = image.size[1]*1.0 / 32
w = image.size[0] / scale
w = int(w)
transformer = resizeNormalize((w, 32))
image = transformer(image)
image = image.astype(np.float32)
image = np.array([[image]])
global graph
with graph.as_default():
preds = model.predict(image)
preds = preds[0]
preds = np.argmax(preds,axis=2).reshape((-1,))
sim_pred = converter.decode(preds)
return sim_pred



34 changes: 18 additions & 16 deletions crnn/crnn.py → crnn/crnn_torch.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
#coding:utf-8
import torch
import torch.utils.data
import numpy as np
from torch.autograd import Variable
from crnn import util
from crnn import dataset
from crnn.network import CRNN
from crnn.utils import strLabelConverter,resizeNormalize
from crnn.network_torch import CRNN
from crnn import keys
from collections import OrderedDict
from config import ocrModel,LSTMFLAG,GPU
from config import chinsesModel
def crnnSource():
"""
加载模型
"""
if chinsesModel:
alphabet = keys.alphabetChinese##中英文模型
else:
alphabet = keys.alphabetEnglish##英文模型

converter = util.strLabelConverter(alphabet)
converter = strLabelConverter(alphabet)
if torch.cuda.is_available() and GPU:
model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG).cuda()##LSTMFLAG=True crnn 否则 dense ocr
else:
Expand All @@ -27,15 +29,14 @@ def crnnSource():
name = k.replace('module.','') # remove `module.`
modelWeights[name] = v
# load params

model.load_state_dict(modelWeights)
model.eval()

return model,converter

##加载模型
model,converter = crnnSource()

model.eval()
def crnnOcr(image):
"""
crnn模型,ocr识别
Expand All @@ -44,21 +45,22 @@ def crnnOcr(image):
scale = image.size[1]*1.0 / 32
w = image.size[0] / scale
w = int(w)
transformer = dataset.resizeNormalize((w, 32))
transformer = resizeNormalize((w, 32))
image = transformer(image)
image = image.astype(np.float32)
image = torch.from_numpy(image)

if torch.cuda.is_available() and GPU:
image = transformer(image).cuda()
image = image.cuda()
else:
image = transformer(image).cpu()
image = image.cpu()

image = image.view(1, *image.size())
image = image.view(1,1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

sim_pred = converter.decode(preds)
return sim_pred


Loading

0 comments on commit 2da42b1

Please sign in to comment.