# -*- coding: utf-8 -*-
# @File : digital2.py
# @Time : 2023/7/10
# @Author : lkj
# @description : Flow 2 - query the standard vector library for semantically similar entries
import time
from collections import Counter

from numpy import dot
from numpy.linalg import norm

from config import milvus_config
from utils.milvus_hlper import Milvus
from utils.request_fun import text_to_vector
  12. m = Milvus(table='jianyu_code_2', **milvus_config)
  13. def count_fun(data:list):
  14. """
  15. 暴力搜索
  16. """
  17. try:
  18. count_dict = {}
  19. for item in data:
  20. split = 0
  21. if len(item) == 1:
  22. count_dict[item] = count_dict.get(item, 0) + 1
  23. else:
  24. for i in range(int((len(item) - 1) / 2)):
  25. if split == 0:
  26. c = item
  27. else:
  28. c = item[:-split]
  29. count_dict[c] = count_dict.get(c, 0) + 1
  30. split += 2
  31. if not count_dict:
  32. return '', 0
  33. max_value = max(count_dict.values()) # 找到字典中最大的 value 值
  34. if max_value == 1:
  35. max_length_key = min(
  36. [key for key, value in count_dict.items() if value == max_value],
  37. key=len
  38. )
  39. else:
  40. max_length_key = max(
  41. [key for key, value in count_dict.items() if value == max_value],
  42. key=len
  43. )
  44. return max_length_key,max_value
  45. except Exception as e:
  46. print('count_fun_errorxxx',e)
  47. return '',0
  48. def check(text_vec,data_: list):
  49. '''
  50. 第二档可信度判断规则
  51. :param text_vec:
  52. :param data_:
  53. :return:
  54. '''
  55. try:
  56. sim_list = [i[2] for i in data_]
  57. sim_mean = sum(sim_list)/len(sim_list)
  58. data = [item for item in data_ if item[2] >= sim_mean-0.025] # 删除es结果中差异性大且相似度低的值
  59. model01_res, model01_ = check_model01(text_vec, data)
  60. if model01_res:
  61. return model01_res, model01_
  62. # 新判断当出现多个大于0.925判断每个上一级分类的选取最高的
  63. score_list = [i[1] for i in data_ if i[2] > 0.92]
  64. count_flag = 3 # 出现词频阈值,根据score_list元素个数计算
  65. if len(score_list) <= 4:
  66. count_flag = 2
  67. word_count = dict(Counter(score_list)) # 统计出现频率高且相似度高的分类
  68. max_word = [(k, v) for k, v in word_count.items() if v >= count_flag]
  69. if len(max_word) == 1:
  70. return max_word[0][0], 'mode01'
  71. elif len(score_list) >= 2:
  72. best_code = ''
  73. best_sim = 0
  74. for code_ in score_list:
  75. if len(code_) ==1:
  76. continue
  77. father_code = code_[:-2]
  78. father_code_name = m.get_name(father_code)
  79. father_vec = text_to_vector(father_code_name)
  80. similarity = round(dot(text_vec, father_vec) / (norm(text_vec) * norm(father_vec)), 4) # 相似度计算
  81. if similarity > best_sim:
  82. best_code = code_
  83. best_sim = similarity
  84. return best_code, 'mode1'
  85. else:
  86. count_code, max_value = count_fun([i[1] for i in data_ if i[2] >= 0.85]) # 如果暴力查询的结果大于6/7 并且分类层级要大于2层
  87. if (max_value/len(data_)) >= (len(data_)-2)/len(data_) and len(count_code)>3:
  88. return count_code,'mode1'
  89. return '', ''
  90. except Exception as e:
  91. print('check_errorfff', e)
  92. return '', 'error'
  93. def check_model01(text_vec,data_es):
  94. """
  95. 第一档可信度判断规则
  96. :param text_vec:
  97. :param data_es:
  98. :return:
  99. """
  100. if data_es[0][2] > 0.945: # es查询结果为正序如果满足极大值或者较大差异直接返回第一个数据
  101. return data_es[0][1], 'mode01'
  102. output_lst = []
  103. word_count = Counter([row[1] for row in data_es if row[2] > 0.85])
  104. if not word_count:
  105. return '', ''
  106. max_pair = max(word_count.items(), key=lambda x: x[1]) # 统计词频如果词频最大值大于等于5/7则输出该值
  107. if max_pair[1]/len(data_es) >= (int(len(data_es)/2)+1)/len(data_es):
  108. return max_pair[0],'mode01'
  109. for i in range(min(3, len(data_es))): # 只判断前三个元素
  110. if data_es[i][0] == data_es[0][0] and data_es[i][2] > 0.85:
  111. output_lst.append(data_es[i][1])
  112. if len(output_lst) == 3:
  113. return data_es[0][1], 'mode01'
  114. elif data_es[0][2] - data_es[1][2] > 0.02 and data_es[0][2] > 0.91: # 第一个结果极大于后面
  115. p_code = data_es[0][1][:-2]
  116. if not p_code or (data_es[0][2] > 0.91): # 如果es得分第一的结果只有一层且相似度大于0.9就默认是正确
  117. return data_es[0][1], 'mode01'
  118. else:
  119. p_name = m.get_name(p_code)
  120. pvec = text_to_vector(p_name)
  121. similarity = round(dot(text_vec, pvec) / (norm(text_vec) * norm(pvec)), 4) # 相似度计算
  122. if similarity > 0.8: # 标的物与该分类的父级的相似度
  123. return data_es[0][1], 'mode01'
  124. return '', ''
  125. def run_mode1(text,baseclass=None):
  126. """
  127. 标的物数字化主函数
  128. :param text:
  129. :param classify_name
  130. :return: result_name:结果名称, similarity:结果与输入文本相似度, mode:流程模式, code:结果编码, credibility:结果可信度
  131. """
  132. try:
  133. vec = text_to_vector(text) # 转成向量
  134. # search_result = m.search_good(vec,7,baseclass) # 查询结果
  135. search_result = m.search_industry(vec,['code', 'class_name', 'embeddings', 'explain', 'root', 'private_code'],
  136. industry_list=['物业'])
  137. print(search_result)
  138. similarity = 0 # 文本与结果相似度
  139. result_name = '' # 分类名称
  140. mode = '' # 分类判断模式
  141. credibility = 0 # 可信度
  142. code = '' # 代码
  143. if search_result and len(search_result) > 2:
  144. check_result = check(vec, search_result) # 结果筛选
  145. if check_result[0]:
  146. code = check_result[0]
  147. result_name = m.get_name(code) # 名称映射
  148. mode = check_result[1]
  149. if result_name:
  150. res_vec = text_to_vector(result_name)
  151. similarity = round(dot(vec, res_vec) / (norm(vec) * norm(res_vec)), 4) # 相似度计算
  152. if mode == 'mode1':
  153. credibility = 0.90
  154. if mode == 'mode01':
  155. credibility = 0.95
  156. if credibility == 0: # 可信度为0 则用第一个结果作为输出
  157. result_name = search_result[0][0]
  158. similarity = search_result[0][2]
  159. mode = ''
  160. code = search_result[0][1]
  161. return [result_name, similarity, mode, code, credibility]
  162. except Exception as e:
  163. print('errrrrssss',e)
  164. print(text,baseclass)
  165. return []
  166. def run_mode1_main(text,baseclass):
  167. try:
  168. result = run_mode1(text, baseclass)
  169. if not result[0]:
  170. return ['', '', '', 0, '']
  171. result.pop(2)
  172. code = result[2]
  173. route = m.get_root_zc(code)
  174. result.append(route)
  175. return result
  176. except Exception as e:
  177. print('errrrr',e)
  178. return ['', '', '', '', '']
  179. if __name__ == '__main__':
  180. a = [('服务', 'C', 0.9276, '服务/', '服务'), ('审计服务', 'C2303', 0.9264, '商务服务/审计服务/', '年审计服务'), ('运行维护服务', 'C1607', 0.9234, '信息技术服务/运行维护服务/', '年信息安全服务'), ('物业管理服务', 'C2104', 0.9202, '房地产服务/物业管理服务/', '年物业服务'), ('服务', 'C', 0.9145, '服务/', '综合服务'), ('会议服务', 'C2201', 0.9111, '会议、展览、住宿和餐饮服务/会议服务/', '会务服务'), ('软件运维服务', 'C160703', 0.9108, '信息技术服务/运行维护服务/软件运维服务/', '业务系统服务')]
  181. v = text_to_vector('xxxx')
  182. print(check(v, a))