jy_code.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2024/2/6
  3. # @Author : lkj
  4. # @description :
  5. from collections import Counter
  6. from numpy import dot
  7. from numpy.linalg import norm
  8. from utils.milvus_hlper import Milvus
  9. from utils.request_fun import text_to_vector
  10. import numpy as np
  11. class JyCode(object):
  12. def __init__(self, db_name, config):
  13. self.m = Milvus(table=db_name, **config)
  14. @staticmethod
  15. def cosine_similarity(vector1, vector2):
  16. """
  17. 余弦相似计算
  18. """
  19. dot_product = np.dot(vector1, vector2)
  20. norm1 = np.linalg.norm(vector1)
  21. norm2 = np.linalg.norm(vector2)
  22. similarity = dot_product / (norm1 * norm2)
  23. return round(similarity, 4)
  24. def check(self, text_vec, data_es: list, offset_value=0.025, min_value=0.9, max_value=0.965):
  25. '''
  26. 规则打可信度
  27. :param text_vec:
  28. :param data_es: 通过向量查询后的数据
  29. :param offset_value topk之间的差异值
  30. :param min_value 可信度最低阈值,低于该阈值不可信
  31. :param max_value 最大阈值,大于该值默认直接可信
  32. :return:
  33. '''
  34. try:
  35. sim_list = [i[2] for i in data_es] # 相似度列表
  36. sim_mean = sum(sim_list) / len(sim_list) # 平均相似度
  37. data = [item for item in data_es if item[2] >= sim_mean - offset_value ] # 删除es结果中差异性大且相似度低的值
  38. count = 3 # 统计个数阈值
  39. if len(data) <= 4:
  40. count = 2
  41. if data[0][2] > max_value: # 查询结果为正序如果满足极大值或者较大差异直接返回第一个数据
  42. return data[0][1], 'mode01'
  43. if data[0][2] - data[1][2] > offset_value and data[0][2] > min_value: # 如果第一个可信度大于第二个0.025并且第一个相似度大于0.9
  44. return data[0][1], 'mode1'
  45. # 新判断当出现多个大于0.9判断每个上一级分类的选取最高的
  46. score_list = [i[1] for i in data_es if i[2] > min_value]
  47. if len(score_list) >= 2:
  48. best_code = ''
  49. best_sim = 0
  50. for end_code in score_list:
  51. father_code = end_code[:-2] # 父级的id
  52. father_code_name = self.m.get_name(father_code) # 父级的name
  53. father_vec = text_to_vector(father_code_name)
  54. similarity = round(dot(text_vec, father_vec) / (norm(text_vec) * norm(father_vec)), 4) # 相似度计算
  55. if similarity > best_sim: # 循环查找父级相似度最大的
  56. best_code = end_code
  57. best_sim = similarity
  58. return best_code, 'mode1'
  59. # 第三档的可信度规则
  60. elif data[0][2] - data[1][2] > (offset_value-0.01) and data[0][2] > min_value:
  61. pcode = data[0][1][:-2]
  62. pname = self.m.get_name(pcode)
  63. pvec = text_to_vector(pname)
  64. similarity = round(dot(text_vec, pvec) / (norm(text_vec) * norm(pvec)), 4) # 相似度计算
  65. if similarity > (min_value-0.1): # 第一个值的父级相似度大于0.8直接返回该值
  66. return data[0][1], 'mode1'
  67. else:
  68. # 统计整个向量库中返回的数据的出现的频率做规则
  69. code_list = []
  70. for row in data:
  71. if len(row[1][:-2]) > 1:
  72. code_list.append(row[1][:-2])
  73. code_list.append(row[1])
  74. word_count = dict(Counter(code_list)) # 统计对应分类及其父类的词频,如果某个词的父类频率高则定位到该类
  75. max_word = [(k, v) for k, v in word_count.items() if v >= count]
  76. if len(max_word) == 1:
  77. return max_word[0][0], 'mode2'
  78. else: # 如果存在多个值进行对比
  79. code = ''
  80. code_sim = 0
  81. for word in max_word:
  82. code_ = word[0]
  83. code_name = self.m.get_name(code_)
  84. vec = text_to_vector(code_name)
  85. sim_ = round(dot(text_vec, vec) / (norm(text_vec) * norm(vec)), 4) # 相似度计算
  86. if sim_ > code_sim:
  87. code_sim = sim_
  88. code = code_
  89. return code, 'mode3'
  90. return '', 'mode0'
  91. except Exception as e:
  92. print('check_error', e)
  93. return '', 'error'
  94. def run_mode1(self,text, baseclass=None):
  95. """
  96. 标的物数字化主函数
  97. :param text:
  98. :param baseclass
  99. :return: result_name:结果名称, similarity:结果与输入文本相似度, mode:流程模式, code:结果编码, credibility:结果可信度
  100. """
  101. vec = text_to_vector(text) # 转成向量
  102. search_result = self.m.search_china(vec, baseclass) # 查询结果
  103. if not search_result:
  104. return '', 0, 'mode0', '', 0
  105. similarity = 0 # 文本与结果相似度
  106. result_name = ''
  107. mode = 'mode0'
  108. credibility = 0
  109. code = ''
  110. if search_result:
  111. check_result = self.check(vec, search_result) # 结果筛选
  112. if check_result[0]:
  113. code = check_result[0]
  114. result_name = self.m.get_name(check_result[0]) # 名称映射
  115. mode = check_result[1]
  116. if result_name:
  117. res_vec = text_to_vector(result_name)
  118. similarity = round(dot(vec, res_vec) / (norm(vec) * norm(res_vec)), 4) # 相似度计算
  119. if mode == 'mode01':
  120. credibility = 0.99
  121. if mode == 'mode1':
  122. credibility = 0.95
  123. if mode == 'mode3' and similarity > 0.85:
  124. credibility = 0.90
  125. if mode == 'mode2' and similarity > 0.8:
  126. credibility = 0.85
  127. if mode in ['mode2', 'mode3']:
  128. pcode = code[:-2]
  129. if not pcode:
  130. pcode = code
  131. p_name = self.m.get_name(pcode) # 父类名称
  132. p_name_vec = text_to_vector(p_name) # 父类向量
  133. p_similarity = round(dot(p_name_vec, vec) / (norm(p_name_vec) * norm(vec)), 4) # 文本与父类计算
  134. if p_similarity > 0.85 and similarity > 0.9 or similarity == 0.99:
  135. mode = 'mode4'
  136. credibility = 0.99
  137. if credibility == 0:
  138. result_name = search_result[0][0]
  139. similarity = search_result[0][2]
  140. mode = ''
  141. code = search_result[0][1]
  142. return result_name, similarity, mode, code, credibility
  143. def run_mode1_main(self,text, baseclass=None):
  144. """
  145. """
  146. try:
  147. result = list(self.run_mode1(text, baseclass))
  148. if not result[0]:
  149. return ['', '', '', 0, '']
  150. result.pop(2)
  151. code = result[2]
  152. route = self.m.get_root_zc(code)
  153. result.append(route)
  154. if result[1] > 0.9 and result[-2] == 0:
  155. result[-2] = 0.85
  156. if result[1] == 1.0:
  157. result[-2] = 0.99
  158. return result
  159. except Exception as e:
  160. print('政采分类错误--->', e)
  161. return ['', '', '', '', '']