226 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from application.settings import yolo_url, detect_url
from utils.websocket_server import room_manager
from utils import os_utils as os
from . import models, crud, schemas
from apps.business.train import models as train_models
from utils.yolov5.models.common import DetectMultiBackend
from utils.yolov5.utils.torch_utils import select_device
from utils.yolov5.utils.dataloaders import LoadStreams
from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
from ultralytics.utils.plotting import Annotator, colors
import time
import torch
import asyncio
import subprocess
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
async def before_detect(
        detect_in: schemas.ProjectDetectLogIn,
        detect: models.ProjectDetect,
        train: train_models.ProjectTrain,
        db: AsyncSession):
    """
    Create and persist the ProjectDetectLog record for a new inference run.

    The run is labelled ``v<detect.detect_version + 1>``. Its weights are the
    training run's ``best_pt`` when ``detect_in.pt_type == 'best'``, otherwise
    ``last_pt``. Its output folder is built from ``detect_url``,
    ``detect.detect_no`` and ``'detect'``.

    :param detect_in: request payload; only ``pt_type`` is read
    :param detect: the detect set that is about to be run
    :param train: the training run whose weights will be used
    :param db: database session used to store the log
    :return: the newly created ProjectDetectLog
    """
    next_version = f"v{detect.detect_version + 1}"
    use_best = detect_in.pt_type == 'best'
    weights_file = train.best_pt if use_best else train.last_pt
    source_folder = detect.folder_url
    output_folder = os.file_path(detect_url, detect.detect_no, 'detect')

    log_entry = models.ProjectDetectLog()
    # Copy every column in the same order as before; dicts keep insertion order.
    column_values = {
        'detect_name': detect.detect_name,
        'detect_id': detect.id,
        'detect_version': next_version,
        'train_id': train.id,
        'train_version': train.train_version,
        'pt_type': detect_in.pt_type,
        'pt_url': weights_file,
        'folder_url': source_folder,
        'detect_folder_url': output_folder,
    }
    for column, value in column_values.items():
        setattr(log_entry, column, value)

    await crud.ProjectDetectLogDal(db).create_data(log_entry)
    return log_entry
