From cdb081654a543ef420f6ec1fc79e0e3df13a3a76 Mon Sep 17 00:00:00 2001 From: jiakunhao Date: Tue, 15 Nov 2022 17:59:06 +0800 Subject: [PATCH] =?UTF-8?q?yolo=E4=B8=AD=E6=96=AD=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E6=B7=BB=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/configs/global_var.py | 27 +++++++++++++++++++++++++++ app/controller/AlgorithmController.py | 10 ++++++---- app/schemas/TrainResult.py | 1 + app/yolov5/train_server.py | 16 ++++++++++++---- 4 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 app/configs/global_var.py diff --git a/app/configs/global_var.py b/app/configs/global_var.py new file mode 100644 index 0000000..bb2add3 --- /dev/null +++ b/app/configs/global_var.py @@ -0,0 +1,27 @@ +""" +@Time : 2022/11/15 10:13 +@Auth : 东 +@File :global_var.py +@IDE :PyCharm +@Motto:ABC(Always Be Coding) +@Desc: + +""" + + +def _init(): # 初始化 + global _global_dict + _global_dict = {} + + +def set_value(key, value): + # 定义一个全局变量 + _global_dict[key] = value + + +def get_value(key): + # 获得一个全局变量,不存在则提示读取对应变量失败 + try: + return _global_dict[key] + except: + print('读取' + key + '失败\r\n') diff --git a/app/controller/AlgorithmController.py b/app/controller/AlgorithmController.py index ddabca1..70eca5f 100644 --- a/app/controller/AlgorithmController.py +++ b/app/controller/AlgorithmController.py @@ -287,11 +287,11 @@ from app.schemas.TrainResult import DetectProcessValueDice, DetectReport from app import file_tool -def error_return(id: str): +def error_return(id: str, data): """ 算法出错,返回 """ - data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': None} + data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data} manager.send_message_proj_json(message=data_res, id=id) # 启动训练 @@ -310,8 +310,10 @@ def train_R0DY(params_str, id): device = params.get('device').value try: train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id) - except: - error_return(id=id) + print("train down!") + except Exception as e: + print(repr(e)) + error_return(id=id,data=repr(e)) # 启动验证程序 diff --git a/app/schemas/TrainResult.py b/app/schemas/TrainResult.py index a5bb7a0..ff07135 100644 --- a/app/schemas/TrainResult.py +++ b/app/schemas/TrainResult.py @@ -32,6 +32,7 @@ class Report(BaseModel): train_mod_savepath: str = Field(..., description="模型保存路径") start_time: datetime.date = Field(datetime.datetime.now(), description="开始时间") end_time: datetime.date = Field(datetime.datetime.now(), description="结束时间") + alg_code: str = Field(..., description="模型编码") class ReportDict(BaseModel): diff --git a/app/yolov5/train_server.py b/app/yolov5/train_server.py index aafb52f..9a625a7 100644 --- a/app/yolov5/train_server.py +++ b/app/yolov5/train_server.py @@ -61,7 +61,7 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, smart_resume, torch_distributed_zero_first) from app.schemas.TrainResult import Report, ProcessValueList from app.controller.AlgorithmController import algorithm_process_value_websocket -from app.controller.AlgorithmController import ifKillDict +from app.configs import global_var from app.utils.websocket_tool import manager LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) @@ -304,7 +304,8 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml report = Report(rate_of_progess=0, precision=[process_value_list], id=id, sum=epochs, progress=0, num_train_img=train_num, - train_mod_savepath=best) + train_mod_savepath=best, + alg_code="R-ODY") def kill_return(): """ @@ -325,8 +326,9 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml ###################结束####################### for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ #callbacks.run('on_train_epoch_start') - global ifKillDict - ifkill = ifKillDict['id'] + print("start get global_var") + ifkill = global_var.get_value(report.id) + print("get global_var down:",ifkill) if ifkill: kill_return() break @@ -350,6 +352,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml optimizer.zero_grad() for i, (imgs, targets, paths, _) in pbar: # batch ------------------------------------------------------------- #callbacks.run('on_train_batch_start') + print("start get global_var") + ifkill = global_var.get_value(report.id) + print("get global_var down:",ifkill) + if ifkill: + kill_return() + break if targets.shape[0] == 0: targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553], [0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549],