__init__.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # coding:utf-8
  2. '''
  3. 数据加载、建立缓存
  4. '''
  5. from sqlalchemy.orm.session import sessionmaker
  6. from machine_models.databases.mysql_helper import init_db
  7. from machine_models.databases.mysql_helper import Model
  8. from machine_models.databases.mysql_helper import AnnotatedData
  9. from machine_models.databases.mongo_helper import MongoConnect
  10. from docs.config import mysql_config
  11. from docs.config import source_mongo_config
  12. from docs.config import catch_mongo_config
  13. from util.fs_client import FileServeClient
  14. from machine_models.tools import chinese2vector
  15. from docs.config import stopWordsPath
  16. from bson import ObjectId
  17. from docs.config import oss_file_config
  18. from docs.config import oss_txt_config
  19. from util.oss_file import OssServeClient
  20. # 链接初始化
  21. Fs = FileServeClient(oss_txt_config)
  22. File = OssServeClient(oss_file_config)
  23. engine = init_db(mysql_config)
  24. Connect = sessionmaker(bind=engine)
  25. session = Connect()
  26. source_mongo = MongoConnect(source_mongo_config)
  27. catch_mongo = MongoConnect(catch_mongo_config)
  28. # 加载停用词
  29. with open(stopWordsPath, "r") as f:
  30. stop_words = [word.strip() for word in f.readlines()]
  31. def get_info(m_id, focus_field: list, need_doc: bool = False):
  32. """
  33. 关注字段获取
  34. :param m_id:
  35. :param focus_field:
  36. :param need_doc: 获取原文档
  37. :return:
  38. """
  39. select_fields = ["title", "detail", "href", "buyer", "winner", "purchasing", "attach_text", "cut_title",
  40. "cut_detail", "cut_buyer", "cut_winner", "cut_purchasing", "cut_attach_text"]
  41. fields = {field: 1 for field in select_fields}
  42. c_info = catch_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
  43. if c_info:
  44. # 获取字段内容
  45. content, add_field = select_field(c_info, focus_field)
  46. # 添加缓存
  47. if add_field:
  48. catch_mongo.update(c_info["_id"], add_field)
  49. doc = c_info if need_doc else {}
  50. return content, doc
  51. s_info = source_mongo.get_by_mid(ObjectId(m_id.strip()), fields)
  52. if s_info:
  53. # 获取字段内容
  54. content, add_field = select_field(s_info, focus_field)
  55. # 添加缓存
  56. s_info.update(add_field)
  57. catch_mongo.insert(s_info)
  58. doc = s_info if need_doc else {}
  59. return content, doc
  60. return "", {}
  61. def select_field(info, focus_field):
  62. """
  63. 字段筛选
  64. :param info:
  65. :param focus_field:
  66. :return:
  67. """
  68. content = "" # 合并的切词文本
  69. add_field = {} # 添加的缓存切词字段
  70. for field in focus_field:
  71. content += " "
  72. if field in info:
  73. content += info[field]
  74. else:
  75. original_field = field.split("_", 1)[-1]
  76. if original_field in info:
  77. add_field[field] = get_content(original_field, info.get(original_field, ""))
  78. content += add_field[field]
  79. return content, add_field
  80. def get_content(field: dict, value: any) -> str:
  81. """
  82. 需求字段合成文本内容
  83. :param field:字段
  84. :param value:值
  85. :return:合并文本
  86. """
  87. content = "" # 正文文本
  88. if value and field == "attach_text": # 附件单独处理
  89. for ind, attach in value.items():
  90. for topic, topic_detail in attach.items():
  91. attach_url = topic_detail.get("attach_url", "")
  92. # 加载oss附件文本
  93. state, attach_txt = Fs.download_text_content(attach_url)
  94. if state:
  95. content += attach_txt
  96. else:
  97. # 通用处理方法
  98. if isinstance(value, str):
  99. content = value if value else ""
  100. else:
  101. return ""
  102. return chinese2vector(content, remove_word=["x"], stopwords=stop_words)
  103. def loading_train_data(project_id, focus_field):
  104. """
  105. 加载训练数据
  106. :param project_id:
  107. :param focus_field:
  108. :return:
  109. """
  110. train_data = []
  111. labels = []
  112. result = session.query(AnnotatedData).filter_by(projectId=project_id).order_by(AnnotatedData.id).all()
  113. for row in result:
  114. label, m_id = row.label, row.infoId
  115. many_label = [tag.strip() for tag in label.split(",") if tag.strip()]
  116. if not many_label:
  117. continue
  118. content, doc = get_info(m_id, focus_field)
  119. # 添加训练文本
  120. if content.strip():
  121. train_data.append(content)
  122. labels.append(many_label)
  123. return train_data, labels, len(labels)
  124. def loading_predict_data(m_id: str, focus_field: list):
  125. """
  126. 加载预测数据
  127. :param m_id:
  128. :param focus_field:
  129. :return:
  130. """
  131. content, doc = get_info(m_id, focus_field, need_doc=True)
  132. return content, doc