async def run_detect_img(
        weights: str,
        source: str,
        project: str,
        name: str,
        log_id: int,
        detect_id: int,
        db: AsyncSession,
        rd: Redis = None):
    """
    Run YOLOv5 ``detect.py`` on a folder of images in a subprocess.

    Every non-blank output line of the subprocess is forwarded to the websocket
    room ``detect_<detect_id>``. ``'error'`` or ``'success'`` is sent once it
    exits. On success, one ProjectDetectLogFile row is created for every
    ProjectDetectFile of the detect set. The row's image_url is built from
    ``project``, ``name`` and the file name.

    :param weights: weight file passed to ``--weights``
    :param source: folder containing the images, passed to ``--source``
    :param project: output root folder, passed to ``--project``
    :param name: version name (output sub-folder), passed to ``--name``
    :param log_id: id of the ProjectDetectLog the result files belong to
    :param detect_id: id of the detect set
    :param db: database session
    :param rd: Redis client holding the ``is_gpu`` flag; when None, no
        ``--device`` option is passed
    :return: None
    """
    import inspect

    yolo_path = os.file_path(yolo_url, 'detect.py')
    room = 'detect_' + str(detect_id)
    await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")
    command = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
               project, "--save-txt", "--conf-thres", "0.4"]

    # On a redis.asyncio client get() returns an awaitable. Without
    # decode_responses the reply is bytes. Normalise both so the flag is honoured.
    is_gpu = rd.get('is_gpu') if rd is not None else None
    if inspect.isawaitable(is_gpu):
        is_gpu = await is_gpu
    if isinstance(is_gpu, bytes):
        is_gpu = is_gpu.decode()
    if is_gpu == 'True':
        # list.append() accepts a single item; both CLI tokens must be added.
        command.extend(["--device", "0"])

    with subprocess.Popen(
            command,
            bufsize=1,  # line-buffered
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so YOLOv5 progress output is streamed too
            text=True,  # decode output as text to avoid encoding issues downstream
            encoding='utf-8',
    ) as process:
        # Read until EOF instead of polling the process: lines still buffered
        # when the process exits are not lost, and no empty EOF line is sent.
        for line in process.stdout:
            if line != '\n':
                await room_manager.send_to_room(room, line + '\n')
        return_code = process.wait()

    if return_code != 0:
        await room_manager.send_to_room(room, 'error')
        # Nothing was produced, so no result-file rows are recorded.
        return
    await room_manager.send_to_room(room, 'success')

    # NOTE(review): the DAL is used with an AsyncSession and create_data is
    # awaited, so get_data is assumed to be a coroutine. Confirm it returns
    # every matching row (not a single row) for this where-clause.
    detect_files = await crud.ProjectDetectFileDal(db).get_data(
        v_where=[models.ProjectDetectFile.detect_id == detect_id])
    detect_log_files = []
    for detect_file in detect_files:
        detect_log_img = models.ProjectDetectLogFile()
        detect_log_img.log_id = log_id
        detect_log_img.image_url = os.file_path(project, name, detect_file.file_name)
        detect_log_img.file_name = detect_file.file_name
        detect_log_files.append(detect_log_img)
    await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
    """
    Run YOLOv5 inference on an RTSP stream and push annotated JPEG frames to a websocket room.

    Frames are sent to the room ``detect_rtsp_<detect_id>`` for as long as
    ``room_manager.rooms`` has an entry for that room. The loop stops at the
    first frame where it does not.

    :param weights_pt: weight file
    :param rtsp_url: video stream address
    :param data: dataset yaml file (class names)
    :param detect_id: id of the detect set, used to name the room
    :param rd: Redis client holding the ``is_gpu`` flag
    :return: None
    """
    import inspect

    room = 'detect_rtsp_' + str(detect_id)

    # On a redis.asyncio client get() returns an awaitable; the reply may be bytes.
    is_gpu = rd.get('is_gpu')
    if inspect.isawaitable(is_gpu):
        is_gpu = await is_gpu
    if isinstance(is_gpu, bytes):
        is_gpu = is_gpu.decode()
    # Choose the device once. YOLOv5's select_device adjusts
    # CUDA_VISIBLE_DEVICES, so selecting 'cpu' first and then CUDA is avoided.
    device = select_device('cuda:0' if is_gpu == 'True' else 'cpu')

    # Load the model and open the stream.
    model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)
    dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
    bs = len(dataset)
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
    dt = (Profile(device=device), Profile(device=device), Profile(device=device))

    # Give the websocket client time to join the room without blocking the event loop.
    await asyncio.sleep(3)

    for _, im, im0s, _, _ in dataset:
        # Stop as soon as the room is gone or empty.
        if not room_manager.rooms.get(room):
            print(room, '结束推理')
            break

        # Preprocess: numpy -> tensor on the model device, 0-255 scaled to 0.0-1.0.
        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()
            im /= 255
            if len(im.shape) == 3:
                im = im[None]  # add batch dimension
            if model.xml and im.shape[0] > 1:
                ims = torch.chunk(im, im.shape[0], 0)

        # Inference; model.xml backends are fed one image at a time.
        with dt[1]:
            if model.xml and im.shape[0] > 1:
                pred = None
                for image in ims:
                    single = model(image, augment=False, visualize=False).unsqueeze(0)
                    pred = single if pred is None else torch.cat((pred, single), dim=0)
                pred = [pred, None]
            else:
                pred = model(im, augment=False, visualize=False)

        # NMS (conf 0.45, IoU 0.45, all classes, not class-agnostic).
        with dt[2]:
            pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)

        # Draw boxes on each original frame and stream it as JPEG.
        for i, det in enumerate(pred):
            im0 = im0s[i].copy()
            annotator = Annotator(im0, line_width=3, example=str(names))
            if len(det):
                # Rescale boxes from the network input size to the original frame size.
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)
                    annotator.box_label(xyxy, f"{names[c]} {conf:.2f}", color=colors(c, True))
            im0 = annotator.result()
            ret, jpeg = cv2.imencode('.jpg', im0)
            if ret:
                await room_manager.send_stream_to_room(room, jpeg.tobytes())
def run_img_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, db: AsyncSession,
                 rd: Redis = None):
    """
    Synchronous wrapper that runs ``run_detect_img`` on a freshly created event loop.

    The new loop is installed as the current loop for the calling thread and
    is always closed afterwards, even if inference raises.

    :param weights: weight file
    :param source: folder containing the images
    :param project: output root folder
    :param name: version name
    :param log_id: id of the ProjectDetectLog
    :param detect_id: id of the detect set
    :param db: database session
    :param rd: Redis client holding the ``is_gpu`` flag. The original call
        omitted this required argument and raised TypeError on every call.
    :return: None
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db, rd))
    finally:
        # Close the loop even on failure so it is not leaked.
        loop.close()
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
    """
    Synchronous wrapper that runs ``run_detect_rtsp`` on a freshly created event loop.

    The new loop is installed as the current loop for the calling thread and
    is always closed afterwards, even if inference raises.

    :param weights_pt: weight file
    :param rtsp_url: video stream address
    :param data: dataset yaml file
    :param detect_id: id of the detect set
    :param rd: Redis client holding the ``is_gpu`` flag
    :return: None
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd))
    finally:
        # Close the loop even on failure so it is not leaked.
        loop.close()