199 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from utils import os_utils as os
from application.settings import *
from . import schemas, models, crud
from utils.websocket_server import room_manager
from apps.business.project import models as proj_models, crud as proj_crud
import yaml
import asyncio
import subprocess
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
async def _await_if_coroutine(value):
    """Return ``value``, awaiting it first when it is a coroutine object."""
    # The DAL method definitions are not visible from this module, so their
    # results are awaited only when they actually turn out to be coroutines.
    if asyncio.iscoroutine(value):
        return await value
    return value


async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
    """
    Build the on-disk dataset for the next YOLOv5 training run of a project.

    Creates the versioned dataset folder tree, writes the dataset YAML file,
    then copies the train/val images and generates their label files.

    :param proj_info: project record; ``id``, ``project_no`` and ``train_version`` are read
    :param db: database session
    :return: tuple ``(data, project, name)``: path of the dataset YAML file,
             the project's directory under ``runs_url``, and the version folder name
    """
    img_dal = proj_crud.ProjectImageDal(db)
    label_dal = proj_crud.ProjectLabelDal(db)

    # Query the two image lists (training set and validation set).
    project_images_train = await _await_if_coroutine(img_dal.get_data(
        v_where=[proj_models.ProjectImage.project_id == proj_info.id,
                 proj_models.ProjectImage.img_type == 'train']))
    project_images_val = await _await_if_coroutine(img_dal.get_data(
        v_where=[proj_models.ProjectImage.project_id == proj_info.id,
                 proj_models.ProjectImage.img_type == 'val']))

    # The new training version is the current one plus one.
    version_path = 'v' + str(proj_info.train_version + 1)
    # Root directory of this training version.
    train_path = os.create_folder(datasets_url, proj_info.project_no, version_path)

    # Labels of the project: two parallel lists, ids and names, index-aligned.
    label_id_list, label_name_list = await _await_if_coroutine(
        label_dal.get_label_for_train(proj_info.id))

    # Image folders.
    img_path_train = os.create_folder(train_path, 'images', 'train')
    img_path_val = os.create_folder(train_path, 'images', 'val')
    # Label folders.
    label_path_train = os.create_folder(train_path, 'labels', 'train')
    label_path_val = os.create_folder(train_path, 'labels', 'val')

    # Dataset description file in the root directory.
    yaml_file = os.file_path(train_path, proj_info.project_no + '.yaml')
    yaml_data = {
        'path': train_path,
        'train': 'images/train',
        'val': 'images/val',
        'test': None,
        'names': dict(enumerate(label_name_list)),
    }
    with open(yaml_file, 'w', encoding='utf-8') as file:
        yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)

    # Copy the images and generate the label .txt files. operate_img_label is
    # a coroutine function: without ``await`` nothing would be written.
    await operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list)
    await operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list)

    # Values handed to the training step.
    data = yaml_file
    project = os.file_path(runs_url, proj_info.project_no)
    name = version_path
    return data, project, name
async def operate_img_label(
        img_list: list[proj_models.ProjectImgLabel],
        img_path: str,
        label_path: str,
        db: AsyncSession,
        label_id_list: list):
    """
    Copy images into the dataset and write one label file per image.

    Each image is copied to ``img_path`` under the name ``image<N>`` (``N`` is
    its position in ``img_list``) and a matching ``image<N>.txt`` is written to
    ``label_path`` with one line per label:
    ``<class index> <center x> <center y> <width> <height>``.

    :param img_list: images to export; ``image_url`` and ``id`` are read from each item
                     (NOTE(review): the items look like image records rather than
                     ``ProjectImgLabel`` rows - confirm the annotation)
    :param img_path: destination folder for the copied images
    :param label_path: destination folder for the generated label files
    :param db: database session
    :param label_id_list: label ids; the position of an id in this list is the class index
    :return: None
    :raises ValueError: if a label's ``label_id`` is not in ``label_id_list``
    """
    img_label_dal = proj_crud.ProjectImgLabelDal(db)
    for i, image in enumerate(img_list):
        # Copy the image under its new base name.
        file_name = 'image' + str(i)
        os.copy_and_rename_file(image.image_url, img_path, file_name)

        # Load the labels of this image and render them into its .txt file.
        img_label_list = await img_label_dal.get_img_label_list(image.id)
        label_txt_path = os.file_path(label_path, file_name + '.txt')
        with open(label_txt_path, 'w', encoding='utf-8') as file:
            for image_label in img_label_list:
                index = label_id_list.index(image_label.label_id)
                # f-string formatting works whether the coordinates are stored
                # as strings or as numbers (plain ``+`` would raise on numbers).
                file.write(
                    f"{index} {image_label.mark_center_x} {image_label.mark_center_y} "
                    f"{image_label.mark_width} {image_label.mark_height}\n"
                )
