#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:32
# @File           : views.py
"""HTTP routes for starting project training runs and querying their results."""

import threading
from typing import Optional

from fastapi import APIRouter, Depends

from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse

from . import models, schemas, crud, service

app = APIRouter()

# Minimum number of images required in each split before training may start.
_MIN_TRAIN_IMAGES = 10
_MIN_VAL_IMAGES = 5


###########################################################
#    Project training information
###########################################################
async def _check_training_images(proj_img_dal: ProjectImageDal, proj_id: int) -> Optional[str]:
    """Return the first reason the project's images are not ready for training.

    Checks, in order:
      1. the training split is non-empty and has at least ``_MIN_TRAIN_IMAGES`` images;
      2. the validation split is non-empty and has at least ``_MIN_VAL_IMAGES`` images;
      3. neither split reports unlabelled images.

    :param proj_img_dal: image data-access object bound to the request's DB session.
    :param proj_id: id of the project to check.
    :return: a user-facing error message for the first failing check, or ``None``
             when every check passes.
    """
    train_count, val_count = await proj_img_dal.get_img_count(proj_id)
    if train_count == 0:
        return "请先上传训练图片"
    if train_count < _MIN_TRAIN_IMAGES:
        return f"训练图片少于{_MIN_TRAIN_IMAGES}张,请继续上传训练图片"
    if val_count == 0:
        return "请先上传验证图片"
    if val_count < _MIN_VAL_IMAGES:
        return f"验证图片少于{_MIN_VAL_IMAGES}张,请继续上传验证图片"

    # A positive count from check_image_label is reported to the user as
    # "unlabelled images exist" for that split.
    train_unlabeled, val_unlabeled = await proj_img_dal.check_image_label(proj_id)
    if train_unlabeled > 0:
        return "训练图片中存在未标注的图片"
    if val_unlabeled > 0:
        return "验证图片中存在未标注的图片"
    return None


@app.post("/", summary="执行训练")
async def run_train(
        train_in: schemas.ProjectTrainIn,
        auth: Auth = Depends(AllUserAuth())):
    """Validate a project's images and start training in a background thread.

    Returns an ``ErrorResponse`` when the project does not exist or its images
    fail the readiness checks. Otherwise it prepares the run via
    ``service.before_train``, starts ``service.run_event_loop`` in a new thread
    and returns immediately with a ``SuccessResponse``. The response does not
    wait for training to finish.
    """
    proj_id = train_in.project_id
    proj_dal = ProjectInfoDal(auth.db)
    proj_img_dal = ProjectImageDal(auth.db)

    proj_info = await proj_dal.get_data(proj_id)
    if proj_info is None:
        return ErrorResponse(msg="项目信息查询错误")

    error_msg = await _check_training_images(proj_img_dal, proj_id)
    if error_msg is not None:
        return ErrorResponse(msg=error_msg)

    data, project, name = service.before_train(proj_info, auth.db)

    # Run the training outside the request. Progress is synchronised to clients
    # over a websocket by the training code.
    # NOTE(review): auth.db is the request-scoped session and is passed to another
    # thread that (by its name) drives its own event loop. If the session is closed
    # when this request ends, or cannot be shared across event loops/threads, the
    # training thread may fail. Confirm in service.run_event_loop; it may need to
    # open its own session instead.
    thread_train = threading.Thread(
        target=service.run_event_loop,
        args=(data, project, name, train_in, proj_id, auth.db),
    )
    thread_train.start()
    return SuccessResponse(msg="执行成功")


@app.get("/{proj_id}", summary="查询训练列表")
async def train_list(
        proj_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """List the training runs of a project, ordered by ``id`` ascending.

    Each row is serialised with ``schemas.ProjectTrainOut``. ``v_return_count=False``
    presumably makes the DAL skip the total count (confirm in the DAL).
    """
    datas = await crud.ProjectTrainDal(auth.db).get_datas(
        v_where=[models.ProjectTrain.project_id == proj_id],
        v_schema=schemas.ProjectTrainOut,
        v_order="asc",
        v_order_field="id",
        v_return_count=False,
    )
    return SuccessResponse(data=datas)


@app.get("/result/{proj_id}", summary="查询训练报告")
async def get_result(
        train_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """Return the training report for one training run.

    NOTE(review): the route template declares ``{proj_id}`` but this handler has
    no ``proj_id`` parameter. FastAPI therefore reads ``train_id`` from the query
    string (``/result/<anything>?train_id=<id>``), and the path segment is unused.
    The path was presumably meant to be ``/result/{train_id}``. It is left as-is
    because changing it would alter the URL contract for existing clients.
    """
    result = await crud.ProjectTrainDal(auth.db).get_result(train_id)
    return SuccessResponse(data=result)