table_field_category.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Text classification (multi-label)
import os

import jieba
import joblib
import pandas as pd
import torch as t
from sklearn.utils import shuffle

from docs.config import ai2config
from util.dictionary import Dictionary

table_field_config = ai2config["table_field_config"]

# Make sure jieba keeps these domain terms as single tokens.
jieba.add_word('型号')
jieba.add_word('规格')
jieba.add_word('设备')
jieba.add_word('名称')

EMBED_DIM = 300  # overwritten with the real vocabulary size once the dictionary is built
vocab_file = table_field_config['vocab_file']
ct = Dictionary(stopwords=[])
class TableFieldCategoryModel(object):
    def __init__(self, config):
        self._corpus_path = config.get("corpus_path")
        self._epochs = config.get("epochs", 500)
        self._lr = config.get("lr", 1e-3)
        self._momentum = config.get("momentum", 0.5)
        self._output = config.get("output", 2)
        self._vocab_file = config.get("vocab_file")
        self._model_path = config.get("model_path")
        self._vocab_label = config.get("vocab_label")
        self._model = None
        self.label2decode = None
    def create_train_data(self):
        """
        Build the training matrix and the multi-hot label vectors.
        :return: x_vector, y, label2encode, label2decode
        """
        if not os.path.exists(self._corpus_path):
            raise FileNotFoundError("corpus file does not exist")
        label_data = pd.read_csv(self._corpus_path)
        label_data.drop_duplicates(['corpus'], inplace=True)
        label_data = shuffle(label_data)
        corpus = label_data['corpus'].values
        category = label_data['label'].values
        ct.append_vocab(text=corpus, need_cut=True)
        ct.build_dictionary(tfidf_limit=1e-6, vocab_file=vocab_file)
        # Sanity check
        print('vocab size:', len(ct.dictionary))
        global EMBED_DIM
        EMBED_DIM = len(ct.dictionary)
        x_vector = ct.vector_corpus(corpus=corpus, dim=EMBED_DIM, use_tfidf=False, return_type='one_hot')
        # A row may carry several labels separated by ';', so each row is
        # encoded as a multi-hot vector.
        label2encode, label2decode = {}, {}
        category = [str(c).split(';') for c in category]
        for c in category:
            for w in c:
                if w not in label2encode:
                    label2encode[w] = len(label2encode)
        label2decode = dict(zip(label2encode.values(), label2encode.keys()))
        y = []
        label_len = len(label2encode)
        print('classes:', label_len)
        for c in category:
            y1 = [0] * label_len
            for w in c:
                y1[label2encode[w]] = 1
            y.append(y1)
        joblib.dump((label2encode, label2decode), self._vocab_label)
        return x_vector, y, label2encode, label2decode
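    # Worked example of the multi-hot encoding above (hypothetical mapping,
    # assuming label2encode == {'设备': 0, '名称': 1, '型号': 2}):
    # a row labelled '设备;名称' becomes the vector [1, 1, 0].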
    @staticmethod
    def make_nn(input_size, output_size):
        # Two hidden layers that halve the width each time; the final
        # activation yields per-class scores for thresholding.
        return t.nn.Sequential(
            t.nn.Linear(input_size, input_size // 2),
            t.nn.ReLU(inplace=True),
            t.nn.Linear(input_size // 2, input_size // 4),
            t.nn.ReLU(inplace=True),
            t.nn.Linear(input_size // 4, output_size),
            t.nn.Sigmoid() if output_size == 2 else t.nn.Softmax(dim=1),
        )
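    # Shape sketch (hypothetical sizes, not taken from the project): with a
    # vocabulary of 1000 words and 5 labels the network is
    # 1000 -> 500 -> 250 -> 5, followed by Softmax since output_size != 2.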
    def train(self):
        x_vector, y, le, _ = self.create_train_data()
        print('classes:', len(le))
        mlp = self.make_nn(EMBED_DIM, len(le))
        optimizer = t.optim.SGD(mlp.parameters(), lr=self._lr, momentum=self._momentum)
        lossfunc = t.nn.BCELoss()
        x_vector = t.tensor(x_vector).float()
        y = t.tensor(y).float()
        for epoch in range(self._epochs):
            outputs = mlp(x_vector)
            optimizer.zero_grad()
            acc(outputs, y)  # print exact-match accuracy for this epoch
            loss = lossfunc(outputs, y)
            loss.backward()
            optimizer.step()
            print('epoch:', epoch, 'loss:', loss.item())
            if loss.item() < 0.01:
                print('early stop: loss below 0.01')
                break
        t.save(mlp.state_dict(), self._model_path)
    def predict(self, corpus):
        global EMBED_DIM
        if not self._model:
            # Lazy-load the dictionary, label maps and model weights on first call.
            ct.load_dictionary(vocab_file=vocab_file)
            EMBED_DIM = len(ct.dictionary)
            le, self.label2decode = joblib.load(self._vocab_label)
            self._model = self.make_nn(EMBED_DIM, len(self.label2decode))
            self._model.load_state_dict(t.load(self._model_path))
        x_vector = ct.vector_corpus(corpus, dim=EMBED_DIM, return_type='one_hot', use_tfidf=False)
        x = t.tensor(x_vector).float()
        x = x.view(x.size()[0], -1)
        y = self._model(x)
        ret = []
        for r in y.data.numpy():
            row_label = []
            for i, w in enumerate(r):
                if w >= 0.20:  # decision threshold
                    row_label.append(self.label2decode[i])
            ret.append(row_label)
        return ret
    def val(self):
        global EMBED_DIM
        ct.load_dictionary(vocab_file=vocab_file)
        EMBED_DIM = len(ct.dictionary)
        le, de = joblib.load(self._vocab_label)
        label_data = pd.read_csv(self._corpus_path)
        label_data.drop_duplicates(['corpus'], inplace=True)
        label_data = shuffle(label_data)
        corpus = label_data['corpus'].values
        category = label_data['label'].values
        x_vector = ct.vector_corpus(corpus, dim=EMBED_DIM, return_type='one_hot', use_tfidf=False)
        mlp = self.make_nn(EMBED_DIM, len(de))
        mlp.load_state_dict(t.load(self._model_path))
        x = t.tensor(x_vector).float()
        x = x.view(x.size()[0], -1)
        y = mlp(x)
        ret = []
        for r in y.data.numpy():
            row_label = []
            for i, w in enumerate(r):
                if w >= 0.35:  # stricter threshold than predict()
                    row_label.append(de[i])
            ret.append(row_label)
        print(list(zip(category, ret)))
        # A row counts as correct when the predicted label set matches the
        # gold set: compare directly for single-label rows, otherwise compare
        # the split gold labels (elif prevents double-counting).
        count = 0
        for i in list(zip(category, ret)):
            if len(i[1]) == 1:
                if i[0] == i[1][0]:
                    count += 1
                else:
                    print(i)
            elif i[0].split(';') == i[1]:
                count += 1
        print(count / len(category))
def acc(outputs, y):
    """Print the fraction of rows whose thresholded prediction matches the target exactly."""
    b = (outputs > 0.25).data.numpy()
    a = (y > 0.5).data.numpy()
    count = 0
    for i in range(len(a)):
        if (a[i] == b[i]).all():
            count += 1
    print('samples:', len(a), 'exact-match accuracy:', count / len(a))
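
# Usage sketch (untested; it assumes table_field_config also carries the
# corpus_path / model_path / vocab_label keys that __init__ reads, and the
# sample strings are placeholders, not project data):
if __name__ == '__main__':
    model = TableFieldCategoryModel(table_field_config)
    model.train()
    print(model.predict(['设备名称', '规格型号']))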