From 90811d985fcba32a6f8c253f6d0bc480056e0888 Mon Sep 17 00:00:00 2001 From: "lotomer@163.com" Date: Fri, 18 Oct 2024 09:02:09 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=B0=83=E7=94=A8set=5Frange?= =?UTF-8?q?=E6=B2=A1=E6=9C=89=E5=BE=97=E5=88=B0=E6=9C=9F=E6=9C=9B=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=E7=9A=84BUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ddddocr/__init__.py | 53 ++++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/ddddocr/__init__.py b/ddddocr/__init__.py index c689590..69b8409 100644 --- a/ddddocr/__init__.py +++ b/ddddocr/__init__.py @@ -54,6 +54,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: self.__word = False self.__resize = [] self.__charset_range = [] + self.__valid_charset_range_index = [] # 指定字符对应的有效索引 self.__channel = 1 if import_onnx_path != "": det = False @@ -2593,6 +2594,16 @@ def set_ranges(self, charset_range): # 去重 self.__charset_range = list(set(self.__charset_range)) + [""] + # 根据指定字符获取对应的索引 + valid_charset_range_index = [] + if len(self.__charset_range) > 0: + for item in self.__charset_range: + if item in self.__charset: + valid_charset_range_index.append(self.__charset.index(item)) + else: + # 未知字符没有索引,直接忽略 + pass + self.__valid_charset_range_index = valid_charset_range_index def classification(self, img, png_fix: bool = False, probability=False): @@ -2667,28 +2678,36 @@ def classification(self, img, png_fix: bool = False, probability=False): result['probability'] = ort_outs_probability else: result['charsets'] = self.__charset_range - probability_result_index = [] - for item in self.__charset_range: - if item in self.__charset: - probability_result_index.append(self.__charset.index(item)) - else: - # 未知字符 - probability_result_index.append(-1) + valid_charset_range_index = self.__valid_charset_range_index probability_result = [] for item in ort_outs_probability: - probability_result.append([item[i] if i != -1 else -1 for i in probability_result_index ]) + probability_result.append([item[i] for i in valid_charset_range_index ]) result['probability'] = probability_result return result else: - last_item = 0 - argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2)) - for item in argmax_result: - if item == last_item: - continue - else: - last_item = item - if item != 0: - result.append(self.__charset[item]) + if len(self.__charset_range) == 0: + # 没有指定特定的字符集合,直接获取结果 + last_item = 0 + argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2)) + for item in argmax_result: + if item == last_item: + continue + else: + last_item = item + if item != 0: + result.append(self.__charset[item]) + else: + # 指定了特定的字符集合 + last_item = 0 + valid_charset_range_index = self.__valid_charset_range_index + for row in np.squeeze(ort_outs[0]): + # 仅在指定字符集合中寻找最大值 + idx = np.argmax(row[list(valid_charset_range_index)]) + if idx == last_item: + continue + else: + last_item = idx + result.append(self.__charset[valid_charset_range_index[idx]]) return ''.join(result) else: