73 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:32
# @File : views.py
from . import models, schemas, crud
from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
from utils.response import SuccessResponse, ErrorResponse
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
import service
import threading
from fastapi import APIRouter, Depends
app = APIRouter()
###########################################################
# 项目训练信息
###########################################################
@app.post("/", summary="执行训练")
async def run_train(
        train_in: schemas.ProjectTrainIn,
        auth: Auth = Depends(AllUserAuth())):
    """Validate a project's image data and start a training run in a thread.

    The handler looks up the project named by ``train_in.project_id``, checks
    the training/validation image counts and the label check result, and only
    then prepares and launches the training thread.

    :param train_in: request payload; ``project_id`` selects the project.
    :param auth: authenticated request context; ``auth.db`` is handed to the
        data-access objects and to the training thread.
    :return: ``ErrorResponse`` describing the first failed precondition, or
        ``SuccessResponse`` once the training thread has been started.
    """
    session = auth.db
    proj_id = train_in.project_id
    info_dal = ProjectInfoDal(session)
    image_dal = ProjectImageDal(session)

    proj_info = await info_dal.get_data(proj_id)
    if proj_info is None:
        return ErrorResponse(msg="项目信息查询错误")

    # Image-count preconditions: at least 10 training and 5 validation images.
    train_count, val_count = await image_dal.get_img_count(proj_id)
    if train_count == 0:
        count_problem = "请先上传训练图片"
    elif train_count < 10:
        count_problem = "训练图片少于10张请继续上传训练图片"
    elif val_count == 0:
        count_problem = "请先上传验证图片"
    elif val_count < 5:
        count_problem = "验证图片少于5张请继续上传验证图片"
    else:
        count_problem = None
    if count_problem is not None:
        return ErrorResponse(count_problem)

    # Label preconditions: any positive value blocks training.
    # (Presumably these are counts of unlabeled images -- confirm against
    # ProjectImageDal.check_image_label.)
    train_unlabeled, val_unlabeled = await image_dal.check_image_label(proj_id)
    if train_unlabeled > 0:
        label_problem = "训练图片中存在未标注的图片"
    elif val_unlabeled > 0:
        label_problem = "验证图片中存在未标注的图片"
    else:
        label_problem = None
    if label_problem is not None:
        return ErrorResponse(label_problem)

    data, project, name = service.before_train(proj_info, session)

    # Run the training asynchronously; its progress is synchronized to the
    # client over websocket.
    # NOTE(review): the request's db handle is passed into the worker thread;
    # confirm it remains usable after this handler returns.
    worker_args = (data, project, name, train_in, proj_id, session,)
    worker = threading.Thread(target=service.run_event_loop, args=worker_args)
    worker.start()
    return SuccessResponse(msg="执行成功")
@app.get("/{proj_id}", summary="查询训练列表")
async def train_list(
        proj_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """List the training records that belong to one project.

    :param proj_id: project id taken from the URL path; used to filter
        ``ProjectTrain`` rows by ``project_id``.
    :param auth: authenticated request context supplying ``auth.db``.
    :return: ``SuccessResponse`` whose ``data`` is the query result, requested
        in ascending ``id`` order and serialized with ``ProjectTrainOut``.
    """
    train_dal = crud.ProjectTrainDal(auth.db)
    query_options = {
        "v_where": [models.ProjectTrain.project_id == proj_id],
        "v_schema": schemas.ProjectTrainOut,
        "v_order": "asc",
        "v_order_field": "id",
        # Only the rows are wanted here, not a total count.
        "v_return_count": False,
    }
    records = await train_dal.get_datas(**query_options)
    return SuccessResponse(data=records)
@app.get("/result/{proj_id}", summary="查询训练报告")
async def get_result(train_id:int, auth: Auth = Depends(AllUserAuth())):
    """Fetch the training report for one training run.

    :param train_id: id of the training run whose result is requested.
    :param auth: authenticated request context supplying ``auth.db``.
    :return: ``SuccessResponse`` whose ``data`` is whatever
        ``ProjectTrainDal.get_result`` returns for ``train_id``.
    """
    # NOTE(review): the route placeholder is ``proj_id`` but the handler only
    # declares ``train_id``. FastAPI treats a parameter that is not in the path
    # as a query parameter, so the path segment is ignored here. Confirm
    # whether the placeholder was meant to be ``{train_id}``.
    train_dal = crud.ProjectTrainDal(auth.db)
    report = await train_dal.get_result(train_id)
    return SuccessResponse(data=report)