yolo中断训练
This commit is contained in:
parent
cb1033b661
commit
b70ba8a431
@ -31,6 +31,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
bp = Blueprint('AlgorithmController', __name__)
|
bp = Blueprint('AlgorithmController', __name__)
|
||||||
|
|
||||||
|
ifKillDict = {}
|
||||||
|
|
||||||
def start_train_algorithm():
|
def start_train_algorithm():
|
||||||
"""
|
"""
|
||||||
@ -147,6 +148,42 @@ def algorithm_process_value_websocket():
|
|||||||
|
|
||||||
return wrapTheFunction
|
return wrapTheFunction
|
||||||
|
|
||||||
|
def algorithm_kill_value_websocket():
|
||||||
|
"""
|
||||||
|
获取kill值, websocket发布
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapTheFunction(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapped_function(*args, **kwargs):
|
||||||
|
data = func(*args, **kwargs)
|
||||||
|
id = data["id"]
|
||||||
|
data_res = {'code': 1, "type": 'kill', 'msg': 'success', 'data': data}
|
||||||
|
manager.send_message_proj_json(message=data_res, id=id)
|
||||||
|
return data
|
||||||
|
|
||||||
|
return wrapped_function
|
||||||
|
|
||||||
|
return wrapTheFunction
|
||||||
|
|
||||||
|
|
||||||
|
def algorithm_error_value_websocket():
|
||||||
|
"""
|
||||||
|
获取error值, websocket发布
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapTheFunction(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapped_function(*args, **kwargs):
|
||||||
|
data = func(*args, **kwargs)
|
||||||
|
id = data["id"]
|
||||||
|
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||||
|
manager.send_message_proj_json(message=data_res, id=id)
|
||||||
|
return data
|
||||||
|
|
||||||
|
return wrapped_function
|
||||||
|
|
||||||
|
return wrapTheFunction
|
||||||
|
|
||||||
def obtain_train_param():
|
def obtain_train_param():
|
||||||
"""
|
"""
|
||||||
@ -164,7 +201,6 @@ def obtain_train_param():
|
|||||||
|
|
||||||
return wrapTheFunction
|
return wrapTheFunction
|
||||||
|
|
||||||
|
|
||||||
def obtain_test_param():
|
def obtain_test_param():
|
||||||
"""
|
"""
|
||||||
获取验证参数
|
获取验证参数
|
||||||
@ -215,6 +251,16 @@ def obtain_download_pt_param():
|
|||||||
|
|
||||||
return wrapTheFunction
|
return wrapTheFunction
|
||||||
|
|
||||||
|
@bp.route('/change_ifKillDIct', methods=['get'])
|
||||||
|
def change_ifKillDIct():
|
||||||
|
"""
|
||||||
|
修改全局变量
|
||||||
|
"""
|
||||||
|
id = request.args.get('id')
|
||||||
|
type = request.args.get('type')
|
||||||
|
global ifKillDict
|
||||||
|
ifKillDict[id] = False
|
||||||
|
return output_wrapped(0, 'success')
|
||||||
|
|
||||||
# @start_train_algorithm()
|
# @start_train_algorithm()
|
||||||
# def start(param: str):
|
# def start(param: str):
|
||||||
@ -241,6 +287,13 @@ from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
|
|||||||
from app import file_tool
|
from app import file_tool
|
||||||
|
|
||||||
|
|
||||||
|
def error_return(id: str):
|
||||||
|
"""
|
||||||
|
算法出错,返回
|
||||||
|
"""
|
||||||
|
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': None}
|
||||||
|
manager.send_message_proj_json(message=data_res, id=id)
|
||||||
|
|
||||||
# 启动训练
|
# 启动训练
|
||||||
@start_train_algorithm()
|
@start_train_algorithm()
|
||||||
def train_R0DY(params_str, id):
|
def train_R0DY(params_str, id):
|
||||||
@ -255,8 +308,10 @@ def train_R0DY(params_str, id):
|
|||||||
epoches = params.get('epochnum').value
|
epoches = params.get('epochnum').value
|
||||||
batch_size = params.get('batch_size').value
|
batch_size = params.get('batch_size').value
|
||||||
device = params.get('device').value
|
device = params.get('device').value
|
||||||
|
try:
|
||||||
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
|
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
|
||||||
|
except:
|
||||||
|
error_return(id=id)
|
||||||
|
|
||||||
|
|
||||||
# 启动验证程序
|
# 启动验证程序
|
||||||
|
@ -61,6 +61,8 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
|
|||||||
smart_resume, torch_distributed_zero_first)
|
smart_resume, torch_distributed_zero_first)
|
||||||
from app.schemas.TrainResult import Report, ProcessValueList
|
from app.schemas.TrainResult import Report, ProcessValueList
|
||||||
from app.controller.AlgorithmController import algorithm_process_value_websocket
|
from app.controller.AlgorithmController import algorithm_process_value_websocket
|
||||||
|
from app.controller.AlgorithmController import ifKillDict
|
||||||
|
from app.utils.websocket_tool import manager
|
||||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||||
@ -304,6 +306,15 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
|||||||
num_train_img=train_num,
|
num_train_img=train_num,
|
||||||
train_mod_savepath=best)
|
train_mod_savepath=best)
|
||||||
|
|
||||||
|
def kill_return():
|
||||||
|
"""
|
||||||
|
算法中断,返回
|
||||||
|
"""
|
||||||
|
id = report.id
|
||||||
|
data = report.dict()
|
||||||
|
data_res = {'code': 1, "type": 'kill', 'msg': 'fail', 'data': data}
|
||||||
|
manager.send_message_proj_json(message=data_res, id=id)
|
||||||
|
|
||||||
@algorithm_process_value_websocket()
|
@algorithm_process_value_websocket()
|
||||||
def report_cellback(i, num_epochs, reportAccu):
|
def report_cellback(i, num_epochs, reportAccu):
|
||||||
report.rate_of_progess = ((i + 1) / num_epochs) * 100
|
report.rate_of_progess = ((i + 1) / num_epochs) * 100
|
||||||
@ -314,6 +325,11 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
|||||||
###################结束#######################
|
###################结束#######################
|
||||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||||
#callbacks.run('on_train_epoch_start')
|
#callbacks.run('on_train_epoch_start')
|
||||||
|
global ifKillDict
|
||||||
|
ifkill = ifKillDict['id']
|
||||||
|
if ifkill:
|
||||||
|
kill_return()
|
||||||
|
break
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
# Update image weights (optional, single-GPU only)
|
# Update image weights (optional, single-GPU only)
|
||||||
|
Loading…
Reference in New Issue
Block a user