# train_eval.py — training / evaluation utilities
  1. # coding: UTF-8
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from sklearn import metrics
  7. import time
  8. from utils import get_time_dif
  9. from tensorboardX import SummaryWriter
  10. # 权重初始化,默认xavier
  11. def init_network(model, method='xavier', exclude='embedding', seed=123):
  12. for name, w in model.named_parameters():
  13. if exclude not in name:
  14. if 'weight' in name:
  15. if method == 'xavier':
  16. nn.init.xavier_normal_(w)
  17. elif method == 'kaiming':
  18. nn.init.kaiming_normal_(w)
  19. else:
  20. nn.init.normal_(w)
  21. elif 'bias' in name:
  22. nn.init.constant_(w, 0)
  23. else:
  24. pass
def train(config, model, train_iter, dev_iter, test_iter):
    """Train *model* with Adam, evaluating on the dev set every 100 batches.

    The model with the lowest dev loss seen so far is checkpointed to
    ``config.save_path``; train/dev loss and accuracy are logged to
    TensorBoard under ``config.log_path``. Training stops early when the
    dev loss has not improved for ``config.require_improvement`` batches.
    Afterwards the saved best checkpoint is evaluated on the test set
    via ``test``.
    """
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index at which dev loss last improved
    flag = False  # set True when no improvement for too long (early stop)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()  # clear gradients each step so they do not accumulate
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                true = labels.data.cpu()  # move to CPU for the sklearn metric below
                predic = torch.max(outputs.data, 1)[1].cpu()  # torch.max returns (values, indices); [1] keeps the argmax indices
                train_acc = metrics.accuracy_score(true, predic)  # train accuracy on this batch
                dev_acc, dev_loss = evaluate(config, model, dev_iter)  # full dev-set evaluation every 100 batches
                if dev_loss < dev_best_loss:  # checkpoint whenever dev loss improves
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()  # evaluate() switched the model to eval mode; switch back
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # dev loss has not improved for a long stretch: stop early
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()
    test(config, model, test_iter)
  72. def test(config, model, test_iter):
  73. # test
  74. model.load_state_dict(torch.load(config.save_path)) # 加载保存的当前最好的模型
  75. model.eval() # 评估模式,冻结dropout等层
  76. start_time = time.time()
  77. test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
  78. msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
  79. print(msg.format(test_loss, test_acc))
  80. print("Precision, Recall and F1-Score...")
  81. print(test_report)
  82. print("Confusion Matrix...")
  83. print(test_confusion)
  84. time_dif = get_time_dif(start_time)
  85. print("Time usage:", time_dif)
  86. def evaluate(config, model, data_iter, test=False):
  87. model.eval()
  88. loss_total = 0
  89. predict_all = np.array([], dtype=int)
  90. labels_all = np.array([], dtype=int)
  91. with torch.no_grad():
  92. for texts, labels in data_iter:
  93. outputs = model(texts)
  94. loss = F.cross_entropy(outputs, labels)
  95. loss_total += loss
  96. labels = labels.data.cpu().numpy() # 这里后面用到了np.append,所以需要.numpy(
  97. predic = torch.max(outputs.data, 1)[1].cpu().numpy()
  98. labels_all = np.append(labels_all, labels) # 拼接所有label
  99. predict_all = np.append(predict_all, predic) # 拼接所有predict
  100. acc = metrics.accuracy_score(labels_all, predict_all)
  101. if test:
  102. report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
  103. confusion = metrics.confusion_matrix(labels_all, predict_all)
  104. return acc, loss_total / len(data_iter), report, confusion
  105. return acc, loss_total / len(data_iter)