123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- # -*- coding: utf-8 -*-
- # @Time : 2023/10/11
- # @Author : lkj
- # @description : milvus
- from typing import Optional, List
- import numpy
- from numpy import dot
- from numpy.linalg import norm
- from pymilvus import (connections, FieldSchema,Collection,CollectionSchema,DataType,utility)
- from config import redis_config
- from utils.redis_helper import RedisString
- # 使用默认数据库 ‘default’,也可以自己建数据库
class Milvus(object):
    """Wrapper around a single Milvus collection plus a Redis code->name lookup.

    The instance holds:
      * ``col``           - the ``pymilvus.Collection`` opened for ``table``
      * ``search_params`` - the search parameters passed to every search call
      * ``r``             - a ``RedisString`` helper used by :meth:`get_name`
    """

    # Name of the vector field that every search in this class queries.
    _VECTOR_FIELD = 'embeddings'
    # Output fields requested by search_good / search_china.
    _DEFAULT_FIELDS = ('code', 'class_name', 'embeddings', 'explain', 'root')

    def __init__(self, table, **kwargs):
        """Connect to Milvus and open the collection named ``table``.

        :param table: collection name passed to ``Collection``.
        :param kwargs: forwarded unchanged to ``connections.connect``.
        """
        connections.connect(**kwargs)
        self.col = Collection(table)
        self.search_params = {
            "metric_type": "L2",
            "ignore_growing": False,
            "params": {"nprobe": 100},
        }
        self.r = RedisString(redis_config)

    def search(self, vec, fileds=None, expr=None, limit=7):
        """Run a raw vector search and return the pymilvus result unchanged.

        :param vec: query vector.
        :param fileds: output fields to return (parameter name kept as-is for
            backward compatibility).
        :param expr: optional boolean filter expression.
        :param limit: maximum number of hits; defaults to the previously
            hard-coded value of 7.
        """
        return self.col.search([vec], self._VECTOR_FIELD, self.search_params, limit,
                               output_fields=fileds,
                               expr=expr)

    def load(self):
        """Call ``load()`` on the wrapped collection."""
        self.col.load()

    def release(self):
        """Call ``release()`` on the wrapped collection."""
        self.col.release()

    def delete(self, expr):
        """Delete the entities matching ``expr`` from the collection."""
        self.col.delete(expr=expr)

    def query(self, q):
        """Run a scalar query with expression ``q`` and return its result."""
        return self.col.query(expr=q)

    def insert(self, data):
        """Insert ``data`` into the collection."""
        self.col.insert(data=data)

    @staticmethod
    def sim(a: list, b: list):
        """Return the cosine similarity of two vectors, rounded to 4 decimals.

        :param a: first vector.
        :param b: second vector.
        :return: rounded cosine similarity, or ``0.0`` when either vector has
            zero norm (the plain formula would divide by zero and yield NaN).
        """
        denominator = norm(a) * norm(b)
        if denominator == 0:
            return 0.0
        return round(dot(a, b) / denominator, 4)

    @staticmethod
    def _strip_code(code):
        """Remove trailing ``"00"`` pairs from ``code``.

        Guards the length so that empty, one-character or all-zero codes no
        longer raise ``IndexError``; ``None`` is treated as an empty string.
        """
        code = code or ''
        while len(code) >= 2 and code.endswith('00'):
            code = code[:-2]
        return code

    def get_name(self, code):
        """Look up the name stored in Redis for ``code``.

        Trailing ``"00"`` pairs are stripped from ``code`` and the result is
        read from the key ``'jycode_' + code``.
        """
        return self.r.string_get('jycode_' + self._strip_code(code))

    def get_root_zc(self, re_code):
        """Build the ``/``-terminated name path for ``re_code``.

        ``(len(re_code) - 1) // 2`` levels are visited; level ``i`` uses
        ``re_code`` with its last ``2 * i`` characters removed. Each name is
        prepended, so the shortest prefix ends up first. Returns ``''`` when
        no level is visited (codes shorter than three characters).
        """
        root = ''
        depth = (len(re_code) - 1) // 2
        for level in range(depth):
            prefix = re_code[:len(re_code) - 2 * level]
            root = self.get_name(prefix) + '/' + root
        return root

    def _build_results(self, res, vec, use_private_code=False):
        """Convert a pymilvus search result into a sorted list of tuples.

        :param res: iterable of hit lists as returned by ``Collection.search``.
        :param vec: query vector, compared against each hit's ``embeddings``.
        :param use_private_code: when true, take the code from the
            ``private_code`` field (falling back to ``code`` when it equals
            the string ``'null'``) and resolve the class name through Redis.
        :return: ``(class_name, code, similarity, root, explain)`` tuples,
            sorted by similarity in descending order.
        """
        results = []
        for hits in res:
            for hit in hits:
                entity = hit.to_dict().get('entity', {})
                code = entity.get('code', '')
                cls_name = entity.get('class_name', '')
                if use_private_code:
                    code = entity.get('private_code', '')
                    if code == 'null':
                        code = entity.get('code', '')
                    # NOTE(review): the original indentation was lost; this
                    # lookup is assumed to apply to every non-public row, not
                    # only to the 'null' fallback - confirm.
                    cls_name = self.get_name(code)
                code = self._strip_code(code)
                # Prefer the path rebuilt from Redis; fall back to the stored one.
                root = self.get_root_zc(code) or entity.get('root', '')
                similarity = self.sim(vec, entity.get('embeddings', []))
                results.append(
                    (cls_name, code, similarity, root, entity.get('explain', ''))
                )
        results.sort(key=lambda item: item[2], reverse=True)
        return results

    def search_good(self, vec, num=7, base=None):
        """Search the collection and return scored classification tuples.

        :param vec: query vector.
        :param num: maximum number of hits.
        :param base: accepted for compatibility; currently unused (the
            ``baseclass`` filter is disabled).
        :return: list of ``(class_name, code, similarity, root, explain)``
            sorted by similarity descending, or ``[]`` on any error.
        """
        try:
            res = self.col.search([vec], self._VECTOR_FIELD, self.search_params, num,
                                  output_fields=list(self._DEFAULT_FIELDS))
            return self._build_results(res, vec)
        except Exception as e:
            # Best-effort contract: callers receive an empty list on failure.
            print('关系库错误:', e)
            return []

    def search_industry(self, vec, output_fields: list, topk=7,
                        industry_list: Optional[List[str]] = None,):
        """Search the classification collection, optionally in private-code mode.

        :param vec: query vector.
        :param output_fields: fields to request from Milvus.
        :param topk: maximum number of hits.
        :param industry_list: when non-empty, codes are read from
            ``private_code`` and names are resolved through Redis. No
            Milvus-side filter is applied; ``expr`` is always ``None``.
        :return: list of ``(class_name, code, similarity, root, explain)``
            sorted by similarity descending, or ``[]`` on any error.
        """
        try:
            res = self.col.search([vec], self._VECTOR_FIELD, self.search_params, topk,
                                  output_fields=output_fields,
                                  expr=None)
            return self._build_results(res, vec,
                                       use_private_code=bool(industry_list))
        except Exception as e:
            # Best-effort contract: callers receive an empty list on failure.
            print('统计局分类错误:', e)
            return []

    def search_china(self, vec, base=None, topk=7):
        """Search the classification collection with the default output fields.

        :param vec: query vector.
        :param base: accepted for compatibility; currently unused (the
            ``baseclass`` filter is disabled).
        :param topk: maximum number of hits.
        :return: list of ``(class_name, code, similarity, root, explain)``
            sorted by similarity descending, or ``[]`` on any error.
        """
        try:
            res = self.col.search([vec], self._VECTOR_FIELD, self.search_params, topk,
                                  output_fields=list(self._DEFAULT_FIELDS))
            return self._build_results(res, vec)
        except Exception as e:
            # Best-effort contract: callers receive an empty list on failure.
            print('统计局分类错误:', e)
            return []

    def update(self, data):
        """Insert ``data`` into the collection (this performs an insert only)."""
        self.col.insert(data)
