train_model.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # coding:utf-8
  2. from sklearn.svm import LinearSVC
  3. import os
  4. import joblib
  5. from machine_models.databases import loading_train_data
  6. from sklearn.model_selection import train_test_split
  7. from machine_models.tools import label2encode
  8. from sklearn.multiclass import OneVsRestClassifier
  9. import datetime
  10. from docs.config import convertField
  11. from machine_models.databases import File
  12. import numpy as np
  13. import uuid
  14. from machine_models.databases.mysql_helper import Model
  def many_recall_score(y_test, y_pred):
      """
      Compute recall for multi-label predictions.

      Recall is the share of true labels in ``y_test`` that are also predicted
      in ``y_pred``. It is pooled over every sample and every label: the sum of
      ``y_pred`` at positions where ``y_test > 0``, divided by the sum of ``y_test``.

      :param y_test: ground-truth label indicator rows, one row per sample
          (array-like; converted with ``np.asarray`` so nested lists work too).
      :param y_pred: predicted label indicator rows with the same shape.
      :return: the pooled recall, or 0.0 when ``y_test`` has no positive labels.
          The old code divided by zero in that case.
      """
      y_test = np.asarray(y_test)
      y_pred = np.asarray(y_pred)
      # Same totals the original per-row loop accumulated, computed in one pass.
      total_count = y_test.sum()
      if not total_count:
          # No true labels at all: recall is undefined; report 0 instead of nan.
          return 0.0
      correct_count = y_pred[y_test > 0].sum()
      return correct_count / total_count
  def recall_score(y_test, y_pred):
      """
      Compute recall for single-label (multi-class) predictions.

      Each sample has exactly one true class. Micro-averaged recall is therefore
      the fraction of samples whose predicted class equals the true class, which
      is the same as exact-match accuracy.

      :param y_test: true class indices, one per sample (array-like).
      :param y_pred: predicted class indices, same length as ``y_test``.
      :return: fraction of exact matches; 0.0 for empty input instead of 0/0.
      """
      y_test = np.asarray(y_test)
      y_pred = np.asarray(y_pred)
      if y_test.size == 0:
          return 0.0
      return (y_test == y_pred).sum() / y_test.size
  def train_ones_label(x_train, y_train, *, random_state=None):
      """
      Train a single-label classifier.

      :param x_train: training feature matrix.
      :param y_train: one class index per training sample.
      :param random_state: seed passed to ``LinearSVC``. With ``None`` (the
          default) the current Unix timestamp in whole seconds is used, as
          before, so runs are not reproducible. Pass an int to get repeatable
          training.
      :return: the fitted ``LinearSVC`` model.
      """
      if random_state is None:
          random_state = int(datetime.datetime.now().timestamp())
      model = LinearSVC(random_state=random_state)
      model.fit(x_train, y_train)
      return model
  def train_many_labels(x_train, y_train, *, random_state=None, n_jobs=-1):
      """
      Train a multi-label classifier.

      A binary ``LinearSVC`` is wrapped in ``OneVsRestClassifier``, which fits
      one copy per label. Together they form the multi-label model.

      :param x_train: training feature matrix.
      :param y_train: label indicator rows (0/1 tags per label, presumably as
          produced by ``label2encode`` — confirm against the caller).
      :param random_state: seed passed to ``LinearSVC``. With ``None`` (the
          default) the current Unix timestamp in whole seconds is used, as before.
      :param n_jobs: number of parallel jobs for ``OneVsRestClassifier``;
          ``-1`` (the default, as before) uses all processors.
      :return: the fitted ``OneVsRestClassifier``.
      """
      if random_state is None:
          random_state = int(datetime.datetime.now().timestamp())
      model = LinearSVC(random_state=random_state)
      # Build the multi-label model from the binary classifier.
      clf = OneVsRestClassifier(model, n_jobs=n_jobs)
      clf.fit(x_train, y_train)
      return clf
  def _first_positive_labels(train_vec, label_vec):
      """
      Turn label indicator rows into single class indices, keeping features aligned.

      Each row is reduced to the index of its first tag equal to 1. A row with
      no such tag has no class to train on, so it is removed from both the
      labels and ``train_vec``. The old code dropped only the label, and
      ``train_test_split`` then failed with inconsistent sample counts.

      :param train_vec: feature matrix with one row per entry of ``label_vec``.
      :param label_vec: iterable of 0/1 tag rows.
      :return: ``(train_vec, single_label)`` with equal numbers of rows.
      """
      keep_rows = []
      single_label = []
      total_rows = 0
      for row, label in enumerate(label_vec):
          total_rows += 1
          first = next((ind for ind, tag in enumerate(label) if tag == 1), None)
          if first is not None:
              keep_rows.append(row)
              single_label.append(first)
      if len(keep_rows) != total_rows:
          # Only re-index when rows were actually dropped; otherwise keep the original object.
          train_vec = train_vec[keep_rows]
      return train_vec, single_label


  def train(project_id, focus_field, tfidf_vec, label_type: int, model_dir: str):
      """
      Train a model for a project, save and upload it, and build its database record.

      :param project_id: project whose training data is read via ``loading_train_data``.
      :param focus_field: field names to train on. Names missing from
          ``convertField`` are skipped; the rest are mapped through it.
      :param tfidf_vec: fitted vectorizer whose ``transform`` turns the raw
          training data into feature vectors.
      :param label_type: 1 for single-label training; any other value means multi-label.
      :param model_dir: directory where ``model.model`` is written before upload.
      :return: a new ``Model`` record object (not saved here) holding the sample
          count, metrics and uploaded file key. Returns ``False`` if training or
          evaluation raised; the exception is logged.
      """
      import logging

      # Fields to train on
      focus_field = [convertField[field] for field in focus_field if field in convertField]
      # Load the data
      train_data, train_label, count = loading_train_data(project_id, focus_field)
      # Vectorize the training data
      train_vec = tfidf_vec.transform(train_data)
      # Encode labels as vectors
      le, label_vec = label2encode(train_label)
      if label_type == 1:
          train_vec, label_vec = _first_positive_labels(train_vec, label_vec)
      x_train, x_test, y_train, y_test = train_test_split(train_vec, label_vec, test_size=0.2, shuffle=True)
      model_path = os.path.join(model_dir, "model.model")
      try:
          if label_type == 1:
              # Single-label training
              y_test = np.array(y_test)
              clf = train_ones_label(x_train, y_train)
              y_pred = clf.predict(x_test)
              # Evaluate. For single-label data the old inline "score" used the same
              # formula as recall_score (exact-match ratio), so precision == recall here.
              recall = recall_score(y_test, y_pred)
              score = recall
          else:
              # Multi-label training
              clf = train_many_labels(x_train, y_train)
              y_pred = clf.predict(x_test)
              # Evaluate.
              # NOTE(review): this is element-wise accuracy over the label matrix
              # (true negatives count as correct), but it is stored as "precision"
              # below. Value kept unchanged; confirm whether true precision is intended.
              score = (y_test == y_pred).sum() / y_test.size
              recall = many_recall_score(y_test, y_pred)
      except Exception:
          # Keep the best-effort contract (callers get False), but record the cause
          # instead of silently discarding it.
          logging.getLogger(__name__).exception(
              "Model training failed for project %s", project_id)
          return False
      # Save the model
      joblib.dump((clf, le), model_path)
      # Upload the model under a fresh UUID key
      model_url = str(uuid.uuid4())
      with open(model_path, "rb") as f:
          File.upload_bytes_file(model_url, f.read())
      # Harmonic mean of the two metrics; 0 when either is 0
      f1_score = ((score * recall) / (score + recall)) * 2 if score and recall else 0
      # Build the database record object
      add_model = Model(sampleData=count, recallRate=recall, precision=score, accuracyRate=f1_score,
                        modelFile=model_url)
      return add_model