完成项目训练模块的接口测试

This commit is contained in:
sunyugang 2025-04-22 10:11:44 +08:00
parent 7a9e571a96
commit 0033746fe1
11 changed files with 139 additions and 91 deletions

View File

@ -6,9 +6,10 @@
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
from typing import Optional
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
class ProjectDetectIn(BaseModel):
@ -36,7 +37,7 @@ class ProjectDetectOut(BaseModel):
file_type: Optional[str]
folder_url: Optional[str]
rtsp_url: Optional[str]
create_time: Optional[datetime]
create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True)

View File

@ -5,9 +5,11 @@
# @File : project_detect_file.py
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
from typing import Optional
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
class ProjectDetectFilePager(BaseModel):
@ -21,6 +23,6 @@ class ProjectDetectFileOut(BaseModel):
detect_id: Optional[int] = Field(..., description="训练集合id")
file_name: Optional[str] = Field(None, description="文件名称")
thumb_file_url: Optional[str] = Field(None, description="文件路径")
create_time: Optional[datetime] = Field(None, description="上传时间")
create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True)

View File

@ -6,9 +6,10 @@
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
from typing import Optional
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
class ProjectDetectLogIn(BaseModel):
@ -25,6 +26,6 @@ class ProjectDetectLogOut(BaseModel):
train_id: Optional[int]
train_version: Optional[str]
pt_type: Optional[str]
create_time: Optional[datetime]
create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True)

View File

@ -6,15 +6,16 @@
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from core.data_types import DatetimeStr
from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime
class ProjectDetectLogFileOut(BaseModel):
id: Optional[int]
file_name: Optional[str]
thumb_file_url: Optional[str]
create_time: Optional[datetime]
create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True)

View File

@ -17,7 +17,6 @@ from utils.response import SuccessResponse, ErrorResponse
import threading
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import Depends, APIRouter, Form, UploadFile

View File

