Skip to content

Commit

Permalink
1.5.1 新增ocr结果概率表返回以及输出结果限定,重新整理文档
Browse files Browse the repository at this point in the history
  • Loading branch information
sml2h3 committed May 6, 2024
1 parent 0fe1283 commit f0ee6b1
Show file tree
Hide file tree
Showing 8 changed files with 649 additions and 338 deletions.
441 changes: 277 additions & 164 deletions README.md

Large diffs are not rendered by default.

441 changes: 277 additions & 164 deletions ddddocr/README.md

Large diffs are not rendered by default.

100 changes: 92 additions & 8 deletions ddddocr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta:
self.use_import_onnx = False
self.__word = False
self.__resize = []
self.__charset_range = []
self.__channel = 1
if import_onnx_path != "":
det = False
Expand Down Expand Up @@ -2552,7 +2553,47 @@ def get_bbox(self, image_bytes):
return []
return result

def classification(self, img, png_fix: bool = False):
def set_ranges(self, charset_range: int | str):
if isinstance(charset_range, int):
if charset_range == 0:
# 数字
self.__charset_range = list("0123456789")
elif charset_range == 1:
# 小写英文
self.__charset_range = list("abcdefghijklmnopqrstuvwxyz")
elif charset_range == 2:
# 大写英文
self.__charset_range = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
elif charset_range == 3:
# 混合英文
self.__charset_range = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
elif charset_range == 4:
# 小写英文+数字
self.__charset_range = list("abcdefghijklmnopqrstuvwxyz") + list(
"0123456789")
elif charset_range == 5:
# 大写英文+数字
self.__charset_range = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + list(
"0123456789")
elif charset_range == 6:
# 混合大小写+数字
self.__charset_range = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + list(
"0123456789")
elif charset_range == 7:
# 除去英文,数字
delete_range = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + list("0123456789")
self.__charset_range = [item for item in self.__charset if item not in delete_range]
elif isinstance(charset_range, str):
charset_range_list = list(charset_range)
self.__charset_range = charset_range_list
else:
raise TypeError("暂时不支持该类型数据的输入")

# 去重
self.__charset_range = list(set(self.__charset_range)) + [""]


def classification(self, img, png_fix: bool = False, probability=False):
if self.det:
raise TypeError("当前识别类型为目标检测")
if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)):
Expand Down Expand Up @@ -2601,19 +2642,62 @@ def classification(self, img, png_fix: bool = False):
result = []

last_item = 0

if self.__word:
for item in ort_outs[1]:
result.append(self.__charset[item])
else:
for item in ort_outs[0][0]:
if item == last_item:
continue
if not self.use_import_onnx:
# 概率输出仅限于使用官方模型
if probability:
ort_outs = ort_outs[0]
ort_outs = np.exp(ort_outs) / np.sum(np.exp(ort_outs))
ort_outs_sum = np.sum(ort_outs, axis=2)
ort_outs_probability = np.empty_like(ort_outs)
for i in range(ort_outs.shape[0]):
ort_outs_probability[i] = ort_outs[i] / ort_outs_sum[i]
ort_outs_probability = np.squeeze(ort_outs_probability).tolist()
result = {}
if len(self.__charset_range) == 0:
# 返回全部
result['charsets'] = self.__charset
result['probability'] = ort_outs_probability
else:
result['charsets'] = self.__charset_range
probability_result_index = []
for item in self.__charset_range:
if item in self.__charset:
probability_result_index.append(self.__charset.index(item))
else:
# 未知字符
probability_result_index.append(-1)
probability_result = []
for item in ort_outs_probability:
probability_result.append([item[i] if i != -1 else -1 for i in probability_result_index ])
result['probability'] = probability_result
return result
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
last_item = 0
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
for item in argmax_result:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
return ''.join(result)

return ''.join(result)
else:
last_item = 0
for item in ort_outs[0][0]:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
return ''.join(result)

def detection(self, img_bytes: bytes = None, img_base64: str = None):
if not self.det:
Expand Down
Binary file modified ddddocr/common.onnx
Binary file not shown.
Binary file modified ddddocr/common_old.onnx
Binary file not shown.
Binary file added ddddocr/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

setup(
name="ddddocr",
version="1.4.11",
version="1.5.1",
author="sml2h3",
description="带带弟弟OCR",
long_description=long_description,
Expand All @@ -28,11 +28,12 @@
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
install_requires=['numpy', 'onnxruntime', 'Pillow', 'opencv-python-headless'],
python_requires='<3.12',
python_requires='<3.13',
include_package_data=True,
install_package_data=True,
)

0 comments on commit f0ee6b1

Please sign in to comment.