predict.py 983 B

123456789101112131415161718192021222324252627282930313233
  1. # coding:utf-8
  2. import joblib
  3. class PredictModel(object):
  4. def __init__(self, dictionary_path, model_path, threshold_val=0.8):
  5. self.dictionary = joblib.load(dictionary_path)
  6. self.model = joblib.load(model_path)
  7. self._threshold_val = threshold_val
  8. def predict(self, contents, threshold=True):
  9. """
  10. 结果预测
  11. :param contents: 需要预测的文本列表
  12. :param threshold:
  13. :return:
  14. """
  15. if not contents:
  16. return []
  17. content_vec = self.dictionary.transform(contents)
  18. if threshold:
  19. predict_result = self.model.predict_proba(content_vec)
  20. predict_result = list(map(self.threshold, predict_result))
  21. else:
  22. predict_result = self.model.predict(content_vec)
  23. return predict_result
  24. def threshold(self, x):
  25. # 预测结果
  26. if x[1] > self._threshold_val:
  27. return 1
  28. else:
  29. return 0