@ -112,6 +112,12 @@ class ProjectInfoDal(DalBase):
await self.flush(obj)
return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)
async def update_version(self, data_id):
proj = await self.get_data(data_id)
if proj:
proj.train_version = proj.train_version + 1
await self.put_data(data_id=data_id, data=proj)
class ProjectImageDal(DalBase):
"""
@ -246,22 +252,26 @@ class ProjectImageDal(DalBase):
# 2 主查询
query = (
select(
models.ProjectImage,
models.ProjectImage.id,
func.ifnull(subquery.c.label_count, 0).label('label_count')
)
.outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
)
train_count_sql = await self.filter_core(
v_start_sql=query,
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'],
v_where=[models.ProjectImage.project_id == proj_id,
models.ProjectImage.img_type == 'train',
subquery.c.label_count == 0],
v_return_sql=True)
train_count = await self.get_count(train_count_sql)
train_count = await self.get_count_sql(train_count_sql)
val_count_sql = await self.filter_core(
v_start_sql=query,
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'],
v_where=[models.ProjectImage.project_id == proj_id,
models.ProjectImage.img_type == 'val',
subquery.c.label_count == 0],
v_return_sql=True)
val_count = await self.get_count(val_count_sql)
val_count = await self.get_count_sql(val_count_sql)
return train_count, val_count
@ -295,14 +305,15 @@ class ProjectLabelDal(DalBase):
async def get_label_for_train(self, project_id: int):
id_list = []
name_list = []
label_list = self.get_datas(
label_list = await self.get_datas(
v_where=[self.model.project_id == project_id],
limit=0,
v_order='asc',
v_order_field='id',
v_return_count=False)
for label in label_list:
id_list.append(label.id)
name_list.append(label.label_name)
id_list.append(label['id'])
name_list.append(label['label_name'])
return id_list, name_list
@ -326,10 +337,12 @@ class ProjectImgLabelDal(DalBase):
async def get_img_label_list(self, image_id: int):
return await self.get_datas(
limit=0,
v_return_count=False,
v_where=[self.model.image_id == image_id],
v_order="asc",
v_order_field="id")
v_order_field="id",
v_return_objs=True)
async def del_img_label(self, label_ids: list[int]):
img_labels = self.get_datas(v_where=[self.model.label_id.in_(label_ids)])

View File

@ -6,7 +6,8 @@
# @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作
from datetime import datetime
from core.data_types import DatetimeStr
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
@ -18,6 +19,7 @@ from typing import Optional
class ProjectTrainIn(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
weights_id: Optional[str] = Field(None, description="权重文件")
weights_name: Optional[str] = Field(None, description="权重文件名称")
epochs: Optional[int] = Field(50, description="训练轮数")
patience: Optional[int] = Field(20, description="早停的耐心值")
@ -28,7 +30,7 @@ class ProjectTrainOut(BaseModel):
weights_name: Optional[str] = Field(None, description="权重名称")
epochs: Optional[int] = Field(None, description="训练轮数")
patience: Optional[int] = Field(None, description="早停的耐心值")
create_time: Optional[datetime] = Field(None, description="训练时间")
create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True)

View File

@ -1,4 +1,4 @@
from utils import os_utils as os
from utils import os_utils as osu
from application.settings import *
from . import schemas, models, crud
from utils.websocket_server import room_manager
@ -19,34 +19,39 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
:param db: 数据库session
:return:
"""
proj_dal = proj_crud.ProjectInfoDal(db)
img_dal = proj_crud.ProjectImageDal(db)
label_dal = proj_crud.ProjectLabelDal(db)
# 先查询两个图片列表
project_images_train = img_dal.get_data(
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train'])
project_images_val = img_dal.get_data(
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val'])
project_images_train = await img_dal.get_datas(
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train'],
limit=0,
v_return_count=False,
v_return_objs=True)
project_images_val = await img_dal.get_datas(
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val'],
limit=0,
v_return_count=False,
v_return_objs=True)
# 得到训练版本
version_path = 'v' + str(proj_info.train_version + 1)
# 创建训练的根目录
train_path = os.create_folder(datasets_url, proj_info.project_no, version_path)
train_path = osu.create_folder(datasets_url, proj_info.project_no, version_path)
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = label_dal.get_label_for_train(proj_info.id)
label_id_list, label_name_list = await label_dal.get_label_for_train(proj_info.id)
# 创建图片的的两个文件夹
img_path_train = os.create_folder(train_path, 'images', 'train')
img_path_val = os.create_folder(train_path, 'images', 'val')
img_path_train = osu.create_folder(train_path, 'images', 'train')
img_path_val = osu.create_folder(train_path, 'images', 'val')
# 创建标签的两个文件夹
label_path_train = os.create_folder(train_path, 'labels', 'train')
label_path_val = os.create_folder(train_path, 'labels', 'val')
label_path_train = osu.create_folder(train_path, 'labels', 'train')
label_path_val = osu.create_folder(train_path, 'labels', 'val')
# 在根目录下创建yaml文件
yaml_file = os.file_path(train_path, proj_info.project_no + '.yaml')
yaml_file = osu.file_path(train_path, proj_info.project_no + '.yaml')
yaml_data = {
'path': train_path,
'train': 'images/train',
@ -59,20 +64,20 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
# 开始循环复制图片和生成label.txt
# 先操作train
operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list)
await operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list)
# 再操作val
operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list)
await operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list)
# 开始执行异步训练
data = yaml_file
project = os.file_path(runs_url, proj_info.project_no)
project = osu.file_path(runs_url, proj_info.project_no)
name = version_path
return data, project, name
async def operate_img_label(
img_list: list[proj_models.ProjectImgLabel],
img_list: list[proj_models.ProjectImage],
img_path: str,
label_path: str,
db: AsyncSession,
@ -90,10 +95,10 @@ async def operate_img_label(
image = img_list[i]
# 先复制图片,并把图片改名,不改后缀
file_name = 'image' + str(i)
os.copy_and_rename_file(image.image_url, img_path, file_name)
osu.copy_and_rename_file(image.image_url, img_path, file_name)
# 查询这张图片的label信息然后生成这张照片的txt文件
img_label_list = await proj_crud.ProjectImgLabelDal(db).get_img_label_list(image.id)
label_txt_path = os.file_path(label_path, file_name + '.txt')
label_txt_path = osu.file_path(label_path, file_name + '.txt')
with open(label_txt_path, 'w', encoding='utf-8') as file:
for image_label in img_label_list:
index = label_id_list.index(image_label.label_id)
@ -103,20 +108,19 @@ async def operate_img_label(
+ image_label.mark_height + '\n')
async def run_event_loop(
def run_event_loop(
data: str,
project: str,
name: str,
train_in: schemas.ProjectTrainIn,
project_id: int,
db: AsyncSession):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience, train_in.weights_id,
project_id, db))
# 可选: 关闭循环
loop.close()
train_info: models.ProjectTrain,
is_gup: str):
# 运行异步函数,开始训练
loop_run = asyncio.new_event_loop()
asyncio.set_event_loop(loop_run)
loop_run.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience,
project_id, train_info, is_gup))
async def run_commend(
@ -125,10 +129,9 @@ async def run_commend(
name: str,
epochs: int,
patience: int,
weights: str,
project_id: int,
db: AsyncSession,
rd: Redis):
train_info: models.ProjectTrain,
is_gpu: str):
"""
执行训练
:param data: 训练数据集
@ -138,23 +141,20 @@ async def run_commend(
:param patience: 早停耐心值
:param weights: 权重文件
:param project_id: 项目id
:param db: 数据库session
:param rd: redis连接
:param train_info: 训练信息
:param is_gpu: 是否是gpu环境
:return:
"""
yolo_path = os.file_path(yolo_url, 'train.py')
yolo_path = osu.file_path(yolo_url, 'train.py')
room = 'train_' + str(project_id)
await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
# 增加权重文件,在之前训练的基础上重新巡逻
if weights != '' and weights is not None:
train_info = await crud.ProjectTrainDal(db).get_data(data_id=int(weights))
if train_info is not None:
commend.append("--weights=" + train_info.best_pt)
# 增加权重文件,在之前训练的基础上重新训练
if train_info is not None:
commend.append("--weights=" + train_info.best_pt)
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本
if is_gpu == 'True':
commend.append("--device=0")
@ -180,20 +180,33 @@ async def run_commend(
await room_manager.send_to_room(room, 'error')
else:
await room_manager.send_to_room(room, 'success')
# 然后保存版本训练信息
train = models.ProjectTrain()
train.project_id = project_id
train.train_version = name
train_url = os.file_path(project, name)
train.train_url = train_url
train.train_data = data
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt')
last_pt_path = os.file_path(train_url, 'weights', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
if weights is not None and weights != '':
train.weights_id = weights
train.weights_name = train_info.train_version
train.patience = patience
train.epochs = epochs
await crud.ProjectTrainDal(db).create_data(data=train)
async def add_train(
db,
project_id,
name,
project,
data,
train_in,
user_id):
# 更新版本信息
await proj_crud.ProjectInfoDal(db).update_version(data_id=project_id)
# 增加训练版本信息
train = models.ProjectTrain()
train.project_id = project_id
train.train_version = name
train_url = osu.file_path(project, name)
train.train_url = train_url
train.train_data = data
train.user_id = user_id
bast_pt_path = osu.file_path(train_url, 'weights', 'best.pt')
last_pt_path = osu.file_path(train_url, 'weights', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
if train_in is not None:
train.weights_id = train_in.weights_id
train.weights_name = train_in.weights_name
train.patience = train_in.patience
train.epochs = train_in.epochs
await crud.ProjectTrainDal(db).create_model(data=train)

View File

@ -3,6 +3,8 @@
# @version : 1.0
# @Create Time : 2025/04/03 10:32
# @File : views.py
from core.database import redis_getter
from . import models, schemas, crud, service
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
@ -10,6 +12,7 @@ from utils.response import SuccessResponse, ErrorResponse
from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
import threading
from redis.asyncio import Redis
from fastapi import APIRouter, Depends
@ -19,10 +22,11 @@ app = APIRouter()
###########################################################
# 项目训练信息
###########################################################
@app.post("/", summary="执行训练")
@app.post("/start", summary="执行训练")
async def run_train(
train_in: schemas.ProjectTrainIn,
auth: Auth = Depends(AllUserAuth())):
auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)):
proj_id = train_in.project_id
proj_dal = ProjectInfoDal(auth.db)
proj_img_dal = ProjectImageDal(auth.db)
@ -43,12 +47,17 @@ async def run_train(
return ErrorResponse("训练图片中存在未标注的图片")
if val_label_count > 0:
return ErrorResponse("验证图片中存在未标注的图片")
data, project, name = service.before_train(proj_info, auth.db)
data, project, name = await service.before_train(proj_info, auth.db)
is_gpu = await rd.get('is_gpu')
train_info = None
if train_in.weights_id is not None:
train_info = await crud.ProjectTrainDal(auth.db).get_data(train_in.weights_id)
# 异步执行操作操作过程通过websocket进行同步
thread_train = threading.Thread(
target=service.run_event_loop,
args=(data, project, name, train_in, proj_id, auth.db,))
args=(data, project, name, train_in, proj_id, train_info, is_gpu))
thread_train.start()
await service.add_train(auth.db, proj_id, name, project, data, train_in, auth.user.id)
return SuccessResponse(msg="执行成功")
@ -57,15 +66,16 @@ async def train_list(
proj_id: int,
auth: Auth = Depends(AllUserAuth())):
datas = await crud.ProjectTrainDal(auth.db).get_datas(
limit=0,
v_where=[models.ProjectTrain.project_id == proj_id],
v_schema=schemas.ProjectTrainOut,
v_order="asc",
v_order_field="id",v_return_count=False)
v_order_field="id",
v_return_count=False)
return SuccessResponse(data=datas)
@app.get("/result/{proj_id}", summary="查询训练报告")
async def get_result(train_id:int, auth: Auth = Depends(AllUserAuth())):
@app.get("/result/{train_id}", summary="查询训练报告")
async def get_result(train_id: int, auth: Auth = Depends(AllUserAuth())):
result = await crud.ProjectTrainDal(auth.db).get_result(train_id)
return SuccessResponse(data=result)

View File

@ -263,11 +263,17 @@ class DalBase:
await self.db.execute(insert(self.model), datas)
await self.db.flush()
async def create_model(self, data: Any) -> None:
"""
创建单个model
:param data: model
"""
self.db.add(data)
await self.db.flush()
async def create_models(self, datas: list[Any]) -> None:
"""
批量创建数据
SQLAlchemy 2.0 批量插入不支持 MySQL 返回值
https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning
:param datas: model数组
"""
self.db.add_all(datas)

View File

@ -7,7 +7,7 @@ from utils.websocket_server import room_manager
def websocket_config(app: FastAPI):
@app.websocket("/{room}")
@app.websocket("/ws/{room}")
async def websocket_room(websocket: WebSocket, room: str):
"""
websocket 房间管理