lijunliang 1 year ago
commit
e5608bafff

BIN
ai-0.0.4-py3-none-any.whl


+ 43 - 0
create_dict.py

@@ -0,0 +1,43 @@
+# coding:utf-8
+from machine_models.tools import link_db
+from machine_models.tools import chinese2vector
+from machine_models.tools import tfidf
+import joblib
+
+if __name__ == '__main__':
+    m_config = {
+        "db": "re4art",
+        "col": "bidding_china_4_9",
+        "host": "192.168.3.207:27092",
+    }
+    with open("stopwords.txt", "r") as f:
+        stop_words = [word.strip() for word in f.readlines()]
+
+    client, col = link_db(m_config)
+    corpus = []
+    with open("./target.csv", "r") as f:
+        read_data = f.read()
+        read_data = read_data.replace("\n", " ")
+    other = chinese2vector(read_data, remove_word=["x"], stopwords=stop_words)
+    print(other)
+    corpus.append(other)
+    count = 0
+    for row in col.find({}).sort("_id", 1):
+        # detail = row.get("detail", "")
+        # title = row.get("title", "")
+        count += 1
+        print(count)
+        # corpus = chinese2vector(title + detail.lower(), remove_word=["x", "m"], stopwords=stop_words)
+        # col.update_one({"_id": row["_id"]}, {"$set": {"cut_detail": corpus}})
+        cut_detail = row.get("cut_detail", "")
+        corpus.append(cut_detail)
+        # if len(contents) > 10000:
+        #     cut_ret = chinese2vectors(contents, remove_word=["x"], stop_words=stop_words)
+        #     corpus.extend(cut_ret)
+        #     contents = []
+    # if contents:
+    #     cut_ret = chinese2vectors(contents, remove_word=["x"], stop_words=stop_words)
+    #     corpus.extend(cut_ret)
+    #     contents = []
+    tfidf_vec, tfidf_ret = tfidf(analyzer="word", space_words=corpus)
+    joblib.dump(tfidf_vec, "docs/model/dictionary")
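
A minimal sketch of how the dumped dictionary is reused downstream (machine_models/__init__.py below loads it the same way; the token string is a made-up example):

    import joblib

    # load the fitted TfidfVectorizer dumped by create_dict.py
    tfidf_vec = joblib.load("docs/model/dictionary")
    # transform expects space-separated tokens, as produced by chinese2vector
    vec = tfidf_vec.transform(["招标 公告 项目"])
    print(vec.shape)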

BIN
docs/__pycache__/config.cpython-37.pyc


+ 54 - 0
docs/config.py

@@ -0,0 +1,54 @@
+# coding:utf-8
+
+mysql_config = {
+    "db": "machineLearning",
+    "ip": "192.168.3.109",
+    "port": "4000",
+    "user": "root",
+    "pwd": "Tidb#20220214",
+    "charset": "utf8"
+}
+
+source_mongo_config = {
+    "host": "192.168.3.207:27001,192.168.3.206:27002",
+    "user": "jyDevGroup",
+    "password": "jy@DevGroup",
+    "db": "qfw_data",
+    "col": "bidding"
+}
+
+catch_mongo_config = {
+    "host": "192.168.3.207:27092",
+    "user": "",
+    "password": "",
+    "db": "re4art",
+    "col": "catch_mongo_test"
+}
+
+oss_file_config = {
+    "access_key_id": "LTAI4G5x9aoZx8dDamQ7vfZi",
+    "access_key_secret": "Bk98FsbPYXcJe72n1bG3Ssf73acuNh",
+    "endpoint": "oss-cn-beijing.aliyuncs.com",
+    "bucket_name": "jy-datafile",
+}
+
+oss_txt_config = {
+    "access_key_id": "LTAI4G5x9aoZx8dDamQ7vfZi",
+    "access_key_secret": "Bk98FsbPYXcJe72n1bG3Ssf73acuNh",
+    "endpoint": "oss-cn-beijing.aliyuncs.com",
+    "bucket_name": "topjy",
+}
+
+convertField = {
+    "标题": "cut_title",
+    "正文": "cut_detail",
+    "采购单位": "cut_buyer",
+    "中标单位": "cut_winner",
+    "标的物": "cut_purchasing",
+    "附件": "cut_attach_text"
+}
+
+stopWordsPath = "./docs/stopwords.txt"
+baseDir = "./docs/"
+dictionaryPath = "./docs/dictionary"
+dictionaryUrl = "111111"
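
convertField maps the caller-facing field names to the cached cut_* fields; a small illustration (the requested list is a made-up example, mirroring how predict_model.py and train_model.py use the mapping):

    from docs.config import convertField

    requested = ["标题", "正文"]
    focus_field = [convertField[f] for f in requested if f in convertField]
    # focus_field == ["cut_title", "cut_detail"]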

BIN
docs/model/dictionary


BIN
docs/model/model.model


+ 627 - 0
docs/stopwords.txt

@@ -0,0 +1,627 @@
+项目
+公告
+-
+招标
+)
+(
+2019
+公示
+、
+中标
+年
+结果
+的
+中心
+公开
+成交
+竞争性
+限公司
+关于
+(
+[
+]
+)
+及
+建设
+有
+合同
+和
+中国
+:
+等
+更
+谈判
+“
+平台
+”
+询价
+单一来源
+安全
+有限公司
+1
+_
+磋
+商公告
+二次
+公司
+编号
+【
+】
+与
+询价公告
+公安局
+集团
+管理
+字
+选人
+化
+年度
+磋商
+县
+段
+项目
+文件
+2019
+联系
+单位
+时间
+信息
+提供
+进行
+功能
+支持
+中心
+名称
+内容
+公告
+政府
+品牌
+方式
+满足
+00
+工作
+限公司
+合同
+质疑
+供应
+具有
+地址
+必须
+交易
+工程
+成交
+编号
+建设
+应商
+现场
+相关
+具备
+05
+10
+资格
+地点
+需求
+30
+中标
+公司
+根据
+检测
+管理
+通过
+注册
+条件
+平台
+公共
+中国
+安全
+企业
+31
+网上
+截止
+报名
+网络
+包括
+修改
+资料
+资质
+06
+用户
+不予
+国家
+本次
+规定
+附件
+咨询
+此条
+使用
+磋商
+生产
+金额
+保证
+cn
+数据
+文件
+记录
+项目
+投标
+2019
+招标
+时间
+供应商
+有限公司
+进行
+满足
+功能
+内容
+以上
+具有
+00
+质疑
+必须
+联系人
+公共资源
+05
+10 
+30
+联系方式
+现场
+名称
+公司
+根据
+地址
+其他
+使用
+报名
+联系电话
+工作
+31
+条件
+包括
+06
+不予
+中心
+规定
+本次
+同时
+此条
+方式
+cn
+有关
+记录
+可以
+上虞
+所有
+http
+www
+完成
+如下
+保证金
+实际
+或者
+gov 
+不得
+三家
+12
+接受
+17
+问题
+文档
+人员
+11
+只有
+下载
+15
+情况
+34
+指定
+以下
+获取
+结果
+在线
+直接
+隐藏
+全部
+能力
+唯一
+按照
+&#
+集团
+参与
+CA
+倾向性
+竞争性
+查看
+com 
+有效
+400
+能够
+14
+开远市
+24
+统一
+组织
+存在
+列入
+综合
+承诺函
+失信
+16
+即可
+及时
+不少
+独立
+联合体
+010
+一个
+是否
+此参数
+20
+000
+形式
+0606
+58851111
+同一
+以及
+13
+没有
+输入
+拓久
+之日起
+亮显
+为了
+处理
+正式
+18
+影响
+唯一性
+IEBOARD
+视通
+东威
+标的
+09
+主要
+最高
+所在
+方有权
+鸿合
+指向性
+其它
+中原
+需要
+知识点
+限制性
+VGA
+分公司
+人民
+2018
+各个
+了解
+业绩
+社会
+为准
+法律法规
+上午
+日立
+小组
+下列
+ggzy
+发现
+三年
+原因
+有限责任
+参考
+无效
+携带
+欢迎
+删除
+IQ
+高科
+湖山
+联系
+无法
+无条件
+其中
+至少
+后期
+2.5
+确定
+并且
+含有
+海康
+不变
+内江
+01
+偏离
+结束
+用途
+必需
+除外
+快捷键
+此项
+220kV
+提出
+备注
+3.1
+4.2
+否则
+通用
+三个
+08
+期间
+日内
+45
+天内
+予以
+由此
+尺寸
+日至
+工作日内
+及其
+差别待遇
+28
+第二
+终止
+明确
+我司
+爱普生
+最终
+印天
+白光
+巨龙
+所以
+网点
+类型
+第一
+不再
+最大
+提升
+4.1
+一切
+品目
+联动
+物理
+任意
+条例
+成立
+变化
+明基
+松下
+赛尔
+独有
+带式
+送货
+一份
+公开
+限制
+文号
+产生
+04
+之间
+方向
+KYRW
+第三
+场地
+效果
+50
+理光
+鲅鱼圈区
+锐取
+airitilibrary
+学术
+第二十二条
+开始
+现将
+威胁
+符合国家
+大华
+一级
+霍邱
+一年
+方须
+医共体
+RJ45
+扣发
+总价
+我们
+包件
+华能
+从而
+健全
+限于
+因此
+多点
+办法
+可自相
+接到
+不能
+ccpc
+届满
+下同
+已经
+最新
+扩展
+落实
+稳定
+特点
+黑名单
+班班
+现对
+真实
+creditchina
+工期
+文件夹
+不足
+PCIe
+准确
+校正
+每个
+便于
+以次充好
+投机取巧
+因无
+歧视性
+包号
+项下
+5.3
+正本
+放弃
+原则上
+得到
+之前
+详细
+统计
+代理商
+整个
+跟踪
+后方
+时至
+提高
+目标
+长乐
+有意
+迷你
+退款
+蓝色
+富民县
+两侧
+板面
+西路
+或是
+拖动
+随时
+互动
+设定
+有赛
+练习
+RGB
+瀚驰
+漫游
+排斥
+小于
+专区
+登陆
+发售
+大厅
+联合
+大街
+成员名单
+关闭
+产地
+中途
+水货
+以便
+开具
+退还
+答疑
+按规定
+不间断
+务必
+初验
+成熟
+版权
+澄清
+原则
+方仅
+本级
+快捷
+历史
+多种
+主体
+对应
+银信
+目的
+信服
+划分
+龙井
+IP
+html
+之一
+重点
+女士
+定点
+计算
+上行政区域
+专用
+汇款
+容量
+适应
+全额
+单上
+常用
+实质性
+批次
+歧视
+精品
+有希沃
+演示
+转载
+非常
+语文
+逼真
+降低
+时间表
+秦淮
+天得
+路天恒
+为鸿合
+出入口
+禁止
+答题
+虚拟
+主页
+紫旭
+分局
+采办
+地区
+加装
+之外
+保有
+剩余时间
+GDC
+对接
+时限
+关系
+获得
+尚未
+户名
+超级
+进一步
+另行
+涉及
+每天
+到场
+平衡
+一台
+DRGs
+固定
+3T
+SSD
+保留
+GHz
+推荐
+自带
+登记表
+转包
+RIS
+临床
+指导
+雪亮
+对于
+现行
+方面
+估算
+产业
+立体声
+社库
+全文
+标包
+发出
+海棠

+ 122 - 0
machine_models/__init__.py

@@ -0,0 +1,122 @@
+# coding:utf-8
+from docs.config import dictionaryPath
+from docs.config import dictionaryUrl
+from machine_models.databases import File
+from machine_models.train_model import train
+from machine_models.databases.mysql_helper import Model
+from machine_models.databases import session
+from machine_models.databases.mysql_helper import Project
+from machine_models.predict_model import predict
+from util.file_operations import generate_directory, del_directory
+from docs.config import baseDir
+import os
+import joblib
+import uuid
+import datetime
+
+# Load the dictionary file once at import time
+if not os.path.exists(dictionaryPath):
+    status = File.download_file(dictionaryUrl, dictionaryPath)
+    if not status:
+        raise ValueError("Failed to download the dictionary file")
+tfidf_vec = joblib.load(dictionaryPath)
+
+def train_fail(project_id, user_id):
+    '''
+    Record a failed training run
+    :param project_id:
+    :param user_id:
+    :return:
+    '''
+    fail_model = Model(state=2, projectId=project_id, createperson=user_id,
+                       createTime=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+    session.add(fail_model)
+    session.commit()
+    return False
+
+
+def train_model(request_params: dict):
+    # Clear cached ORM objects on the session
+    session.expire_all()
+    session.commit()
+
+    # Read the training-job parameters
+    project_id = request_params.get("id")
+    user_id = request_params.get("userId", "")
+    label_type = request_params.get("type", 1)
+    fields = request_params.get("fields", "")
+    model_dir = ""
+    try:
+        # Missing project id
+        if not project_id:
+            return train_fail(project_id, user_id)
+        model_dir = os.path.join(baseDir, str(uuid.uuid4()))
+        dir_status = generate_directory(model_dir)
+        # Directory creation failed
+        if not dir_status:
+            return train_fail(project_id, user_id)
+        # Start training
+        model_detail = train(project_id, fields.split(","), tfidf_vec, label_type, model_dir)
+        # Training failed
+        if not model_detail:
+            return train_fail(project_id, user_id)
+        # Record the successful run
+        model_detail.projectId = project_id
+        model_detail.state = 0
+        model_detail.createperson = user_id
+        model_detail.createTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        session.add(model_detail)
+        session.commit()
+        # Remove the files generated by this run
+        del_directory(model_dir)
+        return True
+    except Exception:
+        # Remove the files generated by this run, then record a single failure
+        if model_dir and os.path.exists(model_dir):
+            del_directory(model_dir)
+        return train_fail(project_id, user_id)
+
+
+def predict_model(request_params):
+    # Clear the local session cache
+    session.expire_all()
+    session.commit()
+    # Read the prediction parameters
+    project_id = request_params.get("id", -1)
+    id_list = request_params.get("id_list", [])
+    model_id = request_params.get("model_id", -1)
+    project_info = session.query(Project).filter_by(id=project_id).first()
+
+    # Look up the project
+    if not project_info:
+        return {"error_code": 0, "error_message": f"Project not found --> {project_id}"}
+    focus_field, target_label, label_type = project_info.focusField.split(
+        ","), project_info.labels.split(","), project_info.type
+
+    # Look up the model
+    model_info = session.query(Model).filter_by(id=model_id).first()
+    print(type(model_id), "-->", model_id)
+    if not model_info:
+        return {"error_code": 0, "error_message": f"Model not found --> {model_id}"}
+
+    # Load the model, downloading it from OSS if it is not cached locally
+    model_url = model_info.modelFile
+    model_dir = os.path.join(baseDir, model_url)
+    model_path = os.path.join(model_dir, "model.model")
+    if not os.path.exists(model_path):
+        dir_status = generate_directory(model_dir)
+        if not dir_status:
+            return {"error_code": 0, "error_message": "Failed to create the model directory, check storage"}
+        status = File.download_file(model_url, model_path)
+        if not status:
+            return {"error_code": 0, "error_message": f"Failed to load the model from OSS --> {model_id}"}
+    try:
+        data = predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path)
+    except Exception as e:
+        print(e)
+        return {"error_code": 0, "error_message": "Prediction failed"}
+    # Clean up the local model cache
+    if os.path.exists(model_dir):
+        del_directory(model_dir)
+    return {"error_code": 1, "data": data}

BIN
machine_models/__pycache__/__init__.cpython-37.pyc


BIN
machine_models/__pycache__/predict_model.cpython-37.pyc


BIN
machine_models/__pycache__/tools.cpython-37.pyc


BIN
machine_models/__pycache__/train_model.cpython-37.pyc


+ 144 - 0
machine_models/databases/__init__.py

@@ -0,0 +1,144 @@
+# coding:utf-8
+'''
+Data loading and caching
+'''
+from sqlalchemy.orm.session import sessionmaker
+from machine_models.databases.mysql_helper import init_db
+from machine_models.databases.mysql_helper import Model
+from machine_models.databases.mysql_helper import AnnotatedData
+from machine_models.databases.mongo_helper import MongoConnect
+from docs.config import mysql_config
+from docs.config import source_mongo_config
+from docs.config import catch_mongo_config
+from util.fs_client import FileServeClient
+from machine_models.tools import chinese2vector
+from docs.config import stopWordsPath
+from bson import ObjectId
+from docs.config import oss_file_config
+from docs.config import oss_txt_config
+from util.oss_file import OssServeClient
+
+# Initialize connections
+Fs = FileServeClient(oss_txt_config)
+File = OssServeClient(oss_file_config)
+engine = init_db(mysql_config)
+Connect = sessionmaker(bind=engine)
+session = Connect()
+source_mongo = MongoConnect(source_mongo_config)
+catch_mongo = MongoConnect(catch_mongo_config)
+
+# Load the stopwords
+with open(stopWordsPath, "r") as f:
+    stop_words = [word.strip() for word in f.readlines()]
+
+
+def get_info(m_id, focus_field: list, need_doc: bool = False):
+    """
+    Fetch the focus fields of a document
+    :param m_id:
+    :param focus_field:
+    :param need_doc: also return the raw document
+    :return:
+    """
+    select_fields = ["title", "detail", "href", "buyer", "winner", "purchasing", "attach_text", "cut_title",
+                     "cut_detail", "cut_buyer", "cut_winner", "cut_purchasing", "cut_attach_text"]
+    fields = {field: 1 for field in select_fields}
+    c_info = catch_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
+    if c_info:
+        # Extract the field content
+        content, add_field = select_field(c_info, focus_field)
+        # Cache the newly cut fields
+        if add_field:
+            catch_mongo.update(c_info["_id"], add_field)
+        doc = c_info if need_doc else {}
+        return content, doc
+    s_info = source_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
+    if s_info:
+        # Extract the field content
+        content, add_field = select_field(s_info, focus_field)
+        # Cache the document together with the cut fields
+        s_info.update(add_field)
+        catch_mongo.insert(s_info)
+        doc = s_info if need_doc else {}
+        return content, doc
+    return "", {}
+
+
+def select_field(info, focus_field):
+    """
+    Select and merge the requested fields
+    :param info:
+    :param focus_field:
+    :return:
+    """
+    content = ""  # merged, tokenized text
+    add_field = {}  # newly cut fields to cache
+    for field in focus_field:
+        content += " "
+        if field in info:
+            content += info[field]
+        else:
+            original_field = field.split("_", 1)[-1]
+            if original_field in info:
+                add_field[field] = get_content(original_field, info.get(original_field, ""))
+                content += add_field[field]
+    return content, add_field
+
+
+def get_content(field: str, value) -> str:
+    """
+    Build tokenized text from a raw field value
+    :param field: field name
+    :param value: field value
+    :return: merged, tokenized text
+    """
+    content = ""  # body text
+    if value and field == "attach_text":  # attachments are handled separately
+        for ind, attach in value.items():
+            for topic, topic_detail in attach.items():
+                attach_url = topic_detail.get("attach_url", "")
+                # Load the attachment text from OSS
+                state, attach_txt = Fs.download_text_content(attach_url)
+                if state:
+                    content += attach_txt
+    else:
+        # Generic handling
+        if isinstance(value, str):
+            content = value if value else ""
+        else:
+            return ""
+    return chinese2vector(content, remove_word=["x"], stopwords=stop_words)
+
+
+def loading_train_data(project_id, focus_field):
+    """
+    Load the training data
+    :param project_id:
+    :param focus_field:
+    :return:
+    """
+    train_data = []
+    labels = []
+    result = session.query(AnnotatedData).filter_by(projectId=project_id).order_by(AnnotatedData.id).all()
+    for row in result:
+        label, m_id = row.label, row.infoId
+        many_label = [tag.strip() for tag in label.split(",") if tag.strip()]
+        if not many_label:
+            continue
+        content, doc = get_info(m_id, focus_field)
+        # Add the training text
+        if content.strip():
+            train_data.append(content)
+            labels.append(many_label)
+    return train_data, labels, len(labels)
+
+
+def loading_predict_data(m_id: str, focus_field: list):
+    """
+    Load the data to predict on
+    :param m_id:
+    :param focus_field:
+    :return:
+    """
+    content, doc = get_info(m_id, focus_field, need_doc=True)
+    return content, doc
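
A hedged usage sketch (assumes the MongoDB instances from docs/config.py are reachable; the id is a placeholder):

    from machine_models.databases import loading_predict_data

    # returns the merged, tokenized focus-field text plus the raw document
    content, doc = loading_predict_data("62a0c3a1e4b0f2d3c4b5a697", ["cut_title", "cut_detail"])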

BIN
machine_models/databases/__pycache__/__init__.cpython-37.pyc


BIN
machine_models/databases/__pycache__/mongo_helper.cpython-37.pyc


BIN
machine_models/databases/__pycache__/mysql_helper.cpython-37.pyc


+ 76 - 0
machine_models/databases/mongo_helper.py

@@ -0,0 +1,76 @@
+# coding:utf-8
+
+"""
+mongodb 数据库连接文件
+"""
+
+from pymongo import MongoClient
+import urllib.parse as parse
+from loguru import logger
+from pymongo.errors import CursorNotFound
+
+
+class MongoConnect(object):
+    def __init__(self, config):
+        self.__host = config.get("host", "")
+        self.__user = config.get("user", "")
+        self.__password = config.get("password", "")
+        self.__database = config.get("db", "")
+        self.__col = config.get("col", "")
+        self.__charset = config.get("charset", "")
+        self.client, self.col = self.connect()
+
+    def connect(self):
+        """
+        Connect to the database
+        :return:
+        """
+        # Escape special characters (quote into locals so a reconnect does not re-quote)
+        user = parse.quote_plus(self.__user)
+        password = parse.quote_plus(self.__password)
+
+        # Connect, with or without credentials
+        if user:
+            client = MongoClient(
+                "mongodb://{}:{}@{}".format(user, password, self.__host),
+                unicode_decode_error_handler='ignore')
+        else:
+            client = MongoClient(
+                "mongodb://{}".format(self.__host),
+                unicode_decode_error_handler='ignore')
+        col = client[self.__database][self.__col]
+        return client, col
+
+    def get_by_mid(self, m_id, fields):
+        info = {}
+        for i in range(2):
+            try:
+                info = self.col.find_one({"_id": m_id}, fields)
+                break
+            except CursorNotFound as e:
+                logger.warning(e)
+                self.client, self.col = self.connect()
+
+        return info
+
+    def insert(self, row):
+        info = {}
+        for i in range(2):
+            try:
+                info = self.col.insert_one(row)
+                break
+            except CursorNotFound as e:
+                logger.warning(e)
+                self.client, self.col = self.connect()
+        return info
+
+    def update(self, m_id, row):
+        info = {}
+        for i in range(2):
+            try:
+                info = self.col.update_one({"_id": m_id}, {"$set": row})
+                break
+            except CursorNotFound as e:
+                logger.warning(e)
+                self.client, self.col = self.connect()
+        return info
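
For illustration, connecting with one of the configs from docs/config.py (the ObjectId is a placeholder):

    from bson import ObjectId
    from docs.config import catch_mongo_config
    from machine_models.databases.mongo_helper import MongoConnect

    mongo = MongoConnect(catch_mongo_config)
    doc = mongo.get_by_mid(ObjectId("62a0c3a1e4b0f2d3c4b5a697"), {"title": 1})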

+ 67 - 0
machine_models/databases/mysql_helper.py

@@ -0,0 +1,67 @@
+# coding:utf-8
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import Column, String, Integer, Float, DATETIME
+
+Base = declarative_base()
+
+
+class AnnotatedData(Base):
+    # Annotated training data table
+    __tablename__ = 'annotatedData'
+    id = Column(Integer, primary_key=True)
+    tenantID = Column(Integer, comment="tenant ID")
+    projectId = Column(Integer, comment="project ID")
+    infoId = Column(String(100), comment="document ID")
+    label = Column(String(255), comment="label tags")
+    createTime = Column(DATETIME, comment="creation time")
+
+
+class Model(Base):
+    # Model table
+    __tablename__ = 'model'
+    id = Column(Integer, primary_key=True)
+    createperson = Column(String(100), comment="creator")
+    createTime = Column(DATETIME, comment='creation time')
+    sampleData = Column(Integer, comment='training sample count')
+    recallRate = Column(Float(11, 2), comment='recall')
+    precision = Column(Float(11, 2), comment='precision')
+    accuracyRate = Column(Float(11, 2), comment='accuracy')
+    state = Column(Integer(), comment='default-model flag: 0 no, 1 yes')
+    modelFile = Column(String(255), comment='model file (OSS storage)')
+    projectId = Column(Integer, comment='project ID')
+
+
+class Project(Base):
+    # Project table
+    __tablename__ = 'project'
+    id = Column(Integer, primary_key=True)
+    name = Column(String(255), comment="project name")
+    labels = Column(String(255), comment='label set')
+    type = Column(Integer, comment='label type (1 = single-label)')
+    userId = Column(Integer, comment='user id')
+    model = Column(Integer, comment='model ID')
+    focusField = Column(String(255), comment='focus fields')
+    createTime = Column(DATETIME, comment='creation time')
+    totalCount = Column(Integer, comment="total count")
+
+
+def init_db(mysql_config):
+    """
+    Initialize the database engine
+    :return:
+    """
+    db = mysql_config.get("db")
+    ip = mysql_config.get("ip")
+    port = mysql_config.get("port")
+    user = mysql_config.get("user")
+    pwd = mysql_config.get("pwd")
+    charset = mysql_config.get("charset")
+    engine = create_engine(
+        f"mysql+pymysql://{user}:{pwd}@{ip}:{port}/{db}?charset={charset}",
+        max_overflow=0,  # max connections created beyond pool_size
+        pool_size=5,  # connection pool size
+        pool_timeout=30,  # seconds to wait for a free connection before raising
+        pool_recycle=-1  # seconds before a pooled connection is recycled (-1 = never)
+    )
+    return engine
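
The engine is consumed through a sessionmaker, as machine_models/databases/__init__.py does; a minimal sketch:

    from sqlalchemy.orm.session import sessionmaker
    from docs.config import mysql_config
    from machine_models.databases.mysql_helper import init_db, Project

    engine = init_db(mysql_config)
    session = sessionmaker(bind=engine)()
    project = session.query(Project).filter_by(id=1).first()  # id is a placeholder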

+ 45 - 0
machine_models/predict_model.py

@@ -0,0 +1,45 @@
+# coding:utf-8
+
+from machine_models.databases import loading_predict_data
+import joblib
+from machine_models.tools import encode2label
+from docs.config import convertField
+
+
+def predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path):
+    '''
+    Prediction entry point
+    :param id_list: list of document ids
+    :param tfidf_vec: fitted tf-idf vectorizer
+    :param label_type: label type (1 = single-label)
+    :param focus_field: focus fields
+    :param target_label: target labels
+    :param model_path: model_path
+    :return:
+    '''
+    model, le = joblib.load(model_path)
+    # Run the prediction
+    focus_field = [convertField[field] for field in focus_field if field in convertField]
+    predict_result = []
+    for m_id in id_list:
+        content, doc = loading_predict_data(m_id, focus_field)
+        if not doc:
+            predict_result.append({"id": m_id, "title": "",
+                                   "url": "", "labels": ""})
+            continue
+        content_vec = tfidf_vec.transform([content])
+        # Single-label
+        if label_type == 1:
+            predict_y = model.predict(content_vec)
+            target = le.classes_[predict_y[0]] if len(predict_y) > 0 else ""
+            predict_result.append({"id": m_id, "title": doc.get("title", ""),
+                                   "url": doc.get("href", ""), "labels": target})
+
+        else:
+            # Multi-label
+            predict_y = model.predict(content_vec)
+            result = encode2label(le, predict_y, target_label)
+            target = result[0] if result else ""
+            predict_result.append({"id": m_id, "title": doc.get("title", ""),
+                                   "url": doc.get("href", ""), "labels": target})
+    return predict_result

+ 120 - 0
machine_models/tools.py

@@ -0,0 +1,120 @@
+# coding:utf-8
+
+import jieba.posseg as psg
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.preprocessing import LabelEncoder, OneHotEncoder
+from sklearn.preprocessing import MultiLabelBinarizer
+from util.htmlutil.htmltag import Clean
+import jieba
+import multiprocessing
+
+jieba.enable_parallel(multiprocessing.cpu_count())
+
+
+def chinese2vectors(chinese: list, remove_word: list, stop_words: list) -> list:
+    """
+    Tokenize Chinese text into space-separated words (batch)
+    :param chinese:
+    :param remove_word: POS tags to drop, e.g. x, n, eng
+    :param stop_words: stopwords
+    :return:
+    """
+    if not remove_word:
+        remove_word = ["x"]
+    if not stop_words:
+        stop_words = []
+    space_words = []
+    for row in chinese:
+        cut_ret = [word for word, x in psg.lcut(Clean(row)) if x not in remove_word and word not in stop_words]
+        space_words.append(" ".join(cut_ret))
+    return space_words
+
+
+def chinese2vector(chinese: str, remove_word: list, stopwords: list) -> str:
+    """
+    Tokenize Chinese text into space-separated words (single document)
+    :param chinese:
+    :param remove_word: POS tags to drop, e.g. x, n, eng
+    :param stopwords: stopwords
+    :return:
+    """
+    if not stopwords:
+        stopwords = []
+    if not remove_word:
+        remove_word = ["x"]
+    cut_ret = [word for word, x in psg.lcut(Clean(chinese)) if x not in remove_word and word not in stopwords]
+    cut_ret = " ".join(cut_ret)
+    return cut_ret
+
+
+def tfidf(analyzer, space_words) -> tuple:
+    '''
+    tf-idf encoding
+    :param analyzer:
+    :param space_words:
+    :return:
+    '''
+    tfidf_vec = TfidfVectorizer(analyzer=analyzer)
+    tfidf_ret = tfidf_vec.fit_transform(space_words)
+    return tfidf_vec, tfidf_ret
+
+
+def one2hot(space_words) -> tuple:
+    '''
+    one-hot encoding
+    :param space_words:
+    :return:
+    '''
+    oht = OneHotEncoder()
+    oht_ret = oht.fit_transform(space_words)
+    return oht, oht_ret
+
+
+def combine_row(target_one: list, target_two: list) -> list:
+    """
+    Concatenate two row lists element-wise
+    :param target_one:
+    :param target_two:
+    :return:
+    """
+    if len(target_one) != len(target_two):
+        raise ValueError("The two lists have different lengths")
+    for ind, row in enumerate(target_two):
+        target_one[ind] += row
+    return target_one
+
+
+def label2encode(labels: list) -> tuple:
+    """
+    Vectorize labels with LabelEncoder + MultiLabelBinarizer
+    :param labels:
+    :return:
+    """
+    le = LabelEncoder()
+    train_labels = []
+    for row in labels:
+        train_labels += row
+    le.fit(train_labels)
+    le_ret = [le.transform(row) for row in labels]
+    le_ret = MultiLabelBinarizer().fit_transform(le_ret)
+    return le, le_ret
+
+
+def encode2label(le, predict_results, target_label: list) -> list:
+    """
+    Decode prediction vectors back to label strings
+    :param le: fitted label encoder
+    :param predict_results: prediction results
+    :param target_label: wanted classes
+    :return:
+    """
+    detail_labels = []
+    for i, label in enumerate(predict_results):
+        if label.sum() > 0:
+            indices = [ind for (ind, x) in enumerate(label) if x > 0]
+            label_str = ','.join([tag for tag in le.inverse_transform(indices) if tag in target_label])
+            detail_labels.append(label_str)
+    return detail_labels
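
Putting the pieces together, roughly as create_dict.py does (the sample texts are placeholders and no stopwords are applied):

    from machine_models.tools import chinese2vector, tfidf

    corpus = [chinese2vector(text, remove_word=["x"], stopwords=[])
              for text in ["某项目公开招标公告", "中标结果公示"]]
    tfidf_vec, tfidf_ret = tfidf(analyzer="word", space_words=corpus)
    print(tfidf_ret.shape)  # (2, vocabulary size)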

+ 127 - 0
machine_models/train_model.py

@@ -0,0 +1,127 @@
+# coding:utf-8
+
+from sklearn.svm import LinearSVC
+import os
+import joblib
+from machine_models.databases import loading_train_data
+from sklearn.model_selection import train_test_split
+from machine_models.tools import label2encode
+from sklearn.multiclass import OneVsRestClassifier
+import datetime
+from docs.config import convertField
+from machine_models.databases import File
+import numpy as np
+import uuid
+from machine_models.databases.mysql_helper import Model
+
+
+def many_recall_score(y_test, y_pred):
+    '''
+    Multi-label recall
+    :param y_test:
+    :param y_pred:
+    :return:
+    '''
+    correct_count = 0
+    total_count = 0
+    for values in zip(y_test, y_pred):
+        test_result = values[0]
+        pred_result = values[1]
+        total_count += test_result.sum()
+        correct_count += pred_result[test_result > 0].sum()
+    return correct_count / total_count
+
+
+def recall_score(y_test, y_pred):
+    '''
+    Single-label recall (computed here as overall accuracy)
+    :param y_test:
+    :param y_pred:
+    :return:
+    '''
+    return (y_test == y_pred).sum() / y_test.size
+
+
+def train_ones_label(x_train, y_train):
+    '''
+    Single-label training
+    :return:
+    '''
+    seed = int(datetime.datetime.now().timestamp())
+    model = LinearSVC(random_state=seed)
+    model.fit(x_train, y_train)
+    return model
+
+
+def train_many_labels(x_train, y_train):
+    '''
+    Multi-label training
+    :param x_train:
+    :param y_train:
+    :return:
+    '''
+    seed = int(datetime.datetime.now().timestamp())
+    model = LinearSVC(random_state=seed)
+    clf = OneVsRestClassifier(model, n_jobs=-1)  # build a multi-label classifier from binary ones
+    clf.fit(x_train, y_train)  # train
+    return clf
+
+
+def train(project_id, focus_field, tfidf_vec, label_type: int, model_dir: str):
+    """
+    Train a model
+    :param project_id:
+    :param focus_field:
+    :param tfidf_vec:
+    :param label_type:
+    :param model_dir:
+    :return:
+    """
+    # Focus fields
+    focus_field = [convertField[field] for field in focus_field if field in convertField]
+    # Load the data
+    train_data, train_label, count = loading_train_data(project_id, focus_field)
+    # Vectorize the training data
+    train_vec = tfidf_vec.transform(train_data)
+    # Encode the labels
+    le, label_vec = label2encode(train_label)
+    if label_type == 1:
+        single_label = []
+        for label in label_vec:
+            for ind, tag in enumerate(label):
+                if tag == 1:
+                    single_label.append(ind)
+                    break
+        label_vec = single_label
+    x_train, x_test, y_train, y_test = train_test_split(train_vec, label_vec, test_size=0.2, shuffle=True)
+    model_path = os.path.join(model_dir, "model.model")
+    try:
+        if label_type == 1:
+            # Single-label training
+            y_test = np.array(y_test)
+            clf = train_ones_label(x_train, y_train)
+            y_pred = clf.predict(x_test)
+            # Evaluate
+            score = (y_test == y_pred).sum() / y_test.size
+            recall = recall_score(y_test, y_pred)
+        else:
+            # Multi-label training
+            clf = train_many_labels(x_train, y_train)
+            y_pred = clf.predict(x_test)
+            # Evaluate
+            score = (y_test == y_pred).sum() / y_test.size
+            recall = many_recall_score(y_test, y_pred)
+    except Exception:
+        return False
+    # Save the model
+    joblib.dump((clf, le), model_path)
+
+    # Upload the model
+    model_url = str(uuid.uuid4())
+    with open(model_path, "rb") as f:
+        File.upload_bytes_file(model_url, f.read())
+    f1_score = ((score * recall) / (score + recall)) * 2 if score and recall else 0
+    # Build the database record
+    add_model = Model(sampleData=count, recallRate=recall, precision=score, accuracyRate=f1_score,
+                      modelFile=model_url)
+    return add_model

+ 31 - 0
predict_server.py

@@ -0,0 +1,31 @@
+# coding:utf-8
+"""
+ 预测客户端
+"""
+import tornado.ioloop
+import tornado.web
+import json
+from machine_models import predict_model
+from loguru import logger
+
+logger.add('./logs/predict_{time}.log', rotation='00:00')
+
+
+class MainHandler(tornado.web.RequestHandler):
+    def post(self):
+        request_params = self.request.body.decode('utf-8')
+        try:
+            request_dict = json.loads(request_params)
+            predict_result = predict_model(request_dict)
+            response_data = json.dumps(predict_result)
+        except Exception as e:
+            logger.warning(e)
+            response_data = json.dumps({"error_code": 0})
+        self.write(response_data)
+
+
+if __name__ == '__main__':
+    application = tornado.web.Application([(r"/jy_machining/predict", MainHandler), ])
+    application.listen(8686)
+    print('server start')
+    tornado.ioloop.IOLoop.instance().start()
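
A hedged client-side sketch for exercising the endpoint (assumes the server is running locally on port 8686 and the requests package is installed; ids are placeholders):

    import json
    import requests

    payload = {"id": 1, "model_id": 3, "id_list": ["62a0c3a1e4b0f2d3c4b5a697"]}
    resp = requests.post("http://127.0.0.1:8686/jy_machining/predict",
                         data=json.dumps(payload))
    print(resp.json())  # {"error_code": 1, "data": [...]} on success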

+ 53 - 0
train_server.py

@@ -0,0 +1,53 @@
+# coding:utf-8
+'''
+Training server (nsq consumer)
+'''
+import nsq
+import json
+from machine_models import train_model
+from loguru import logger
+from queue import Queue
+import time
+from threading import Thread
+
+logger.add('./logs/runtime_{time}.log', rotation='00:00')
+queueSave = Queue(maxsize=10000)  # task queue
+
+
+def train_start():
+    # Poll the task queue and run training jobs
+    global queueSave
+    while True:
+        if not queueSave.empty():
+            params = queueSave.get()
+            train_model(params)
+            continue
+        time.sleep(5)
+
+
+def handler(message):
+    '''
+    nsq queue callback
+    :param message:
+    :return:
+    '''
+    global queueSave
+    try:
+        body = message.body
+        body = json.loads(body)
+        queueSave.put(body)
+    except Exception as e:
+        logger.warning("start--> {}", e)
+    return True
+
+
+r = nsq.Reader(message_handler=handler, nsqd_tcp_addresses=['192.168.3.13:4150'], topic='machine_train',
+               channel='NO.1',
+               lookupd_poll_interval=5,
+               lookupd_connect_timeout=10000,
+               lookupd_request_timeout=10000)
+if __name__ == '__main__':
+    train_thread = Thread(target=train_start)
+    train_thread.start()
+    nsq.run()
+    train_thread.join()
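
The handler expects the nsq message body to be the same JSON that train_model consumes; a sketch of a valid message (all values are placeholders):

    import json

    message_body = json.dumps({
        "id": 1,             # project id
        "userId": "u1",
        "type": 1,           # 1 = single-label
        "fields": "标题,正文"
    })
    # publish message_body to topic 'machine_train' on 192.168.3.13:4150,
    # e.g. via nsqd's HTTP /pub endpoint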

BIN
util/__pycache__/file_operations.cpython-37.pyc


BIN
util/__pycache__/fs_client.cpython-37.pyc


BIN
util/__pycache__/oss_file.cpython-37.pyc


+ 62 - 0
util/file_operations.py

@@ -0,0 +1,62 @@
+# coding:utf-8
+'''
+File operations
+'''
+import os
+import shutil
+from shutil import copyfile
+
+
+def save_file(file: bytes, filename):
+    '''
+    Save bytes to a file
+    :param file:
+    :param filename:
+    :return:
+    '''
+    try:
+        with open(filename, "wb") as fw:
+            fw.write(file)
+        return True
+    except Exception:
+        return False
+
+
+def generate_directory(dir_path: str) -> bool:
+    '''
+    Create a directory (world-writable)
+    :param dir_path: directory path
+    :return:
+    '''
+    try:
+        if not os.path.exists(dir_path):
+            os.makedirs(dir_path)
+            os.chmod(dir_path, 0o777)
+    except Exception:
+        return False
+    return True
+
+
+def del_directory(dir_path: str):
+    '''
+    Delete a directory tree
+    :param dir_path: directory path
+    :return:
+    '''
+    if os.path.exists(dir_path):
+        shutil.rmtree(dir_path)
+
+
+def file_copy(source_path: str, target_path: str):
+    """
+    Copy a file to the target path
+    :param source_path: source file path
+    :param target_path: target path
+    :return:
+    """
+    try:
+        copyfile(source_path, target_path)
+    except IOError:
+        return False
+    return True

+ 54 - 0
util/fs_client.py

@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+import oss2
+
+
+class FileServeClient(object):
+    def __init__(self, config):
+        '''
+        Text storage client
+        Currently backed by the Aliyun OSS object storage service
+        Note: all reads and writes are keyed by object-name, so keep the object-name
+        '''
+        self.auth = None
+        self.bucket = None
+        self._access_key_id = config.get("access_key_id", "")
+        self._access_key_secret = config.get("access_key_secret", "")
+        self._endpoint = config.get("endpoint", "")
+        self._bucket_name = config.get("bucket_name", "")
+        self.do_auth()
+
+    def do_auth(self):
+        '''
+        Authenticate
+        '''
+        auth = oss2.Auth(self._access_key_id, self._access_key_secret)
+        bucket = oss2.Bucket(auth, self._endpoint, self._bucket_name)
+        self.auth = auth
+        self.bucket = bucket
+
+    def upload_text_file(self, object_name: str, file_content: str) -> tuple:
+        '''
+        Upload text
+        '''
+        result = self.bucket.put_object(object_name, bytes(file_content, encoding='utf-8'))
+        status, request_id = result.status, result.request_id
+        return status, request_id
+
+    def download_text_content(self, object_name) -> tuple:
+        '''
+        Download text content, verifying the CRC
+        '''
+        object_stream = self.bucket.get_object(object_name)
+        content = object_stream.read()
+        if object_stream.client_crc == object_stream.server_crc:
+            return True, str(content, encoding='utf-8')
+        else:
+            return False, ''
+
+    def delete_object(self, object_name: str) -> tuple:
+        '''
+        Delete an object
+        '''
+        result = self.bucket.delete_object(object_name)
+        status, request_id = result.status, result.request_id
+        return status, request_id

BIN
util/htmlutil/__pycache__/htmltag.cpython-37.pyc


+ 78 - 0
util/htmlutil/htmltag.py

@@ -0,0 +1,78 @@
+# coding:utf-8
+import re
+
+br_reg = re.compile('<br[/]*>', re.I)
+table_reg = re.compile('<([/]*table[^>]*)>', re.I)
+tablebody_reg = re.compile('<([/]*tbody[^>]*)>', re.I)
+input_reg = re.compile(r'<[/]*input[^>].*?value="(.*?)"[/]>', re.I)
+tr_reg = re.compile('<([/]*tr[^>]*)>', re.I)
+th_reg = re.compile('<([/]*th[^>]*)>', re.I)
+td_reg = re.compile('<([/]*td[^>]*)>', re.I)
+p_reg = re.compile('<[/]?p>', re.I)
+othertag_reg = re.compile('<[^>]+>', re.I | re.M)
+other_symbol_reg = re.compile('[\t| ]*')
+seg_first_space_reg = re.compile('\n+\\s*', re.M)
+mul_crcf_reg = re.compile('\n+', re.M)
+brackets_reg = re.compile('\\s+')
+table_fk_reg = re.compile('(\\[table[^\\]]*\\])(.*?)(\\[/table\\])', re.M | re.S | re.I)
+
+
+# Clean HTML tags
+def Clean(html: str):
+    html = br_reg.sub('\n', html)
+    html = table_reg.sub('', html)
+    html = tablebody_reg.sub('', html)
+    html = tr_reg.sub('\n', html)
+    html = td_reg.sub(' ', html)
+    html = p_reg.sub('\n', html)
+    html = othertag_reg.sub('', html)
+    html = other_symbol_reg.sub('', html)
+    html = seg_first_space_reg.sub('\n', html)
+    html = mul_crcf_reg.sub('\n', html)
+    return html
+
+
+def ClearSpace(txt: str):
+    return brackets_reg.sub('', txt)
+
+
+# Clean HTML tags but keep table markup
+def CleanKeepTable(html: str):
+    html = br_reg.sub('\n', html)
+    html = table_reg.sub(subFunc4Match, html)
+    html = tablebody_reg.sub(subFunc4Match, html)
+    html = tr_reg.sub(subFunc4Match, html)
+    html = td_reg.sub(subFunc4Match, html)
+    html = th_reg.sub(subFunc4Match, html)
+    html = p_reg.sub('\n', html)
+    html = othertag_reg.sub('', html)
+    # html = other_symbol_reg.sub('',html)
+    html = seg_first_space_reg.sub('\n', html)
+    # print("-->", html)
+    html = table_fk_reg.sub(lambda x: x.group(1) + mul_crcf_reg.sub(' ', x.group(2)) + x.group(3), html)
+    html = mul_crcf_reg.sub('\n', html)
+    # restore the bracket-escaped table tags
+    html = html.replace('[', '<').replace(']', '>')
+    html = html.replace('<table', '\n<table').replace('</table>', '</table>\n')
+    return html
+
+
+def subFunc4Match(strmatch):
+    try:
+        if strmatch:
+            return '[%s]' % (strmatch.group(1))
+        else:
+            return ""
+    except Exception as e:
+        print(e)
+
+
+def extract_input_value(html):
+    input_reg = re.compile(r'<[/]*input[^>].*?value="(.*?)"[/]>', re.I)
+    input_r = re.compile(r'<[/]*input[^>].*?[/]>', re.I)
+    result = input_r.findall(html)
+    for input_detail in result:
+        ret = input_reg.findall(input_detail)
+        if ret:
+            html = html.replace(input_detail, f"</td><td>{ret[0]}")
+    return html
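
Clean in isolation, for illustration (the HTML snippet is a made-up example):

    from util.htmlutil.htmltag import Clean

    html = "<p>项目公告</p><table><tr><td>编号</td><td>001</td></tr></table>"
    print(Clean(html))  # tags stripped, block tags become newlines, whitespace collapsed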

+ 70 - 0
util/oss_file.py

@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+import shutil
+import oss2
+
+
+class OssServeClient(object):
+    def __init__(self, config):
+        '''
+        File storage client
+        Currently backed by the Aliyun OSS object storage service
+        Note: all reads and writes are keyed by object-name, so keep the object-name
+        '''
+        self.auth = None
+        self.bucket = None
+        self._access_key_id = config.get("access_key_id", "")
+        self._access_key_secret = config.get("access_key_secret", "")
+        self._endpoint = config.get("endpoint", "")
+        self._bucket_name = config.get("bucket_name", "")
+        self.do_auth()
+
+    def do_auth(self):
+        '''
+        Authenticate
+        '''
+        auth = oss2.Auth(self._access_key_id, self._access_key_secret)
+        bucket = oss2.Bucket(auth, self._endpoint, self._bucket_name)
+        self.auth = auth
+        self.bucket = bucket
+
+    def delete_object(self, object_name: str) -> tuple:
+        '''
+        Delete an object
+        '''
+        result = self.bucket.delete_object(object_name)
+        status, request_id = result.status, result.request_id
+        return status, request_id
+
+    def upload_bytes_file(self, object_name: str, file_content: bytes):
+        '''
+        Upload a file
+        :param object_name: fid
+        :param file_content: file bytes
+        :return:
+        '''
+        result = self.bucket.put_object(object_name, file_content)
+        status, request_id = result.status, result.request_id
+        return status, request_id
+
+    def download_file(self, object_name, save_path):
+        '''
+        Download a file to local disk
+        :param object_name: fid
+        :param save_path: save path
+        :return:
+        '''
+        object_stream = self.bucket.get_object_to_file(object_name, save_path)
+        return object_stream.status == 200
+
+    def download_file_stream(self, object_name, filename):
+        '''
+        Download a file as a stream
+        :param object_name: fid
+        :param filename: local file path
+        :return:
+        '''
+        object_stream = self.bucket.get_object(object_name)
+        with open(filename, 'wb') as file:
+            shutil.copyfileobj(object_stream, file)
+        return object_stream.status, filename