# -*- coding: utf-8 -*-
# @Time : 2023/10/11
# @Author : lkj
# @description : milvus
from typing import Optional, List

import numpy
from numpy import dot
from numpy.linalg import norm
from pymilvus import (connections, FieldSchema, Collection, CollectionSchema, DataType, utility)

from config import redis_config
from utils.redis_helper import RedisString


# Uses the default database 'default'; a dedicated database can be created instead.
class Milvus(object):
    """Wrapper around a single Milvus collection plus a Redis code -> name lookup.

    The constructor connects to Milvus with the given keyword arguments, binds
    the named collection and prepares the search parameters used by every
    search method (L2 metric, ``nprobe`` = 100).
    """

    # Entity fields requested by ``search_good`` and ``search_china``.
    _DEFAULT_OUTPUT_FIELDS = ['code', 'class_name', 'embeddings', 'explain', 'root']

    def __init__(self, table, **kwargs):
        """Connect to Milvus and bind the collection called ``table``.

        :param table: name of an existing collection.
        :param kwargs: forwarded unchanged to ``connections.connect``.
        """
        connections.connect(**kwargs)
        self.col = Collection(table)
        self.search_params = {
            "metric_type": "L2",
            "ignore_growing": False,
            "params": {"nprobe": 100},
        }
        self.r = RedisString(redis_config)

    def search(self, vec, fileds=None, expr=None):
        """Run a raw vector search on the ``embeddings`` field (limit 7).

        :param vec: query vector.
        :param fileds: output fields to return (parameter name kept as-is
            for backward compatibility with existing callers).
        :param expr: optional boolean filter expression.
        :return: whatever ``Collection.search`` returns.
        """
        res = self.col.search([vec], 'embeddings', self.search_params, 7,
                              output_fields=fileds, expr=expr)
        return res

    def load(self):
        """Load the collection."""
        self.col.load()

    def release(self):
        """Release the collection."""
        self.col.release()

    def delete(self, expr):
        """Delete the entities matching ``expr``."""
        self.col.delete(expr=expr)

    def query(self, q):
        """Run a scalar query with expression ``q`` and return the result."""
        res = self.col.query(expr=q)
        return res

    def insert(self, data):
        """Insert ``data`` into the collection."""
        self.col.insert(data=data)

    @staticmethod
    def sim(a: list, b: list):
        """Return the cosine similarity of two vectors, rounded to 4 decimals.

        :param a: first vector.
        :param b: second vector.
        :return: cosine similarity; 0.0 when either vector has zero norm
            (avoids a NaN that would corrupt the later sort).
        """
        denominator = norm(a) * norm(b)
        if denominator == 0:
            return 0.0
        s = dot(a, b) / denominator
        return round(s, 4)

    @staticmethod
    def _strip_code(code):
        """Drop trailing '00' pairs from ``code``.

        Unlike indexing ``code[-1]`` / ``code[-2]``, this is safe for empty,
        single-character and all-zero codes.
        """
        while code.endswith('00'):
            code = code[:-2]
        return code

    def get_name(self, code):
        """Look up the name stored in Redis for ``code``.

        Trailing '00' pairs are removed first; the key read is
        ``'jycode_' + <stripped code>``.
        """
        code = self._strip_code(code)
        name = self.r.string_get('jycode_' + code)
        return name

    def get_root_zc(self, re_code):
        """Build the '/'-joined chain of names for ``re_code`` and its prefixes.

        ``(len(re_code) - 1) // 2`` prefixes are visited, starting with the
        full code and removing two characters each step; each name is
        prepended, so the result ends with the name of the full code followed
        by '/'. Returns '' when there is no level to visit. A lookup result
        that is not a string makes the concatenation raise ``TypeError``.
        """
        root = ''
        levels = (len(re_code) - 1) // 2
        for depth in range(levels):
            prefix = re_code[:len(re_code) - 2 * depth]
            root = self.get_name(prefix) + '/' + root
        return root

    def _rank_hits(self, vec, res, use_private_code=False):
        """Convert search hits to tuples sorted by cosine similarity (descending).

        :param vec: the query vector, compared against each hit's embeddings.
        :param res: iterable of hit lists as returned by ``Collection.search``.
        :param use_private_code: when true, take the code from ``private_code``
            (falling back to ``code`` if it equals the string 'null') and
            resolve the class name through ``get_name``.
        :return: list of ``(class_name, code, similarity, root, explain)``.
        """
        rows = []
        for hits in res:
            for hit in hits:
                entity = hit.to_dict().get('entity', {})
                code = entity.get('code', '')
                cls_name = entity.get('class_name', '')
                if use_private_code:
                    code = entity.get('private_code', '')
                    if code == 'null':
                        code = entity.get('code', '')
                    # NOTE(review): the name lookup is applied to every
                    # private-mode row (not only the 'null' fallback) --
                    # confirm against the intended behaviour.
                    cls_name = self.get_name(code)
                code = self._strip_code(code)
                explain = entity.get('explain', '')
                # Prefer the chain rebuilt from Redis; fall back to the stored one.
                root = self.get_root_zc(code) or entity.get('root', '')
                sim_res = self.sim(vec, entity.get('embeddings', []))
                rows.append((cls_name, code, sim_res, root, explain))
        rows.sort(key=lambda item: item[2], reverse=True)
        return rows

    def search_good(self, vec, num=7, base=None):
        """Search the collection and return hits ranked by cosine similarity.

        :param vec: query vector.
        :param num: maximum number of hits.
        :param base: currently unused (the ``baseclass`` filter is not applied).
        :return: list of ``(class_name, code, similarity, root, explain)``;
            an empty list if anything fails (the error is printed).
        """
        try:
            res = self.col.search([vec], 'embeddings', self.search_params, num,
                                  output_fields=self._DEFAULT_OUTPUT_FIELDS)
            return self._rank_hits(vec, res)
        except Exception as e:
            print('关系库错误:', e)
            return []

    def search_industry(self, vec, output_fields: list, topk=7,
                        industry_list: Optional[List[str]] = None,):
        """Search the statistics-bureau classification.

        :param vec: vector of the subject matter being classified.
        :param output_fields: entity fields to request from Milvus.
        :param topk: maximum number of hits.
        :param industry_list: industry scope. When non-empty the private codes
            are used for the result rows. NOTE: no ``industry in [...]``
            filter expression is sent to Milvus; the search itself is
            unfiltered.
        :return: list of ``(class_name, code, similarity, root, explain)``;
            an empty list if anything fails (the error is printed).
        """
        try:
            public = not industry_list
            res = self.col.search([vec], 'embeddings', self.search_params, topk,
                                  output_fields=output_fields, expr=None)
            return self._rank_hits(vec, res, use_private_code=not public)
        except Exception as e:
            print('统计局分类错误:', e)
            return []

    def search_china(self, vec, base=None, topk=7):
        """Search the statistics-bureau classification with the default fields.

        :param vec: vector of the subject matter being classified.
        :param base: subject-matter category; currently unused (the
            ``baseclass`` filter is not applied).
        :param topk: maximum number of hits.
        :return: list of ``(class_name, code, similarity, root, explain)``;
            an empty list if anything fails (the error is printed).
        """
        try:
            res = self.col.search([vec], 'embeddings', self.search_params, topk,
                                  output_fields=self._DEFAULT_OUTPUT_FIELDS)
            return self._rank_hits(vec, res)
        except Exception as e:
            print('统计局分类错误:', e)
            return []

    def update(self, data):
        """Insert ``data`` into the collection (performs an insert only)."""
        self.col.insert(data)


