|
@@ -0,0 +1,232 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+# @Time : 2023/10/11
|
|
|
+# @Author : lkj
|
|
|
+# @description : milvus
|
|
|
+from typing import Optional, List
|
|
|
+
|
|
|
+import numpy
|
|
|
+from numpy import dot
|
|
|
+from numpy.linalg import norm
|
|
|
+from pymilvus import (connections, FieldSchema,Collection,CollectionSchema,DataType,utility)
|
|
|
+from config import redis_config
|
|
|
+from utils.redis_helper import RedisString
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# 使用默认数据库 ‘default’,也可以自己建数据库
|
|
|
+
|
|
|
+
|
|
|
class Milvus(object):
    """Wrapper around one Milvus collection plus a Redis code-to-name lookup.

    Vector searches run against the collection's ``embeddings`` field using
    ``self.search_params``. The classification helpers (``search_good``,
    ``search_industry``, ``search_china``) re-rank the returned hits by the
    cosine similarity between the query vector and each hit's stored vector.
    """

    def __init__(self, table, **kwargs):
        """Connect to Milvus and bind to the collection named ``table``.

        :param table: name of the collection to operate on.
        :param kwargs: forwarded unchanged to ``connections.connect``.
        """
        connections.connect(**kwargs)
        self.col = Collection(table)
        self.search_params = {
            "metric_type": "L2",
            "ignore_growing": False,
            "params": {"nprobe": 100},
        }
        self.r = RedisString(redis_config)

    def search(self, vec, fileds=None, expr=None):
        """Run a raw vector search and return the pymilvus result unchanged.

        :param vec: query vector.
        :param fileds: output fields to return with each hit (parameter name
            kept as-is for backward compatibility).
        :param expr: optional boolean filter expression.
        :return: the search result for the single query vector (limit 7).
        """
        return self.col.search([vec], 'embeddings', self.search_params, 7,
                               output_fields=fileds,
                               expr=expr)

    def load(self):
        """Load the collection (delegates to ``Collection.load``)."""
        self.col.load()

    def release(self):
        """Release the collection (delegates to ``Collection.release``)."""
        self.col.release()

    def delete(self, expr):
        """Delete the entities matching the boolean expression ``expr``."""
        self.col.delete(expr=expr)

    def query(self, q):
        """Run a scalar query with expression ``q`` and return its result."""
        return self.col.query(expr=q)

    def insert(self, data):
        """Insert ``data`` into the collection."""
        self.col.insert(data=data)

    @staticmethod
    def sim(a: list, b: list):
        """Return the cosine similarity of two vectors, rounded to 4 decimals.

        :param a: first vector.
        :param b: second vector.
        :return: cosine similarity; ``0.0`` when either vector is empty or has
            zero norm, so callers never receive NaN as a sort key.
        """
        a = numpy.asarray(a)
        b = numpy.asarray(b)
        if a.size == 0 or b.size == 0:
            return 0.0
        denominator = norm(a) * norm(b)
        if denominator == 0:
            return 0.0
        return round(dot(a, b) / denominator, 4)

    @staticmethod
    def _strip_code(code):
        """Remove trailing ``"00"`` pairs from ``code``; ``None`` becomes ``''``."""
        code = code or ''
        # endswith() is safe on empty / one-character strings, unlike
        # indexing code[-1] and code[-2] directly.
        while code.endswith('00'):
            code = code[:-2]
        return code

    def get_name(self, code):
        """Look up the name stored in Redis for ``code``.

        Trailing ``"00"`` pairs are stripped from ``code`` first and the key
        ``'jycode_' + stripped_code`` is read.

        :return: whatever ``RedisString.string_get`` returns for that key.
        """
        return self.r.string_get('jycode_' + self._strip_code(code))

    def get_root_zc(self, re_code):
        """Build the ``"ancestor/.../name/"`` path for ``re_code``.

        The code is shortened two characters at a time; each prefix is
        resolved with ``get_name`` and prepended to the path. The number of
        levels is ``(len(re_code) - 1) // 2``.

        :return: the path string, or ``''`` when the code has no levels or a
            level has no name in Redis (callers then fall back to the hit's
            stored ``root`` field).
        """
        root = ''
        levels = (len(re_code) - 1) // 2
        for depth in range(levels):
            prefix = re_code[:len(re_code) - 2 * depth]
            name = self.get_name(prefix)
            if name is None:
                # A missing name previously raised TypeError on concatenation.
                return ''
            root = name + '/' + root
        return root

    def _collect_results(self, res, vec, use_private_code=False):
        """Turn a search result into tuples sorted by cosine similarity.

        :param res: result returned by ``Collection.search``.
        :param vec: the query vector, compared against each hit's vector.
        :param use_private_code: when true, take the code from the hit's
            ``private_code`` field (falling back to ``code`` when it is
            missing or the string ``'null'``) and resolve the name via Redis.
        :return: list of ``(class_name, code, similarity, root, explain)``,
            highest similarity first.
        """
        results = []
        for hits in res:
            for hit in hits:
                entity = hit.to_dict().get('entity', {})
                code = entity.get('code', '')
                cls_name = entity.get('class_name', '')
                if use_private_code:
                    code = entity.get('private_code', '')
                    if not code or code == 'null':
                        code = entity.get('code', '')
                    cls_name = self.get_name(code)
                code = self._strip_code(code)
                # Prefer the path rebuilt from Redis; otherwise use the stored one.
                root = self.get_root_zc(code) or entity.get('root', '')
                score = self.sim(vec, entity.get('embeddings', []))
                explain = entity.get('explain', '')
                results.append((cls_name, code, score, root, explain))
        results.sort(key=lambda item: item[2], reverse=True)
        return results

    def search_good(self, vec, num=7, base=None):
        """Search the collection and return hits ranked by cosine similarity.

        :param vec: query vector.
        :param num: maximum number of hits to request.
        :param base: currently unused; no filter expression is passed.
        :return: list of ``(class_name, code, similarity, root, explain)``,
            or ``[]`` if the search fails.
        """
        try:
            res = self.col.search(
                [vec], 'embeddings', self.search_params, num,
                output_fields=['code', 'class_name', 'embeddings', 'explain', 'root'],
            )
            return self._collect_results(res, vec)
        except Exception as e:
            print('关系库错误:', e)
            return []

    def search_industry(self, vec, output_fields: list, topk=7,
                        industry_list: Optional[List[str]] = None,):
        """Search the collection, optionally reporting private codes.

        :param vec: query vector.
        :param output_fields: fields to fetch for every hit.
        :param topk: maximum number of hits to request.
        :param industry_list: when non-empty, each hit's ``private_code`` is
            used instead of ``code`` and its name is read from Redis. It is
            not applied as a search filter (the search runs with
            ``expr=None``).
        :return: list of ``(class_name, code, similarity, root, explain)``,
            or ``[]`` if the search fails.
        """
        try:
            res = self.col.search([vec], 'embeddings', self.search_params, topk,
                                  output_fields=output_fields,
                                  expr=None)
            return self._collect_results(res, vec,
                                         use_private_code=bool(industry_list))
        except Exception as e:
            print('统计局分类错误:', e)
            return []

    def search_china(self, vec, base=None, topk=7):
        """Search the collection and return hits ranked by cosine similarity.

        :param vec: query vector.
        :param base: currently unused; no filter expression is passed.
        :param topk: maximum number of hits to request.
        :return: list of ``(class_name, code, similarity, root, explain)``,
            or ``[]`` if the search fails.
        """
        try:
            res = self.col.search(
                [vec], 'embeddings', self.search_params, topk,
                output_fields=['code', 'class_name', 'embeddings', 'explain', 'root'],
            )
            return self._collect_results(res, vec)
        except Exception as e:
            print('统计局分类错误:', e)
            return []

    def update(self, data):
        """Insert ``data`` into the collection (implemented as a plain insert)."""
        self.col.insert(data)
