|
@@ -0,0 +1,144 @@
|
|
|
|
+# coding:utf-8
|
|
|
|
+'''
|
|
|
|
+数据加载、建立缓存
|
|
|
|
+'''
|
|
|
|
+from sqlalchemy.orm.session import sessionmaker
|
|
|
|
+from machine_models.databases.mysql_helper import init_db
|
|
|
|
+from machine_models.databases.mysql_helper import Model
|
|
|
|
+from machine_models.databases.mysql_helper import AnnotatedData
|
|
|
|
+from machine_models.databases.mongo_helper import MongoConnect
|
|
|
|
+from docs.config import mysql_config
|
|
|
|
+from docs.config import source_mongo_config
|
|
|
|
+from docs.config import catch_mongo_config
|
|
|
|
+from util.fs_client import FileServeClient
|
|
|
|
+from machine_models.tools import chinese2vector
|
|
|
|
+from docs.config import stopWordsPath
|
|
|
|
+from bson import ObjectId
|
|
|
|
+from docs.config import oss_file_config
|
|
|
|
+from docs.config import oss_txt_config
|
|
|
|
+from util.oss_file import OssServeClient
|
|
|
|
+
|
|
|
|
+# 链接初始化
|
|
|
|
+Fs = FileServeClient(oss_txt_config)
|
|
|
|
+File = OssServeClient(oss_file_config)
|
|
|
|
+engine = init_db(mysql_config)
|
|
|
|
+Connect = sessionmaker(bind=engine)
|
|
|
|
+session = Connect()
|
|
|
|
+source_mongo = MongoConnect(source_mongo_config)
|
|
|
|
+catch_mongo = MongoConnect(catch_mongo_config)
|
|
|
|
+
|
|
|
|
+# 加载停用词
|
|
|
|
+with open(stopWordsPath, "r") as f:
|
|
|
|
+ stop_words = [word.strip() for word in f.readlines()]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_info(m_id, focus_field: list, need_doc: bool = False):
|
|
|
|
+ """
|
|
|
|
+ 关注字段获取
|
|
|
|
+ :param m_id:
|
|
|
|
+ :param focus_field:
|
|
|
|
+ :param need_doc: 获取原文档
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ select_fields = ["title", "detail", "href", "buyer", "winner", "purchasing", "attach_text", "cut_title",
|
|
|
|
+ "cut_detail", "cut_buyer", "cut_winner", "cut_purchasing", "cut_attach_text"]
|
|
|
|
+ fields = {field: 1 for field in select_fields}
|
|
|
|
+ c_info = catch_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
|
|
|
|
+ if c_info:
|
|
|
|
+ # 获取字段内容
|
|
|
|
+ content, add_field = select_field(c_info, focus_field)
|
|
|
|
+ # 添加缓存
|
|
|
|
+ if add_field:
|
|
|
|
+ catch_mongo.update(c_info["_id"], add_field)
|
|
|
|
+ doc = c_info if need_doc else {}
|
|
|
|
+ return content, doc
|
|
|
|
+ s_info = source_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
|
|
|
|
+ if s_info:
|
|
|
|
+ # 获取字段内容
|
|
|
|
+ content, add_field = select_field(s_info, focus_field)
|
|
|
|
+ # 添加缓存
|
|
|
|
+ s_info.update(add_field)
|
|
|
|
+ catch_mongo.insert(s_info)
|
|
|
|
+ doc = s_info if need_doc else {}
|
|
|
|
+ return content, doc
|
|
|
|
+ return "", {}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def select_field(info, focus_field):
|
|
|
|
+ """
|
|
|
|
+ 字段筛选
|
|
|
|
+ :param info:
|
|
|
|
+ :param focus_field:
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ content = "" # 合并的切词文本
|
|
|
|
+ add_field = {} # 添加的缓存切词字段
|
|
|
|
+ for field in focus_field:
|
|
|
|
+ content += " "
|
|
|
|
+ if field in info:
|
|
|
|
+ content += info[field]
|
|
|
|
+ else:
|
|
|
|
+ original_field = field.split("_", 1)[-1]
|
|
|
|
+ if original_field in info:
|
|
|
|
+ add_field[field] = get_content(original_field, info.get(original_field, ""))
|
|
|
|
+ content += add_field[field]
|
|
|
|
+ return content, add_field
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_content(field: dict, value: any) -> str:
|
|
|
|
+ """
|
|
|
|
+ 需求字段合成文本内容
|
|
|
|
+ :param field:字段
|
|
|
|
+ :param value:值
|
|
|
|
+ :return:合并文本
|
|
|
|
+ """
|
|
|
|
+ content = "" # 正文文本
|
|
|
|
+ if value and field == "attach_text": # 附件单独处理
|
|
|
|
+ for ind, attach in value.items():
|
|
|
|
+ for topic, topic_detail in attach.items():
|
|
|
|
+ attach_url = topic_detail.get("attach_url", "")
|
|
|
|
+ # 加载oss附件文本
|
|
|
|
+ state, attach_txt = Fs.download_text_content(attach_url)
|
|
|
|
+ if state:
|
|
|
|
+ content += attach_txt
|
|
|
|
+ else:
|
|
|
|
+ # 通用处理方法
|
|
|
|
+ if isinstance(value, str):
|
|
|
|
+ content = value if value else ""
|
|
|
|
+ else:
|
|
|
|
+ return ""
|
|
|
|
+ return chinese2vector(content, remove_word=["x"], stopwords=stop_words)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def loading_train_data(project_id, focus_field):
|
|
|
|
+ """
|
|
|
|
+ 加载训练数据
|
|
|
|
+ :param project_id:
|
|
|
|
+ :param focus_field:
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ train_data = []
|
|
|
|
+ labels = []
|
|
|
|
+ result = session.query(AnnotatedData).filter_by(projectId=project_id).order_by(AnnotatedData.id).all()
|
|
|
|
+ for row in result:
|
|
|
|
+ label, m_id = row.label, row.infoId
|
|
|
|
+ many_label = [tag.strip() for tag in label.split(",") if tag.strip()]
|
|
|
|
+ if not many_label:
|
|
|
|
+ continue
|
|
|
|
+ content, doc = get_info(m_id, focus_field)
|
|
|
|
+ # 添加训练文本
|
|
|
|
+ if content.strip():
|
|
|
|
+ train_data.append(content)
|
|
|
|
+ labels.append(many_label)
|
|
|
|
+ return train_data, labels, len(labels)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def loading_predict_data(m_id: str, focus_field: list):
|
|
|
|
+ """
|
|
|
|
+ 加载预测数据
|
|
|
|
+ :param m_id:
|
|
|
|
+ :param focus_field:
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ content, doc = get_info(m_id, focus_field, need_doc=True)
|
|
|
|
+ return content, doc
|