def create(db_name="classify", host="192.168.3.109", port=19530,
           collection_name='scene'):
    """Create a collection and build an IVF_FLAT/L2 index on its vector field.

    Called with no arguments this behaves as before: it connects to the
    ``classify`` database on ``192.168.3.109:19530`` and creates ``scene``.

    :param db_name: Milvus database to connect to.
    :param host: Milvus server host.
    :param port: Milvus server port.
    :param collection_name: name of the collection to create.
    """
    connections.connect(db_name=db_name, host=host, port=port)
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="class_name", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name='code', dtype=DataType.VARCHAR, max_length=16),
        FieldSchema(name="p_name", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name='p_code', dtype=DataType.VARCHAR, max_length=16),
        # 'baseclass' (VARCHAR, max_length=10) is intentionally not part of the schema.
        FieldSchema(name='explain', dtype=DataType.VARCHAR, max_length=500),
        # 'root' (VARCHAR, max_length=300) is intentionally not part of the schema.
        FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=768),
    ]
    schema = CollectionSchema(fields, description='场景分类')
    collection = Collection(collection_name, schema=schema)
    index = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}}
    collection.create_index("embeddings", index)
if __name__ == '__main__':
    # Manual smoke test: embed one phrase and print the industry-mode matches.
    # (Call create() here instead when the collection has to be built first.)
    from request_fun import text_to_vector
    from config import milvus_config

    requested_fields = ['code', 'class_name', 'embeddings', 'explain', 'root', 'private_code']
    query_vector = text_to_vector('施工人员派遣')
    client = Milvus('jianyu_code', **milvus_config)
    matches = client.search_industry(query_vector, requested_fields,
                                     industry_list=['物业'])
    print(matches)
|