|
|
|
+
|
|
|
def create(db_name="classify", host="192.168.3.109", port=19530,
           collection_name='scene'):
    """Create a classification collection and index its vector field.

    Connects to Milvus, declares the schema below, creates the collection and
    builds an IVF_FLAT / L2 index on ``embeddings``. Called without arguments
    it uses the same connection settings and collection name as before.

    :param db_name: Milvus database to connect to.
    :param host: Milvus server host.
    :param port: Milvus server port.
    :param collection_name: name of the collection to create.
    """
    connections.connect(db_name=db_name, host=host, port=port)
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="class_name", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name='code', dtype=DataType.VARCHAR, max_length=16),
        FieldSchema(name="p_name", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name='p_code', dtype=DataType.VARCHAR, max_length=16),
        # FieldSchema(name='baseclass', dtype=DataType.VARCHAR, max_length=10),
        FieldSchema(name='explain', dtype=DataType.VARCHAR, max_length=500),
        # FieldSchema(name='root', dtype=DataType.VARCHAR, max_length=300),
        FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=768),
    ]

    schema = CollectionSchema(fields, description='场景分类')
    collection = Collection(collection_name, schema=schema)
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128},
    }
    collection.create_index("embeddings", index_params)
|
|
|
+
|
|
|
if __name__ == '__main__':
    # Manual smoke test: embed a sample phrase and print the ranked
    # classification hits from the 'jianyu_code' collection.
    # Uncomment to (re)create the collection first:
    # create()
    from request_fun import text_to_vector
    from config import milvus_config

    wanted_fields = ['code', 'class_name', 'embeddings', 'explain', 'root',
                     'private_code']
    sample_text = '施工人员派遣'
    query_vector = text_to_vector(sample_text)
    client = Milvus('jianyu_code', **milvus_config)
    hits = client.search_industry(query_vector, wanted_fields,
                                  industry_list=['物业'])
    print(hits)
|