147 lines
6.0 KiB
Python

#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : views.py
# @IDE : PyCharm
# @desc : 路由,视图文件
from core.dependencies import IdList
from core.database import redis_getter
from . import schemas, crud, params, service
from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
import threading
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import Depends, APIRouter, Form, UploadFile
# Router that collects every endpoint defined in this module.
app = APIRouter()
###########################################################
# 项目推理集合信息
###########################################################
@app.get("/list", summary="获取项目推理集合信息列表")
async def detect_list(
        p: params.ProjectDetectParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return the detect collections matching *p* together with the total count.

    The query parameters are forwarded unchanged to the DAL's ``get_datas``;
    ``v_return_count=True`` makes it return a ``(rows, count)`` pair.
    """
    dal = crud.ProjectDetectDal(auth.db)
    query = p.dict()
    rows, total = await dal.get_datas(**query, v_return_count=True)
    return SuccessResponse(rows, count=total)
@app.post("/", summary="创建项目推理集合信息")
async def add_detect(
        data: schemas.ProjectDetectIn,
        auth: Auth = Depends(AllUserAuth())):
    """Create a detect collection unless its name is already used in the project.

    ``check_name`` is treated as truthy when a collection with
    ``data.detect_name`` already exists in ``data.project_id``; in that case an
    error response is returned and nothing is created.
    """
    dal = crud.ProjectDetectDal(auth.db)
    if not await dal.check_name(data.detect_name, data.project_id):
        await dal.create_data(data=data)
        return SuccessResponse(msg="保存成功")
    return ErrorResponse(msg="该项目中存在相同名称的集合")
@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the detect collections whose primary keys are listed in ``ids.ids``.

    ``v_soft=False`` is passed to the DAL, presumably requesting a physical
    (non-soft) delete.
    """
    dal = crud.ProjectDetectDal(auth.db)
    target_ids = ids.ids
    await dal.delete_datas(ids=target_ids, v_soft=False)
    return SuccessResponse("删除成功")
###########################################################
# 项目推理集合文件信息
###########################################################
@app.get("/file", summary="获取项目推理集合文件信息列表")
async def file_list(
        p: params.ProjectDetectFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return the files of a detect collection matching *p*.

    When ``p.limit`` is truthy the result is paged and a total ``count`` is
    included in the response; otherwise the plain list of rows is returned.
    """
    paged = bool(p.limit)
    dal = crud.ProjectDetectFileDal(auth.db)
    result = await dal.get_datas(**p.dict(), v_return_count=paged)
    if not paged:
        return SuccessResponse(result)
    rows, total = result
    return SuccessResponse(rows, count=total)
@app.post("/file", summary="上传项目推理集合文件")
async def upload_file(
        detect_id: int = Form(...),
        files: list[UploadFile] = Form(...),
        auth: Auth = Depends(AllUserAuth())):
    """Upload one or more files into an existing detect collection.

    :param detect_id: primary key of the target detect collection (form field).
    :param files: the uploaded files (multipart form field).
    :return: ``SuccessResponse`` once ``add_file`` completes, or
        ``ErrorResponse`` when the detect collection cannot be found.
    """
    # Look the collection up through the detect DAL, as run_detect_yolo does for
    # the same id; the file DAL's primary keys are file ids, not collection ids.
    # The call is awaited like every other DAL call in this module; without the
    # await the result is a coroutine object and the None check never fires.
    # v_return_none=True assumes the shared DalBase signature (the one that also
    # accepts v_return_count / v_soft) returns None for a missing row — confirm.
    detect_out = await crud.ProjectDetectDal(auth.db).get_data(data_id=detect_id, v_return_none=True)
    if detect_out is None:
        return ErrorResponse("推理集合查询失败,请刷新后再试")
    await crud.ProjectDetectFileDal(auth.db).add_file(detect_out, files)
    return SuccessResponse(msg="上传成功")
@app.delete("/file", summary="删除项目推理集合文件信息")
async def delete_file(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the detect-collection files whose ids are listed in ``ids.ids``.

    Deletion is delegated to the file DAL's ``delete_files`` helper.
    """
    file_dal = crud.ProjectDetectFileDal(auth.db)
    await file_dal.delete_files(ids=ids.ids)
    return SuccessResponse("删除成功")
@app.post("/detect", summary="开始推理")
def run_detect_yolo(
        detect_log_in: schemas.ProjectDetectLogIn,
        auth: Auth = Depends(AllUserAuth()),
        rd: Redis = Depends(redis_getter)):
    """Start inference on a detect collection with weights from a training run.

    Looks up the detect collection (``detect_log_in.detect_id``) and the
    training record (``detect_log_in.train_id``), then:

    * ``file_type`` ``'img'`` / ``'video'``: calls ``service.before_detect`` to
      obtain a detect-log record and runs ``service.run_img_loop`` on a new
      thread;
    * ``file_type`` ``'rtsp'``: if ``room_manager`` has no room named
      ``detect_rtsp_<detect.id>``, picks ``train.best_pt`` or ``train.last_pt``
      according to ``detect_log_in.pt_type`` and runs ``service.run_rtsp_loop``
      on a new thread.

    The response is returned right after the thread is started; it does not
    wait for inference to finish. Any other ``file_type`` starts nothing but
    still reports success.

    NOTE(review): this is a plain ``def`` and none of the DAL/service calls
    below are awaited, while every other endpoint in this module awaits its DAL
    calls. If ``get_data`` / ``file_count`` / ``before_detect`` are coroutine
    functions, ``detect`` and ``train`` are coroutine objects (never ``None``)
    and ``detect.rtsp_url`` raises ``AttributeError`` — confirm against
    ``crud`` / ``service``.
    NOTE(review): ``auth.db`` and ``rd`` come from ``Depends`` and are presumably
    request-scoped, yet they are handed to long-running threads — confirm they
    remain usable after the response is sent.
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    train_dal = ProjectTrainDal(auth.db)
    detect = detect_dal.get_data(detect_log_in.detect_id)
    if detect is None:
        # NOTE(review): message says "training collection"; this is the detect collection.
        return ErrorResponse(msg="训练集合不存在")
    train = train_dal.get_data(detect_log_in.train_id)
    if train is None:
        return ErrorResponse("训练权重不存在")
    file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
    # Refuse to run when the collection has no uploaded files and no RTSP URL.
    if file_count == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
    if detect.file_type == 'img' or detect.file_type == 'video':
        # Offline inference over the collection's stored files.
        detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
        thread_train = threading.Thread(target=service.run_img_loop,
                                        args=(detect_log.pt_url, detect_log.folder_url,
                                              detect_log.detect_folder_url, detect_log.detect_version,
                                              detect_log.id, detect_log.detect_id, auth.db,))
        thread_train.start()
    elif detect.file_type == 'rtsp':
        # One websocket room per collection; an existing room presumably means a
        # stream for this collection is already running, so none is started.
        room = 'detect_rtsp_' + str(detect.id)
        if not room_manager.rooms.get(room):
            # pt_type 'best' selects the best checkpoint; anything else falls back to the last one.
            if detect_log_in.pt_type == 'best':
                weights_pt = train.best_pt
            else:
                weights_pt = train.last_pt
            thread_train = threading.Thread(target=service.run_rtsp_loop,
                                            args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
            thread_train.start()
    return SuccessResponse(msg="执行成功")
###########################################################
# 项目推理记录信息
###########################################################
@app.get("/log", summary="获取项目推理记录列表")
async def log_pager(
        p: params.ProjectDetectLogParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return the detect logs matching *p* together with the total count.

    ``v_return_count=True`` makes the DAL return a ``(rows, count)`` pair.
    """
    log_dal = crud.ProjectDetectLogDal(auth.db)
    filters = p.dict()
    rows, total = await log_dal.get_datas(**filters, v_return_count=True)
    return SuccessResponse(rows, count=total)
@app.get("/log_files", summary="获取项目推理记录文件列表")
async def log_files(
        p: params.ProjectDetectLogFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return the detect-log files matching *p*, without a total count."""
    file_dal = crud.ProjectDetectLogFileDal(auth.db)
    rows = await file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(rows)