148 lines
6.0 KiB
Python
Raw Normal View History

2025-04-11 08:54:28 +08:00
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : views.py
# @IDE : PyCharm
# @desc : 路由,视图文件
2025-04-17 15:57:16 +08:00
import service
from . import schemas, crud, params
2025-04-11 08:54:28 +08:00
from core.dependencies import IdList
2025-04-17 15:57:16 +08:00
from core.database import redis_getter
from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal
2025-04-11 08:54:28 +08:00
from apps.vadmin.auth.utils.current import AllUserAuth
2025-04-17 15:57:16 +08:00
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
2025-04-11 08:54:28 +08:00
2025-04-17 15:57:16 +08:00
import threading
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import Depends, APIRouter, Form, UploadFile
2025-04-11 08:54:28 +08:00
# Router object on which every endpoint in this module is registered
# through the @app.get / @app.post / @app.delete decorators below.
app = APIRouter()
###########################################################
# 项目推理集合信息
###########################################################
2025-04-17 15:57:16 +08:00
# List project detect collections.
#
# p    -- query parameters; expanded into keyword arguments for the DAL query.
# auth -- authenticated request context; supplies the database session.
#
# Responds with the rows and the count that ProjectDetectDal.get_datas returns
# when asked for a count.
@app.get("/list", summary="获取项目推理集合信息列表")
async def detect_list(
    p: params.ProjectDetectParams = Depends(),
    auth: Auth = Depends(AllUserAuth()),
):
    detect_dal = crud.ProjectDetectDal(auth.db)
    rows, total = await detect_dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(rows, count=total)
2025-04-17 15:57:16 +08:00
# Create a project detect collection.
#
# data -- request body describing the new collection.
# auth -- authenticated request context; supplies the database session.
#
# The request is rejected with an ErrorResponse when check_name reports a hit
# for (detect_name, project_id); otherwise the record is created.
@app.post("/", summary="创建项目推理集合信息")
async def add_detect(
    data: schemas.ProjectDetectIn,
    auth: Auth = Depends(AllUserAuth()),
):
    dal = crud.ProjectDetectDal(auth.db)
    name_taken = await dal.check_name(data.detect_name, data.project_id)
    if not name_taken:
        await dal.create_data(data=data)
        return SuccessResponse(msg="保存成功")
    return ErrorResponse(msg="该项目中存在相同名称的集合")
2025-04-11 08:54:28 +08:00
2025-04-17 15:57:16 +08:00
# Delete project detect collections by ID.
#
# ids  -- dependency carrying the list of IDs to delete (ids.ids).
# auth -- authenticated request context; supplies the database session.
@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
    ids: IdList = Depends(),
    auth: Auth = Depends(AllUserAuth()),
):
    dal = crud.ProjectDetectDal(auth.db)
    # v_soft=False is forwarded to the DAL as-is -- presumably this requests a
    # hard (non-soft) delete; confirm against delete_datas.
    await dal.delete_datas(ids=ids.ids, v_soft=False)
    return SuccessResponse("删除成功")
###########################################################
# 项目推理集合文件信息
###########################################################
2025-04-17 15:57:16 +08:00
# List the files of project detect collections.
#
# p    -- query parameters; expanded into keyword arguments for the DAL query.
# auth -- authenticated request context; supplies the database session.
#
# When p.limit is truthy the DAL is asked for a count as well and the response
# carries it; otherwise only the rows are returned.
@app.get("/file", summary="获取项目推理集合文件信息列表")
async def file_list(
    p: params.ProjectDetectFileParams = Depends(),
    auth: Auth = Depends(AllUserAuth()),
):
    if not p.limit:
        rows = await crud.ProjectDetectFileDal(auth.db).get_datas(
            **p.dict(), v_return_count=False
        )
        return SuccessResponse(rows)

    rows, total = await crud.ProjectDetectFileDal(auth.db).get_datas(
        **p.dict(), v_return_count=True
    )
    return SuccessResponse(rows, count=total)
# Upload files into a project detect collection.
#
# detect_id -- form field: ID of the collection that receives the files.
# files     -- form field: the uploaded files.
# auth      -- authenticated request context; supplies the database session.
#
# Responds with an ErrorResponse when the collection lookup yields None,
# otherwise passes the collection and the files to ProjectDetectFileDal.add_file.
@app.post("/file", summary="上传项目推理集合文件")
async def upload_file(
    detect_id: int = Form(...),
    files: list[UploadFile] = Form(...),
    auth: Auth = Depends(AllUserAuth()),
):
    # detect_id identifies a collection, so it is resolved through the
    # collection DAL (as the /detect endpoint does) rather than the file DAL.
    # The lookup has to be awaited: left un-awaited it is a coroutine object,
    # which is never None and is not a collection record.
    # NOTE(review): whether get_data returns None or raises for an unknown ID
    # depends on the DAL implementation -- confirm the guard below is reachable.
    detect_out = await crud.ProjectDetectDal(auth.db).get_data(data_id=detect_id)
    if detect_out is None:
        return ErrorResponse("训练集合查询失败,请刷新后再试")

    file_dal = crud.ProjectDetectFileDal(auth.db)
    await file_dal.add_file(detect_out, files)
    return SuccessResponse(msg="上传成功")
