# coding:utf-8
"""Entry points for training a project model and predicting with a stored model."""
import datetime
import logging
import os
import uuid

import joblib

from docs.config import dictionaryPath
from docs.config import dictionaryUrl
from machine_models.databases import File
from machine_models.train_model import train
from machine_models.databases.mysql_helper import Model
from machine_models.databases import session
from machine_models.databases.mysql_helper import Project
from machine_models.predict_model import predict
from util.file_operations import generate_directory, del_directory
from docs.config import baseDir

logger = logging.getLogger(__name__)

# Load the dictionary file once at import time, downloading it first if it
# is not present locally.
if not os.path.exists(dictionaryPath):
    status = File.download_file(dictionaryUrl, dictionaryPath)
    if not status:
        raise ValueError("词典文件下载失败")
tfidf_vec = joblib.load(dictionaryPath)


def _timestamp():
    """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def _remove_model_dir(model_dir):
    """Delete ``model_dir`` if it exists; log (but do not raise) on failure.

    Cleanup is best-effort: a failure to remove temporary files must not
    change the outcome that has already been recorded in the database.
    """
    if not model_dir or not os.path.exists(model_dir):
        return
    try:
        del_directory(model_dir)
    except Exception:
        logger.exception("Failed to remove model directory %s", model_dir)


def train_fail(project_id, user_id):
    """Record a failed training run and return ``False``.

    A ``Model`` row with ``state=2`` is added to the session and committed.

    :param project_id: id of the project whose training failed
    :param user_id: id of the user who requested the training
    :return: always ``False``
    """
    fail_model = Model(
        state=2,
        projectId=project_id,
        createperson=user_id,
        createTime=_timestamp(),
    )
    session.add(fail_model)
    session.commit()
    return False


def train_model(request_params: dict) -> bool:
    """Train a model for a project and record the outcome.

    Keys read from ``request_params``:

    * ``id`` - project id; a missing/falsy value is recorded as a failure
    * ``userId`` - requesting user (default ``""``)
    * ``type`` - label type passed to ``train`` (default ``1``)
    * ``fields`` - comma-separated field names (default ``""``)

    :param request_params: request payload as a dict
    :return: ``True`` when training succeeded and the model row (``state=0``)
        was committed, ``False`` when a failure row was recorded instead
    """
    # Drop cached ORM state so the run sees current database contents.
    session.expire_all()
    session.commit()

    project_id = request_params.get("id")
    user_id = request_params.get("userId", "")
    label_type = request_params.get("type", 1)
    fields = request_params.get("fields", "")
    model_dir = ""
    try:
        # No project id supplied.
        if not project_id:
            return train_fail(project_id, user_id)

        model_dir = os.path.join(baseDir, str(uuid.uuid4()))
        # Working directory could not be created.
        if not generate_directory(model_dir):
            return train_fail(project_id, user_id)

        model_detail = train(project_id, fields.split(","), tfidf_vec, label_type, model_dir)
        # Training produced no result.
        if not model_detail:
            return train_fail(project_id, user_id)

        # Record the successful run.
        model_detail.projectId = project_id
        model_detail.state = 0
        model_detail.createperson = user_id
        model_detail.createTime = _timestamp()
        session.add(model_detail)
        session.commit()
        return True
    except Exception:
        # Log the cause and record the failure exactly once.
        logger.exception("Training failed for project %s", project_id)
        return train_fail(project_id, user_id)
    finally:
        # Remove files generated by this run on every exit path, including
        # the early "training produced no result" return.
        _remove_model_dir(model_dir)


def predict_model(request_params: dict) -> dict:
    """Run prediction for a project using a stored model.

    Keys read from ``request_params``:

    * ``id`` - project id (default ``-1``)
    * ``id_list`` - ids handed to ``predict`` (default ``[]``)
    * ``model_id`` - id of the model row to use (default ``-1``)

    :param request_params: request payload as a dict
    :return: ``{"error_code": 1, "data": ...}`` on success, otherwise
        ``{"error_code": 0, "error_message": ...}``
    """
    # Drop cached ORM state so the lookups see current database contents.
    session.expire_all()
    session.commit()

    project_id = request_params.get("id", -1)
    id_list = request_params.get("id_list", [])
    model_id = request_params.get("model_id", -1)

    # Look up the project.
    project_info = session.query(Project).filter_by(id=project_id).first()
    if not project_info:
        return {"error_code": 0, "error_message": f"项目信息不存在--> {project_id}"}
    focus_field = project_info.focusField.split(",")
    target_label = project_info.labels.split(",")
    label_type = project_info.type

    # Look up the model.
    model_info = session.query(Model).filter_by(id=model_id).first()
    logger.debug("model_id type=%s value=%s", type(model_id), model_id)
    if not model_info:
        return {"error_code": 0, "error_message": f"模型信息不存在--> {model_id}"}

    # Fetch the model file unless a local copy is already present.
    model_url = model_info.modelFile
    model_dir = os.path.join(baseDir, model_url)
    model_path = os.path.join(model_dir, "model.model")
    if not os.path.exists(model_path):
        if not generate_directory(model_dir):
            return {"error_code": 0, "error_message": "文件夹创建失败,请检查存储设备-->"}
        if not File.download_file(model_url, model_path):
            return {"error_code": 0, "error_message": f"oss储存模型加载失败--> {model_id}"}

    try:
        data = predict(id_list, tfidf_vec, label_type, focus_field, target_label, model_path)
    except Exception:
        logger.exception("Prediction failed for project %s with model %s", project_id, model_id)
        # NOTE(review): the local model copy is left in place on failure, as
        # in the original code - confirm whether it should be removed here.
        return {"error_code": 0, "error_message": "预测过程出错"}

    # Remove the local model copy after a successful prediction.
    _remove_model_dir(model_dir)
    return {"error_code": 1, "data": data}