# coding:utf-8
import joblib


class PredictModel(object):
    """Text classifier wrapper: a fitted vectorizer plus a fitted model.

    Both artifacts are loaded with ``joblib.load``. ``dictionary`` must
    expose ``transform``; ``model`` must expose ``predict`` and, for
    thresholded prediction, ``predict_proba``.
    """

    def __init__(self, dictionary_path, model_path, threshold_val=0.8):
        """Load the vectorizer and the model from disk.

        :param dictionary_path: path to the joblib-dumped vectorizer.
        :param model_path: path to the joblib-dumped classifier.
        :param threshold_val: probability cut-off used by :meth:`threshold`;
            a sample is labelled 1 only when its column-1 probability is
            strictly greater than this value.
        """
        # NOTE(review): joblib.load unpickles its input; only load trusted files.
        self.dictionary = joblib.load(dictionary_path)
        self.model = joblib.load(model_path)
        self._threshold_val = threshold_val

    def predict(self, contents, threshold=True):
        """Predict labels for a batch of texts.

        :param contents: list (or other sized array-like, e.g. a NumPy array
            or pandas Series) of texts to classify.
        :param threshold: when true, label each sample by applying
            :meth:`threshold` to its ``predict_proba`` row and return a
            ``list`` of 0/1 ints; when false, return the output of
            ``self.model.predict`` unchanged.
        :return: ``[]`` for ``None`` or empty input, otherwise the
            predictions described above.
        """
        if self._is_empty(contents):
            return []
        content_vec = self.dictionary.transform(contents)
        if threshold:
            probabilities = self.model.predict_proba(content_vec)
            return list(map(self.threshold, probabilities))
        return self.model.predict(content_vec)

    @staticmethod
    def _is_empty(contents):
        """Return True for ``None`` or a zero-length input.

        ``not contents`` raises ValueError for NumPy arrays / pandas Series
        ("truth value ... is ambiguous"), so the length is tested explicitly.
        Iterables without ``len`` (e.g. generators) are treated as non-empty
        and handed to the vectorizer, as the previous check did.
        """
        if contents is None:
            return True
        try:
            return len(contents) == 0
        except TypeError:
            return False

    def threshold(self, x):
        """Map one ``predict_proba`` row to a 0/1 label.

        :param x: probability row for a single sample; ``x[1]`` is read as
            the probability of the positive class.
        :return: 1 if ``x[1] > self._threshold_val`` else 0.
        """
        # Assumes binary classification where column 1 is the positive class
        # (predict_proba columns follow model.classes_ order) — confirm for this model.
        return 1 if x[1] > self._threshold_val else 0