123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- # coding:utf-8
- from machine_models.databases import loading_predict_data
- import joblib
- from machine_models.tools import encode2label
- from docs.config import convertField
def predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path):
    """Predict labels for every id in ``id_list`` and return one row per id.

    :param id_list: iterable of document ids to run prediction on.
    :param tfidf_vec: TF-IDF vectorizer; its ``transform`` is called on each
        document's content (presumably already fitted -- confirm with caller).
    :param label_type: ``1`` selects single-label decoding; any other value
        selects multi-label decoding through ``encode2label``.
    :param focus_field: field names of interest. Each name is mapped through
        ``convertField``; names that are not keys of ``convertField`` are dropped.
    :param target_label: forwarded unchanged to ``encode2label`` in the
        multi-label case.
    :param model_path: path handed to ``joblib.load``; the loaded object must
        unpack into ``(model, le)`` where ``le`` exposes ``classes_``.
    :return: list of dicts with keys ``id``, ``title``, ``url`` and ``labels``,
        in the same order as ``id_list``. Ids whose document could not be
        loaded get a row with empty ``title``, ``url`` and ``labels``.
    """
    # NOTE(review): joblib.load unpickles the file, so model_path must only
    # ever point at trusted artifacts.
    model, le = joblib.load(model_path)

    # Translate caller-facing field names, skipping unknown ones.
    fields = [convertField[name] for name in focus_field if name in convertField]

    predict_result = []
    for m_id in id_list:
        content, doc = loading_predict_data(m_id, fields)
        if not doc:
            # Append (not rebind) the placeholder row: rebinding the list here
            # would throw away the rows already collected for earlier ids.
            predict_result.append({"id": m_id, "title": "",
                                   "url": "", "labels": ""})
            continue

        content_vec = tfidf_vec.transform([content])
        # The raw prediction is computed the same way for both label types;
        # only the decoding into label text differs.
        predict_y = model.predict(content_vec)

        if label_type == 1:
            # Single label: the predicted value indexes into le.classes_.
            target = le.classes_[predict_y[0]] if len(predict_y) > 0 else ""
        else:
            # Multi label: decoding is delegated to encode2label; only the
            # first decoded entry is used because one document was predicted.
            result = encode2label(le, predict_y, target_label)
            target = result[0] if result else ""

        predict_result.append({"id": m_id, "title": doc.get("title", ""),
                               "url": doc.get("href", ""), "labels": target})
    return predict_result
|