2025-04-11 08:54:28 +08:00
|
|
|
#!/usr/bin/python
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# @version : 1.0
|
|
|
|
# @Create Time : 2025/04/03 10:30
|
|
|
|
# @File : views.py
|
|
|
|
# @IDE : PyCharm
|
|
|
|
# @desc : 路由,视图文件
|
2025-04-17 15:57:16 +08:00
|
|
|
|
|
|
|
import service
|
|
|
|
from . import schemas, crud, params
|
2025-04-11 08:54:28 +08:00
|
|
|
from core.dependencies import IdList
|
2025-04-17 15:57:16 +08:00
|
|
|
from core.database import redis_getter
|
|
|
|
from utils.websocket_server import room_manager
|
|
|
|
from apps.business.train.crud import ProjectTrainDal
|
2025-04-11 08:54:28 +08:00
|
|
|
from apps.vadmin.auth.utils.current import AllUserAuth
|
2025-04-17 15:57:16 +08:00
|
|
|
from apps.vadmin.auth.utils.validation.auth import Auth
|
|
|
|
from utils.response import SuccessResponse, ErrorResponse
|
2025-04-11 08:54:28 +08:00
|
|
|
|
2025-04-17 15:57:16 +08:00
|
|
|
import threading
|
|
|
|
from redis.asyncio import Redis
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from fastapi import Depends, APIRouter, Form, UploadFile
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Router that every project-inference endpoint in this module is registered on.
app: APIRouter = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
###########################################################
|
|
|
|
# 项目推理集合信息
|
|
|
|
###########################################################
|
2025-04-17 15:57:16 +08:00
|
|
|
@app.get("/list", summary="获取项目推理集合信息列表")
async def detect_list(
        p: params.ProjectDetectParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List inference (detect) collections together with the total row count.

    The query parameters in ``p`` are forwarded unchanged to
    ``ProjectDetectDal.get_datas``; with ``v_return_count=True`` it yields a
    ``(rows, total)`` pair, which is returned as data plus ``count``.
    """
    dal = crud.ProjectDetectDal(auth.db)
    query = p.dict()
    rows, total = await dal.get_datas(**query, v_return_count=True)
    return SuccessResponse(rows, count=total)
|
|
|
|
|
|
|
|
|
2025-04-17 15:57:16 +08:00
|
|
|
@app.post("/", summary="创建项目推理集合信息")
async def add_detect(
        data: schemas.ProjectDetectIn,
        auth: Auth = Depends(AllUserAuth())):
    """Create an inference collection whose name is unique within its project.

    If ``check_name`` reports that ``data.detect_name`` is already used in
    ``data.project_id``, an ``ErrorResponse`` is returned and nothing is
    created.
    """
    dal = crud.ProjectDetectDal(auth.db)
    name_taken = await dal.check_name(data.detect_name, data.project_id)
    if name_taken:
        return ErrorResponse(msg="该项目中存在相同名称的集合")
    await dal.create_data(data=data)
    return SuccessResponse(msg="保存成功")
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
2025-04-17 15:57:16 +08:00
|
|
|
@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the inference collections whose ids are listed in ``ids.ids``.

    ``v_soft=False`` presumably requests a physical (non-soft) delete —
    confirm against the DAL base class.
    """
    dal = crud.ProjectDetectDal(auth.db)
    await dal.delete_datas(ids=ids.ids, v_soft=False)
    return SuccessResponse("删除成功")
|
|
|
|
|
|
|
|
|
|
|
|
###########################################################
|
2025-04-17 15:57:16 +08:00
|
|
|
# 项目推理集合文件信息
|
2025-04-11 08:54:28 +08:00
|
|
|
###########################################################
|
2025-04-17 15:57:16 +08:00
|
|
|
@app.get("/file", summary="获取项目推理集合文件信息列表")
async def file_list(
        p: params.ProjectDetectFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List the files of an inference collection.

    When ``p.limit`` is truthy the total count is requested as well and
    returned as ``count``; otherwise only the matching rows are returned.
    """
    dal = crud.ProjectDetectFileDal(auth.db)
    if not p.limit:
        rows = await dal.get_datas(**p.dict(), v_return_count=False)
        return SuccessResponse(rows)
    rows, total = await dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(rows, count=total)
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/file", summary="上传项目推理集合文件")
async def upload_file(
        detect_id: int = Form(...),
        files: list[UploadFile] = Form(...),
        auth: Auth = Depends(AllUserAuth())):
    """Attach uploaded files to the inference collection ``detect_id``.

    The collection is looked up first. If it cannot be found, an
    ``ErrorResponse`` is returned. Otherwise the collection and the uploaded
    files are passed to ``ProjectDetectFileDal.add_file``.
    """
    # ``detect_id`` identifies a ProjectDetect collection, so it is resolved
    # with the collection DAL. Querying the file DAL by this id would look in
    # the wrong table.
    # The lookup must be awaited like every other DAL call in this module.
    # Without ``await`` the result is a coroutine object that is never None,
    # so the not-found branch below could never run.
    # NOTE(review): ``v_return_none=True`` assumes the DalBase flag that
    # returns None instead of raising when no row matches — confirm in crud.
    detect_out = await crud.ProjectDetectDal(auth.db).get_data(data_id=detect_id, v_return_none=True)
    if detect_out is None:
        return ErrorResponse("训练集合查询失败,请刷新后再试")
    await crud.ProjectDetectFileDal(auth.db).add_file(detect_out, files)
    return SuccessResponse(msg="上传成功")
|
|
|
|
|
|
|
|
|
|
|
|
@app.delete("/file", summary="删除项目推理集合文件信息")
async def delete_file(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete inference-collection files by id via ``ProjectDetectFileDal.delete_files``."""
    dal = crud.ProjectDetectFileDal(auth.db)
    await dal.delete_files(ids=ids.ids)
    return SuccessResponse("删除成功")
|
|
|
|
|
|
|
|
|
2025-04-17 15:57:16 +08:00
|
|
|
async def _maybe_await(value):
    """Return ``value``, awaiting it first when it is awaitable.

    Used for project helpers whose sync/async nature is not visible from this
    module (``ProjectDetectFileDal.file_count`` and ``service.before_detect``).
    """
    import inspect

    if inspect.isawaitable(value):
        return await value
    return value


@app.post("/detect", summary="开始推理")
async def run_detect_yolo(
        detect_log_in: schemas.ProjectDetectLogIn,
        auth: Auth = Depends(AllUserAuth()),
        rd: Redis = Depends(redis_getter)):
    """Start an inference run for a collection using the selected training weights.

    Validation, in order:
    - The collection and the training record must both exist.
    - The collection must contain files or have an ``rtsp_url``.

    Otherwise an ``ErrorResponse`` is returned.

    Dispatch on ``detect.file_type``:
    - ``'img'`` / ``'video'``: ``service.before_detect`` prepares a detect log,
      then ``service.run_img_loop`` runs in a background thread.
    - ``'rtsp'``: unless ``room_manager.rooms`` already has a truthy entry for
      ``detect_rtsp_<id>``, ``service.run_rtsp_loop`` runs in a background
      thread. ``pt_type == 'best'`` selects ``train.best_pt``; any other value
      selects ``train.last_pt``.
    - Any other file type: nothing is started.

    In every non-error case a success response is returned immediately.
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    train_dal = ProjectTrainDal(auth.db)

    # The DAL lookups are coroutines (the other endpoints await the same DAL
    # classes), so this endpoint must be ``async`` and await them.
    # As a plain ``def``, ``detect``/``train`` were coroutine objects: the None
    # checks never fired and ``detect.file_type`` raised AttributeError.
    # NOTE(review): ``v_return_none=True`` assumes the DalBase flag that
    # returns None instead of raising when no row matches — confirm in crud.
    detect = await detect_dal.get_data(detect_log_in.detect_id, v_return_none=True)
    if detect is None:
        return ErrorResponse(msg="训练集合不存在")
    train = await train_dal.get_data(detect_log_in.train_id, v_return_none=True)
    if train is None:
        return ErrorResponse("训练权重不存在")

    file_count = await _maybe_await(
        crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id))
    if file_count == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")

    # NOTE(review): the request-scoped ``auth.db`` session and the ``rd``
    # client are handed to worker threads that outlive this request.
    # Confirm that the service functions do not rely on them after the
    # response is sent.
    if detect.file_type in ('img', 'video'):
        detect_log = await _maybe_await(
            service.before_detect(detect_log_in, detect, train, auth.db))
        worker = threading.Thread(
            target=service.run_img_loop,
            args=(detect_log.pt_url, detect_log.folder_url,
                  detect_log.detect_folder_url, detect_log.detect_version,
                  detect_log.id, detect_log.detect_id, auth.db,))
        worker.start()
    elif detect.file_type == 'rtsp':
        room = 'detect_rtsp_' + str(detect.id)
        # Only one RTSP loop per collection: skip if the room is already active.
        if not room_manager.rooms.get(room):
            weights_pt = train.best_pt if detect_log_in.pt_type == 'best' else train.last_pt
            worker = threading.Thread(
                target=service.run_rtsp_loop,
                args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
            worker.start()
    return SuccessResponse(msg="执行成功")
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
###########################################################
|
|
|
|
# 项目推理记录信息
|
|
|
|
###########################################################
|
2025-04-17 15:57:16 +08:00
|
|
|
@app.get("/log", summary="获取项目推理记录列表")
async def log_pager(
        p: params.ProjectDetectLogParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return inference-run log rows matching ``p`` plus the total count."""
    log_dal = crud.ProjectDetectLogDal(auth.db)
    rows, total = await log_dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(rows, count=total)
|
|
|
|
|
|
|
|
|
2025-04-17 15:57:16 +08:00
|
|
|
@app.get("/log_files", summary="获取项目推理记录文件列表")
async def log_files(
        p: params.ProjectDetectLogFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return the inference-log file rows matching ``p`` (no total count)."""
    file_dal = crud.ProjectDetectLogFileDal(auth.db)
    rows = await file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(rows)
|
2025-04-11 08:54:28 +08:00
|
|
|
|