def create(db_name="classify", host="192.168.3.109", port=19530, collection_name='scene'):
    """Create a collection and its IVF_FLAT / L2 index on ``embeddings``.

    The defaults reproduce the previously hard-coded connection settings and
    collection name, so calling ``create()`` with no arguments is unchanged.

    :param db_name: Milvus database to connect to.
    :param host: Milvus server host.
    :param port: Milvus server port.
    :param collection_name: name of the collection to create.
    """
    connections.connect(db_name=db_name, host=host, port=port)
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="class_name", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name='code', dtype=DataType.VARCHAR, max_length=16),
        FieldSchema(name="p_name", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name='p_code', dtype=DataType.VARCHAR, max_length=16),
        # FieldSchema(name='baseclass', dtype=DataType.VARCHAR, max_length=10),
        FieldSchema(name='explain', dtype=DataType.VARCHAR, max_length=500),
        # FieldSchema(name='root', dtype=DataType.VARCHAR, max_length=300),
        FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=768),
    ]
    schema = CollectionSchema(fields, description='场景分类')
    milvus_conn = Collection(collection_name, schema=schema)
    index = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}}
    milvus_conn.create_index("embeddings", index)


if __name__ == '__main__':
    # create()
    from request_fun import text_to_vector
    from config import milvus_config

    text = '施工人员派遣'
    vector = text_to_vector(text)
    col = Milvus('jianyu_code', **milvus_config)
    print(col.search_industry(vector,
                              ['code', 'class_name', 'embeddings', 'explain',
                               'root', 'private_code'],
                              industry_list=['物业']))