123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- # coding:utf-8
- # @description : 三大分类预测
- import os
- import torch
- import pickle as pkl
- from tqdm import tqdm
- import models.FastText as m
- MAX_VOCAB_SIZE = 10000
- UNK, PAD = '<UNK>', '<PAD>'
def build_vocab(file_path, tokenizer, max_size, min_freq):
    """Build a token -> index vocabulary from a tab-separated text file.

    Only the first tab-separated field of every non-blank line is tokenized.
    Tokens seen fewer than ``min_freq`` times are discarded; the rest are
    ranked by descending frequency (ties keep first-seen order) and cut to
    at most ``max_size`` entries. ``UNK`` and ``PAD`` then receive the next
    two indices.

    Args:
        file_path: Path of the UTF-8 encoded corpus file.
        tokenizer: Callable turning a text string into an iterable of tokens.
        max_size: Maximum number of corpus tokens kept in the vocabulary.
        min_freq: Minimum occurrence count a token needs to be kept.

    Returns:
        dict mapping every kept token, plus ``UNK`` and ``PAD``, to its index.
    """
    counts = {}
    with open(file_path, 'r', encoding='UTF-8') as corpus:
        for raw_line in tqdm(corpus):
            stripped = raw_line.strip()
            if stripped:
                # Only the text column (before the first tab) is counted.
                text = stripped.split('\t')[0]
                for tok in tokenizer(text):
                    if tok in counts:
                        counts[tok] += 1
                    else:
                        counts[tok] = 1

    frequent = [item for item in counts.items() if item[1] >= min_freq]
    # list.sort is stable, so equally frequent tokens keep first-seen order.
    frequent.sort(key=lambda item: item[1], reverse=True)

    vocab = {}
    for index, (tok, _count) in enumerate(frequent[:max_size]):
        vocab[tok] = index

    # Both special tokens are numbered from the size measured before either
    # one is inserted.
    size = len(vocab)
    vocab[UNK] = size
    vocab[PAD] = size + 1
    return vocab
class PredictObject(object):
    """Load a pickled vocabulary and trained ``m.Model`` weights for CPU inference.

    The instance turns raw texts into padded id sequences with hashed
    bigram/trigram features and feeds them to the model.
    """

    def __init__(self, config, ):
        """Load the vocabulary and model weights described by ``config``.

        Args:
            config: Configuration object. This method reads ``vocab_path``,
                ``n_gram_vocab`` and ``save_path`` from it and writes
                ``n_vocab`` (the vocabulary size) back onto it before the
                model is built.

        Raises:
            IOError: If ``config.vocab_path`` does not exist.
        """
        if not os.path.exists(config.vocab_path):
            # An `assert IOError(...)` can never fail (an exception instance
            # is truthy), so the error has to be raised explicitly.
            raise IOError("词典文件不存在!!!")
        # NOTE(security): unpickling executes arbitrary code; vocab_path must
        # only ever point at a trusted file.
        with open(config.vocab_path, 'rb') as vocab_file:
            self.vocab = pkl.load(vocab_file)
        self.buckets = config.n_gram_vocab
        config.n_vocab = len(self.vocab)
        self.config = config
        self.model = m.Model(config)
        self.model.load_state_dict(torch.load(config.save_path, map_location='cpu'))
        self.model.eval()

    @staticmethod
    def biGramHash(sequence, t, buckets):
        """Return the bigram hash bucket for position ``t``.

        The hash uses the id just before position ``t`` (0 when ``t`` is the
        first position), reduced modulo ``buckets``.
        """
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        return (t1 * 14918087) % buckets

    @staticmethod
    def triGramHash(sequence, t, buckets):
        """Return the trigram hash bucket for position ``t``.

        The hash combines the two ids before position ``t`` (0 for positions
        that fall before the start), reduced modulo ``buckets``.
        """
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        t2 = sequence[t - 2] if t - 2 >= 0 else 0
        return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets

    def convert2vec(self, org_contents, pad_size=32):
        """Convert texts into padded id sequences plus n-gram hash features.

        Args:
            org_contents: A single text string, or an iterable of texts.
                Each text may be a string (split into characters) or an
                already tokenized sequence (copied, so the caller's object
                is never modified by padding).
            pad_size: Target sequence length. Shorter texts are padded with
                ``PAD``, longer ones truncated. A falsy value leaves the
                length unchanged.

        Returns:
            A list with one ``(token_ids, seq_len, bigram, trigram)`` tuple
            per text. ``seq_len`` is the length before padding, capped at
            ``pad_size``.
        """
        if isinstance(org_contents, str):
            # A bare string is one text, not a batch of one-character texts.
            org_contents = [org_contents]

        unk_id = self.vocab.get(UNK)
        contents = []
        for content in org_contents:
            # Char-level tokenization; list() also copies list inputs.
            token = list(content)
            seq_len = len(token)
            if pad_size:
                if seq_len < pad_size:
                    token.extend([PAD] * (pad_size - seq_len))
                else:
                    token = token[:pad_size]
                    seq_len = pad_size
            # word to id
            words_line = [self.vocab.get(word, unk_id) for word in token]
            # fasttext ngram: one hash per position of the final sequence,
            # so the features stay aligned even when padding is disabled.
            positions = range(len(words_line))
            bigram = [self.biGramHash(words_line, i, self.buckets) for i in positions]
            trigram = [self.triGramHash(words_line, i, self.buckets) for i in positions]
            contents.append((words_line, seq_len, bigram, trigram))
        return contents

    def to_tensor(self, datas):
        """Stack the tuples produced by ``convert2vec`` into CPU LongTensors.

        Args:
            datas: List of ``(token_ids, seq_len, bigram, trigram)`` tuples.

        Returns:
            Tuple ``(x, seq_len, bigram, trigram)`` of ``torch.LongTensor``.
        """
        x = torch.LongTensor([item[0] for item in datas]).to('cpu')
        bigram = torch.LongTensor([item[2] for item in datas]).to('cpu')
        trigram = torch.LongTensor([item[3] for item in datas]).to('cpu')
        # Length before padding (values above pad_size are set to pad_size).
        seq_len = torch.LongTensor([item[1] for item in datas]).to('cpu')
        return x, seq_len, bigram, trigram

    def predict(self, texts):
        """Run the model on ``texts`` without tracking gradients.

        Args:
            texts: A text string or an iterable of texts, as accepted by
                ``convert2vec``.

        Returns:
            Tuple ``(pre_result, outputs)``: ``pre_result`` holds the index
            of the largest value along dimension 1 of the model output, and
            ``outputs`` is the unmodified model output (presumably per-class
            scores; confirm against ``m.Model``).
        """
        with torch.no_grad():
            features = self.convert2vec(texts, self.config.pad_size)
            outputs = self.model(self.to_tensor(features))
            pre_result = torch.max(outputs.data, 1)[1]
        return pre_result, outputs
if __name__ == '__main__':
    from docs.config import config

    p = PredictObject(config)
    # predict() consumes a batch (an iterable of texts), each text given as a
    # sequence of character tokens, so the demo sentence is split into
    # characters and wrapped in a one-item batch.
    text = '中国人民武装警察部队七台河支队武警七台河支队更换基层套装门项目更正公告'
    print(p.predict([list(text)]))
|