# predict.py (3.7 KB)
# coding:utf-8
# @description : Prediction for the three major classification categories
import os
import pickle as pkl
from collections import Counter

import torch
from tqdm import tqdm

import models.FastText as m
  8. MAX_VOCAB_SIZE = 10000
  9. UNK, PAD = '<UNK>', '<PAD>'
  10. def build_vocab(file_path, tokenizer, max_size, min_freq):
  11. vocab_dic = {}
  12. with open(file_path, 'r', encoding='UTF-8') as f:
  13. for line in tqdm(f):
  14. lin = line.strip()
  15. if not lin:
  16. continue
  17. content = lin.split('\t')[0]
  18. for word in tokenizer(content):
  19. vocab_dic[word] = vocab_dic.get(word, 0) + 1
  20. vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[
  21. :max_size]
  22. vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
  23. vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
  24. return vocab_dic
  25. class PredictObject(object):
  26. def __init__(self, config, ):
  27. if os.path.exists(config.vocab_path):
  28. self.vocab = pkl.load(open(config.vocab_path, 'rb'))
  29. else:
  30. assert IOError("词典文件不存在!!!")
  31. self.buckets = config.n_gram_vocab
  32. config.n_vocab = len(self.vocab)
  33. self.config = config
  34. self.model = m.Model(config)
  35. self.model.load_state_dict(torch.load(config.save_path, map_location='cpu'))
  36. self.model.eval()
  37. @staticmethod
  38. def biGramHash(sequence, t, buckets):
  39. t1 = sequence[t - 1] if t - 1 >= 0 else 0
  40. return (t1 * 14918087) % buckets
  41. @staticmethod
  42. def triGramHash(sequence, t, buckets):
  43. t1 = sequence[t - 1] if t - 1 >= 0 else 0
  44. t2 = sequence[t - 2] if t - 2 >= 0 else 0
  45. return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
  46. def convert2vec(self, org_contents, pad_size=32):
  47. tokenizer = lambda x: [y for y in x] # char-level
  48. contents = []
  49. for token in org_contents:
  50. words_line = []
  51. seq_len = len(token)
  52. if pad_size:
  53. if len(token) < pad_size:
  54. token.extend([PAD] * (pad_size - len(token)))
  55. else:
  56. token = token[:pad_size]
  57. seq_len = pad_size
  58. # word to id
  59. for word in token:
  60. words_line.append(self.vocab.get(word, self.vocab.get(UNK)))
  61. # fasttext ngram
  62. bigram = []
  63. trigram = []
  64. # ------ngram------
  65. for i in range(pad_size):
  66. bigram.append(self.biGramHash(words_line, i, self.buckets))
  67. trigram.append(self.triGramHash(words_line, i, self.buckets))
  68. # -----------------
  69. contents.append((words_line, seq_len, bigram, trigram))
  70. return contents
  71. def to_tensor(self, datas):
  72. x = torch.LongTensor([_[0] for _ in datas]).to('cpu')
  73. bigram = torch.LongTensor([_[2] for _ in datas]).to('cpu')
  74. trigram = torch.LongTensor([_[3] for _ in datas]).to('cpu')
  75. # pad前的长度(超过pad_size的设为pad_size)
  76. seq_len = torch.LongTensor([_[1] for _ in datas]).to('cpu')
  77. return x, seq_len, bigram, trigram
  78. def predict(self, texts):
  79. with torch.no_grad():
  80. texts = self.convert2vec(texts, self.config.pad_size)
  81. tensor_text = self.to_tensor(texts)
  82. outputs = self.model(tensor_text)
  83. pre_result = torch.max(outputs.data, 1)[1]
  84. cod = outputs
  85. return pre_result, cod
  86. if __name__ == '__main__':
  87. from docs.config import config
  88. p = PredictObject(config)
  89. print(p.predict('中国人民武装警察部队七台河支队武警七台河支队更换基层套装门项目更正公告'))