liukangjia 1 year ago
parent
commit
d8d5111d40

+ 8 - 0
.idea/.gitignore

@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/../../../../:\jywork\text_to_vector\.idea/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/

+ 73 - 0
.idea/inspectionProfiles/Project_Default.xml

@@ -0,0 +1,73 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="60">
+            <item index="0" class="java.lang.String" itemvalue="traitlets" />
+            <item index="1" class="java.lang.String" itemvalue="gensim" />
+            <item index="2" class="java.lang.String" itemvalue="protobuf" />
+            <item index="3" class="java.lang.String" itemvalue="greenlet" />
+            <item index="4" class="java.lang.String" itemvalue="huggingface-hub" />
+            <item index="5" class="java.lang.String" itemvalue="gast" />
+            <item index="6" class="java.lang.String" itemvalue="jupyter-server" />
+            <item index="7" class="java.lang.String" itemvalue="torchvision" />
+            <item index="8" class="java.lang.String" itemvalue="redis" />
+            <item index="9" class="java.lang.String" itemvalue="paddlepaddle" />
+            <item index="10" class="java.lang.String" itemvalue="filelock" />
+            <item index="11" class="java.lang.String" itemvalue="pyzmq" />
+            <item index="12" class="java.lang.String" itemvalue="bleach" />
+            <item index="13" class="java.lang.String" itemvalue="certifi" />
+            <item index="14" class="java.lang.String" itemvalue="lxml" />
+            <item index="15" class="java.lang.String" itemvalue="beautifulsoup4" />
+            <item index="16" class="java.lang.String" itemvalue="tokenizers" />
+            <item index="17" class="java.lang.String" itemvalue="nbclassic" />
+            <item index="18" class="java.lang.String" itemvalue="transformers" />
+            <item index="19" class="java.lang.String" itemvalue="jupyter_client" />
+            <item index="20" class="java.lang.String" itemvalue="etcd3" />
+            <item index="21" class="java.lang.String" itemvalue="cryptography" />
+            <item index="22" class="java.lang.String" itemvalue="pexpect" />
+            <item index="23" class="java.lang.String" itemvalue="nbconvert" />
+            <item index="24" class="java.lang.String" itemvalue="attrs" />
+            <item index="25" class="java.lang.String" itemvalue="simhash" />
+            <item index="26" class="java.lang.String" itemvalue="regex" />
+            <item index="27" class="java.lang.String" itemvalue="PyMySQL" />
+            <item index="28" class="java.lang.String" itemvalue="html-table-extractor" />
+            <item index="29" class="java.lang.String" itemvalue="ptyprocess" />
+            <item index="30" class="java.lang.String" itemvalue="SQLAlchemy" />
+            <item index="31" class="java.lang.String" itemvalue="sklearn" />
+            <item index="32" class="java.lang.String" itemvalue="wcwidth" />
+            <item index="33" class="java.lang.String" itemvalue="importlib-metadata" />
+            <item index="34" class="java.lang.String" itemvalue="mysql-connector-python" />
+            <item index="35" class="java.lang.String" itemvalue="websocket-client" />
+            <item index="36" class="java.lang.String" itemvalue="zipp" />
+            <item index="37" class="java.lang.String" itemvalue="tenacity" />
+            <item index="38" class="java.lang.String" itemvalue="urllib3" />
+            <item index="39" class="java.lang.String" itemvalue="pymilvus" />
+            <item index="40" class="java.lang.String" itemvalue="scipy" />
+            <item index="41" class="java.lang.String" itemvalue="six" />
+            <item index="42" class="java.lang.String" itemvalue="nbformat" />
+            <item index="43" class="java.lang.String" itemvalue="packaging" />
+            <item index="44" class="java.lang.String" itemvalue="torch" />
+            <item index="45" class="java.lang.String" itemvalue="grpcio-tools" />
+            <item index="46" class="java.lang.String" itemvalue="prometheus-client" />
+            <item index="47" class="java.lang.String" itemvalue="mistune" />
+            <item index="48" class="java.lang.String" itemvalue="pandas" />
+            <item index="49" class="java.lang.String" itemvalue="importlib-resources" />
+            <item index="50" class="java.lang.String" itemvalue="jupyter-console" />
+            <item index="51" class="java.lang.String" itemvalue="typing_extensions" />
+            <item index="52" class="java.lang.String" itemvalue="debugpy" />
+            <item index="53" class="java.lang.String" itemvalue="Flask-BasicAuth" />
+            <item index="54" class="java.lang.String" itemvalue="tensorboardX" />
+            <item index="55" class="java.lang.String" itemvalue="grpcio" />
+            <item index="56" class="java.lang.String" itemvalue="pycryptodome" />
+            <item index="57" class="java.lang.String" itemvalue="pytz" />
+            <item index="58" class="java.lang.String" itemvalue="Pillow" />
+            <item index="59" class="java.lang.String" itemvalue="ujson" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>

+ 6 - 0
.idea/inspectionProfiles/profiles_settings.xml

@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

+ 4 - 0
.idea/misc.xml

@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (nlp) (2)" project-jdk-type="Python SDK" />
+</project>

+ 8 - 0
.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/text_to_vector.iml" filepath="$PROJECT_DIR$/.idea/text_to_vector.iml" />
+    </modules>
+  </component>
+</project>

+ 8 - 0
.idea/text_to_vector.iml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.7 (nlp) (2)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

BIN
__pycache__/config.cpython-37.pyc


+ 66 - 0
config.py

@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2023/7/25 
+# @Author  : lkj
+# @description :
+import json
+
+milvus_config = {'db_name': 'classify',
+                 'host': "192.168.3.109",
+                 # 'host': '172.17.162.35',
+                 'port': "19530"}  # Milvus connection settings
+
+daili = '172.17.4.188:9090'  # service proxy center (172.17.4.188)
+vector_url = 'http://192.168.3.109:19104/t2v'
+
+
+redis_config = {
+    'host': '192.168.3.109',
+    'port': 6379,
+    'db': 1,
+    'pwd': 'root'
+}
+
+mongo_input_path = {
+            # "port": "192.168.3.71:29099",
+            "user":"JSYJZ_RWBidAi_ProG",
+            "password":"JSLi@20LiefK3d",
+            "db": "qfw",
+            "col": "bidding",
+            # "db": "re4art",
+            # "col": "china_good_classify_test",
+}
+# mongo_input_path2 = {
+#             "port": "192.168.3.71:29099",
+#             # "user":"JSYJZ_RWBidAi_ProG",
+#             # "password":"JSLi@20LiefK3d",
+#             "db": "re4art",
+#             "col": "goods_count2",
+# }
+
+mysql_path2 = {
+    # 'mysql_host': '172.17.46.46',
+    # 'mysql_host': '192.168.3.206',
+    'mysql_user': 'root',
+    'mysql_password': 'Topnet123_jycms',
+    # 'mysql_password': '123456',
+    'mysql_db': 'Call_Accounting',
+    # 'mysql_db': 'jy',
+    'mysql_port': '3376',
+    # 'mysql_port': '3306'
+}
+
+mysql_path = {
+    # 'mysql_host': '172.17.4.242',
+    'mysql_host': '192.168.3.14',
+    # 'mysql_host': '127.0.0.1',
+    # 'mysql_user': 'liukangjia',
+    'mysql_user': 'root',
+    # 'mysql_password': 'Lkj#20230630N',
+    'mysql_password': '=PDT49#80Z!RVv52_z',
+    # 'mysql_password': '123456',
+    # 'mysql_db': 'Call_Accounting',
+    'mysql_db': 'lkj',
+    'mysql_port': '4000',
+    # 'mysql_port': '3306'
+}
+#

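These dictionaries are imported wherever a client is opened (utils/milvus_hlper.py below does `Milvus(table=..., **milvus_config)`, and utils/mysql_helper.py takes `mysql_path`). A minimal sketch of that wiring, using only the keys defined above; a reachable Milvus instance is assumed:

from pymilvus import connections
from config import milvus_config

# milvus_config carries db_name/host/port, which pymilvus accepts as keywords
connections.connect(**milvus_config)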
+ 241 - 0
data/stoptext.txt

@@ -0,0 +1,241 @@
+竞争性谈判文件
+答疑补遗文件
+分谈分签
+公开选取
+重更改
+集采
+项目包
+自行采购
+第三包
+标段
+网上商城
+直购订单
+第四包
+第一包
+第二包
+合同信息
+合同公告
+询价文件
+第十一次
+履约验收
+中标结果公示
+某部
+某单位
+服务中心
+第六包
+二包
+一包
+网上超市
+三包
+异常公告
+公示单
+四包
+某部
+政府采购意向
+五包
+标项一
+标项二
+预公示
+项目部
+标项三
+标项四
+重更改
+项目包
+第三包
+第四包
+第一包
+第二包
+第六包
+第五包
+更改通知
+集采商品
+控制价
+征求意见公示
+重采购
+重废标
+人民政府
+重流标
+竞争性
+答复函
+包入围
+综合项目
+定点单位
+预公告
+重终止
+重项目
+评标公示
+候选人
+重成交
+澄清公告
+招标文件
+投标文件
+废标文件
+定点采购
+撤销公告
+库征集
+中标公告
+工程类
+工字号
+包采公
+总承包
+承包商
+承包人
+工程量清单
+补遗文件
+邀请书
+拦标价
+调整公告
+需求方案设计
+竞标文件
+采购文件
+更正文件
+响应文件
+取消公告
+需求论证
+单一来源
+物投资
+邀请书
+代理公司
+预算编号
+终止公告
+暂停公告
+备案号
+预公告
+国际招标
+供应商入围资格
+竞价公告
+结果公告
+补充通知
+采购信息
+招标备案
+失败公示
+公开招标
+重更正
+成交公告
+政采工
+网上询价
+质疑答复函
+预中标
+集中采购
+竞争性磋商
+答复文件
+文件公示
+公开招标公告
+招标控制价文件
+竞争性谈判文件
+磋商文件
+代理机构服务
+务实竞争性谈判
+资格预审文件
+政府采购
+合同公示
+再预公示
+截止时间
+地点信息
+招标控制价最高投标限价公布
+询价公告
+中标成交公告
+更正公告
+第五包
+更改通知
+控制价
+征求意见公示
+重采购
+重废标
+人民政府
+重流标
+竞争性
+答复函
+包入围
+综合项目
+定点单位
+预公告
+重终止
+重项目
+评标公示
+候选人
+重成交
+澄清公告
+招标文件
+投标文件
+废标文件
+定点采购
+撤销公告
+库征集
+中标公告
+工程类
+工字号
+包采公
+总承包
+国际公开重新招标公告
+货物类
+审核前
+企业入围
+征求意见
+采购清单
+电子化公开招标
+交通类
+非招标采购供应商库
+定点服务单位
+招标采购供应商库
+痕迹类
+评审工作
+厦门中实
+厦门务实
+承包商
+承包人
+工程量清单
+补遗文件
+邀请书
+拦标价
+代理机构服务
+项目
+供货商
+第二批次
+第三批次
+第一批次
+第四批次
+第五批次
+第六批次
+第七批次
+政采编号
+项目管理
+调整公告
+第1包组
+评审结果
+釆购项目
+新一期
+第一次
+第二次
+第三次
+第四次
+第五次
+第六次
+服务类
+货物类
+工程类
+采购项目
+第七次
+第二十
+2包
+1包
+供应商库
+采购项
+3包
+4包
+二采购
+一采购
+三采购
+四采购
+需求方案设计
+及有关
+造价单位
+拟采购
+分谈
+标后
+定点议价
+中标候选人公示
+定点服务
+电子卖场
+电子验收单
+综合服务

BIN
proto/__pycache__/text2vector_pb2.cpython-37.pyc


+ 12 - 0
proto/text2vector.proto

@@ -0,0 +1,12 @@
+syntax = "proto3";
+package proto;
+
+// Request
+message Text2VectorReq{
+    string text = 1;
+}
+
+// Response
+message Text2VectorResp{
+  repeated float vector = 1;
+}

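The checked-in text2vector_pb2.py below matches the pre-protobuf-3.20 code-generation style, so a matching grpcio-tools/protobuf version is assumed if it is ever regenerated. A hedged round-trip sketch of the two messages, with a made-up example string:

from proto.text2vector_pb2 import Text2VectorReq, Text2VectorResp

req = Text2VectorReq(text='成型设备')           # request carries the raw text
payload = req.SerializeToString()               # bytes sent over gRPC
assert Text2VectorReq.FromString(payload).text == req.text

resp = Text2VectorResp(vector=[0.1, 0.2, 0.3])  # repeated float = the embedding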
+ 107 - 0
proto/text2vector_pb2.py

@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# source: text2vector.proto
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+  name='text2vector.proto',
+  package='proto',
+  syntax='proto3',
+  serialized_options=None,
+  serialized_pb=b'\n\x11text2vector.proto\x12\x05proto\"\x1e\n\x0eText2VectorReq\x12\x0c\n\x04text\x18\x01 \x01(\t\"!\n\x0fText2VectorResp\x12\x0e\n\x06vector\x18\x01 \x03(\x02\x62\x06proto3'
+)
+
+
+
+
+_TEXT2VECTORREQ = _descriptor.Descriptor(
+  name='Text2VectorReq',
+  full_name='proto.Text2VectorReq',
+  filename=None,
+  file=DESCRIPTOR,
+  containing_type=None,
+  fields=[
+    _descriptor.FieldDescriptor(
+      name='text', full_name='proto.Text2VectorReq.text', index=0,
+      number=1, type=9, cpp_type=9, label=1,
+      has_default_value=False, default_value=b"".decode('utf-8'),
+      message_type=None, enum_type=None, containing_type=None,
+      is_extension=False, extension_scope=None,
+      serialized_options=None, file=DESCRIPTOR),
+  ],
+  extensions=[
+  ],
+  nested_types=[],
+  enum_types=[
+  ],
+  serialized_options=None,
+  is_extendable=False,
+  syntax='proto3',
+  extension_ranges=[],
+  oneofs=[
+  ],
+  serialized_start=28,
+  serialized_end=58,
+)
+
+
+_TEXT2VECTORRESP = _descriptor.Descriptor(
+  name='Text2VectorResp',
+  full_name='proto.Text2VectorResp',
+  filename=None,
+  file=DESCRIPTOR,
+  containing_type=None,
+  fields=[
+    _descriptor.FieldDescriptor(
+      name='vector', full_name='proto.Text2VectorResp.vector', index=0,
+      number=1, type=2, cpp_type=6, label=3,
+      has_default_value=False, default_value=[],
+      message_type=None, enum_type=None, containing_type=None,
+      is_extension=False, extension_scope=None,
+      serialized_options=None, file=DESCRIPTOR),
+  ],
+  extensions=[
+  ],
+  nested_types=[],
+  enum_types=[
+  ],
+  serialized_options=None,
+  is_extendable=False,
+  syntax='proto3',
+  extension_ranges=[],
+  oneofs=[
+  ],
+  serialized_start=60,
+  serialized_end=93,
+)
+
+DESCRIPTOR.message_types_by_name['Text2VectorReq'] = _TEXT2VECTORREQ
+DESCRIPTOR.message_types_by_name['Text2VectorResp'] = _TEXT2VECTORRESP
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+Text2VectorReq = _reflection.GeneratedProtocolMessageType('Text2VectorReq', (_message.Message,), {
+  'DESCRIPTOR' : _TEXT2VECTORREQ,
+  '__module__' : 'text2vector_pb2'
+  # @@protoc_insertion_point(class_scope:proto.Text2VectorReq)
+  })
+_sym_db.RegisterMessage(Text2VectorReq)
+
+Text2VectorResp = _reflection.GeneratedProtocolMessageType('Text2VectorResp', (_message.Message,), {
+  'DESCRIPTOR' : _TEXT2VECTORRESP,
+  '__module__' : 'text2vector_pb2'
+  # @@protoc_insertion_point(class_scope:proto.Text2VectorResp)
+  })
+_sym_db.RegisterMessage(Text2VectorResp)
+
+
+# @@protoc_insertion_point(module_scope)

+ 3 - 0
proto/text2vector_pb2_grpc.py

@@ -0,0 +1,3 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+import grpc
+

+ 143 - 0
text_vector_pulish.py

@@ -0,0 +1,143 @@
+# coding:utf-8
+import re
+from hashlib import md5
+from jy_utils.mongodb_utils import MongoInterface
+from jy_utils.task_manage import AsyncTaskScheduler
+from jy_utils.tools import add_logger_file
+from bson import ObjectId
+import threading
+from utils.request_fun import text_to_vector
+from utils.title_ner import title_topic_process
+add_logger_file("./logs_vector")
+SAVE_MAX_id = 'save_id_max'
+MongoConfig = {
+    "ip_port": "172.17.189.140:27080",
+    # "ip_port": "192.168.3.71:29099",
+    "user": "JSYJZ_RWBidAi_ProG",
+    "password": "JSLi@20LiefK3d",
+    "db": "qfw",
+    # "db": "re4art",
+    "col": "bidding",
+}
+
+MongoConfig2 = {
+    "ip_port": "172.17.189.140:27080",
+    # "ip_port": "192.168.3.71:29099",
+    "user": "JSYJZ_RWBidAi_ProG",
+    "password": "JSLi@20LiefK3d",
+    "db": "ai",
+    # "db": "re4art",
+    "col": "vector_file",
+}
+AsyncConfig = {
+    "max_queue_size": 5000,
+    "producer_interval": 10,
+    "consumer_interval": 2,
+     "run_status": True,
+}
+
+at = AsyncTaskScheduler(AsyncConfig)
+
+mg = MongoInterface(MongoConfig)
+mg2 = MongoInterface(MongoConfig2)
+
+with open('./data/stoptext.txt', 'r', encoding='utf-8') as f:
+    stopcontent = f.readlines()
+
+
+def producer_handle(data):
+    if data.get('toptype', '') in ['拟建', '产权']:  # skip proposed-construction ('拟建') and property-rights ('产权') records
+        return False, data
+    return True, data  # True means enqueue; data is the payload to enqueue
+
+
+def stop_content(text: str):
+    """
+    Strip stop phrases: fixed phrases that must be removed but that a tokenizer
+    is likely to segment incorrectly, e.g. 重采购, 重招标.
+    :param text:
+    :return:
+    """
+    for sw in stopcontent:
+        sw = sw.replace('\n', '')
+        if sw in text:
+            text = text.replace(sw, '')
+    return text
+
+
+def re_tract(title):
+    """
+    Regex pass over the title to speed up extraction.
+    :param title:
+    :return:
+    """
+    patterns = ['.*关于(.*?)的网上超市.*']
+    for pattern in patterns:
+        text = [i for i in re.findall(pattern, title) if i]
+        if text:
+            return ''.join(text)
+
+
+def topic_trace(title, projectname):
+    """
+    Core-phrase extraction.
+    """
+    if '采购意向' in projectname and '采购意向' in title:
+        return title,'title'
+    title_topic = re_tract(title)
+    if title_topic:
+        return title_topic,'re'
+    title_topic = re_tract(projectname)
+    if title_topic:
+        return title_topic,'re'
+    if ('采购意向' in title or '...' in title) and '采购意向' not in projectname:
+        title_topic, flag = title_topic_process(stop_content(projectname))
+    else:
+        title_topic, flag = title_topic_process(stop_content(title))
+        if flag == 'title' and projectname:
+            title_topic, flag = title_topic_process(stop_content(projectname))
+    if not title_topic:
+        title_topic = title
+        flag = 'title'
+    title_topic = re.sub(r'[^\w\s]', '', title_topic)
+    return title_topic,flag
+
+
+@at.many_thread(num=2)
+@at.consumer
+def consumer_handle(*args, **kwargs):
+    '''
+    Per-record processing logic.
+    :param data:
+    :return:
+    '''
+
+    row = kwargs.get("data")
+    ids = row.get('_id', '')
+    projectname = row.get('projectname', '')
+    title = row.get('title', '')
+    title_topic, flag = topic_trace(title, projectname)  # core-phrase extraction
+    title_topic = title_topic.replace('"', '').replace('\\', '')
+    mg.update_one_by_field(MongoConfig.get('col', ''), {'_id': ids}, {'topic_test': title_topic})  # write the core phrase back to bidding
+    if flag != 'title':
+        topic_hash = md5(title_topic.encode('utf-8')).hexdigest()
+        vector = text_to_vector(title_topic)
+        mg2.update_one_by_field(MongoConfig2.get('col', ''),
+                                {'hash_id': topic_hash}, {'topic_name': title_topic,
+                                                          'hash_id': topic_hash,
+                                                          'vector': str(vector)}, True)  # upsert into the vector table; a new row is inserted when hash_id does not exist
+
+
+if __name__ == '__main__':
+
+    fields = ['_id', 'detail', 'title', 'projectname', 'toptype', 'topic_word',
+              'rate', 'tag_subinformation_ai', 'tag_topinformation_ai']
+    incremental_iterator = mg.bigdata_iterator(MongoConfig["col"], fields,
+                                               id_range=[ObjectId('0' * 24), ObjectId("f" * 24)],
+                                               reverse=True)
+
+    t = threading.Thread(target=at.producer, args=(incremental_iterator, producer_handle))  # producer
+    t.start()
+    consumer_handle()  # consumer
+    t.join()

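The publisher above is a producer/consumer pipeline: bigdata_iterator streams bidding records into a queue, and consumer_handle extracts a core phrase, embeds it, and upserts it keyed by an MD5 hash. A hedged, dependency-free sketch of just the dedup key (the jy_utils and Mongo pieces are project-internal):

from hashlib import md5

title_topic = '办公家具'  # assumed extractor output
topic_hash = md5(title_topic.encode('utf-8')).hexdigest()
# the same phrase always maps to the same hash_id, so the upsert in
# consumer_handle stores exactly one vector per distinct core phrase
print(topic_hash)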
BIN
utils/__pycache__/redis_helper.cpython-37.pyc


BIN
utils/__pycache__/request_fun.cpython-37.pyc


BIN
utils/__pycache__/title_ner.cpython-37.pyc


+ 205 - 0
utils/digital.py

@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2023/7/10 
+# @Author  : lkj
+# @description : pipeline 1: query only
+import re
+import time
+from collections import Counter
+from numpy import dot
+from numpy.linalg import norm
+from utils.milvus_hlper import Milvus
+from utils.request_fun import text_to_vector
+from config import milvus_config
+import numpy as np
+
+m = Milvus(table='zc_classify', **milvus_config)
+
+
+def cosine_similarity(vector1, vector2):
+    """
+    Cosine similarity between two vectors.
+    """
+    dot_product = np.dot(vector1, vector2)
+    norm1 = np.linalg.norm(vector1)
+    norm2 = np.linalg.norm(vector2)
+    similarity = dot_product / (norm1 * norm2)
+    return round(similarity,4)
+
+
+def check(text_vec, data_es: list):
+    '''
+    Rule-based confidence scoring.
+    :param text_vec:
+    :param data_es:
+    :return:
+    '''
+    try:
+        sim_list = [i[2] for i in data_es]  # similarity scores
+        sim_mean = sum(sim_list) / len(sim_list)  # mean similarity
+        data = [item for item in data_es if item[2] >= sim_mean - 0.025]  # drop hits that deviate strongly and score low
+        count = 3  # frequency threshold
+        if len(data) <= 4:
+            count = 2
+        if data[0][2] > 0.965:  # results are sorted best-first; a very high top score is returned outright
+            return data[0][1], 'mode01'
+        if data[0][2] - data[1][2] > 0.025 and data[0][2] > 0.9:  # top score beats the runner-up by 0.025 and exceeds 0.9
+            return data[0][1], 'mode1'
+        # when several hits exceed 0.9, compare their parent categories and keep the best one
+        score_list = [i[1] for i in data_es if i[2] > 0.9]
+        if len(score_list) >= 2:
+            best_code = ''
+            best_sim = 0
+            for end_code in score_list:
+                father_code = end_code[:-2]  # parent id
+                father_code_name = m.get_name(father_code)  # parent name
+                father_vec = text_to_vector(father_code_name)
+                similarity = round(dot(text_vec, father_vec) / (norm(text_vec) * norm(father_vec)), 4)  # cosine similarity
+                if similarity > best_sim:  # keep the code whose parent is most similar
+                    best_code = end_code
+                    best_sim = similarity
+            return best_code, 'mode1'
+
+        # third confidence tier
+        elif data[0][2] - data[1][2] > 0.015 and data[0][2] > 0.9:
+            pcode = data[0][1][:-2]
+            pname = m.get_name(pcode)
+            pvec = text_to_vector(pname)
+            similarity = round(dot(text_vec, pvec) / (norm(text_vec) * norm(pvec)), 4)  # cosine similarity
+            if similarity > 0.8:  # return the top hit when its parent is similar enough
+                return data[0][1], 'mode1'
+        else:
+            # frequency-based rule over everything the vector store returned
+            code_list = []
+            for row in data:
+                if len(row[1][:-2]) > 1:
+                    code_list.append(row[1][:-2])
+                code_list.append(row[1])
+            word_count = dict(Counter(code_list))  # count each category and its parent; a frequent parent pins the class
+            max_word = [(k, v) for k, v in word_count.items() if v >= count]
+            if len(max_word) == 1:
+                return max_word[0][0], 'mode2'
+            else:  # several candidates: compare them against the query vector
+                code = ''
+                code_sim = 0
+                for word in max_word:
+                    code_ = word[0]
+                    code_name = m.get_name(code_)
+                    vec = text_to_vector(code_name)
+                    sim_ = round(dot(text_vec, vec) / (norm(text_vec) * norm(vec)), 4)  # cosine similarity
+                    if sim_ > code_sim:
+                        code_sim = sim_
+                        code = code_
+                return code, 'mode3'
+        return '', 'mode0'
+    except Exception as e:
+        print('check_error', e)
+        return '', 'error'
+
+
+def run_mode1(text, baseclass):
+    """
+    Main entry point for subject-matter digitization.
+    :param text:
+    :param baseclass:
+    :return: result_name: category name, similarity: similarity between result and input, mode: decision mode, code: category code, credibility: confidence score
+    """
+    vec = text_to_vector(text)  # embed the text
+    search_result = m.search_china(vec, baseclass)  # vector search
+    if not search_result:
+        return '', 0, 'mode0', '', 0
+    similarity = 0  # similarity between input text and result
+    result_name = ''
+    mode = 'mode0'
+    credibility = 0
+    code = ''
+    if search_result:
+        check_result = check(vec, search_result)  # filter the hits
+        if check_result[0]:
+            code = check_result[0]
+            result_name = m.get_name(check_result[0])  # code-to-name lookup
+            mode = check_result[1]
+            if result_name:
+                res_vec = text_to_vector(result_name)
+                similarity = round(dot(vec, res_vec) / (norm(vec) * norm(res_vec)), 4)  # cosine similarity
+                if mode == 'mode01':
+                    credibility = 0.99
+                if mode == 'mode1':
+                    credibility = 0.95
+                if mode == 'mode3' and similarity > 0.85:
+                    credibility = 0.90
+                if mode == 'mode2' and similarity > 0.8:
+                    credibility = 0.85
+    if mode in ['mode2', 'mode3']:
+        pcode = code[:-2]
+        if not pcode:
+            pcode = code
+        p_name = m.get_name(pcode)  # parent category name
+        p_name_vec = text_to_vector(p_name)  # parent category vector
+        p_similarity = round(dot(p_name_vec, vec) / (norm(p_name_vec) * norm(vec)), 4)  # text vs. parent similarity
+        if p_similarity > 0.85 and similarity > 0.9 or similarity == 0.99:
+            mode = 'mode4'
+            credibility = 0.99
+    if credibility == 0:
+        result_name = search_result[0][0]
+        similarity = search_result[0][2]
+        mode = ''
+        code = search_result[0][1]
+    return result_name, similarity, mode, code, credibility
+
+
+def run_mode1_main(text, baseclass=None):
+    """
+    Wrapper around run_mode1 that appends the category route and normalizes the output.
+    """
+    try:
+        result = list(run_mode1(text, baseclass))
+        if not result[0]:
+            return ['', '', '', 0, '']
+        result.pop(2)
+        code = result[2]
+        route = m.get_root_zc(code)
+        result.append(route)
+        if result[1] > 0.9 and result[-2] == 0:
+            result[-2] = 0.85
+        if result[1] == 1.0:
+            result[-2] = 0.99
+        return result
+    except Exception as e:
+        print('government-procurement classification error --->', e)
+        return ['', '', '', '', '']
+
+
+if __name__ == '__main__':
+    print(run_mode1_main('成型设备', '工程'))
+    exit()
+    while True:
+        t = input('enter text: ')
+        print(run_mode1_main(t))
+    # exit()
+    # import pandas as pd
+    # data = pd.read_csv('./data/test.csv', encoding='utf-8', sep='\t')
+    # for name in data['name']:
+    #     print('input--->', name)
+    #     run_result = run_mode1(name)
+    #
+    #     china_name = run_result[0]
+    #     china_name_code = run_result[3]
+    #     china_name_dis = run_result[1]
+    #     score = 0.99
+    #     root = ''
+    #     stop = 2
+    #     for i in range(0, len(china_name_code), 2):
+    #         root = root + name_maps.get(china_name_code[0:stop], '') + '/'
+    #         stop += 2
+    #     res = [name, china_name, china_name_code, root, run_result[4]]
+    #     print('similarity:', run_result[4])
+    #     with open('data/result3.csv', 'a', newline='', encoding='utf-8') as f:
+    #         witer = csv.writer(f)
+    #         witer.writerow(res)
+    #     print('output--->', run_result[0])
+    #     # print('similarity--->', res[1])
+    #     print('mode--->', run_result[2])
+    #         # print('explanation:', i[3] + '\n')
+    #     print('*' * 30)

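A quick numeric check of the cosine_similarity helper above: orthogonal vectors score 0.0 and parallel ones score 1.0, which is why the thresholds in check() sit in the 0.9-0.965 band.

import numpy as np

def cosine_similarity(v1, v2):
    # same formula as the helper above
    return round(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)), 4)

print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0
print(cosine_similarity([1.0, 2.0], [2.0, 4.0]))  # 1.0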
+ 204 - 0
utils/digital2.py

@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2023/7/10 
+# @Author  : lkj
+# @description : pipeline 2: semantic-similarity search against the standard vector store
+import time
+from collections import Counter
+from numpy import dot
+from numpy.linalg import norm
+from utils.milvus_hlper import Milvus
+from utils.request_fun import text_to_vector
+from config import milvus_config
+
+
+m = Milvus(table='jianyu_code_2', **milvus_config)
+
+
+def count_fun(data: list):
+    """
+    Brute-force frequency count over every prefix of each code.
+    """
+    try:
+        count_dict = {}
+        for item in data:
+            split = 0
+            if len(item) == 1:
+                count_dict[item] = count_dict.get(item, 0) + 1
+            else:
+                for i in range(int((len(item) - 1) / 2)):
+                    if split == 0:
+                        c = item
+                    else:
+                        c = item[:-split]
+                    count_dict[c] = count_dict.get(c, 0) + 1
+                    split += 2
+        if not count_dict:
+            return '', 0
+        max_value = max(count_dict.values())  # largest count in the dict
+        if max_value == 1:
+            max_length_key = min(
+                [key for key, value in count_dict.items() if value == max_value],
+                key=len
+            )
+        else:
+            max_length_key = max(
+                [key for key, value in count_dict.items() if value == max_value],
+                key=len
+            )
+
+        return max_length_key, max_value
+    except Exception as e:
+        print('count_fun_error', e)
+        return '', 0
+
+
+def check(text_vec, data_: list):
+    '''
+    Second-tier confidence rules.
+    :param text_vec:
+    :param data_:
+    :return:
+    '''
+    try:
+        sim_list = [i[2] for i in data_]
+        sim_mean = sum(sim_list) / len(sim_list)
+        data = [item for item in data_ if item[2] >= sim_mean - 0.025]  # drop hits that deviate strongly and score low
+        model01_res, model01_ = check_model01(text_vec, data)
+        if model01_res:
+            return model01_res, model01_
+        # when several hits score above 0.92, compare their parent categories and keep the best one
+        score_list = [i[1] for i in data_ if i[2] > 0.92]
+        count_flag = 3  # frequency threshold, derived from the size of score_list
+        if len(score_list) <= 4:
+            count_flag = 2
+        word_count = dict(Counter(score_list))  # count categories that are both frequent and high-similarity
+        max_word = [(k, v) for k, v in word_count.items() if v >= count_flag]
+        if len(max_word) == 1:
+            return max_word[0][0], 'mode01'
+        elif len(score_list) >= 2:
+            best_code = ''
+            best_sim = 0
+            for code_ in score_list:
+                if len(code_) == 1:
+                    continue
+                father_code = code_[:-2]
+                father_code_name = m.get_name(father_code)
+                father_vec = text_to_vector(father_code_name)
+                similarity = round(dot(text_vec, father_vec) / (norm(text_vec) * norm(father_vec)), 4)  # cosine similarity
+                if similarity > best_sim:
+                    best_code = code_
+                    best_sim = similarity
+            return best_code, 'mode1'
+        else:
+            count_code, max_value = count_fun([i[1] for i in data_ if i[2] >= 0.85])  # accept the brute-force winner only when it covers most hits and the code is deeper than one level
+            if (max_value / len(data_)) >= (len(data_) - 2) / len(data_) and len(count_code) > 3:
+                return count_code, 'mode1'
+        return '', ''
+    except Exception as e:
+        print('check_error', e)
+        return '', 'error'
+
+
+def check_model01(text_vec, data_es):
+    """
+    First-tier confidence rules.
+    :param text_vec:
+    :param data_es:
+    :return:
+    """
+    if data_es[0][2] > 0.945:  # results are sorted best-first; a very high top score is returned outright
+        return data_es[0][1], 'mode01'
+    output_lst = []
+    word_count = Counter([row[1] for row in data_es if row[2] > 0.85])
+    if not word_count:
+        return '', ''
+    max_pair = max(word_count.items(), key=lambda x: x[1])  # return the modal code when it covers a majority of hits
+    if max_pair[1] / len(data_es) >= (int(len(data_es) / 2) + 1) / len(data_es):
+        return max_pair[0], 'mode01'
+
+    for i in range(min(3, len(data_es))):  # only the first three hits are inspected
+        if data_es[i][0] == data_es[0][0] and data_es[i][2] > 0.85:
+            output_lst.append(data_es[i][1])
+    if len(output_lst) == 3:
+        return data_es[0][1], 'mode01'
+
+    elif data_es[0][2] - data_es[1][2] > 0.02 and data_es[0][2] > 0.91:  # top hit clearly ahead of the rest
+        p_code = data_es[0][1][:-2]
+        if not p_code or (data_es[0][2] > 0.91):  # a single-level top hit above 0.91 is accepted as-is
+            return data_es[0][1], 'mode01'
+        else:
+            p_name = m.get_name(p_code)
+            pvec = text_to_vector(p_name)
+            similarity = round(dot(text_vec, pvec) / (norm(text_vec) * norm(pvec)), 4)  # cosine similarity
+            if similarity > 0.8:  # similarity between the subject matter and the hit's parent category
+                return data_es[0][1], 'mode01'
+    return '', ''
+
+
+def run_mode1(text, baseclass=None):
+    """
+    Main entry point for subject-matter digitization.
+    :param text:
+    :param baseclass:
+    :return: result_name: category name, similarity: similarity between result and input, mode: decision mode, code: category code, credibility: confidence score
+    """
+    try:
+
+        vec = text_to_vector(text)  # embed the text
+        # search_result = m.search_good(vec, 7, baseclass)
+        search_result = m.search_industry(vec, ['code', 'class_name', 'embeddings', 'explain', 'root', 'private_code'],
+                                          industry_list=['物业'])
+        print(search_result)  # debug output
+        similarity = 0  # similarity between input text and result
+        result_name = ''  # category name
+        mode = ''  # decision mode
+        credibility = 0  # confidence score
+        code = ''  # category code
+
+        if search_result and len(search_result) > 2:
+            check_result = check(vec, search_result)  # filter the hits
+            if check_result[0]:
+                code = check_result[0]
+                result_name = m.get_name(code)  # code-to-name lookup
+                mode = check_result[1]
+                if result_name:
+                    res_vec = text_to_vector(result_name)
+                    similarity = round(dot(vec, res_vec) / (norm(vec) * norm(res_vec)), 4)  # cosine similarity
+                    if mode == 'mode1':
+                        credibility = 0.90
+                    if mode == 'mode01':
+                        credibility = 0.95
+        if credibility == 0:  # zero confidence: fall back to the top hit
+            result_name = search_result[0][0]
+            similarity = search_result[0][2]
+            mode = ''
+            code = search_result[0][1]
+        return [result_name, similarity, mode, code, credibility]
+    except Exception as e:
+        print('run_mode1 error', e)
+        print(text, baseclass)
+        return []
+
+
+def run_mode1_main(text, baseclass):
+    try:
+        result = run_mode1(text, baseclass)
+
+        if not result[0]:
+            return ['', '', '', 0, '']
+        result.pop(2)
+        code = result[2]
+        route = m.get_root_zc(code)
+
+        result.append(route)
+        return result
+    except Exception as e:
+        print('run_mode1_main error', e)
+        return ['', '', '', '', '']
+
+
+if __name__ == '__main__':
+    a = [('服务', 'C', 0.9276, '服务/', '服务'), ('审计服务', 'C2303', 0.9264, '商务服务/审计服务/', '年审计服务'), ('运行维护服务', 'C1607', 0.9234, '信息技术服务/运行维护服务/', '年信息安全服务'), ('物业管理服务', 'C2104', 0.9202, '房地产服务/物业管理服务/', '年物业服务'), ('服务', 'C', 0.9145, '服务/', '综合服务'), ('会议服务', 'C2201', 0.9111, '会议、展览、住宿和餐饮服务/会议服务/', '会务服务'), ('软件运维服务', 'C160703', 0.9108, '信息技术服务/运行维护服务/软件运维服务/', '业务系统服务')]
+    v = text_to_vector('xxxx')
+    print(check(v, a))

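A hedged walkthrough of count_fun above, inlined so it runs without the module-level Milvus connection; the three codes are made up. Each code votes for itself and for every ancestor reached by trimming two trailing characters:

def count_fun_demo(data):
    count_dict = {}
    for item in data:
        if len(item) == 1:
            count_dict[item] = count_dict.get(item, 0) + 1
            continue
        for trim in range(0, len(item) - 1, 2):
            c = item if trim == 0 else item[:-trim]
            count_dict[c] = count_dict.get(c, 0) + 1
    mx = max(count_dict.values())
    keys = [k for k, v in count_dict.items() if v == mx]
    # a unanimous count of 1 picks the shortest key; a real majority picks the longest
    return (min(keys, key=len) if mx == 1 else max(keys, key=len)), mx

print(count_fun_demo(['C2104', 'C210402', 'C2303']))
# votes: C2104 x2, C21 x2, C210402 x1, C2303 x1, C23 x1 -> ('C2104', 2)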
+ 78 - 0
utils/htmltag.py

@@ -0,0 +1,78 @@
+# coding:utf-8
+import re
+br_reg = re.compile('<br[/]*>', re.I)
+table_reg = re.compile('<([/]*table[^>]*)>', re.I)
+tablebody_reg = re.compile('<([/]*tbody[^>]*)>', re.I)
+input_reg = re.compile(r'<[/]*input[^>].*?value="(.*?)"[/]>', re.I)
+tr_reg = re.compile('<([/]*tr[^>]*)>', re.I)
+th_reg = re.compile('<([/]*th[^>]*)>', re.I)
+td_reg = re.compile('<([/]*td[^>]*)>', re.I)
+p_reg = re.compile('<[/]?p>', re.I)
+othertag_reg = re.compile('<[^>]+>', re.I | re.M)
+other_symbol_reg = re.compile('[\t| ]*')
+seg_first_space_reg = re.compile('\n+\\s*', re.M)
+mul_crcf_reg = re.compile('\n+', re.M)
+brackets_reg = re.compile('\\s+')
+table_fk_reg = re.compile('(\\[table[^\\]]*\\])(.*?)(\\[/table\\])', re.M | re.S | re.I)
+
+
+# Strip HTML tags
+def Clean(html: str):
+    html = br_reg.sub('\n', html)
+    html = table_reg.sub('', html)
+    html = tablebody_reg.sub('', html)
+    html = tr_reg.sub('\n', html)
+    html = td_reg.sub(' ', html)
+    html = p_reg.sub('\n', html)
+    html = othertag_reg.sub('', html)
+    html = other_symbol_reg.sub('', html)
+    html = seg_first_space_reg.sub('\n', html)
+    html = mul_crcf_reg.sub('\n', html)
+    return html
+
+
+def ClearSpace(txt: str):
+    return brackets_reg.sub('', txt)
+
+
+# Strip HTML tags but keep table markup
+def CleanKeepTable(html: str):
+    html = br_reg.sub('\n', html)
+    html = table_reg.sub(subFunc4Match, html)
+    html = tablebody_reg.sub(subFunc4Match, html)
+    html = tr_reg.sub(subFunc4Match, html)
+    html = td_reg.sub(subFunc4Match, html)
+    html = th_reg.sub(subFunc4Match, html)
+    html = p_reg.sub('\n', html)
+    html = othertag_reg.sub('', html)
+    # html = other_symbol_reg.sub('',html)
+    html = seg_first_space_reg.sub('\n', html)
+    # print("-->", html)
+    html = table_fk_reg.sub(lambda x: x.group(1) + mul_crcf_reg.sub(' ', x.group(2)) + x.group(3), html)
+    html = mul_crcf_reg.sub('\n', html)
+    # convert the bracketed table placeholders back into real tags
+    html = html.replace('[', '<').replace(']', '>')
+    html = html.replace('<table', '\n<table').replace('</table>', '</table>\n')
+    return html
+
+
+def subFunc4Match(strmatch):
+    try:
+        if strmatch:
+            return '[%s]' % (strmatch.group(1))
+        else:
+            return ""
+    except Exception as e:
+        print(e)
+
+
+def extract_input_value(html):
+    input_reg = re.compile(r'<[/]*input[^>].*?value="(.*?)"[/]>', re.I)
+    input_r = re.compile(r'<[/]*input[^>].*?[/]>', re.I)
+    result = input_r.findall(html)
+    for input_detail in result:
+        ret = input_reg.findall(input_detail)
+        if ret:
+            html = html.replace(input_detail, f"</td><td>{ret[0]}")
+    return html
+

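A small usage sketch of the cleaners above; utils.htmltag depends only on re, so this runs as-is. The HTML fragment is made up:

from utils.htmltag import Clean, CleanKeepTable

snippet = '<p>项目名称</p><table><tbody><tr><td>A</td><td>B</td></tr></tbody></table><br/>'
print(Clean(snippet))           # all markup stripped; <p>/<tr>/<br> become line breaks, spaces squeezed out
print(CleanKeepTable(snippet))  # same pass, but the table structure survives as real tags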
+ 169 - 0
utils/jy_code.py

@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2024/2/6 
+# @Author  : lkj
+# @description :
+from collections import Counter
+from numpy import dot
+from numpy.linalg import norm
+from utils.milvus_hlper import Milvus
+from utils.request_fun import text_to_vector
+import numpy as np
+
+
+class JyCode(object):
+    def __init__(self, db_name, config):
+        self.m = Milvus(table=db_name, **config)
+
+    @staticmethod
+    def cosine_similarity(vector1, vector2):
+        """
+        Cosine similarity between two vectors.
+        """
+        dot_product = np.dot(vector1, vector2)
+        norm1 = np.linalg.norm(vector1)
+        norm2 = np.linalg.norm(vector2)
+        similarity = dot_product / (norm1 * norm2)
+        return round(similarity, 4)
+
+    def check(self, text_vec, data_es: list, offset_value=0.025, min_value=0.9, max_value=0.965):
+        '''
+        Rule-based confidence scoring.
+        :param text_vec:
+        :param data_es: hits returned by the vector search
+        :param offset_value: allowed score gap between top-k hits
+        :param min_value: minimum trusted similarity; anything lower is rejected
+        :param max_value: similarity above which the top hit is trusted outright
+        :return:
+        '''
+        try:
+            sim_list = [i[2] for i in data_es]  # similarity scores
+            sim_mean = sum(sim_list) / len(sim_list)  # mean similarity
+            data = [item for item in data_es if item[2] >= sim_mean - offset_value]  # drop hits that deviate strongly and score low
+            count = 3  # frequency threshold
+            if len(data) <= 4:
+                count = 2
+            if data[0][2] > max_value:  # results are sorted best-first; a very high top score is returned outright
+                return data[0][1], 'mode01'
+            if data[0][2] - data[1][2] > offset_value and data[0][2] > min_value:  # top score clearly ahead of the runner-up
+                return data[0][1], 'mode1'
+            # when several hits exceed min_value, compare their parent categories and keep the best one
+            score_list = [i[1] for i in data_es if i[2] > min_value]
+            if len(score_list) >= 2:
+                best_code = ''
+                best_sim = 0
+                for end_code in score_list:
+                    father_code = end_code[:-2]  # parent id
+                    father_code_name = self.m.get_name(father_code)  # parent name
+                    father_vec = text_to_vector(father_code_name)
+                    similarity = round(dot(text_vec, father_vec) / (norm(text_vec) * norm(father_vec)), 4)  # cosine similarity
+                    if similarity > best_sim:  # keep the code whose parent is most similar
+                        best_code = end_code
+                        best_sim = similarity
+                return best_code, 'mode1'
+
+            # third confidence tier
+            elif data[0][2] - data[1][2] > (offset_value - 0.01) and data[0][2] > min_value:
+                pcode = data[0][1][:-2]
+                pname = self.m.get_name(pcode)
+                pvec = text_to_vector(pname)
+                similarity = round(dot(text_vec, pvec) / (norm(text_vec) * norm(pvec)), 4)  # cosine similarity
+                if similarity > (min_value - 0.1):  # return the top hit when its parent is similar enough
+                    return data[0][1], 'mode1'
+            else:
+                # frequency-based rule over everything the vector store returned
+                code_list = []
+                for row in data:
+                    if len(row[1][:-2]) > 1:
+                        code_list.append(row[1][:-2])
+                    code_list.append(row[1])
+                word_count = dict(Counter(code_list))  # count each category and its parent; a frequent parent pins the class
+                max_word = [(k, v) for k, v in word_count.items() if v >= count]
+                if len(max_word) == 1:
+                    return max_word[0][0], 'mode2'
+                else:  # several candidates: compare them against the query vector
+                    code = ''
+                    code_sim = 0
+                    for word in max_word:
+                        code_ = word[0]
+                        code_name = self.m.get_name(code_)
+                        vec = text_to_vector(code_name)
+                        sim_ = round(dot(text_vec, vec) / (norm(text_vec) * norm(vec)), 4)  # cosine similarity
+                        if sim_ > code_sim:
+                            code_sim = sim_
+                            code = code_
+                    return code, 'mode3'
+            return '', 'mode0'
+        except Exception as e:
+            print('check_error', e)
+            return '', 'error'
+
+    def run_mode1(self, text, baseclass=None):
+        """
+        Main entry point for subject-matter digitization.
+        :param text:
+        :param baseclass:
+        :return: result_name: category name, similarity: similarity between result and input, mode: decision mode, code: category code, credibility: confidence score
+        """
+        vec = text_to_vector(text)  # embed the text
+        search_result = self.m.search_china(vec, baseclass)  # vector search
+        if not search_result:
+            return '', 0, 'mode0', '', 0
+        similarity = 0  # similarity between input text and result
+        result_name = ''
+        mode = 'mode0'
+        credibility = 0
+        code = ''
+        if search_result:
+            check_result = self.check(vec, search_result)  # filter the hits
+            if check_result[0]:
+                code = check_result[0]
+                result_name = self.m.get_name(check_result[0])  # code-to-name lookup
+                mode = check_result[1]
+                if result_name:
+                    res_vec = text_to_vector(result_name)
+                    similarity = round(dot(vec, res_vec) / (norm(vec) * norm(res_vec)), 4)  # cosine similarity
+                    if mode == 'mode01':
+                        credibility = 0.99
+                    if mode == 'mode1':
+                        credibility = 0.95
+                    if mode == 'mode3' and similarity > 0.85:
+                        credibility = 0.90
+                    if mode == 'mode2' and similarity > 0.8:
+                        credibility = 0.85
+        if mode in ['mode2', 'mode3']:
+            pcode = code[:-2]
+            if not pcode:
+                pcode = code
+            p_name = self.m.get_name(pcode)  # parent category name
+            p_name_vec = text_to_vector(p_name)  # parent category vector
+            p_similarity = round(dot(p_name_vec, vec) / (norm(p_name_vec) * norm(vec)), 4)  # text vs. parent similarity
+            if p_similarity > 0.85 and similarity > 0.9 or similarity == 0.99:
+                mode = 'mode4'
+                credibility = 0.99
+        if credibility == 0:
+            result_name = search_result[0][0]
+            similarity = search_result[0][2]
+            mode = ''
+            code = search_result[0][1]
+        return result_name, similarity, mode, code, credibility
+
+    def run_mode1_main(self, text, baseclass=None):
+        """
+        Wrapper around run_mode1 that appends the category route and normalizes the output.
+        """
+        try:
+            result = list(self.run_mode1(text, baseclass))
+            if not result[0]:
+                return ['', '', '', 0, '']
+            result.pop(2)
+            code = result[2]
+            route = self.m.get_root_zc(code)
+            result.append(route)
+            if result[1] > 0.9 and result[-2] == 0:
+                result[-2] = 0.85
+            if result[1] == 1.0:
+                result[-2] = 0.99
+            return result
+        except Exception as e:
+            print('government-procurement classification error --->', e)
+            return ['', '', '', '', '']

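A hedged usage sketch of JyCode; the collection name 'zc_classify' mirrors the one utils/digital.py opens, and a live Milvus plus the Redis behind get_name are required for this to return anything:

from config import milvus_config
from utils.jy_code import JyCode

jy = JyCode('zc_classify', milvus_config)
print(jy.run_mode1_main('成型设备', '工程'))
# -> [result_name, similarity, code, credibility, route] on success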
+ 107 - 0
utils/manager.py

@@ -0,0 +1,107 @@
+# coding:utf-8
+def answers_writing(question: str) -> str:
+    """
+    Q&A prompt.
+    :param question: input text
+    :return:
+    """
+    input_text = f"问答:\n问题:{question}\n答案:"
+    return input_text
+
+
+def dialogue(context: str, question: str) -> str:
+    """
+    Dialogue prompt.
+    :param context:
+    :param question:
+    :return:
+    """
+    input_text = context + "\n用户:" + question + "\n剑鱼:"
+    return input_text
+
+
+def create_context(context: str, question: str, answer: str) -> str:
+    """
+    Build the running dialogue context.
+    :param context: previous context
+    :param question: question
+    :param answer: answer
+    :return:
+    """
+    context = context + "\n用户:" + question + "\n剑鱼:" + answer
+    return context
+
+
+def title_words_article(title: str, words: str) -> str:
+    """
+    Generate an article from a title and keywords.
+    :param title:
+    :param words:
+    :return:
+    """
+    input_text = f"根据标题和关键词生成文章:\n标题:{title};关键词:{words}\n答案:"
+    return input_text
+
+
+def title_article(title: str) -> str:
+    """
+    Generate an article from a title.
+    :param title:
+    :return:
+    """
+    input_text = f"根据标题生成文章:\n标题:{title}\n答案:"
+    return input_text
+
+
+def copy_writing(title: str, question: str) -> str:
+    """
+    Generate marketing copy.
+    :param title:
+    :param question:
+    :return:
+    """
+    input_text = f"营销文案生成:\n标题:{title}\n{question}\n答案:"
+    return input_text
+
+
+def info_extract(title: str, question: str) -> str:
+    """
+    Generic information extraction.
+    :param title:
+    :param question:
+    :return:
+    """
+    input_text = f"信息抽取:\n{title}\n问题:{question}\n答案:"
+    return input_text
+
+
+def extract_words(title: str) -> str:
+    """
+    Keyword extraction.
+    :param title: article text
+    :return:
+    """
+    input_text = f"抽取关键词:\n{title}\n关键词:"
+    return input_text
+
+
+def reading(title: str, question: str) -> str:
+    """
+    Reading comprehension.
+    :param title:
+    :param question:
+    :return:
+    """
+    input_text = f"阅读理解:\n段落:{title}\n问题:{question}\n答案:"
+    return input_text
+
+
+def classify(title: str, question: str) -> str:
+    """
+    Text classification.
+    :param title:
+    :param question:
+    :return:
+    """
+    input_text = f"文本分类:\n{title}\n选项:{question}\n答案:"
+    return input_text

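These helpers only format prompt strings for the downstream model; utils.manager has no third-party imports, so a quick look at two of them runs anywhere. The question text is made up:

from utils.manager import answers_writing, classify

print(answers_writing('什么是招标控制价?'))
# 问答:
# 问题:什么是招标控制价?
# 答案:
print(classify('某单位办公家具采购中标公告', '中标,废标,流标'))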
+ 232 - 0
utils/milvus_hlper.py

@@ -0,0 +1,232 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2023/10/11 
+# @Author  : lkj
+# @description : milvus
+from typing import Optional, List
+
+import numpy
+from numpy import dot
+from numpy.linalg import norm
+from pymilvus import (connections, FieldSchema,Collection,CollectionSchema,DataType,utility)
+from config import redis_config
+from utils.redis_helper import RedisString
+
+
+
+# The default database 'default' is used; a dedicated database can be created instead.
+
+
+class Milvus(object):
+    def __init__(self,table,**kwargs):
+
+        connections.connect(**kwargs)
+        self.col = Collection(table)
+        self.search_params = {
+            "metric_type": "L2",
+            "ignore_growing": False,
+            "params": {"nprobe": 100},
+        }
+        self.r = RedisString(redis_config)
+
+    def search(self, vec, fields=None, expr=None):
+        res = self.col.search([vec], 'embeddings', self.search_params, 7,
+                              output_fields=fields,
+                              expr=expr)
+        return res
+
+    def load(self):
+        self.col.load()
+
+    def release(self):
+        self.col.release()
+
+    def delete(self,expr):
+        self.col.delete(expr=expr)
+
+    def query(self,q):
+        res = self.col.query(expr=q)
+        return res
+
+    def insert(self,data):
+        self.col.insert(data=data)
+
+    @staticmethod
+    def sim(a: list, b: list):
+        """
+        Cosine similarity between two vectors.
+        :param a:
+        :param b:
+        :return:
+        """
+        s = dot(a, b) / (norm(a) * norm(b))
+        return round(s, 4)
+
+    def get_name(self, code):
+        """
+        Look up the name for a code via Redis.
+        """
+        while code[-1] == '0' and code[-2] == '0':  # trim trailing '00' pairs down to the significant code
+            code = code[:-1]
+            code = code[:-1]
+        name = self.r.string_get('jycode_' + code)
+        return name
+
+    def get_root_zc(self, re_code):
+        """
+        Build the full category route for a code.
+        """
+        split = 0
+        root = ''
+        level = (len(re_code) - 1) / 2
+        for i in range(int(level)):
+            if split == 0:
+                c = re_code
+            else:
+                c = re_code[:-split]
+            split += 2
+            name_code = self.get_name(c)
+            root = name_code + '/' + root
+        return root
+
+    def search_good(self,vec,num=7,base=None):
+        try:
+            res = self.col.search([vec],
+                                  'embeddings', self.search_params, num,
+                                  output_fields=['code','class_name','embeddings','explain','root'],) #expr=f'baseclass=="{base}"'
+            result_list = []
+            for hit in res:
+                # print(hit)
+                for row in hit:
+                    row = row.to_dict()
+                    code = row.get('entity',{}).get('code','')
+                    while code[-1] == '0' and code[-2] == '0':
+                        code = code[:-1]
+                        code = code[:-1]
+                    explain = row.get('entity', {}).get('explain', '')
+                    root = self.get_root_zc(code)
+                    if not root:
+                        root = row.get('entity', {}).get('root', '')
+                    vec_cls = row.get('entity', {}).get('embeddings', [])
+                    sim_res = self.sim(vec, vec_cls)
+                    cls_name = row.get('entity', {}).get('class_name', '')
+                    result_list.append((cls_name, code, sim_res, root, explain))
+            result_list = sorted(result_list, key=lambda x: x[2], reverse=True)
+            return result_list
+        except Exception as e:
+            print('relation store error:', e)
+            return []
+
+    def search_industry(self, vec, output_fields: list, topk=7,
+                        industry_list: Optional[List[str]] = None):
+        """
+        Search the statistics-bureau classification.
+        vec: embedding of the subject matter
+        industry_list: industry scope --> list
+
+        """
+        try:
+            public = True
+            if industry_list:
+                # expr = f'industry in {industry_list}'
+                public = False
+            else:
+                expr = None
+            res = self.col.search([vec], 'embeddings', self.search_params, topk,
+                                  output_fields=output_fields,
+                                  expr=None)
+            result_list = []
+            for hit in res:
+                for row in hit:
+                    row = row.to_dict()
+                    code = row.get('entity', {}).get('code', '')
+                    cls_name = row.get('entity', {}).get('class_name', '')
+                    if not public:
+                        code = row.get('entity', {}).get('private_code', '')
+                        if code == 'null':
+                            code = row.get('entity', {}).get('code', '')
+                        cls_name = self.get_name(code)
+                    while code[-1] == '0' and code[-2] == '0':
+                        code = code[:-1]
+                        code = code[:-1]
+                    explain = row.get('entity', {}).get('explain', '')
+                    root = self.get_root_zc(code)
+                    if not root:
+                        root = row.get('entity', {}).get('root', '')
+                    vec_cls = row.get('entity', {}).get('embeddings', [])
+                    sim_res = self.sim(vec, vec_cls)
+                    result_list.append((cls_name, code, sim_res, root, explain))
+            result_list = sorted(result_list, key=lambda x: x[2], reverse=True)
+            return result_list
+        except Exception as e:
+            print('statistics-bureau classification error:', e)
+            return []
+
+    def search_china(self, vec, base=None, topk=7):
+        """
+        Search the statistics-bureau classification.
+        vec: embedding of the subject matter
+        base: subject-matter base class
+
+        """
+        try:
+
+            res = self.col.search([vec], 'embeddings', self.search_params, topk,
+                                  output_fields=['code','class_name','embeddings','explain','root'],
+                                  )#expr=f'baseclass=="{base}"'
+            result_list = []
+            for hit in res:
+                # print(hit)
+                for row in hit:
+
+                    row = row.to_dict()
+                    code = row.get('entity',{}).get('code','')
+                    while code[-1] == '0' and code[-2] == '0':
+                        code = code[:-1]
+                        code = code[:-1]
+                    ids = row['id']
+                    explain = row.get('entity', {}).get('explain', '')
+                    root = self.get_root_zc(code)
+                    if not root:
+                        root = row.get('entity', {}).get('root', '')
+                    vec_cls = row.get('entity', {}).get('embeddings', [])
+                    sim_res = self.sim(vec, vec_cls)
+                    cls_name = row.get('entity', {}).get('class_name', '')
+                    result_list.append((cls_name, code, sim_res, root, explain))
+            result_list = sorted(result_list, key=lambda x: x[2], reverse=True)
+            return result_list
+        except Exception as e:
+            print('statistics-bureau classification error:', e)
+            return []
+
+    def update(self, data):
+        self.col.insert(data)
+
+def create():
+    onn = connections.connect(db_name="classify", host="192.168.3.109", port=19530)
+    fields = [
+        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+        FieldSchema(name="class_name", dtype=DataType.VARCHAR, max_length=100),
+        FieldSchema(name='code', dtype=DataType.VARCHAR, max_length=16),
+        FieldSchema(name="p_name", dtype=DataType.VARCHAR, max_length=100),
+        FieldSchema(name='p_code', dtype=DataType.VARCHAR, max_length=16),
+        # FieldSchema(name='baseclass', dtype=DataType.VARCHAR, max_length=10),
+        FieldSchema(name='explain', dtype=DataType.VARCHAR, max_length=500),
+        # FieldSchema(name='root', dtype=DataType.VARCHAR, max_length=300),
+        FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=768),
+    ]
+
+    schema = CollectionSchema(fields, description='场景分类')
+    milvus_conn = Collection('scene', schema=schema)
+    index = {"index_type": "IVF_FLAT", "metric_type": "L2",  "params":{"nlist":128}}
+    milvus_conn.create_index("embeddings", index)
+
+if __name__ == '__main__':
+    # create()
+    # pass
+    from request_fun import text_to_vector
+    from config import milvus_config
+    text = '施工人员派遣'
+    vector = text_to_vector(text)
+    col = Milvus('jianyu_code',**milvus_config)
+    print(col.search_industry(vector, ['code', 'class_name', 'embeddings', 'explain', 'root', 'private_code'],
+                              industry_list=['物业']))

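A hedged trace of the trailing-zero normalization used by get_name() and the search methods above: '00' pairs are stripped until the significant part of the code remains, so padded and unpadded codes resolve to the same Redis key:

code = 'C160700'
while code[-1] == '0' and code[-2] == '0':
    code = code[:-1]
    code = code[:-1]
print(code)  # C1607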
+ 60 - 0
utils/mysql_helper.py

@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2023/8/26 
+# @Author  : lkj
+# @description :
+import time
+
+from loguru import logger
+import mysql.connector
+
+
+class MysqlConn(object):
+    def __init__(self,mysql_path):
+        self.mysql_path = mysql_path
+        self.cursor = ''
+        self.mysql_conn = ''
+        self.mysql_conn, self.cursor = self.link_mysql()
+
+    def link_mysql(self):
+        """
+        Open the connection.
+        """
+        try:
+            mysql_conn = mysql.connector.connect(
+                host=self.mysql_path.get('mysql_host'),
+                port=self.mysql_path.get('mysql_port'),
+                user=self.mysql_path['mysql_user'],
+                password=self.mysql_path['mysql_password'],
+                database=self.mysql_path['mysql_db'],
+                connect_timeout=6000,
+                charset='utf8',
+            )
+
+            cursor = mysql_conn.cursor(buffered=True)
+            print('mysql_cursor ---> success')
+            return mysql_conn, cursor
+        except Exception as e:
+            logger.info('failed to connect to the MySQL database --> {}'.format(e))
+            return '', ''
+
+    def search(self, sql, data):
+        """
+        Run a parameterized query; reconnect and return an empty result on failure.
+        """
+        try:
+            self.cursor.execute(sql, (data,))
+            data = self.cursor.fetchall()
+
+            return data, True
+        except Exception as e:
+            print('MySQL query error', e)
+            print(sql)
+            time.sleep(5)
+            self.mysql_conn, self.cursor = self.link_mysql()
+            return [], False
+
+    def close(self):
+        self.cursor.close()
+        self.mysql_conn.close()
+
+

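A hedged usage sketch of MysqlConn; the SQL and table name are hypothetical, and config.mysql_path must point at a reachable server. Note that search() binds exactly one parameter, matching its (data,) tuple:

from config import mysql_path
from utils.mysql_helper import MysqlConn

db = MysqlConn(mysql_path)
rows, ok = db.search('SELECT name FROM code_map WHERE code = %s', 'C16')  # code_map is illustrative
if ok:
    print(rows)
db.close()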
+ 103 - 0
utils/predict.py

@@ -0,0 +1,103 @@
+# coding:utf-8
+# @description : prediction for the three top-level categories
+import os
+import torch
+import pickle as pkl
+from tqdm import tqdm
+import models.FastText as m
+
+MAX_VOCAB_SIZE = 10000
+UNK, PAD = '<UNK>', '<PAD>'
+
+
+def build_vocab(file_path, tokenizer, max_size, min_freq):
+    vocab_dic = {}
+    with open(file_path, 'r', encoding='UTF-8') as f:
+        for line in tqdm(f):
+            lin = line.strip()
+            if not lin:
+                continue
+            content = lin.split('\t')[0]
+            for word in tokenizer(content):
+                vocab_dic[word] = vocab_dic.get(word, 0) + 1
+        vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[
+                     :max_size]
+        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
+        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
+    return vocab_dic
+
+
+class PredictObject(object):
+    def __init__(self, config, ):
+        if os.path.exists(config.vocab_path):
+            self.vocab = pkl.load(open(config.vocab_path, 'rb'))
+        else:
+            raise IOError("vocabulary file not found!!!")
+        self.buckets = config.n_gram_vocab
+        config.n_vocab = len(self.vocab)
+        self.config = config
+        self.model = m.Model(config)
+        self.model.load_state_dict(torch.load(config.save_path, map_location='cpu'))
+        self.model.eval()
+
+    @staticmethod
+    def biGramHash(sequence, t, buckets):
+        t1 = sequence[t - 1] if t - 1 >= 0 else 0
+        return (t1 * 14918087) % buckets
+
+    @staticmethod
+    def triGramHash(sequence, t, buckets):
+        t1 = sequence[t - 1] if t - 1 >= 0 else 0
+        t2 = sequence[t - 2] if t - 2 >= 0 else 0
+        return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
+
+    def convert2vec(self, org_contents, pad_size=32):
+        tokenizer = lambda x: [y for y in x]  # char-level
+        contents = []
+        for content in org_contents:
+            words_line = []
+            token = tokenizer(content)  # tokenize each text; iterating the string's chars directly was a bug
+            seq_len = len(token)
+            if pad_size:
+                if len(token) < pad_size:
+                    token.extend([PAD] * (pad_size - len(token)))
+                else:
+                    token = token[:pad_size]
+                    seq_len = pad_size
+            # word to id
+            for word in token:
+                words_line.append(self.vocab.get(word, self.vocab.get(UNK)))
+
+            # fasttext ngram
+            bigram = []
+            trigram = []
+            # ------ngram------
+            for i in range(pad_size):
+                bigram.append(self.biGramHash(words_line, i, self.buckets))
+                trigram.append(self.triGramHash(words_line, i, self.buckets))
+            # -----------------
+            contents.append((words_line, seq_len, bigram, trigram))
+        return contents
+
+    def to_tensor(self, datas):
+        x = torch.LongTensor([_[0] for _ in datas]).to('cpu')
+        bigram = torch.LongTensor([_[2] for _ in datas]).to('cpu')
+        trigram = torch.LongTensor([_[3] for _ in datas]).to('cpu')
+
+        # original (pre-padding) sequence lengths, capped at pad_size
+        seq_len = torch.LongTensor([_[1] for _ in datas]).to('cpu')
+        return x, seq_len, bigram, trigram
+
+    def predict(self, texts):
+        """texts: a list of strings; returns (predicted class indices, raw logits)."""
+        with torch.no_grad():
+            texts = self.convert2vec(texts, self.config.pad_size)
+            tensor_text = self.to_tensor(texts)
+            outputs = self.model(tensor_text)
+            pre_result = torch.max(outputs.data, 1)[1]
+            return pre_result, outputs
+
+
+if __name__ == '__main__':
+    from docs.config import config
+    p = PredictObject(config)
+    # predict() expects a list of texts
+    print(p.predict(['中国人民武装警察部队七台河支队武警七台河支队更换基层套装门项目更正公告']))

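biGramHash and triGramHash above implement the FastText n-gram trick: the ids of the one or two preceding tokens are hashed into a fixed bucket space, so n-grams get embeddings without an explicit n-gram vocabulary. A standalone toy run (same constants, small bucket count for readability):

    def bi_gram_hash(sequence, t, buckets):
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        return (t1 * 14918087) % buckets

    def tri_gram_hash(sequence, t, buckets):
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        t2 = sequence[t - 2] if t - 2 >= 0 else 0
        return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets

    ids = [5, 9, 2]   # toy word ids
    buckets = 16      # toy bucket count; the model uses config.n_gram_vocab
    print([bi_gram_hash(ids, t, buckets) for t in range(len(ids))])   # [0, 3, 15]
    print([tri_gram_hash(ids, t, buckets) for t in range(len(ids))])
    # position 0 has no left context, so its bucket always comes from the implicit id 0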
+ 59 - 0
utils/process.py

@@ -0,0 +1,59 @@
+# coding:utf-8
+import pandas as pd
+import re
+
+
+def text_parse(text):
+    # strip punctuation and other special symbols
+    reg_1 = '[!"#%&\'()*+,-./::;;|=?@,\t—。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
+    # remove all whitespace (covers newlines as well)
+    reg_2 = '\\s+'
+    text = re.sub(reg_1, ' ', text)
+    text = re.sub(reg_2, '', text)
+    return text
+
+
+def split_data(file_path):
+    df = pd.read_csv(file_path, encoding='utf-8')
+    df.text = df.text.map(text_parse)
+    df['label_id'] = df.label
+    df = df[['text', 'label_id']]
+    df = df.drop(df[df['text'].map(len) < 2].index)
+    df = df.drop(df[df['text'].map(len) > 250].index)
+    df.drop_duplicates(inplace=True)
+    # shuffle, then split train/test/dev at 0.7 / 0.15 / 0.15
+    df = df.sample(frac=1.0)
+    rows, cols = df.shape
+    split_index_1 = int(rows * 0.15)
+    split_index_2 = int(rows * 0.3)
+
+    # slice the three splits
+    df_test = df.iloc[0:split_index_1, :]
+    df_dev = df.iloc[split_index_1:split_index_2, :]
+    df_train = df.iloc[split_index_2: rows, :]
+
+    df_test.to_csv('./data/test.txt', sep="\t", index=False, header=None, encoding='utf-8')
+    df_train.to_csv('./data/train.txt', sep="\t", index=False, header=None, encoding='utf-8')
+    df_dev.to_csv('./data/dev.txt', sep="\t", index=False, header=None, encoding='utf-8')
+    return df_test, df_dev, df_train
+
+
+def data_show(file):
+    import matplotlib.pyplot as plt
+    import numpy as np
+    df = pd.read_csv(file, names=['text', 'label'], sep='\t', encoding='utf-8')
+    # plot the sorted distribution of text lengths
+    len_list = sorted(len(i) for i in df.text)
+    plt.bar(np.arange(len(len_list)), len_list)
+    plt.show()
+
+
+if __name__ == '__main__':
+    split_data('../data/other_data/target_label3.csv')

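A quick sanity check of the 70/15/15 slicing arithmetic used in split_data, on a toy DataFrame:

    import pandas as pd

    df = pd.DataFrame({'text': [f't{i}' for i in range(100)], 'label_id': [0] * 100})
    rows = len(df)
    i1, i2 = int(rows * 0.15), int(rows * 0.3)
    df_test, df_dev, df_train = df.iloc[:i1], df.iloc[i1:i2], df.iloc[i2:]
    print(len(df_test), len(df_dev), len(df_train))  # 15 15 70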
+ 160 - 0
utils/redis_helper.py

@@ -0,0 +1,160 @@
+# coding:utf-8
+import time
+
+import redis
+
+
+class RedisString(object):
+
+    def __init__(self, config):
+        self.config = config
+        self.r = self.connect()
+
+    def connect(self):
+        while True:
+            try:
+                self.r = redis.StrictRedis(host=self.config.get("host"), port=self.config.get("port"),
+                                           password=self.config.get('pwd'),
+                                           db=self.config.get('db', 1), decode_responses=True)
+                self.r.ping()  # StrictRedis connects lazily; ping to verify the connection now
+                break
+            except Exception as e:
+                print('redis_error', e)
+                time.sleep(20)
+        return self.r
+
+    def string_set(self, k, v, ex=None):
+        """SET a value with an optional expiry (seconds)."""
+        rest = self.r.set(k, v, ex=ex)
+        return rest
+
+    def update(self, key, v, ex=None):
+        """SET only when the key does not exist yet."""
+        if not self.r.exists(key):
+            rest = self.r.set(key, v, ex=ex)
+            return rest
+        else:
+            return 'exists'
+
+    def string_get(self, check_data):
+        """GET a value."""
+        try:
+            rest = self.r.get(check_data)
+        except Exception as e:
+            print(e)
+            rest = None
+        return rest
+
+    def string_mset(self):
+        """MSET -- set several key/value pairs at once."""
+        d = {
+            'user3': 'Bob',
+            'user4': 'Bobx'
+        }
+        rest = self.r.mset(d)
+        print(rest)
+        return rest
+
+    def string_mget(self):
+        """MGET -- fetch several keys at once."""
+        ls = ['user3', 'user4']
+        rest = self.r.mget(ls)
+        print(rest)
+        return rest
+
+    def exists(self, key):
+        return bool(self.r.exists(key))
+
+    def string_del(self, key):
+        """DEL a key."""
+        rest = self.r.delete(key)
+        return rest
+
+    def add_string_to_set(self, key, string):
+        # add the string to the set only if it is not already a member
+        if not self.r.sismember(key, string):
+            rest = self.r.sadd(key, string)
+            return True
+        else:
+            return False
+
+# TODO: helper to delete keys that do not start with "jycode_" (see the commented SCAN loop below and the sketch after this diff)
+
+
+
+
+if __name__ == '__main__':
+    configs = {'host':'192.168.3.109','port':'6379','pwd':'root','db':1}
+    r = RedisString(configs)
+    import json
+    mysql_path = {
+        # 'mysql_host': '172.17.4.242',
+        'mysql_host': '192.168.3.14',
+        # 'mysql_host': '127.0.0.1',
+        # 'mysql_user': 'liukangjia',
+        'mysql_user': 'root',
+        # 'mysql_password': 'Lkj#20230630N',
+        'mysql_password': '=PDT49#80Z!RVv52_z',
+        # 'mysql_password': '123456',
+        # 'mysql_db': 'Call_Accounting',
+        'mysql_db': 'lkj',
+        'mysql_port': '4000',
+        # 'mysql_port': '3306'
+    }
+    # with open('../data/code_to_name.json', 'r', encoding='utf-8') as f:
+    #     name_maps = json.load(f)
+    # print(r.update('jycode_C23120301','sss',ex=30))
+    # cursor = '0'
+    # while True:
+    #     cursor, keys = r.r.scan(cursor=cursor)
+    #     for key in keys:
+    #         if not key.startswith('jycode_'):
+    #             r.r.delete(key)
+    #     if cursor == '0':
+    #         break
+    print(r.string_get('jycode_C210407',))
+    # exit()
+    # for k,v in name_maps.items():
+    #     print(k,v)
+    #     r.string_set('jycode_'+k,v)
+    # from mysql_helper import MysqlConn
+    # m = MysqlConn(mysql_path)
+    # max_id = 0
+    # import pandas as pd
+    # df = pd.read_csv('../gpa.csv',names=['a','b'])
+    # # print(df.head())
+    # # print(len(df))
+    # # print(df.isna().sum().sum())
+    # # with open('../标题高频主干词.csv','r',encoding='utf-8') as f:
+    #     # data = f.readlines()
+    # # for i in data:
+    # #     print(i)
+    # for i in df.values:
+    #     # print(i)
+    # #     print(i[0])
+    #     if r.exists(i[0]):
+    #         continue
+    #     if type(i[0]) != str:
+    #         continue
+    #     with open('high_words2.csv','a',encoding='utf-8') as f:
+    #         f.write(i[0]+'\n')
+    #     r.string_set(i[0],'',14 * 24 * 60 * 60)
+    # while True:
+    #     m.cursor.execute(f'select * from zc_topic where id >{max_id} limit 1000')
+    #     data = m.cursor.fetchall()
+    #     if len(data) == 0:
+    #         break
+    #     for row in data:
+    #         print(row)
+    #         max_id = row[0]
+    #         v = {'code':row[2], 'root': row[4],'source':0.99, 'mode': 'mode1'}
+    #         v = json.dumps(v,ensure_ascii=False)
+    #         r.string_set(row[1],v,300)
+        # exit()
+    # print(r.string_get('10M纯铜电话线(带水晶头)'))
+    # r.string_set('10M纯铜电话线(带水晶头)','')
+    # r.string_set('lk',11)
+    # print(r.string_get('lk'))
+    # print(res.string_del('德州市环境保护科学研究所有限公司'))
+    # res.string_del('lkj')

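The TODO above is best done with SCAN, which iterates in chunks instead of blocking the server the way KEYS would; a minimal sketch against the RedisString wrapper (in redis-py the cursor comes back as an int, so compare with 0, not '0'):

    def delete_non_jycode_keys(r):
        """Delete every key that does not start with 'jycode_'."""
        cursor = 0
        while True:
            cursor, keys = r.r.scan(cursor=cursor, count=500)
            for key in keys:
                if not key.startswith('jycode_'):
                    r.r.delete(key)
            if cursor == 0:
                break

    # usage: delete_non_jycode_keys(RedisString(configs))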
+ 207 - 0
utils/request_fun.py

@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2023/7/10
+# @Author  : lkj
+# @description : service request helpers
+import copy
+
+import requests
+from pymongo import MongoClient
+from proto.text2vector_pb2 import Text2VectorReq, Text2VectorResp
+from a2s.a2s_client import a2s_execute
+from a2s.tools import json_serialize, json_deserialize
+from config import daili
+
+
+def top_t(text):
+    data = {
+        "text": text,
+    }
+    result = a2s_execute(a2s_ip=daili, topic="topic_tract", timeout=6, bytes_data=json_serialize(data))
+    result = json_deserialize(result)
+    if result.get('status', 400) == 200:
+        return result.get('result', '')
+    else:
+        return ''
+
+
+def text_to_vector(data):
+    """
+    Convert text to an embedding vector.
+    :param data: input text
+    :return: list of floats (empty on failure)
+    """
+    result = []
+    try:
+        # call the remote service
+        resp = Text2VectorReq(text=data)
+        response = a2s_execute(daili, "t2v2", timeout=30, bytes_data=resp.SerializeToString())
+        # parse the response
+        req = Text2VectorResp()
+        req.ParseFromString(response)
+        result = list(req.vector)
+    except Exception as e:
+        print(e)
+    return result
+
+
+# phrases indicating the model judged the pair NOT to be the same class
+NEG_PHRASES = ['属于不同类', '不属于', '可能属于', '不属于同一类', '不是同一类', '是不同类',
+               '属于不同', '不是', '不同的', '不一定', '一定不是']
+
+
+def _neg_flag(result):
+    """Return 0 if the response contains any negative phrase, else 1."""
+    for phrase in NEG_PHRASES:
+        if phrase in result:
+            return 0
+    return 1
+
+
+def chat_glm(text):
+    """Baichuan (百川) model."""
+    response = requests.post('http://123.207.51.36:7885/',
+                             json={"prompt": text, "identity": "剑鱼chat", "top_p": 0.8, "temperature": 0.8},
+                             timeout=150).json()
+    result = response.get('response', '')
+    return _neg_flag(result), result
+
+
+def ali2(text):
+    """Alibaba LLM."""
+    response = requests.post('http://123.207.51.36:7880',
+                             json={"prompt": text, "identity": "剑鱼chat", "top_p": 0.8, "temperature": 0.8,
+                                   'max_length': 100},
+                             timeout=150).json()
+    result = response.get('response', '')
+    return _neg_flag(result), result
+
+
+def chat_jzy(text):
+    """InternLM (书生) model."""
+    response = requests.post('http://119.91.64.110:7885/',
+                             json={"prompt": text, "identity": "剑鱼chat", "top_p": 0.8, "temperature": 0.8,
+                                   "max_length": 100},
+                             timeout=300).json()
+    result = response.get('response', '')
+    return _neg_flag(result), result
+
+
+def three_classify(title_, detail):
+    """If the title contains '维修' (repair) and the stripped title classifies as
+    货物 (goods), assume it is actually a 服务 (service)."""
+    if '维修' in title_:
+        title = title_.replace('维修', '')
+        response = requests.post('http://192.168.3.109:20623/search/', data={'title': title, 'detail': detail}).json()
+        result = response.get('result', [])
+        if result and result[0] == '货物':  # guard against an empty result, which used to crash here
+            return ['服务', '60%']
+        response = requests.post('http://192.168.3.109:20623/search/',
+                                 data={'title': title_, 'detail': detail}).json()
+        return response.get('result', [])
+    response = requests.post('http://192.168.3.109:20623/search/', data={'title': title_, 'detail': detail}).json()
+    return response.get('result', [])
+
+
+def seq_gpt(text, labels, task='分类'):
+    """
+    Three-category test, e.g. text='{text}\n属于下面哪一个分类' with labels='服务类,工程类,货物类'.
+    :param text:
+    :param labels:
+    :param task: '分类' (classify) or '抽取' (extract); defaults to classify
+    :return:
+    """
+    response = requests.post('http://192.168.3.109:20016/cls', data={
+        'text': f'{text}', 'task': f'{task}', 'labels': f'{labels}'}).json()
+    print(response)
+    result = response.get('result', {}).get('text', '')
+    return result
+
+
+def is_goods(title):
+    response = requests.post('http://192.168.3.109:20622/', data={'title': title}).json()
+    result = response.get('result', [])
+    return result
+
+
+def ts_ent(title):
+    response = requests.post('http://192.168.3.109:20631/', data={'text': title}).json()
+    print(response)
+    result = response.get('result', {})  # the service returns a dict; .get('entity') below needs one
+    entity = result.get('entity', [])
+    entity_list = []
+    for e in entity:
+        if e:
+            for flag in ['basic', 'food', 'product', 'matter', 'medicine', 'activity']:
+                if flag in e[1] and e[0] not in entity_list:
+                    entity_list.append(e[0])
+    res = ''.join(entity_list)
+    if not res:
+        res = title
+    return result, res
+
+
+def process_model(text):
+    """Keep only the product-describing spans recognized by the goods-recognition service."""
+    old_text = copy.deepcopy(text)
+    text_list = []
+    result = a2s_execute("192.168.3.240:9090", "recognition_goods", timeout=30, bytes_data=json_serialize({"text": text}))
+    # NOTE: eval() assumes the service returns a dict literal; json_deserialize would be safer if it returns JSON
+    result = eval(result).get('output', [])
+    for res in result:
+        type_ = res.get('type', '')
+        span = res.get('span', '')
+        if '材质' in type_ or '款式' in type_ or '产品' in type_ or '对象' in type_ or '适用场景' in type_:
+            if span not in text_list:
+                text_list.append(span)
+    text = ''.join(text_list)
+    if len(text) <= 1:
+        text = old_text
+    return text
+
+
+def con():
+    mongo_client = MongoClient('192.168.3.71:29099')
+    col_ = mongo_client['re4art']['better_goods']
+    return col_
+
+
+if __name__ == '__main__':
+    # text = '晨光 M&G Eplus盒 装黑色长尾夹 ABS92728 32mm 12个/盒 12盒/包 120盒/箱'
+    # c = '长尾票夹'
+    # print(1111,three_classify('恒源祥 H4YH10全棉长绒棉60支贡缎刺四件套1.5/1.8米床单被套200*230',''))
+    # print(222,chat_glm('新华书店/百年初心成大道——党史学习教育案例选编/政治类书籍属于装订图书吗'))
+    # print(333,chat_jzy(text,c))
+    # exit()
+    tes = '润农生物饲料研发生产线技改项目'
+    # print(three_classify(tes,''))
+    print(text_to_vector(tes))
+    exit()
+    col = con()
+    count = 0
+
+    for row in col.find({"purchasing_score": {'$gte':0.7}},{'title':1,'detail':1,'s_subscopeclass':1,'projectname':1,'purchasinglist':1,'purchasing_score':1}).sort('_id', 1):
+        title = row.get('title', '')
+        ids = row['_id']
+        detail = row.get('detail', '')
+        c = three_classify(title,detail)
+        basicClass = ''
+        rate = '0%'
+        if c:
+            basicClass = c[0]
+            rate = c[1]
+        # col.update_one({'_id': ids}, {'$set': {'basicClass': basicClass,'rate':rate}})
+        print(ids)

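All three chat endpoints share the _neg_flag helper defined above; it can be exercised offline:

    # offline check of the negative-phrase detector; no model service needed
    assert _neg_flag('两者属于同一类产品') == 1   # no negative phrase present
    assert _neg_flag('它们不属于同一类') == 0     # matches '不属于'
    assert _neg_flag('不一定是同一类') == 0       # matches '不一定'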
Diff file too large to display
+ 62 - 0
utils/request_test.py


+ 29 - 0
utils/t2v_client.py

@@ -0,0 +1,29 @@
+# coding:utf-8
+from a2s.a2s_client import a2s_execute
+
+from proto.text2vector_pb2 import Text2VectorReq, Text2VectorResp
+
+
+def start(data: str):
+    # No SSL is used here, so the channel is insecure
+    result = []
+    try:
+        # call the remote service
+        resp = Text2VectorReq(text=data)
+        response = a2s_execute("192.168.3.240:9090", "t2v2", timeout=60, bytes_data=resp.SerializeToString())
+        # parse the response
+        req = Text2VectorResp()
+        req.ParseFromString(response)
+        result = list(req.vector)
+    except Exception as e:
+        print(e)
+    return result
+
+
+if __name__ == '__main__':
+    r = start("公共厕所服务")
+    print(len(r))
+    print(r)
+# http://172.17.145.164:19805,http://172.17.148.50:19805,http://172.17.4.184:19805

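Judging from how Text2VectorReq and Text2VectorResp are used here and in request_fun.py, the proto presumably carries a string text field and a repeated float vector field. A local round trip of the request message, no server needed:

    from proto.text2vector_pb2 import Text2VectorReq

    req = Text2VectorReq(text='公共厕所服务')
    payload = req.SerializeToString()   # the bytes handed to a2s_execute
    parsed = Text2VectorReq()
    parsed.ParseFromString(payload)
    assert parsed.text == '公共厕所服务'
    print(len(payload), 'bytes on the wire')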
+ 101 - 0
utils/title_ner.py

@@ -0,0 +1,101 @@
+# coding:utf-8
+import re
+
+from a2s.a2s_client import a2s_execute
+from a2s.tools import json_serialize, json_deserialize
+from loguru import logger
+from config import daili
+from utils.request_fun import top_t
+
+
+def start(data: dict):
+    # No SSL is used here, so the channel is insecure
+    result = {}
+    try:
+        retry = 5
+        for r in range(retry):
+            bytes_data = json_serialize(data)
+            result = a2s_execute(daili, 'title_ner', 60, bytes_data)
+            if result is None:
+                continue
+            result = json_deserialize(result)
+            return result
+    except Exception as e:
+        logger.info(str(e))
+    return result
+
+
+def title_topic_merge(text):
+    """
+    Extract the subject from a title, merging multiple procurement items into one query string.
+    """
+    tet = re.sub(r'[^\w\s]', '', text)
+    input_text = {"text": tet}
+    res = start(input_text)
+    topic_res = ''
+    flag = ''
+    if res:
+        res_list = res.get('result', [])
+        for i in res_list:
+            target = i.get('TARGET', [])
+            topic_res = ''.join([topic[0] for topic in target])
+    if topic_res in ['建设']:
+        topic_res = ''
+    return topic_res, flag
+
+
+def title_topic_process(text):
+    """
+    Extract the subject (target) from a title via NER.
+    """
+    input_text = {"text": re.sub(r'[^\w\s]', '', text).replace('定点','')}
+    pattern = r'项目'
+    count_re = len(re.findall(pattern, text))
+    res = start(input_text)
+    topic_res = ''
+    flag = ''
+    if res:
+        res_list = res.get('result',[])
+        for i in res_list:
+            target = i.get('TARGET', [])
+            if count_re >=2:
+                topic_res = ''.join([topic[0] for topic in target])
+            else:
+                for j in target:
+                    topic_res = j[0]
+                    flag = 'ner'
+    if topic_res in ['建设']:
+        topic_res = ''
+    if not topic_res:
+        topic_res = text
+        flag = 'title'
+    return topic_res,flag
+
+
+def topic_trace(title, projectname):
+    """
+    Core-term extraction: try the title first, fall back to the project name, then to the rule-based extractor.
+    """
+    if '采购意向' in projectname:
+        return title
+    if ('采购意向' in title or '...' in title) and projectname:
+        title_topic, flag = title_topic_process(projectname)
+    else:
+        title_topic, flag = title_topic_process(title)
+        if title_topic == title and projectname:
+            title_topic, flag = title_topic_process(projectname)
+    if not title_topic:
+        title_topic = top_t(title)
+    if not title_topic:
+        title_topic = top_t(projectname)
+    if not title_topic:
+        title_topic = title
+    title_topic = re.sub(r'[^\w\s]', '', title_topic)
+    return title_topic
+
+
+if __name__ == '__main__':
+    data = " 广州公司-(珠海)智慧能源-显示屏-2312(急)变更公告"
+    r = start({"text": data})
+    print(topic_trace(data, data))
+    print(r)

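The r'[^\w\s]' cleanup used throughout this module keeps word characters and whitespace only; applied to the sample title above:

    import re

    title = " 广州公司-(珠海)智慧能源-显示屏-2312(急)变更公告"
    print(re.sub(r'[^\w\s]', '', title))
    # -> ' 广州公司珠海智慧能源显示屏2312急变更公告' (hyphens and brackets stripped; \w matches CJK)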
+ 179 - 0
utils/topic_extract.py

@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2023/3/3 14:04
+# @Author  : lkj
+import json
+import re
+from pathlib import Path
+
+import jieba
+from LAC import LAC
+
+jieba.add_word('等保')
+
+
+class Topic(object):
+
+    def __init__(self):
+
+        self.base_dir = Path(__file__).resolve().parent.parent
+        self.lac = LAC(mode='lac')
+        self.lac.load_customization('./data/lac_dict.txt')
+        with open('./data/stopwords_topic.txt', 'r', encoding='utf-8') as f:
+            stopwords = f.readlines()
+            self.stopwords = [i.replace('\n', '') for i in stopwords]
+        with open('./data/stoptext.txt', 'r', encoding='utf-8') as f:
+            self.stopcontent = f.readlines()
+        self.hw = open('./data/hw.txt', 'r', encoding='utf-8').readlines()
+        self.gcs = open('./data/gc.txt', 'r', encoding='utf-8').readlines()
+        self.fws = open('./data/fw.txt', 'r', encoding='utf-8').readlines()
+
+    def classify(self, text):
+        """
+        Rule-based classification from the title suffix.
+        :param text:
+        :return:
+        """
+        class_name = []
+        flag = text[-4:]
+        for good in self.hw:  # goods
+            good = good.replace('\n', '')
+            if good in flag:
+                class_name.append('货物')
+        for gc in self.gcs:  # works/construction
+            gc = gc.replace('\n', '')
+            if gc in flag:
+                class_name.append('工程')
+        for fw in self.fws:  # services
+            fw = fw.replace('\n', '')
+            if fw in flag:
+                class_name.append('服务')
+        class_name = list(set(class_name))
+        for i in ['及', '建设', '系统', '升级']:  # ambiguous markers: abstain
+            if i in text[-8:]:
+                class_name.clear()
+        if len(class_name) > 1:
+            class_name.clear()
+        return class_name
+
+    def lac_cut(self, text):
+        """
+        Use LAC POS tags to strip leading person/place/org tokens.
+        :param text:
+        :return:
+        """
+        lac_result = self.lac.run(text)
+        lac_res = []
+        index_list = []
+        for index, pos in enumerate(lac_result[1]):
+            if pos in ['PER', 'LOC', 'ORG']:
+                index_list.append(index)
+        if index_list:  # drop everything up to and including the last PER/LOC/ORG token
+            del lac_result[0][0:max(index_list) + 1]
+            del lac_result[1][0:max(index_list) + 1]
+        for index, pos in enumerate(lac_result[1]):
+            if pos in ['w', 't', 'ns', ]:  # drop these POS, except a '.' sitting between numerals
+                start = index - 1
+                if start < 0:
+                    start = 0
+                end = index + 1
+                if end == len(lac_result[1]):
+                    end = index
+                # keep a decimal point: both neighbours must be numerals ('m');
+                # the original truthiness check on the left neighbour was a bug
+                if lac_result[1][start] == 'm' and lac_result[1][end] == 'm':
+                    lac_res.append(lac_result[0][index])
+                continue
+            lac_res.append(lac_result[0][index])
+        lac_res = "".join(lac_res)
+        return lac_res
+
+    @staticmethod
+    def re_process(text):
+        """
+        Regex cleanup rules.
+        :param text:
+        :return:
+        """
+        text = re.sub('第.*?包', '', text)
+        re_list2 = re.findall('\(.*?\)', text)
+        for i in re_list2:
+            if i not in ['(勘察)', '(测绘)', '(监理)']:
+                text = text.replace(i, '')
+        text = re.sub('\[.*?\]', '', text)
+        text = re.sub('(.*?)', '', text)  # full-width parentheses
+        text = re.sub('.*大楼', '', text)
+        text = re.sub('.*号楼', '', text)
+        text = re.sub(r"\d{4}年\d{1,2}至\d{1,2}月", '', text)
+        text = re.sub('.*?-竞争性磋商-[a-zA-Z0-9_-]{4,20}', '', text)
+        text = re.sub('.*?-竞争性谈判-[a-zA-Z0-9_-]{4,20}', '', text)
+        text = re.sub('.*?-公开招标-[a-zA-Z0-9_-]{4,20}', '', text)
+        text = re.sub('[0-9]{1,9}年度', '', text)
+        text = re.sub('[0-9]{4,9}年', '', text)
+        text = re.sub('[0-9]{1,2}月', '', text)
+        text = re.sub('[0-9]{1,2}日', '', text)
+        text = re.sub('[!#%&()*+,/\-·$¥::;;,()|=?@\t—?★【】《》?、!\[\[^_`{|}~]', '', text)
+        text = re.sub('[a-zA-Z0-9_]{5,30}', '', text)
+        text = re.sub('工字.*', '', text)
+        text = re.sub('.*县', '', text)
+        text = re.sub('.*委员会', '', text)
+        text = re.sub('第[0-9]{0,4}包', '', text)
+        text = re.sub('.*村委会', '', text)
+        text = re.sub('.*州界', '', text)
+        text = re.sub('.*大学', '', text)
+        text = re.sub('.*学院', '', text)
+        text = re.sub('20[0-9]{2}级', '', text)
+        text = re.sub('20[0-9]{2}', '', text)
+        return text
+
+    def stop_word(self, text: str):
+        """
+        Remove stopwords.
+        :param text:
+        :return:
+        """
+        jieba_cut = jieba.lcut(text)
+        new_text = [i for i in jieba_cut if i not in self.stopwords]
+        text = ''.join(new_text)
+        return text
+
+    def stop_content(self, text: str):
+        """
+        Stop-phrases: fixed strings removed whole because a tokenizer may split
+        them incorrectly, e.g. 重采购, 重招标.
+        :param text:
+        :return:
+        """
+        for sw in self.stopcontent:
+            sw = sw.replace('\n', '')
+            if sw in text:
+                text = text.replace(sw, '')
+        return text
+
+    def tract(self, text):
+        """
+        Main entry point.
+        :param text:
+        :return:
+        """
+        try:
+            text = self.re_process(text)  # regex cleanup
+            text = self.stop_content(text)  # stop-phrases
+            text = self.lac_cut(text)  # LAC: strip LOC/ORG/... tokens
+            text = self.stop_word(text)  # stopwords
+            cls = ''
+            if text:
+                if jieba.lcut(text)[0] in ['及', '至', '和', '与', '所', '并']:
+                    text = text[1:]
+                cls = ''.join(self.classify(text))
+            return text, cls
+        except Exception as e:
+            print('rule error', e)
+            return '', ''
+
+
+if __name__ == '__main__':
+    t = Topic()
+    while True:
+        a = input('>>>>>')
+        print(t.tract(a))

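A standalone miniature of the regex stage, chaining three of the rules above on a made-up title to show the order of operations:

    import re

    def mini_clean(text):
        text = re.sub('第.*?包', '', text)              # drop lot markers such as 第2包
        text = re.sub('[0-9]{4,9}年', '', text)         # drop years
        text = re.sub('[a-zA-Z0-9_-]{5,30}', '', text)  # drop long tender codes
        return text

    print(mini_clean('2023年办公设备采购第2包ZB-20230101'))  # -> 办公设备采购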
+ 116 - 0
utils/train_eval.py

@@ -0,0 +1,116 @@
+# coding: UTF-8
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from sklearn import metrics
+import time
+from utils import get_time_dif
+from tensorboardX import SummaryWriter
+
+
+# Weight initialization; defaults to Xavier
+def init_network(model, method='xavier', exclude='embedding', seed=123):
+    for name, w in model.named_parameters():
+        if exclude not in name:
+            if 'weight' in name:
+                if method == 'xavier':
+                    nn.init.xavier_normal_(w)
+                elif method == 'kaiming':
+                    nn.init.kaiming_normal_(w)
+                else:
+                    nn.init.normal_(w)
+            elif 'bias' in name:
+                nn.init.constant_(w, 0)
+            else:
+                pass
+
+
+def train(config, model, train_iter, dev_iter, test_iter):
+    start_time = time.time()
+    model.train()
+    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
+    total_batch = 0  # number of batches seen so far
+    dev_best_loss = float('inf')
+    last_improve = 0  # batch index of the last dev-loss improvement
+    flag = False  # set when training has stalled for too long
+    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
+    for epoch in range(config.num_epochs):
+        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
+        for i, (trains, labels) in enumerate(train_iter):
+            outputs = model(trains)
+            model.zero_grad()  # gradients accumulate in PyTorch; clear them every step
+            loss = F.cross_entropy(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            if total_batch % 100 == 0:
+                true = labels.data.cpu()  # move to CPU for metric computation
+                predic = torch.max(outputs.data, 1)[1].cpu()  # torch.max returns (values, indices); keep the indices
+                train_acc = metrics.accuracy_score(true, predic)
+                dev_acc, dev_loss = evaluate(config, model, dev_iter)  # evaluate every 100 training batches
+                if dev_loss < dev_best_loss:  # checkpoint whenever dev loss improves
+                    dev_best_loss = dev_loss
+                    torch.save(model.state_dict(), config.save_path)
+                    improve = '*'
+                    last_improve = total_batch
+                else:
+                    improve = ''
+                time_dif = get_time_dif(start_time)
+                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
+                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
+                writer.add_scalar("loss/train", loss.item(), total_batch)
+                writer.add_scalar("loss/dev", dev_loss, total_batch)
+                writer.add_scalar("acc/train", train_acc, total_batch)
+                writer.add_scalar("acc/dev", dev_acc, total_batch)
+
+                model.train()
+            total_batch += 1
+            if total_batch - last_improve > config.require_improvement:
+                print("No optimization for a long time, auto-stopping...")
+                flag = True
+                break
+        if flag:
+            break
+    writer.close()
+    test(config, model, test_iter)
+
+
+def test(config, model, test_iter):
+    model.load_state_dict(torch.load(config.save_path))  # load the best checkpoint saved during training
+    model.eval()  # eval mode: disables dropout etc.
+    start_time = time.time()
+    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
+    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
+    print(msg.format(test_loss, test_acc))
+    print("Precision, Recall and F1-Score...")
+    print(test_report)
+    print("Confusion Matrix...")
+    print(test_confusion)
+    time_dif = get_time_dif(start_time)
+    print("Time usage:", time_dif)
+
+
+def evaluate(config, model, data_iter, test=False):
+    model.eval()
+    loss_total = 0
+    predict_all = np.array([], dtype=int)
+    labels_all = np.array([], dtype=int)
+    with torch.no_grad():
+        for texts, labels in data_iter:
+            outputs = model(texts)
+            loss = F.cross_entropy(outputs, labels)
+            loss_total += loss.item()  # accumulate as a plain float
+            labels = labels.data.cpu().numpy()  # np.append below needs numpy arrays
+            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
+            labels_all = np.append(labels_all, labels)  # collect all labels
+            predict_all = np.append(predict_all, predic)  # collect all predictions
+
+    acc = metrics.accuracy_score(labels_all, predict_all)
+    if test:
+        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
+        confusion = metrics.confusion_matrix(labels_all, predict_all)
+        return acc, loss_total / len(data_iter), report, confusion
+    return acc, loss_total / len(data_iter)

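The metrics in evaluate() are plain sklearn calls; a toy run on fake labels, using the three class names that appear elsewhere in this commit as stand-ins for config.class_list:

    import numpy as np
    from sklearn import metrics

    labels_all = np.array([0, 1, 2, 1, 0, 2])
    predict_all = np.array([0, 1, 1, 1, 0, 2])

    print(metrics.accuracy_score(labels_all, predict_all))  # 0.8333...
    print(metrics.classification_report(labels_all, predict_all,
                                        target_names=['服务', '工程', '货物'], digits=4))
    print(metrics.confusion_matrix(labels_all, predict_all))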
+ 155 - 0
utils/utils.py

@@ -0,0 +1,155 @@
+# coding: UTF-8
+import os
+import torch
+import numpy as np
+import pickle as pkl
+from tqdm import tqdm
+import time
+from datetime import timedelta
+
+MAX_VOCAB_SIZE = 100000  # vocabulary size cap
+
+
+def build_vocab(file_path, tokenizer, max_size, min_freq):  # build the vocabulary
+    UNK, PAD = '<UNK>', '<PAD>'
+    vocab_dic = {}
+    with open(file_path, 'r', encoding='UTF-8') as f:
+        for line in tqdm(f):
+            lin = line.strip()
+            if not lin:
+                continue
+            content = lin.split('\t')[0]
+            for word in tokenizer(content):
+                vocab_dic[word] = vocab_dic.get(word, 0) + 1
+        vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
+        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
+        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
+    return vocab_dic
+
+
+def build_dataset(config, ues_word):  # load train/dev/test
+    UNK, PAD = '<UNK>', '<PAD>'
+    if ues_word:
+        tokenizer = lambda x: x.split(' ')  # word-level: tokens separated by spaces
+    else:
+        tokenizer = lambda x: [y for y in x]  # char-level
+    if os.path.exists(config.vocab_path):
+        vocab = pkl.load(open(config.vocab_path, 'rb'))
+    else:
+        vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
+        pkl.dump(vocab, open(config.vocab_path, 'wb'))
+    print(f"Vocab size: {len(vocab)}")
+
+    def load_dataset(path, pad_size=300):
+        contents = []
+        with open(path, 'r', encoding='UTF-8') as f:
+            for line in tqdm(f):
+                lin = line.strip()
+                if not lin:
+                    continue
+                content, label = lin.split('\t')
+                words_line = []
+                token = tokenizer(content)
+                seq_len = len(token)
+                if pad_size:
+                    if len(token) < pad_size:
+                        token.extend([PAD] * (pad_size - len(token)))
+                    else:
+                        token = token[:pad_size]
+                        seq_len = pad_size
+                # word to id
+                for word in token:
+                    words_line.append(vocab.get(word, vocab.get(UNK)))
+                contents.append((words_line, int(label), seq_len))
+
+        return contents  # [([...], 0), ([...], 1), ...]
+    train = load_dataset(config.train_path, config.pad_size)
+    dev = load_dataset(config.dev_path, config.pad_size)
+    test = load_dataset(config.test_path, config.pad_size)
+    return vocab, train, dev, test
+
+
+class DatasetIterater(object):
+    def __init__(self, batches, batch_size, device, model_name):
+        self.batch_size = batch_size
+        self.batches = batches
+        self.n_batches = len(batches) // batch_size
+        self.residue = False  # True when the last batch is partial
+        if len(batches) % batch_size != 0:  # was `% self.n_batches`, a bug (and a ZeroDivisionError risk)
+            self.residue = True
+        self.index = 0
+        self.device = device
+        self.model_name = model_name
+
+    def _to_tensor(self, datas):
+        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
+        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
+
+        # original (pre-padding) sequence lengths, capped at pad_size
+        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
+        if self.model_name == 'Bert':
+            mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
+            return (x, seq_len, mask), y
+        return (x, seq_len), y
+
+    def __next__(self):
+        if self.residue and self.index == self.n_batches:
+            batches = self.batches[self.index * self.batch_size: len(self.batches)]
+            self.index += 1
+            batches = self._to_tensor(batches)
+            return batches
+
+        elif self.index >= self.n_batches:
+            self.index = 0
+            raise StopIteration
+        else:
+            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
+            self.index += 1
+            batches = self._to_tensor(batches)
+            return batches
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        if self.residue:
+            return self.n_batches + 1
+        else:
+            return self.n_batches
+
+
+def build_iterator(dataset, config):
+    return DatasetIterater(dataset, config.batch_size, config.device, config.model_name)
+
+
+def get_vocab():
+    # vocab.pkl is a pickled dict (word -> id), not a CSV
+    vocab = pkl.load(open('../data/vocab.pkl', 'rb'))
+    return list(vocab.keys()), vocab
+
+def get_time_dif(start_time):
+    """获取已使用时间"""
+    end_time = time.time()
+    time_dif = end_time - start_time
+    return timedelta(seconds=int(round(time_dif)))
+
+
+if __name__ == "__main__":
+    '''Extract pretrained word vectors (random init here; the extraction is commented out).'''
+    data_path = '../data'
+    train_dir = data_path + "/train.txt"
+    vocab_dir = data_path + "/vocab.pkl"
+    # pretrain_dir = "./data/word_embedding/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5"
+    emb_dim = 300
+    filename_trimmed_dir = "../data/word_embedding/embedding_table"
+    if os.path.exists(vocab_dir):
+        word_to_id = pkl.load(open(vocab_dir, 'rb'))
+    else:
+        # tokenizer = lambda x: x.split(' ')  # word-level vocab (tokens separated by spaces)
+        tokenizer = lambda x: [y for y in x]  # char-level vocab
+        word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
+        pkl.dump(word_to_id, open(vocab_dir, 'wb'))
+        print(word_to_id)
+
+    embeddings = np.random.rand(len(word_to_id), emb_dim)
+    # np.save() takes no `embeddings=` keyword; savez_compressed stores the array under that key
+    np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)

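A toy run of DatasetIterater showing the residue handling: 10 samples with batch size 4 yield two full batches plus one partial batch (assumes the repo root is on sys.path):

    from utils.utils import DatasetIterater

    data = [([1, 2, 3], 0, 3) for _ in range(10)]  # fake (word ids, label, seq_len) samples
    it = DatasetIterater(data, 4, 'cpu', 'FastText')
    print(len(it))  # 3: two full batches plus the residue batch
    for (x, seq_len), y in it:
        print(x.shape, y.shape)  # ends with torch.Size([2, 3]) torch.Size([2])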
Some files were not shown because too many files changed