2025-04-11 08:54:28 +08:00
|
|
|
|
#!/usr/bin/python
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
# @version : 1.0
|
|
|
|
|
# @Create Time : 2025/04/03 10:32
|
|
|
|
|
# @File : views.py
|
2025-04-22 10:11:44 +08:00
|
|
|
|
|
|
|
|
|
from core.database import redis_getter
|
2025-04-17 16:25:44 +08:00
|
|
|
|
from . import models, schemas, crud, service
|
2025-04-11 08:54:28 +08:00
|
|
|
|
from apps.vadmin.auth.utils.current import AllUserAuth
|
|
|
|
|
from apps.vadmin.auth.utils.validation.auth import Auth
|
2025-04-17 16:25:44 +08:00
|
|
|
|
from utils.response import SuccessResponse, ErrorResponse
|
|
|
|
|
from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
|
2025-04-17 11:03:05 +08:00
|
|
|
|
|
|
|
|
|
import threading
|
2025-04-22 10:11:44 +08:00
|
|
|
|
from redis.asyncio import Redis
|
2025-04-17 11:03:05 +08:00
|
|
|
|
from fastapi import APIRouter, Depends
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Router collecting the project-training HTTP endpoints declared in this module.
app: APIRouter = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
###########################################################
|
2025-04-17 11:03:05 +08:00
|
|
|
|
# 项目训练信息
|
2025-04-11 08:54:28 +08:00
|
|
|
|
###########################################################
|
2025-04-22 10:11:44 +08:00
|
|
|
|
# Minimum image counts each dataset split must reach before training may start.
MIN_TRAIN_IMAGES = 10
MIN_VAL_IMAGES = 5


async def _check_project_images(proj_img_dal, proj_id):
    """Return an error message if the project's images are not ready for training, else None.

    Checks run in this order and stop at the first failure:
    non-empty training split, at least MIN_TRAIN_IMAGES training images,
    non-empty validation split, at least MIN_VAL_IMAGES validation images,
    then no unlabeled images in either split.

    :param proj_img_dal: a ProjectImageDal bound to the request's DB session.
    :param proj_id: id of the project being validated.
    :return: the (user-facing) error message string, or None when all checks pass.
    """
    train_count, val_count = await proj_img_dal.get_img_count(proj_id)
    if train_count == 0:
        return "请先上传训练图片"
    if train_count < MIN_TRAIN_IMAGES:
        return f"训练图片少于{MIN_TRAIN_IMAGES}张,请继续上传训练图片"
    if val_count == 0:
        return "请先上传验证图片"
    if val_count < MIN_VAL_IMAGES:
        return f"验证图片少于{MIN_VAL_IMAGES}张,请继续上传验证图片"

    # Presumably counts of images lacking labels in each split (any non-zero
    # value is reported as "unlabeled images exist") — confirm against the DAL.
    train_unlabeled, val_unlabeled = await proj_img_dal.check_image_label(proj_id)
    if train_unlabeled > 0:
        return "训练图片中存在未标注的图片"
    if val_unlabeled > 0:
        return "验证图片中存在未标注的图片"
    return None


@app.post("/start", summary="执行训练")
async def run_train(
        train_in: schemas.ProjectTrainIn,
        auth: Auth = Depends(AllUserAuth()),
        rd: Redis = Depends(redis_getter)):
    """Validate a project's dataset and launch a training run in a background thread.

    :param train_in: training request; ``project_id`` selects the project and an
        optional ``weights_id`` selects an existing training record whose data is
        handed to the training thread.
    :param auth: authenticated user context; ``auth.db`` is the DB session used
        for all queries and ``auth.user.id`` is recorded as the run's owner.
    :param rd: Redis client; the ``is_gpu`` key is read and passed to training.
    :return: ErrorResponse describing the first failed precondition, or
        SuccessResponse once the training thread has been started and the run
        has been recorded via ``service.add_train``.
    """
    proj_id = train_in.project_id
    proj_dal = ProjectInfoDal(auth.db)
    proj_img_dal = ProjectImageDal(auth.db)

    proj_info = await proj_dal.get_data(proj_id)
    if proj_info is None:
        return ErrorResponse(msg="项目信息查询错误")

    error_msg = await _check_project_images(proj_img_dal, proj_id)
    if error_msg is not None:
        return ErrorResponse(msg=error_msg)

    data, project, name = await service.before_train(proj_info, auth.db)
    # Raw value from Redis; its type (bytes/str/None) depends on the client's
    # decode settings — NOTE(review): confirm what run_event_loop expects.
    is_gpu = await rd.get('is_gpu')

    train_info = None
    if train_in.weights_id is not None:
        train_info = await crud.ProjectTrainDal(auth.db).get_data(train_in.weights_id)

    # Training runs in a separate thread so this request returns immediately;
    # per the original author's note, progress is synced to clients over websocket.
    # NOTE(review): the thread starts before add_train records the run, so if
    # add_train raises, training keeps running without a DB record — consider
    # whether the order should be swapped.
    thread_train = threading.Thread(
        target=service.run_event_loop,
        args=(data, project, name, train_in, proj_id, train_info, is_gpu))
    thread_train.start()

    await service.add_train(auth.db, proj_id, name, project, data, train_in, auth.user.id)
    return SuccessResponse(msg="执行成功")
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
|
2025-04-17 11:03:05 +08:00
|
|
|
|
@app.get("/{proj_id}", summary="查询训练列表")
async def train_list(
        proj_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """List all training runs of one project, ordered by ascending id.

    :param proj_id: project whose training runs are listed.
    :param auth: authenticated user context supplying the DB session.
    :return: SuccessResponse whose ``data`` holds the runs serialized with
        ``schemas.ProjectTrainOut`` (no total count is requested).
    """
    train_dal = crud.ProjectTrainDal(auth.db)
    project_filter = [models.ProjectTrain.project_id == proj_id]
    # NOTE(review): limit=0 presumably disables pagination in get_datas — confirm.
    records = await train_dal.get_datas(
        limit=0,
        v_where=project_filter,
        v_schema=schemas.ProjectTrainOut,
        v_order="asc",
        v_order_field="id",
        v_return_count=False,
    )
    return SuccessResponse(data=records)
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
|
2025-04-22 10:11:44 +08:00
|
|
|
|
@app.get("/result/{train_id}", summary="查询训练报告")
async def get_result(
        train_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """Return the training report of a single training run.

    :param train_id: id of the training run.
    :param auth: authenticated user context supplying the DB session.
    :return: SuccessResponse wrapping whatever ``ProjectTrainDal.get_result`` yields.
    """
    train_dal = crud.ProjectTrainDal(auth.db)
    return SuccessResponse(data=await train_dal.get_result(train_id))
|