123456789101112131415161718192021222324252627282930313233 |
- # coding:utf-8
- import joblib
class PredictModel(object):
    """Wrap a persisted text transformer and classifier for batch prediction.

    Both artifacts are loaded with ``joblib.load``. The transformer
    (``dictionary``) must provide ``transform``. The classifier (``model``)
    must provide ``predict`` and, for thresholded prediction,
    ``predict_proba``.
    """

    def __init__(self, dictionary_path, model_path, threshold_val=0.8):
        """Load the transformer and the classifier from disk.

        :param dictionary_path: path to the joblib-dumped text transformer
            (presumably a fitted vectorizer; confirm against the training code).
        :param model_path: path to the joblib-dumped classifier.
        :param threshold_val: value that column 1 of ``predict_proba`` must
            strictly exceed for a sample to be labelled 1. Defaults to 0.8.
        """
        # SECURITY: joblib.load unpickles arbitrary objects and can execute
        # code while loading; only load artifacts from trusted locations.
        self.dictionary = joblib.load(dictionary_path)
        self.model = joblib.load(model_path)
        self._threshold_val = threshold_val

    def predict(self, contents, threshold=True):
        """Predict labels for a batch of texts.

        :param contents: texts to classify. This may be any sized collection
            (list, tuple, numpy array, pandas Series) or a plain iterable.
            ``None`` or an empty collection yields ``[]`` without calling the
            transformer or the model.
        :param threshold: if True, score the texts with ``predict_proba`` and
            label each row with :meth:`threshold`, using the configured
            cut-off. If False, return ``self.model.predict`` output as-is.
        :return: a list of 0/1 ints when ``threshold`` is True; otherwise
            whatever ``self.model.predict`` returns.
        """
        if contents is None:
            return []
        if not hasattr(contents, "__len__"):
            # Iterators and generators have no len(). Materialize them so the
            # emptiness check below works and transform() receives a sequence.
            contents = list(contents)
        # Use len() rather than truthiness, because `not array` / `not series`
        # raises ValueError for multi-element numpy arrays and pandas Series.
        if len(contents) == 0:
            return []
        content_vec = self.dictionary.transform(contents)
        if threshold:
            probabilities = self.model.predict_proba(content_vec)
            return [self.threshold(row) for row in probabilities]
        return self.model.predict(content_vec)

    def threshold(self, x):
        """Map one row of ``predict_proba`` output to a 0/1 label.

        :param x: per-class probabilities for a single sample. Index 1 is
            presumably the positive class of a binary classifier; confirm
            this against the model's class ordering.
        :return: 1 if ``x[1]`` strictly exceeds the configured threshold,
            else 0.
        """
        return 1 if x[1] > self._threshold_val else 0
|