# coding:utf-8
from machine_models.databases import loading_predict_data
import joblib
from machine_models.tools import encode2label
from docs.config import convertField


def predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path):
    """Prediction entry point: predict a label for every id in ``id_list``.

    :param id_list: iterable of document ids to predict for
    :param tfidf_vec: fitted TF-IDF vectorizer; only its ``transform`` method
        is used here
    :param label_type: ``1`` selects the single-label path; any other value
        selects the multi-label path
    :param focus_field: field names of interest; each is mapped through
        ``convertField`` and names missing from that mapping are dropped
    :param target_label: target labels, forwarded unchanged to
        ``encode2label`` on the multi-label path
    :param model_path: path to a joblib dump that unpacks to
        ``(model, label_encoder)``
    :return: list with one dict per input id, in input order, with keys
        ``id``, ``title``, ``url`` and ``labels``. Ids whose document could
        not be loaded get empty strings for everything except ``id``.
    """
    # NOTE(review): joblib.load unpickles the file, so model_path must point
    # to a trusted artifact.
    model, le = joblib.load(model_path)

    # Translate the caller-facing field names; unknown names are skipped.
    focus_field = [convertField[field] for field in focus_field if field in convertField]

    predict_result = []
    for m_id in id_list:
        content, doc = loading_predict_data(m_id, focus_field)
        if not doc:
            # Append a placeholder row. The previous code rebound
            # predict_result to a new one-element list here, which threw away
            # every prediction collected for earlier ids.
            predict_result.append({"id": m_id, "title": "", "url": "", "labels": ""})
            continue

        content_vec = tfidf_vec.transform([content])
        predict_y = model.predict(content_vec)

        if label_type == 1:
            # Single-label: map the predicted class index back to its label.
            target = le.classes_[predict_y[0]] if len(predict_y) > 0 else ""
        else:
            # Multi-label: decoding is delegated to encode2label; the first
            # element of its result is used (presumably the labels for the
            # single document passed in -- confirm against encode2label).
            result = encode2label(le, predict_y, target_label)
            target = result[0] if result else ""

        predict_result.append({
            "id": m_id,
            "title": doc.get("title", ""),
            "url": doc.get("href", ""),
            "labels": target,
        })

    return predict_result