# coding:utf-8
# @description : prediction for the three major categories
import os
import torch
import pickle as pkl
from tqdm import tqdm
import models.FastText as m

MAX_VOCAB_SIZE = 10000
# NOTE(review): both literals were empty strings, which makes the two keys
# collide in the vocabulary and makes padding indistinguishable from unknown
# tokens. Restored to the conventional markers -- confirm that the pickled
# vocabulary at config.vocab_path was built with these exact keys.
UNK, PAD = '<UNK>', '<PAD>'


def build_vocab(file_path, tokenizer, max_size, min_freq):
    """Build a token -> index vocabulary from a tab-separated text file.

    Only the first tab-separated field of every non-blank line is tokenized.

    Args:
        file_path: Path of the UTF-8 encoded corpus file.
        tokenizer: Callable turning a string into an iterable of tokens.
        max_size: Maximum number of corpus tokens kept (most frequent first).
        min_freq: Minimum number of occurrences a token needs to be kept.

    Returns:
        dict mapping each kept token to its frequency rank, followed by the
        ``UNK`` and ``PAD`` markers at the next two indices.
    """
    vocab_dic = {}
    with open(file_path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content = lin.split('\t')[0]
            for word in tokenizer(content):
                vocab_dic[word] = vocab_dic.get(word, 0) + 1

    frequent = [item for item in vocab_dic.items() if item[1] >= min_freq]
    # sorted() is stable, so ties keep their first-seen order.
    vocab_list = sorted(frequent, key=lambda x: x[1], reverse=True)[:max_size]
    vocab_dic = {word: idx for idx, (word, _count) in enumerate(vocab_list)}
    vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic


class PredictObject(object):
    """Loads a saved FastText model plus its vocabulary and runs predictions on CPU."""

    def __init__(self, config, ):
        """Load the vocabulary and model weights described by ``config``.

        Reads ``config.vocab_path``, ``config.n_gram_vocab`` and
        ``config.save_path``; writes ``config.n_vocab``.

        Raises:
            IOError: If ``config.vocab_path`` does not exist.
        """
        # The original `assert IOError(...)` could never fail (an exception
        # instance is truthy), so the missing file went unnoticed until later.
        if not os.path.exists(config.vocab_path):
            raise IOError("词典文件不存在!!!")
        # NOTE: unpickling can execute arbitrary code; vocab_path must point
        # to a trusted file.
        with open(config.vocab_path, 'rb') as vocab_file:
            self.vocab = pkl.load(vocab_file)
        self.buckets = config.n_gram_vocab
        config.n_vocab = len(self.vocab)
        self.config = config
        self.model = m.Model(config)
        self.model.load_state_dict(torch.load(config.save_path, map_location='cpu'))
        self.model.eval()

    @staticmethod
    def biGramHash(sequence, t, buckets):
        """Hash the id preceding position ``t`` (0 if there is none) into ``buckets``."""
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        return (t1 * 14918087) % buckets

    @staticmethod
    def triGramHash(sequence, t, buckets):
        """Hash the two ids preceding position ``t`` (0 where missing) into ``buckets``."""
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        t2 = sequence[t - 2] if t - 2 >= 0 else 0
        return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets

    def convert2vec(self, org_contents, pad_size=32):
        """Convert raw samples into id sequences plus n-gram hash features.

        Args:
            org_contents: Iterable of samples; each sample is a string
                (split into characters) or a sequence of tokens. A single
                bare string is treated as one sample.
            pad_size: Target length. Shorter samples are padded with ``PAD``
                and longer ones truncated. A falsy value disables both.

        Returns:
            list of ``(ids, seq_len, bigram, trigram)`` tuples, where
            ``seq_len`` is the sample length before padding, capped at
            ``pad_size``.
        """
        if isinstance(org_contents, str):
            # Iterating a bare string would treat every character as its own sample.
            org_contents = [org_contents]

        unk_id = self.vocab.get(UNK)
        contents = []
        for sample in org_contents:
            # list() gives char-level tokens for a str and a private copy for
            # a token list, so the caller's data is never mutated.
            tokens = list(sample)
            seq_len = len(tokens)
            if pad_size:
                if seq_len < pad_size:
                    tokens.extend([PAD] * (pad_size - seq_len))
                else:
                    tokens = tokens[:pad_size]
                    seq_len = pad_size
            # word to id
            words_line = [self.vocab.get(tok, unk_id) for tok in tokens]
            # fasttext ngram features, one per position of the final sequence
            # (equals pad_size whenever padding is enabled).
            positions = range(len(words_line))
            bigram = [self.biGramHash(words_line, i, self.buckets) for i in positions]
            trigram = [self.triGramHash(words_line, i, self.buckets) for i in positions]
            contents.append((words_line, seq_len, bigram, trigram))
        return contents

    def to_tensor(self, datas):
        """Stack the tuples produced by ``convert2vec`` into CPU LongTensors.

        Returns:
            tuple ``(x, seq_len, bigram, trigram)``.
        """
        x = torch.LongTensor([_[0] for _ in datas]).to('cpu')
        bigram = torch.LongTensor([_[2] for _ in datas]).to('cpu')
        trigram = torch.LongTensor([_[3] for _ in datas]).to('cpu')
        # length before padding (values above pad_size are set to pad_size)
        seq_len = torch.LongTensor([_[1] for _ in datas]).to('cpu')
        return x, seq_len, bigram, trigram

    def predict(self, texts):
        """Run the model on ``texts`` without tracking gradients.

        Args:
            texts: A single string or an iterable of samples, as accepted by
                ``convert2vec``.

        Returns:
            tuple ``(pre_result, cod)``: the index of the maximum output
            along dimension 1 for every sample, and the raw model outputs.
        """
        with torch.no_grad():
            samples = self.convert2vec(texts, self.config.pad_size)
            tensor_text = self.to_tensor(samples)
            outputs = self.model(tensor_text)
            pre_result = torch.max(outputs.data, 1)[1]
            cod = outputs
        return pre_result, cod


if __name__ == '__main__':
    from docs.config import config

    p = PredictObject(config)
    print(p.predict('中国人民武装警察部队七台河支队武警七台河支队更换基层套装门项目更正公告'))