# Delete project detect collection files by ID.
#
# ids  -- dependency carrying the list of file IDs to delete (ids.ids).
# auth -- authenticated request context; supplies the database session.
@app.delete("/file", summary="删除项目推理集合文件信息")
async def delete_file(
    ids: IdList = Depends(),
    auth: Auth = Depends(AllUserAuth()),
):
    file_dal = crud.ProjectDetectFileDal(auth.db)
    await file_dal.delete_files(ids=ids.ids)
    return SuccessResponse("删除成功")
2025-04-17 15:57:16 +08:00
# Start an inference (detect) run for a collection.
#
# detect_log_in -- request body; detect_id, train_id and pt_type are read here.
# auth          -- authenticated request context; supplies the database session.
# rd            -- Redis client, forwarded to the RTSP worker.
#
# Flow:
#   1. Load the collection and the training record; respond with an
#      ErrorResponse if either lookup yields None.
#   2. Refuse to run when the collection has no files and no rtsp_url.
#   3. file_type 'img' / 'video': call service.before_detect, then run
#      service.run_img_loop in a background thread.
#      file_type 'rtsp': run service.run_rtsp_loop in a background thread,
#      unless room_manager already has an entry for the collection's room.
#   4. Respond with success; the worker thread is not waited for.
@app.post("/detect", summary="开始推理")
async def run_detect_yolo(
    detect_log_in: schemas.ProjectDetectLogIn,
    auth: Auth = Depends(AllUserAuth()),
    rd: Redis = Depends(redis_getter),
):
    import inspect  # local import: only needed for the awaitable check below

    detect_dal = crud.ProjectDetectDal(auth.db)
    train_dal = ProjectTrainDal(auth.db)

    # The DAL lookups are awaited like every other DAL call in this module.
    # Un-awaited they are coroutine objects: never None, and without the
    # attributes (rtsp_url, file_type, ...) read further down.
    detect = await detect_dal.get_data(detect_log_in.detect_id)
    if detect is None:
        return ErrorResponse(msg="训练集合不存在")

    train = await train_dal.get_data(detect_log_in.train_id)
    if train is None:
        return ErrorResponse("训练权重不存在")

    file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
    if file_count == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")

    if detect.file_type in ('img', 'video'):
        detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
        # before_detect is defined outside this file; accept both a plain
        # return value and a coroutine implementation.
        if inspect.isawaitable(detect_log):
            detect_log = await detect_log
        # NOTE(review): auth.db is handed to a thread that outlives this
        # request -- confirm the session is still usable inside the worker.
        worker = threading.Thread(
            target=service.run_img_loop,
            args=(
                detect_log.pt_url,
                detect_log.folder_url,
                detect_log.detect_folder_url,
                detect_log.detect_version,
                detect_log.id,
                detect_log.detect_id,
                auth.db,
            ),
        )
        worker.start()
    elif detect.file_type == 'rtsp':
        room = 'detect_rtsp_' + str(detect.id)
        # Only start a worker when room_manager has no (truthy) entry for the
        # room -- presumably that means no stream is being served yet; confirm.
        if not room_manager.rooms.get(room):
            # pt_type 'best' selects the best weights, anything else the last ones.
            weights_pt = train.best_pt if detect_log_in.pt_type == 'best' else train.last_pt
            worker = threading.Thread(
                target=service.run_rtsp_loop,
                args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd),
            )
            worker.start()

    return SuccessResponse(msg="执行成功")
2025-04-11 08:54:28 +08:00
###########################################################
# 项目推理记录信息
###########################################################
2025-04-17 15:57:16 +08:00
# List project detect log records.
#
# p    -- query parameters; expanded into keyword arguments for the DAL query.
# auth -- authenticated request context; supplies the database session.
#
# Responds with the rows and the count that ProjectDetectLogDal.get_datas
# returns when asked for a count.
@app.get("/log", summary="获取项目推理记录列表")
async def log_pager(
    p: params.ProjectDetectLogParams = Depends(),
    auth: Auth = Depends(AllUserAuth()),
):
    log_dal = crud.ProjectDetectLogDal(auth.db)
    rows, total = await log_dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(rows, count=total)
2025-04-17 15:57:16 +08:00
# List the files attached to project detect log records.
#
# p    -- query parameters; expanded into keyword arguments for the DAL query.
# auth -- authenticated request context; supplies the database session.
#
# No count is requested from the DAL; the response carries the rows only.
@app.get("/log_files", summary="获取项目推理记录文件列表")
async def log_files(
    p: params.ProjectDetectLogFileParams = Depends(),
    auth: Auth = Depends(AllUserAuth()),
):
    log_file_dal = crud.ProjectDetectLogFileDal(auth.db)
    rows = await log_file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(rows)
2025-04-11 08:54:28 +08:00