123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- # coding:utf-8
- from docs.config import dictionaryPath
- from docs.config import dictionaryUrl
- from machine_models.databases import File
- from machine_models.train_model import train
- from machine_models.databases.mysql_helper import Model
- from machine_models.databases import session
- from machine_models.databases.mysql_helper import Project
- from machine_models.predict_model import predict
- from util.file_operations import generate_directory, del_directory
- from docs.config import baseDir
- import os
- import joblib
- import uuid
- import datetime
# Load the dictionary file exactly once, at import time. When it is not yet on
# disk it is downloaded first; a failed download aborts the import.
# NOTE(review): `tfidf_vec` is presumably a fitted TF-IDF vectorizer - confirm.
if not os.path.exists(dictionaryPath):
    status = File.download_file(dictionaryUrl, dictionaryPath)
    if not status:
        raise ValueError("词典文件下载失败")

tfidf_vec = joblib.load(dictionaryPath)
def train_fail(project_id, user_id):
    """Persist a failure record for a training attempt.

    Adds a ``Model`` row with ``state=2`` carrying the project id, the
    requesting user and the current local timestamp, then commits it.

    :param project_id: id of the project whose training failed
    :param user_id: id stored in the record's ``createperson`` column
    :return: always ``False``, so callers can ``return train_fail(...)``
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    failure_record = Model(
        state=2,
        projectId=project_id,
        createperson=user_id,
        createTime=timestamp,
    )
    session.add(failure_record)
    session.commit()
    return False
def train_model(request_params: dict):
    """Train a model for a project and persist the outcome.

    Reads from ``request_params``:

    * ``"id"``     - project id (required; a falsy value is recorded as a failure)
    * ``"userId"`` - id of the requesting user (default ``""``)
    * ``"type"``   - label type forwarded to ``train`` (default ``1``)
    * ``"fields"`` - comma-separated field names (default ``""``)

    On success the object returned by ``train`` is stored with ``state=0``.
    Every failure path stores exactly one failure record via ``train_fail``.
    The per-run scratch directory is removed on every exit path.

    :param request_params: request payload described above
    :return: ``True`` when training succeeded and was recorded, else ``False``
    """
    # Drop locally cached ORM state so this run sees fresh database rows.
    session.expire_all()
    session.commit()

    project_id = request_params.get("id")
    user_id = request_params.get("userId", "")
    label_type = request_params.get("type", 1)
    fields = request_params.get("fields", "")
    model_dir = ""
    try:
        # No project id supplied.
        if not project_id:
            return train_fail(project_id, user_id)

        # Each run works in its own uniquely named scratch directory.
        model_dir = os.path.join(baseDir, str(uuid.uuid4()))
        if not generate_directory(model_dir):
            return train_fail(project_id, user_id)

        model_detail = train(project_id, fields.split(","), tfidf_vec, label_type, model_dir)
        if not model_detail:
            return train_fail(project_id, user_id)

        # Record the successful training run.
        model_detail.projectId = project_id
        model_detail.state = 0
        model_detail.createperson = user_id
        model_detail.createTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        session.add(model_detail)
        session.commit()
        return True
    except Exception as e:
        # Surface the cause instead of swallowing it silently.
        print(e)
        # A failed flush/commit leaves the session unusable until it is rolled
        # back; roll back so the failure record below can be written.
        session.rollback()
        # Record the failure once (the original wrote two rows here).
        return train_fail(project_id, user_id)
    finally:
        # Best-effort removal of this run's files on every exit path; a cleanup
        # problem must not change the training result already decided above.
        if model_dir and os.path.exists(model_dir):
            try:
                del_directory(model_dir)
            except Exception as cleanup_error:
                print(cleanup_error)
def predict_model(request_params):
    """Run a stored model against a list of record ids.

    Reads from ``request_params``:

    * ``"id"``       - project id (default ``-1``)
    * ``"id_list"``  - ids forwarded to ``predict`` (default ``[]``)
    * ``"model_id"`` - id of the ``Model`` row to use (default ``-1``)

    The model file is downloaded into ``baseDir/<modelFile>/model.model`` when
    it is not already present, and that directory is removed again before
    returning, whatever the outcome.

    :param request_params: request payload described above
    :return: ``{"error_code": 1, "data": ...}`` on success, otherwise
        ``{"error_code": 0, "error_message": ...}``
    """
    # Drop locally cached ORM state so the queries below see fresh rows.
    session.expire_all()
    session.commit()

    project_id = request_params.get("id", -1)
    id_list = request_params.get("id_list", [])
    model_id = request_params.get("model_id", -1)

    # Look up the project.
    project_info = session.query(Project).filter_by(id=project_id).first()
    if not project_info:
        return {"error_code": 0, "error_message": f"项目信息不存在--> {project_id}"}
    focus_field = project_info.focusField.split(",")
    target_label = project_info.labels.split(",")
    label_type = project_info.type

    # Look up the model.
    model_info = session.query(Model).filter_by(id=model_id).first()
    if not model_info:
        return {"error_code": 0, "error_message": f"模型信息不存在--> {model_id}"}

    model_url = model_info.modelFile
    # Failure records are created without a model file. A missing value would
    # crash os.path.join, and an empty one would make model_dir equal baseDir,
    # so the cleanup below would delete the whole base directory.
    if not model_url:
        return {"error_code": 0, "error_message": f"模型文件不存在--> {model_id}"}

    model_dir = os.path.join(baseDir, model_url)
    model_path = os.path.join(model_dir, "model.model")
    try:
        # Fetch the model only when it is not already on disk.
        if not os.path.exists(model_path):
            if not generate_directory(model_dir):
                return {"error_code": 0, "error_message": "文件夹创建失败,请检查存储设备-->"}
            if not File.download_file(model_url, model_path):
                return {"error_code": 0, "error_message": f"oss储存模型加载失败--> {model_id}"}

        try:
            data = predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path)
        except Exception as e:
            print(e)
            return {"error_code": 0, "error_message": "预测过程出错"}
        return {"error_code": 1, "data": data}
    finally:
        # Clear the local model cache on every exit path, so a failed download
        # or prediction cannot leave a stale model.model that later calls reuse.
        # Best effort: a cleanup problem must not replace the result.
        if os.path.exists(model_dir):
            try:
                del_directory(model_dir)
            except Exception as cleanup_error:
                print(cleanup_error)
|