async def run_event_loop(
        data: str,
        project: str,
        name: str,
        train_in: schemas.ProjectTrainIn,
        project_id: int,
        db: AsyncSession,
        rd: "Redis | None" = None):
    """
    Run the training command for a project.

    :param data: dataset YAML file
    :param project: output directory of the project's training runs
    :param name: name of this run (version folder)
    :param train_in: training request; ``epochs``, ``patience`` and ``weights_id`` are read
    :param project_id: project id
    :param db: database session
    :param rd: redis connection forwarded to ``run_commend``
    :return: None
    """
    # This function is itself a coroutine, so it always executes inside a
    # running event loop. Creating a second loop and calling
    # ``run_until_complete`` from here raises RuntimeError; awaiting the
    # training coroutine directly is the correct way to run it.
    await run_commend(
        data,
        project,
        name,
        train_in.epochs,
        train_in.patience,
        train_in.weights_id,
        project_id,
        db,
        rd,
    )
async def run_commend(
        data: str,
        project: str,
        name: str,
        epochs: int,
        patience: int,
        weights: str,
        project_id: int,
        db: AsyncSession,
        rd: Redis):
    """
    Run YOLOv5 training in a subprocess, stream its output to the project's
    websocket room, then store a record of the training run.

    :param data: dataset YAML file
    :param project: output directory of the project's training runs
    :param name: name of this run (version folder)
    :param epochs: number of training epochs
    :param patience: early-stopping patience
    :param weights: id of a previous training record whose ``best_pt`` is used
                    as initial weights; empty or None to skip
    :param project_id: project id
    :param db: database session
    :param rd: redis connection; the ``is_gpu`` key is read. When None the
               ``--device`` option is not added
    :return: None
    """
    yolo_path = os.file_path(yolo_url, 'train.py')
    room = 'train_' + str(project_id)
    await room_manager.send_to_room(room, "AiCheckV2.0: 模型训练开始,请稍等。。。\n")
    command = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
               "--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]

    # Optional initial weights: continue from a previous training run.
    has_weights = weights is not None and weights != ''
    train_info = None
    if has_weights:
        train_info = await crud.ProjectTrainDal(db).get_data(data_id=int(weights))
        if train_info is not None:
            command.append("--weights=" + train_info.best_pt)

    # ``rd`` is an asyncio redis client, so ``get`` must be awaited; comparing
    # the un-awaited coroutine with 'True' would never match.
    is_gpu = await rd.get('is_gpu') if rd is not None else None
    if isinstance(is_gpu, bytes):
        # Clients created without decode_responses return bytes.
        is_gpu = is_gpu.decode('utf-8')
    if is_gpu == 'True':
        command.append("--device=0")

    # Start the training subprocess.
    with subprocess.Popen(
            command,
            bufsize=1,  # line-buffered (0 = unbuffered, other positive values = approximate buffer size)
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so progress output is streamed as well
            text=True,  # text mode avoids decoding issues downstream
            encoding='utf-8',
    ) as process:
        while True:
            # readline() blocks; run it in a worker thread so the event loop
            # stays responsive while training is running.
            line = await asyncio.to_thread(process.stdout.readline)
            if not line:
                # EOF: the subprocess closed its output, everything was read.
                break
            if line != '\n' and '0%' not in line:
                await room_manager.send_to_room(room, line + '\n')
        # Wait for the process to finish and fetch its return code.
        return_code = await asyncio.to_thread(process.wait)

    if return_code != 0:
        await room_manager.send_to_room(room, 'error')
    else:
        await room_manager.send_to_room(room, 'success')

    # Store the information about this training version.
    train = models.ProjectTrain()
    train.project_id = project_id
    train.train_version = name
    train_url = os.file_path(project, name)
    train.train_url = train_url
    train.train_data = data
    train.best_pt = os.file_path(train_url, 'weights', 'best.pt')
    train.last_pt = os.file_path(train_url, 'weights', 'last.pt')
    if has_weights:
        train.weights_id = weights
        # The referenced record may not exist; only copy its version name
        # when the lookup above actually found it.
        if train_info is not None:
            train.weights_name = train_info.train_version
    train.patience = patience
    train.epochs = epochs
    await crud.ProjectTrainDal(db).create_data(data=train)