__init__.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # coding:utf-8
  2. from docs.config import dictionaryPath
  3. from docs.config import dictionaryUrl
  4. from machine_models.databases import File
  5. from machine_models.train_model import train
  6. from machine_models.databases.mysql_helper import Model
  7. from machine_models.databases import session
  8. from machine_models.databases.mysql_helper import Project
  9. from machine_models.predict_model import predict
  10. from util.file_operations import generate_directory, del_directory
  11. from docs.config import baseDir
  12. import os
  13. import joblib
  14. import uuid
  15. import datetime
  16. # 词典文件加载,只加载一次
  17. if not os.path.exists(dictionaryPath):
  18. status = File.download_file(dictionaryUrl, dictionaryPath)
  19. if not status:
  20. raise ValueError("词典文件下载失败")
  21. tfidf_vec = joblib.load(dictionaryPath)
  22. def train_fail(project_id, user_id):
  23. '''
  24. 记录失败日志
  25. :param project_id:
  26. :param user_id:
  27. :return:
  28. '''
  29. fail_model = Model(state=2, projectId=project_id, createperson=user_id,
  30. createTime=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  31. session.add(fail_model)
  32. session.commit()
  33. return False
  34. def train_model(request_params: dict):
  35. # 清空数据库链接对象缓存
  36. session.expire_all()
  37. session.commit()
  38. # 获取训练项目数据
  39. project_id = request_params.get("id")
  40. user_id = request_params.get("userId", "")
  41. label_type = request_params.get("type", 1)
  42. fields = request_params.get("fields", "")
  43. model_dir = ""
  44. try:
  45. # 不存在项目Id
  46. if not project_id:
  47. return train_fail(project_id, user_id)
  48. model_dir = os.path.join(baseDir, str(uuid.uuid4()))
  49. dir_status = generate_directory(model_dir)
  50. # 文件夹生成错误
  51. if not dir_status:
  52. return train_fail(project_id, user_id)
  53. # 开始训练
  54. model_detail = train(project_id, fields.split(","), tfidf_vec, label_type, model_dir)
  55. # 训练失败
  56. if not model_detail:
  57. return train_fail(project_id, user_id)
  58. # 训练成功记录
  59. model_detail.projectId = project_id
  60. model_detail.state = 0
  61. model_detail.createperson = user_id
  62. model_detail.createTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  63. session.add(model_detail)
  64. session.commit()
  65. # 清空本次训练生成文件
  66. del_directory(model_dir)
  67. return True
  68. except Exception as e:
  69. train_fail(project_id, user_id)
  70. # 清空本次训练生成文件
  71. if model_dir and os.path.exists(model_dir):
  72. del_directory(model_dir)
  73. return train_fail(project_id, user_id)
  74. def predict_model(request_params):
  75. # 清空本地数据库缓存
  76. session.expire_all()
  77. session.commit()
  78. # 获取预测参数
  79. project_id = request_params.get("id", -1)
  80. id_list = request_params.get("id_list", [])
  81. model_id = request_params.get("model_id", -1)
  82. project_info = session.query(Project).filter_by(id=project_id).first()
  83. # 查询项目信息
  84. if not project_info:
  85. return {"error_code": 0, "error_message": f"项目信息不存在--> {project_id}"}
  86. focus_field, target_label, label_type = project_info.focusField.split(
  87. ","), project_info.labels.split(","), project_info.type
  88. # 查询模型信息
  89. model_info = session.query(Model).filter_by(id=model_id).first()
  90. print(type(model_id), "-->", model_id)
  91. if not model_info:
  92. return {"error_code": 0, "error_message": f"模型信息不存在--> {model_id}"}
  93. # 加载模型
  94. model_url = model_info.modelFile
  95. model_dir = os.path.join(baseDir, model_url)
  96. model_path = os.path.join(model_dir, "model.model")
  97. if not os.path.exists(model_path):
  98. dir_status = generate_directory(model_dir)
  99. if not dir_status:
  100. return {"error_code": 0, "error_message": f"文件夹创建失败,请检查存储设备-->"}
  101. status = File.download_file(model_url, model_path)
  102. if not status:
  103. return {"error_code": 0, "error_message": f"oss储存模型加载失败--> {model_id}"}
  104. try:
  105. data = predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path)
  106. except Exception as e:
  107. print(e)
  108. return {"error_code": 0, "error_message": "预测过程出错"}
  109. # 清空缓存
  110. if os.path.exists(model_dir):
  111. del_directory(model_dir)
  112. return {"error_code": 1, "data": data}