# coding:utf-8
"""
Data loading and cache building.

Documents are looked up in the cache Mongo collection first, falling back
to the source collection; freshly cut (tokenized) text fields are written
back into the cache so subsequent lookups are cheap.
"""
from sqlalchemy.orm.session import sessionmaker
from machine_models.databases.mysql_helper import init_db
from machine_models.databases.mysql_helper import Model
from machine_models.databases.mysql_helper import AnnotatedData
from machine_models.databases.mongo_helper import MongoConnect
from docs.config import mysql_config
from docs.config import source_mongo_config
from docs.config import catch_mongo_config
from util.fs_client import FileServeClient
from machine_models.tools import chinese2vector
from docs.config import stopWordsPath
from bson import ObjectId
from docs.config import oss_file_config
from docs.config import oss_txt_config
from util.oss_file import OssServeClient

# Connection initialisation
Fs = FileServeClient(oss_txt_config)
File = OssServeClient(oss_file_config)
engine = init_db(mysql_config)
Connect = sessionmaker(bind=engine)
session = Connect()
source_mongo = MongoConnect(source_mongo_config)
catch_mongo = MongoConnect(catch_mongo_config)

# Load stop words, one word per line.
# NOTE(review): encoding was previously unspecified (platform default);
# utf-8 assumed for a Chinese word list — confirm against the file.
with open(stopWordsPath, "r", encoding="utf-8") as f:
    stop_words = [word.strip() for word in f.readlines()]


def get_info(m_id, focus_field: list, need_doc: bool = False):
    """
    Fetch the requested fields of a document, cache first.

    :param m_id: Mongo document id string (surrounding whitespace stripped)
    :param focus_field: field names whose text is concatenated (usually cut_*)
    :param need_doc: also return the raw document
    :return: (merged text content, document dict — empty unless need_doc)
    """
    select_fields = ["title", "detail", "href", "buyer", "winner",
                     "purchasing", "attach_text", "cut_title", "cut_detail",
                     "cut_buyer", "cut_winner", "cut_purchasing",
                     "cut_attach_text"]
    fields = {field: 1 for field in select_fields}
    c_info = catch_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
    if c_info:
        # Cache hit: merge the requested fields.
        content, add_field = select_field(c_info, focus_field)
        # Persist any fields that had to be cut on the fly.
        if add_field:
            catch_mongo.update(c_info["_id"], add_field)
        doc = c_info if need_doc else {}
        return content, doc
    s_info = source_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
    if s_info:
        # Cache miss: build from the source document, then cache it
        # (including the newly cut fields).
        content, add_field = select_field(s_info, focus_field)
        s_info.update(add_field)
        catch_mongo.insert(s_info)
        doc = s_info if need_doc else {}
        return content, doc
    # Document not found in either collection.
    return "", {}


def select_field(info, focus_field):
    """
    Select and merge the requested fields from a document.

    :param info: document dict from Mongo
    :param focus_field: field names to merge (cut_* names preferred)
    :return: (merged text, dict of cut fields newly computed — to be cached)
    """
    content = ""    # merged cut-word text
    add_field = {}  # newly computed cut-word fields to add to the cache
    for field in focus_field:
        content += " "
        if field in info:
            content += info[field]
        else:
            # Fall back to the original field (strip the leading "cut_")
            # and cut it now, remembering the result for caching.
            original_field = field.split("_", 1)[-1]
            if original_field in info:
                add_field[field] = get_content(original_field,
                                               info.get(original_field, ""))
                content += add_field[field]
    return content, add_field


def get_content(field: str, value) -> str:
    """
    Build the cut-word text for one field value.

    :param field: field name (``attach_text`` gets special handling)
    :param value: field value — a nested dict for attachments, else a string
    :return: cut (tokenized) text, stop words removed
    """
    content = ""  # raw text before cutting
    if value and field == "attach_text":
        # Attachments: walk {index: {topic: {"attach_url": ...}}} and pull
        # each attachment's text from OSS.
        for ind, attach in value.items():
            for topic, topic_detail in attach.items():
                attach_url = topic_detail.get("attach_url", "")
                state, attach_txt = Fs.download_text_content(attach_url)
                if state:
                    content += attach_txt
    else:
        # Generic fields: only plain strings are usable.
        if isinstance(value, str):
            content = value if value else ""
        else:
            return ""
    return chinese2vector(content, remove_word=["x"], stopwords=stop_words)


def loading_train_data(project_id, focus_field):
    """
    Load training data for a project.

    :param project_id: annotation project id (MySQL filter)
    :param focus_field: field names passed through to get_info
    :return: (training texts, label lists, sample count)
    """
    train_data = []
    labels = []
    result = (session.query(AnnotatedData)
              .filter_by(projectId=project_id)
              .order_by(AnnotatedData.id)
              .all())
    for row in result:
        label, m_id = row.label, row.infoId
        # Labels are stored comma-separated; skip rows with no usable label.
        many_label = [tag.strip() for tag in label.split(",") if tag.strip()]
        if not many_label:
            continue
        content, doc = get_info(m_id, focus_field)
        # Keep only samples that produced non-empty text.
        if content.strip():
            train_data.append(content)
            labels.append(many_label)
    return train_data, labels, len(labels)


def loading_predict_data(m_id: str, focus_field: list):
    """
    Load data for prediction.

    :param m_id: Mongo document id string
    :param focus_field: field names passed through to get_info
    :return: (merged text content, raw document)
    """
    content, doc = get_info(m_id, focus_field, need_doc=True)
    return content, doc