#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:30
# @File           : views.py
# @IDE            : PyCharm
# @desc           : Routes / view functions for project inference (detect) sets

from core.dependencies import IdList
from core.database import redis_getter
from . import schemas, crud, params, service
from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse

import threading
from redis.asyncio import Redis
from fastapi import Depends, APIRouter, Form, UploadFile

app = APIRouter()

# NOTE: handler documentation below is written as comments rather than
# docstrings on purpose: FastAPI publishes a handler's docstring as the
# route's OpenAPI description, which would change the generated API docs.


###########################################################
#    Project inference sets
###########################################################
# List inference sets matching the query parameters, together with the
# total row count.
@app.get("/list", summary="获取项目推理集合信息列表")
async def detect_list(
        p: params.ProjectDetectParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    detect_dal = crud.ProjectDetectDal(auth.db)
    datas, count = await detect_dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(datas, count=count)


# Create an inference set; rejected when `check_name` reports that the
# name is already used inside the same project.
@app.post("/", summary="创建项目推理集合信息")
async def add_detect(
        data: schemas.ProjectDetectIn,
        auth: Auth = Depends(AllUserAuth())):
    detect_dal = crud.ProjectDetectDal(auth.db)
    name_taken = await detect_dal.check_name(data.detect_name, data.project_id)
    if name_taken:
        return ErrorResponse(msg="该项目中存在相同名称的集合")
    await detect_dal.create_data(data=data)
    return SuccessResponse(msg="保存成功")


# Delete the inference sets whose ids are given (`v_soft=False` is passed
# to the DAL).
@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
    return SuccessResponse("删除成功")


###########################################################
#    Project inference set files
###########################################################
# List files of an inference set. The total count is only requested (and
# returned) when the query carries a truthy `limit`.
@app.get("/file", summary="获取项目推理集合文件信息列表")
async def file_list(
        p: params.ProjectDetectFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    file_dal = crud.ProjectDetectFileDal(auth.db)
    if p.limit:
        datas, count = await file_dal.get_datas(**p.dict(), v_return_count=True)
        return SuccessResponse(datas, count=count)
    datas = await file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(datas)


# Upload files into the inference set identified by `detect_id`.
@app.post("/file", summary="上传项目推理集合文件")
async def upload_file(
        detect_id: int = Form(...),
        files: list[UploadFile] = Form(...),
        auth: Auth = Depends(AllUserAuth())):
    # The lookup was previously not awaited, so `detect_out` was a coroutine
    # object: the `is None` guard could never fire and `add_file` received
    # the coroutine instead of the record.
    # It also went through the *file* DAL with a detect id; the inference
    # set is resolved through ProjectDetectDal, as `run_detect_yolo` does.
    # NOTE(review): assumes `add_file` expects the inference-set record as
    # its first argument — confirm against crud.ProjectDetectFileDal.
    detect_out = await crud.ProjectDetectDal(auth.db).get_data(data_id=detect_id)
    if detect_out is None:
        # NOTE(review): the message says "训练集合" (training set) although
        # this is the inference set; left unchanged, confirm intended wording.
        return ErrorResponse("训练集合查询失败,请刷新后再试")
    await crud.ProjectDetectFileDal(auth.db).add_file(detect_out, files)
    return SuccessResponse(msg="上传成功")


# Delete the inference set files whose ids are given.
@app.delete("/file", summary="删除项目推理集合文件信息")
async def delete_file(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    await crud.ProjectDetectFileDal(auth.db).delete_files(ids=ids.ids)
    return SuccessResponse("删除成功")


# Start an inference run for an inference set using the weights of a
# training record. The actual work is handed to a background thread; the
# response only reports that the run was launched.
#
# Declared `async` so the DAL / service calls can be awaited. It used to be
# a plain `def` that called them without awaiting and then read attributes
# (`rtsp_url`, `file_type`, ...) off the un-awaited results.
# NOTE(review): `get_data` and `file_count` are assumed to be coroutines
# like every other DAL method awaited in this module, and
# `service.before_detect` is assumed to be one as well — confirm.
@app.post("/detect", summary="开始推理")
async def run_detect_yolo(
        detect_log_in: schemas.ProjectDetectLogIn,
        auth: Auth = Depends(AllUserAuth()),
        rd: Redis = Depends(redis_getter)):
    detect_dal = crud.ProjectDetectDal(auth.db)
    train_dal = ProjectTrainDal(auth.db)

    detect = await detect_dal.get_data(detect_log_in.detect_id)
    if detect is None:
        # NOTE(review): message says "训练集合" although the inference set is
        # what was not found; left unchanged, confirm intended wording.
        return ErrorResponse(msg="训练集合不存在")

    train = await train_dal.get_data(detect_log_in.train_id)
    if train is None:
        return ErrorResponse("训练权重不存在")

    file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
    # Nothing to run on: no uploaded files and no stream address either.
    if file_count == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")

    if detect.file_type in ('img', 'video'):
        detect_log = await service.before_detect(detect_log_in, detect, train, auth.db)
        # NOTE(review): the request-scoped `auth.db` is handed to the worker
        # thread — confirm it is still usable after the response is sent.
        worker = threading.Thread(
            target=service.run_img_loop,
            args=(
                detect_log.pt_url,
                detect_log.folder_url,
                detect_log.detect_folder_url,
                detect_log.detect_version,
                detect_log.id,
                detect_log.detect_id,
                auth.db,
            ),
        )
        worker.start()
    elif detect.file_type == 'rtsp':
        room = 'detect_rtsp_' + str(detect.id)
        # Only launch a stream worker when `room_manager` has no (truthy)
        # entry for this room yet.
        if not room_manager.rooms.get(room):
            weights_pt = train.best_pt if detect_log_in.pt_type == 'best' else train.last_pt
            worker = threading.Thread(
                target=service.run_rtsp_loop,
                args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,),
            )
            worker.start()
    return SuccessResponse(msg="执行成功")


###########################################################
#    Project inference logs
###########################################################
# List inference log records matching the query parameters, together with
# the total row count.
@app.get("/log", summary="获取项目推理记录列表")
async def log_pager(
        p: params.ProjectDetectLogParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    log_dal = crud.ProjectDetectLogDal(auth.db)
    datas, count = await log_dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(datas, count=count)


# List the files attached to inference log records (no total count).
@app.get("/log_files", summary="获取项目推理记录文件列表")
async def log_files(
        p: params.ProjectDetectLogFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    log_file_dal = crud.ProjectDetectLogFileDal(auth.db)
    datas = await log_file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(datas)