Skip to content

Commit

Permalink
修复调用set_range没有得到期望结果的BUG
Browse files Browse the repository at this point in the history
  • Loading branch information
lotomer committed Oct 18, 2024
1 parent db75d4a commit 90811d9
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions ddddocr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta:
self.__word = False
self.__resize = []
self.__charset_range = []
self.__valid_charset_range_index = [] # 指定字符对应的有效索引
self.__channel = 1
if import_onnx_path != "":
det = False
Expand Down Expand Up @@ -2593,6 +2594,16 @@ def set_ranges(self, charset_range):

# 去重
self.__charset_range = list(set(self.__charset_range)) + [""]
# 根据指定字符获取对应的索引
valid_charset_range_index = []
if len(self.__charset_range) > 0:
for item in self.__charset_range:
if item in self.__charset:
valid_charset_range_index.append(self.__charset.index(item))
else:
# 未知字符没有索引,直接忽略
pass
self.__valid_charset_range_index = valid_charset_range_index


def classification(self, img, png_fix: bool = False, probability=False):
Expand Down Expand Up @@ -2667,28 +2678,36 @@ def classification(self, img, png_fix: bool = False, probability=False):
result['probability'] = ort_outs_probability
else:
result['charsets'] = self.__charset_range
probability_result_index = []
for item in self.__charset_range:
if item in self.__charset:
probability_result_index.append(self.__charset.index(item))
else:
# 未知字符
probability_result_index.append(-1)
valid_charset_range_index = self.__valid_charset_range_index
probability_result = []
for item in ort_outs_probability:
probability_result.append([item[i] if i != -1 else -1 for i in probability_result_index ])
probability_result.append([item[i] for i in valid_charset_range_index ])
result['probability'] = probability_result
return result
else:
last_item = 0
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
for item in argmax_result:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
if len(self.__charset_range) == 0:
# 没有指定特定的字符集合,直接获取结果
last_item = 0
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
for item in argmax_result:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
else:
# 指定了特定的字符集合
last_item = 0
valid_charset_range_index = self.__valid_charset_range_index
for row in np.squeeze(ort_outs[0]):
# 仅在指定字符集合中寻找最大值
idx = np.argmax(row[list(valid_charset_range_index)])
if idx == last_item:
continue
else:
last_item = idx
result.append(self.__charset[valid_charset_range_index[idx]])
return ''.join(result)

else:
Expand Down

0 comments on commit 90811d9

Please sign in to comment.