From 0033746fe1a7104059f702cd799f88b5e3e73a6c Mon Sep 17 00:00:00 2001 From: sunyugang Date: Tue, 22 Apr 2025 10:11:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E9=A1=B9=E7=9B=AE=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E6=A8=A1=E5=9D=97=E7=9A=84=E6=8E=A5=E5=8F=A3=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../business/detect/schemas/project_detect.py | 7 +- .../detect/schemas/project_detect_file.py | 8 +- .../detect/schemas/project_detect_log.py | 7 +- .../detect/schemas/project_detect_log_file.py | 5 +- apps/business/detect/views.py | 1 - apps/business/project/crud.py | 31 +++-- apps/business/train/schemas/project_train.py | 6 +- apps/business/train/service.py | 127 ++++++++++-------- apps/business/train/views.py | 26 ++-- core/crud.py | 10 +- core/websocket_app.py | 2 +- 11 files changed, 139 insertions(+), 91 deletions(-) diff --git a/apps/business/detect/schemas/project_detect.py b/apps/business/detect/schemas/project_detect.py index a454e7a..696f886 100644 --- a/apps/business/detect/schemas/project_detect.py +++ b/apps/business/detect/schemas/project_detect.py @@ -6,9 +6,10 @@ # @IDE : PyCharm # @desc : pydantic 模型,用于数据库序列化操作 -from pydantic import BaseModel, Field, ConfigDict +from core.data_types import DatetimeStr + from typing import Optional -from datetime import datetime +from pydantic import BaseModel, Field, ConfigDict class ProjectDetectIn(BaseModel): @@ -36,7 +37,7 @@ class ProjectDetectOut(BaseModel): file_type: Optional[str] folder_url: Optional[str] rtsp_url: Optional[str] - create_time: Optional[datetime] + create_datetime: DatetimeStr model_config = ConfigDict(from_attributes=True) diff --git a/apps/business/detect/schemas/project_detect_file.py b/apps/business/detect/schemas/project_detect_file.py index 60249ea..795201e 100644 --- a/apps/business/detect/schemas/project_detect_file.py +++ b/apps/business/detect/schemas/project_detect_file.py @@ -5,9 +5,11 @@ # @File : project_detect_file.py # @IDE : PyCharm # @desc : pydantic 模型,用于数据库序列化操作 -from pydantic import BaseModel, Field, ConfigDict + +from core.data_types import DatetimeStr + from typing import Optional -from datetime import datetime +from pydantic import BaseModel, Field, ConfigDict class ProjectDetectFilePager(BaseModel): @@ -21,6 +23,6 @@ class ProjectDetectFileOut(BaseModel): detect_id: Optional[int] = Field(..., description="训练集合id") file_name: Optional[str] = Field(None, description="文件名称") thumb_file_url: Optional[str] = Field(None, description="文件路径") - create_time: Optional[datetime] = Field(None, description="上传时间") + create_datetime: DatetimeStr model_config = ConfigDict(from_attributes=True) diff --git a/apps/business/detect/schemas/project_detect_log.py b/apps/business/detect/schemas/project_detect_log.py index 042d21a..9232fd6 100644 --- a/apps/business/detect/schemas/project_detect_log.py +++ b/apps/business/detect/schemas/project_detect_log.py @@ -6,9 +6,10 @@ # @IDE : PyCharm # @desc : pydantic 模型,用于数据库序列化操作 -from pydantic import BaseModel, Field, ConfigDict +from core.data_types import DatetimeStr + from typing import Optional -from datetime import datetime +from pydantic import BaseModel, Field, ConfigDict class ProjectDetectLogIn(BaseModel): @@ -25,6 +26,6 @@ class ProjectDetectLogOut(BaseModel): train_id: Optional[int] train_version: Optional[str] pt_type: Optional[str] - create_time: Optional[datetime] + create_datetime: DatetimeStr model_config = ConfigDict(from_attributes=True) diff --git a/apps/business/detect/schemas/project_detect_log_file.py b/apps/business/detect/schemas/project_detect_log_file.py index 50ae092..0121081 100644 --- a/apps/business/detect/schemas/project_detect_log_file.py +++ b/apps/business/detect/schemas/project_detect_log_file.py @@ -6,15 +6,16 @@ # @IDE : PyCharm # @desc : pydantic 模型,用于数据库序列化操作 +from core.data_types import DatetimeStr + from pydantic import BaseModel, ConfigDict from typing import Optional -from datetime import datetime class ProjectDetectLogFileOut(BaseModel): id: Optional[int] file_name: Optional[str] thumb_file_url: Optional[str] - create_time: Optional[datetime] + create_datetime: DatetimeStr model_config = ConfigDict(from_attributes=True) diff --git a/apps/business/detect/views.py b/apps/business/detect/views.py index 4144d0a..2ee8545 100644 --- a/apps/business/detect/views.py +++ b/apps/business/detect/views.py @@ -17,7 +17,6 @@ from utils.response import SuccessResponse, ErrorResponse import threading from redis.asyncio import Redis -from sqlalchemy.ext.asyncio import AsyncSession from fastapi import Depends, APIRouter, Form, UploadFile diff --git a/apps/business/project/crud.py b/apps/business/project/crud.py index b501054..4b67c1e 100644 --- a/apps/business/project/crud.py +++ b/apps/business/project/crud.py @@ -112,6 +112,12 @@ class ProjectInfoDal(DalBase): await self.flush(obj) return await self.out_dict(obj, None, False, schemas.ProjectInfoOut) + async def update_version(self, data_id): + proj = await self.get_data(data_id) + if proj: + proj.train_version = proj.train_version + 1 + await self.put_data(data_id=data_id, data=proj) + class ProjectImageDal(DalBase): """ @@ -246,22 +252,26 @@ class ProjectImageDal(DalBase): # 2 主查询 query = ( select( - models.ProjectImage, + models.ProjectImage.id, func.ifnull(subquery.c.label_count, 0).label('label_count') ) .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id) ) train_count_sql = await self.filter_core( v_start_sql=query, - v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'], + v_where=[models.ProjectImage.project_id == proj_id, + models.ProjectImage.img_type == 'train', + subquery.c.label_count == 0], v_return_sql=True) - train_count = await self.get_count(train_count_sql) + train_count = await self.get_count_sql(train_count_sql) val_count_sql = await self.filter_core( v_start_sql=query, - v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'], + v_where=[models.ProjectImage.project_id == proj_id, + models.ProjectImage.img_type == 'val', + subquery.c.label_count == 0], v_return_sql=True) - val_count = await self.get_count(val_count_sql) + val_count = await self.get_count_sql(val_count_sql) return train_count, val_count @@ -295,14 +305,15 @@ class ProjectLabelDal(DalBase): async def get_label_for_train(self, project_id: int): id_list = [] name_list = [] - label_list = self.get_datas( + label_list = await self.get_datas( v_where=[self.model.project_id == project_id], + limit=0, v_order='asc', v_order_field='id', v_return_count=False) for label in label_list: - id_list.append(label.id) - name_list.append(label.label_name) + id_list.append(label['id']) + name_list.append(label['label_name']) return id_list, name_list @@ -326,10 +337,12 @@ class ProjectImgLabelDal(DalBase): async def get_img_label_list(self, image_id: int): return await self.get_datas( + limit=0, v_return_count=False, v_where=[self.model.image_id == image_id], v_order="asc", - v_order_field="id") + v_order_field="id", + v_return_objs=True) async def del_img_label(self, label_ids: list[int]): img_labels = self.get_datas(v_where=[self.model.label_id.in_(label_ids)]) diff --git a/apps/business/train/schemas/project_train.py b/apps/business/train/schemas/project_train.py index 02fc45f..cbbfb12 100644 --- a/apps/business/train/schemas/project_train.py +++ b/apps/business/train/schemas/project_train.py @@ -6,7 +6,8 @@ # @IDE : PyCharm # @desc : pydantic 模型,用于数据库序列化操作 -from datetime import datetime +from core.data_types import DatetimeStr + from pydantic import BaseModel, Field, ConfigDict from typing import Optional @@ -18,6 +19,7 @@ from typing import Optional class ProjectTrainIn(BaseModel): project_id: Optional[int] = Field(..., description="项目id") weights_id: Optional[str] = Field(None, description="权重文件") + weights_name: Optional[str] = Field(None, description="权重文件名称") epochs: Optional[int] = Field(50, description="训练轮数") patience: Optional[int] = Field(20, description="早停的耐心值") @@ -28,7 +30,7 @@ class ProjectTrainOut(BaseModel): weights_name: Optional[str] = Field(None, description="权重名称") epochs: Optional[int] = Field(None, description="训练轮数") patience: Optional[int] = Field(None, description="早停的耐心值") - create_time: Optional[datetime] = Field(None, description="训练时间") + create_datetime: DatetimeStr model_config = ConfigDict(from_attributes=True) diff --git a/apps/business/train/service.py b/apps/business/train/service.py index 3fd307b..b71aaff 100644 --- a/apps/business/train/service.py +++ b/apps/business/train/service.py @@ -1,4 +1,4 @@ -from utils import os_utils as os +from utils import os_utils as osu from application.settings import * from . import schemas, models, crud from utils.websocket_server import room_manager @@ -19,34 +19,39 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession): :param db: 数据库session :return: """ - proj_dal = proj_crud.ProjectInfoDal(db) img_dal = proj_crud.ProjectImageDal(db) label_dal = proj_crud.ProjectLabelDal(db) # 先查询两个图片列表 - project_images_train = img_dal.get_data( - v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train']) - project_images_val = img_dal.get_data( - v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val']) + project_images_train = await img_dal.get_datas( + v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train'], + limit=0, + v_return_count=False, + v_return_objs=True) + project_images_val = await img_dal.get_datas( + v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val'], + limit=0, + v_return_count=False, + v_return_objs=True) # 得到训练版本 version_path = 'v' + str(proj_info.train_version + 1) # 创建训练的根目录 - train_path = os.create_folder(datasets_url, proj_info.project_no, version_path) + train_path = osu.create_folder(datasets_url, proj_info.project_no, version_path) # 查询项目所属标签,返回两个 id,name一一对应的数组 - label_id_list, label_name_list = label_dal.get_label_for_train(proj_info.id) + label_id_list, label_name_list = await label_dal.get_label_for_train(proj_info.id) # 创建图片的的两个文件夹 - img_path_train = os.create_folder(train_path, 'images', 'train') - img_path_val = os.create_folder(train_path, 'images', 'val') + img_path_train = osu.create_folder(train_path, 'images', 'train') + img_path_val = osu.create_folder(train_path, 'images', 'val') # 创建标签的两个文件夹 - label_path_train = os.create_folder(train_path, 'labels', 'train') - label_path_val = os.create_folder(train_path, 'labels', 'val') + label_path_train = osu.create_folder(train_path, 'labels', 'train') + label_path_val = osu.create_folder(train_path, 'labels', 'val') # 在根目录下创建yaml文件 - yaml_file = os.file_path(train_path, proj_info.project_no + '.yaml') + yaml_file = osu.file_path(train_path, proj_info.project_no + '.yaml') yaml_data = { 'path': train_path, 'train': 'images/train', @@ -59,20 +64,20 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession): # 开始循环复制图片和生成label.txt # 先操作train - operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list) + await operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list) # 再操作val - operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list) + await operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list) # 开始执行异步训练 data = yaml_file - project = os.file_path(runs_url, proj_info.project_no) + project = osu.file_path(runs_url, proj_info.project_no) name = version_path return data, project, name async def operate_img_label( - img_list: list[proj_models.ProjectImgLabel], + img_list: list[proj_models.ProjectImage], img_path: str, label_path: str, db: AsyncSession, @@ -90,10 +95,10 @@ async def operate_img_label( image = img_list[i] # 先复制图片,并把图片改名,不改后缀 file_name = 'image' + str(i) - os.copy_and_rename_file(image.image_url, img_path, file_name) + osu.copy_and_rename_file(image.image_url, img_path, file_name) # 查询这张图片的label信息然后生成这张照片的txt文件 img_label_list = await proj_crud.ProjectImgLabelDal(db).get_img_label_list(image.id) - label_txt_path = os.file_path(label_path, file_name + '.txt') + label_txt_path = osu.file_path(label_path, file_name + '.txt') with open(label_txt_path, 'w', encoding='utf-8') as file: for image_label in img_label_list: index = label_id_list.index(image_label.label_id) @@ -103,20 +108,19 @@ async def operate_img_label( + image_label.mark_height + '\n') -async def run_event_loop( +def run_event_loop( data: str, project: str, name: str, train_in: schemas.ProjectTrainIn, project_id: int, - db: AsyncSession): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - # 运行异步函数 - loop.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience, train_in.weights_id, - project_id, db)) - # 可选: 关闭循环 - loop.close() + train_info: models.ProjectTrain, + is_gup: str): + # 运行异步函数,开始训练 + loop_run = asyncio.new_event_loop() + asyncio.set_event_loop(loop_run) + loop_run.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience, + project_id, train_info, is_gup)) async def run_commend( @@ -125,10 +129,9 @@ async def run_commend( name: str, epochs: int, patience: int, - weights: str, project_id: int, - db: AsyncSession, - rd: Redis): + train_info: models.ProjectTrain, + is_gpu: str): """ 执行训练 :param data: 训练数据集 @@ -138,23 +141,20 @@ async def run_commend( :param patience: 早停耐心值 :param weights: 权重文件 :param project_id: 项目id - :param db: 数据库session - :param rd: redis连接 + :param train_info: 训练信息 + :param is_gpu: 是否是gpu环境 :return: """ - yolo_path = os.file_path(yolo_url, 'train.py') + yolo_path = osu.file_path(yolo_url, 'train.py') room = 'train_' + str(project_id) await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n") commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)] - # 增加权重文件,在之前训练的基础上重新巡逻 - if weights != '' and weights is not None: - train_info = await crud.ProjectTrainDal(db).get_data(data_id=int(weights)) - if train_info is not None: - commend.append("--weights=" + train_info.best_pt) + # 增加权重文件,在之前训练的基础上重新训练 + if train_info is not None: + commend.append("--weights=" + train_info.best_pt) - is_gpu = rd.get('is_gpu') # 判断是否存在cuda版本 if is_gpu == 'True': commend.append("--device=0") @@ -180,20 +180,33 @@ async def run_commend( await room_manager.send_to_room(room, 'error') else: await room_manager.send_to_room(room, 'success') - # 然后保存版本训练信息 - train = models.ProjectTrain() - train.project_id = project_id - train.train_version = name - train_url = os.file_path(project, name) - train.train_url = train_url - train.train_data = data - bast_pt_path = os.file_path(train_url, 'weights', 'best.pt') - last_pt_path = os.file_path(train_url, 'weights', 'last.pt') - train.best_pt = bast_pt_path - train.last_pt = last_pt_path - if weights is not None and weights != '': - train.weights_id = weights - train.weights_name = train_info.train_version - train.patience = patience - train.epochs = epochs - await crud.ProjectTrainDal(db).create_data(data=train) \ No newline at end of file + + +async def add_train( + db, + project_id, + name, + project, + data, + train_in, + user_id): + # 更新版本信息 + await proj_crud.ProjectInfoDal(db).update_version(data_id=project_id) + # 增加训练版本信息 + train = models.ProjectTrain() + train.project_id = project_id + train.train_version = name + train_url = osu.file_path(project, name) + train.train_url = train_url + train.train_data = data + train.user_id = user_id + bast_pt_path = osu.file_path(train_url, 'weights', 'best.pt') + last_pt_path = osu.file_path(train_url, 'weights', 'last.pt') + train.best_pt = bast_pt_path + train.last_pt = last_pt_path + if train_in is not None: + train.weights_id = train_in.weights_id + train.weights_name = train_in.weights_name + train.patience = train_in.patience + train.epochs = train_in.epochs + await crud.ProjectTrainDal(db).create_model(data=train) \ No newline at end of file diff --git a/apps/business/train/views.py b/apps/business/train/views.py index 66ed4a4..2b81a29 100644 --- a/apps/business/train/views.py +++ b/apps/business/train/views.py @@ -3,6 +3,8 @@ # @version : 1.0 # @Create Time : 2025/04/03 10:32 # @File : views.py + +from core.database import redis_getter from . import models, schemas, crud, service from apps.vadmin.auth.utils.current import AllUserAuth from apps.vadmin.auth.utils.validation.auth import Auth @@ -10,6 +12,7 @@ from utils.response import SuccessResponse, ErrorResponse from apps.business.project.crud import ProjectInfoDal, ProjectImageDal import threading +from redis.asyncio import Redis from fastapi import APIRouter, Depends @@ -19,10 +22,11 @@ app = APIRouter() ########################################################### # 项目训练信息 ########################################################### -@app.post("/", summary="执行训练") +@app.post("/start", summary="执行训练") async def run_train( train_in: schemas.ProjectTrainIn, - auth: Auth = Depends(AllUserAuth())): + auth: Auth = Depends(AllUserAuth()), + rd: Redis = Depends(redis_getter)): proj_id = train_in.project_id proj_dal = ProjectInfoDal(auth.db) proj_img_dal = ProjectImageDal(auth.db) @@ -43,12 +47,17 @@ async def run_train( return ErrorResponse("训练图片中存在未标注的图片") if val_label_count > 0: return ErrorResponse("验证图片中存在未标注的图片") - data, project, name = service.before_train(proj_info, auth.db) + data, project, name = await service.before_train(proj_info, auth.db) + is_gpu = await rd.get('is_gpu') + train_info = None + if train_in.weights_id is not None: + train_info = await crud.ProjectTrainDal(auth.db).get_data(train_in.weights_id) # 异步执行操作,操作过程通过websocket进行同步 thread_train = threading.Thread( target=service.run_event_loop, - args=(data, project, name, train_in, proj_id, auth.db,)) + args=(data, project, name, train_in, proj_id, train_info, is_gpu)) thread_train.start() + await service.add_train(auth.db, proj_id, name, project, data, train_in, auth.user.id) return SuccessResponse(msg="执行成功") @@ -57,15 +66,16 @@ async def train_list( proj_id: int, auth: Auth = Depends(AllUserAuth())): datas = await crud.ProjectTrainDal(auth.db).get_datas( + limit=0, v_where=[models.ProjectTrain.project_id == proj_id], v_schema=schemas.ProjectTrainOut, v_order="asc", - v_order_field="id",v_return_count=False) + v_order_field="id", + v_return_count=False) return SuccessResponse(data=datas) -@app.get("/result/{proj_id}", summary="查询训练报告") -async def get_result(train_id:int, auth: Auth = Depends(AllUserAuth())): +@app.get("/result/{train_id}", summary="查询训练报告") +async def get_result(train_id: int, auth: Auth = Depends(AllUserAuth())): result = await crud.ProjectTrainDal(auth.db).get_result(train_id) return SuccessResponse(data=result) - diff --git a/core/crud.py b/core/crud.py index 5d1a074..19e2395 100644 --- a/core/crud.py +++ b/core/crud.py @@ -263,11 +263,17 @@ class DalBase: await self.db.execute(insert(self.model), datas) await self.db.flush() + async def create_model(self, data: Any) -> None: + """ + 创建单个model + :param data: model + """ + self.db.add(data) + await self.db.flush() + async def create_models(self, datas: list[Any]) -> None: """ 批量创建数据 - SQLAlchemy 2.0 批量插入不支持 MySQL 返回值: - https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning :param datas: model数组 """ self.db.add_all(datas) diff --git a/core/websocket_app.py b/core/websocket_app.py index 77f9cdb..c091205 100644 --- a/core/websocket_app.py +++ b/core/websocket_app.py @@ -7,7 +7,7 @@ from utils.websocket_server import room_manager def websocket_config(app: FastAPI): - @app.websocket("/{room}") + @app.websocket("/ws/{room}") async def websocket_room(websocket: WebSocket, room: str): """ websocket 房间管理