predict_model.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # coding:utf-8
  2. from machine_models.databases import loading_predict_data
  3. import joblib
  4. from machine_models.tools import encode2label
  5. from docs.config import convertField
  6. def predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path):
  7. '''
  8. 预测入口
  9. :param id_list: id列表
  10. :param tfidf_vec: tf-idf 词典
  11. :param label_type: 类型
  12. :param focus_field:关注字段
  13. :param target_label:目标标签
  14. :param model_path:model_path
  15. :return:
  16. '''
  17. model, le = joblib.load(model_path)
  18. # 开始预测
  19. focus_field = [convertField[field] for field in focus_field if field in convertField]
  20. predict_result = []
  21. for m_id in id_list:
  22. content, doc = loading_predict_data(m_id, focus_field)
  23. if not doc:
  24. predict_result = [{"id": m_id, "title": "",
  25. "url": "", "labels": ""}]
  26. continue
  27. content_vec = tfidf_vec.transform([content])
  28. # 单标签
  29. if label_type == 1:
  30. predict_y = model.predict(content_vec)
  31. target = le.classes_[predict_y[0]] if len(predict_y) > 0 else ""
  32. predict_result.append({"id": m_id, "title": doc.get("title", ""),
  33. "url": doc.get("href", ""), "labels": target})
  34. else:
  35. # 多标签
  36. predict_y = model.predict(content_vec)
  37. result = encode2label(le, predict_y, target_label)
  38. target = result[0] if result else ""
  39. predict_result.append({"id": m_id, "title": doc.get("title", ""),
  40. "url": doc.get("href", ""), "labels": target})
  41. return predict_result