# coding:utf-8
'''
Data loading and cache construction.
'''
- from sqlalchemy.orm.session import sessionmaker
- from machine_models.databases.mysql_helper import init_db
- from machine_models.databases.mysql_helper import Model
- from machine_models.databases.mysql_helper import AnnotatedData
- from machine_models.databases.mongo_helper import MongoConnect
- from docs.config import mysql_config
- from docs.config import source_mongo_config
- from docs.config import catch_mongo_config
- from util.fs_client import FileServeClient
- from machine_models.tools import chinese2vector
- from docs.config import stopWordsPath
- from bson import ObjectId
- from docs.config import oss_file_config
- from docs.config import oss_txt_config
- from util.oss_file import OssServeClient
# Connection initialisation: object-storage clients, MySQL session and the
# two Mongo connections (source data and the cache of processed documents).
Fs = FileServeClient(oss_txt_config)
File = OssServeClient(oss_file_config)
engine = init_db(mysql_config)
Connect = sessionmaker(bind=engine)
session = Connect()
source_mongo = MongoConnect(source_mongo_config)
catch_mongo = MongoConnect(catch_mongo_config)

# Load the stop-word list: one entry per line, surrounding whitespace removed.
# NOTE(review): the file is opened with the platform default encoding —
# confirm that matches the encoding of the stop-word file.
with open(stopWordsPath, "r") as f:
    stop_words = [line.strip() for line in f]
def get_info(m_id, focus_field: list, need_doc: bool = False):
    """
    Fetch the merged text of the requested fields for one document.

    The cache collection is consulted first. On a miss the source collection
    is read, and the document (together with any freshly computed fields) is
    inserted into the cache.

    :param m_id: document id as a string; surrounding whitespace is stripped
        before it is converted to an ``ObjectId``
    :param focus_field: names of the fields whose text should be merged
    :param need_doc: when true, also return the fetched document
    :return: ``(content, doc)``; ``doc`` is ``{}`` unless ``need_doc`` is
        true, and ``("", {})`` is returned when neither collection holds the
        document
    """
    # Only these fields are requested from Mongo.
    projection = dict.fromkeys(
        (
            "title", "detail", "href", "buyer", "winner", "purchasing",
            "attach_text", "cut_title", "cut_detail", "cut_buyer",
            "cut_winner", "cut_purchasing", "cut_attach_text",
        ),
        1,
    )
    object_id = ObjectId(m_id.strip())

    # Cache hit: reuse stored fields and persist any newly computed ones.
    cached = catch_mongo.get_by_mid(object_id, projection)
    if cached:
        content, new_fields = select_field(cached, focus_field)
        if new_fields:
            catch_mongo.update(cached["_id"], new_fields)
        return content, (cached if need_doc else {})

    # Cache miss: fall back to the source collection.
    sourced = source_mongo.get_by_mid(object_id, projection)
    if not sourced:
        return "", {}

    content, new_fields = select_field(sourced, focus_field)
    # Store the source document plus the computed fields in the cache.
    sourced.update(new_fields)
    catch_mongo.insert(sourced)
    return content, (sourced if need_doc else {})
def select_field(info, focus_field):
    """
    Build the merged text for the requested fields of a document.

    Every requested field contributes a leading space followed by its text.
    A field already present in ``info`` is used as is. Otherwise the part of
    the name after the first underscore is looked up as the raw field, its
    text is produced by ``get_content`` and recorded so the caller can cache
    it.

    :param info: document mapping field names to values
    :param focus_field: names of the fields to merge, in order
    :return: ``(content, add_field)`` — the merged text and a dict of the
        newly computed fields keyed by the requested field name
    """
    merged_text = ""        # concatenated text of all requested fields
    computed_fields = {}    # fields computed here, to be cached by the caller
    for name in focus_field:
        merged_text += " "
        if name in info:
            merged_text += info[name]
            continue

        # e.g. "cut_title" -> "title"; a name without "_" maps to itself
        raw_name = name.split("_", 1)[-1]
        if raw_name not in info:
            continue

        segmented = get_content(raw_name, info.get(raw_name, ""))
        computed_fields[name] = segmented
        merged_text += segmented
    return merged_text, computed_fields
def get_content(field: str, value: object) -> str:
    """
    Turn the raw value of one field into processed text.

    For a non-empty ``attach_text`` value, the text of every attachment is
    downloaded through ``Fs`` and concatenated. For any other field a string
    value is used directly. The resulting text is passed through
    ``chinese2vector`` with the module's stop-word list.

    :param field: name of the raw field (compared against ``"attach_text"``)
    :param value: raw field value; for ``attach_text`` a mapping whose values
        are mappings of entries carrying an ``attach_url`` key, otherwise a
        string
    :return: the ``chinese2vector`` result for the collected text, or ``""``
        when the value is neither attachment data nor a string
    """
    if value and field == "attach_text":
        # Attachments are handled separately: fetch each one's text from OSS.
        pieces = []
        for attach in value.values():
            for topic_detail in attach.values():
                attach_url = topic_detail.get("attach_url", "")
                state, attach_txt = Fs.download_text_content(attach_url)
                # Keep only the downloads that report success.
                if state:
                    pieces.append(attach_txt)
        content = "".join(pieces)
    elif isinstance(value, str):
        # Generic handling: plain string fields are used unchanged.
        content = value
    else:
        # Empty attachment data or an unsupported type yields no text.
        return ""
    return chinese2vector(content, remove_word=["x"], stopwords=stop_words)
def loading_train_data(project_id, focus_field):
    """
    Load the training texts and labels annotated for one project.

    Rows are read from ``AnnotatedData`` ordered by id. A row is skipped when
    it has no usable label or when no text could be obtained for it.

    :param project_id: project whose annotated rows are loaded
    :param focus_field: names of the fields merged into each training text
    :return: ``(train_data, labels, count)`` — the texts, the list of labels
        for each text (index-aligned with ``train_data``) and the number of
        samples
    """
    rows = (
        session.query(AnnotatedData)
        .filter_by(projectId=project_id)
        .order_by(AnnotatedData.id)
        .all()
    )

    train_data = []
    labels = []
    for row in rows:
        # A missing label (None) is treated like an empty one instead of
        # failing on ``.split``.
        raw_label = row.label or ""
        # Labels are comma separated; blank entries are dropped.
        many_label = [tag.strip() for tag in raw_label.split(",") if tag.strip()]
        if not many_label:
            continue

        content, _ = get_info(row.infoId, focus_field)
        # Only rows that produced some text become training samples.
        if content.strip():
            train_data.append(content)
            labels.append(many_label)
    return train_data, labels, len(labels)
def loading_predict_data(m_id: str, focus_field: list):
    """
    Load the text to predict on together with its document.

    :param m_id: id of the document, as a string
    :param focus_field: names of the fields merged into the text
    :return: ``(content, doc)`` as produced by ``get_info`` with
        ``need_doc=True``
    """
    return get_info(m_id, focus_field, need_doc=True)
|