Merge branch 'master' of https://gitea.star-rising.cn/xkrs_manan/RODY
This commit is contained in:
commit
4b636f1b7d
27
app/configs/global_var.py
Normal file
27
app/configs/global_var.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""
|
||||
@Time : 2022/11/15 10:13
|
||||
@Auth : 东
|
||||
@File :global_var.py
|
||||
@IDE :PyCharm
|
||||
@Motto:ABC(Always Be Coding)
|
||||
@Desc:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _init(): # 初始化
|
||||
global _global_dict
|
||||
_global_dict = {}
|
||||
|
||||
|
||||
def set_value(key, value):
|
||||
# 定义一个全局变量
|
||||
_global_dict[key] = value
|
||||
|
||||
|
||||
def get_value(key):
|
||||
# 获得一个全局变量,不存在则提示读取对应变量失败
|
||||
try:
|
||||
return _global_dict[key]
|
||||
except:
|
||||
print('读取' + key + '失败\r\n')
|
@ -31,6 +31,7 @@ from pathlib import Path
|
||||
|
||||
bp = Blueprint('AlgorithmController', __name__)
|
||||
|
||||
ifKillDict = {}
|
||||
|
||||
def start_train_algorithm():
|
||||
"""
|
||||
@ -147,6 +148,42 @@ def algorithm_process_value_websocket():
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def algorithm_kill_value_websocket():
|
||||
"""
|
||||
获取kill值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 1, "type": 'kill', 'msg': 'success', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_error_value_websocket():
|
||||
"""
|
||||
获取error值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def obtain_train_param():
|
||||
"""
|
||||
@ -164,7 +201,6 @@ def obtain_train_param():
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def obtain_test_param():
|
||||
"""
|
||||
获取验证参数
|
||||
@ -215,6 +251,16 @@ def obtain_download_pt_param():
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
@bp.route('/change_ifKillDIct', methods=['get'])
|
||||
def change_ifKillDIct():
|
||||
"""
|
||||
修改全局变量
|
||||
"""
|
||||
id = request.args.get('id')
|
||||
type = request.args.get('type')
|
||||
global ifKillDict
|
||||
ifKillDict[id] = False
|
||||
return output_wrapped(0, 'success')
|
||||
|
||||
# @start_train_algorithm()
|
||||
# def start(param: str):
|
||||
@ -241,6 +287,13 @@ from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
|
||||
from app import file_tool
|
||||
|
||||
|
||||
def error_return(id: str, data):
|
||||
"""
|
||||
算法出错,返回
|
||||
"""
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
|
||||
# 启动训练
|
||||
@start_train_algorithm()
|
||||
def train_R0DY(params_str, id):
|
||||
@ -255,8 +308,12 @@ def train_R0DY(params_str, id):
|
||||
epoches = params.get('epochnum').value
|
||||
batch_size = params.get('batch_size').value
|
||||
device = params.get('device').value
|
||||
|
||||
try:
|
||||
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
|
||||
print("train down!")
|
||||
except Exception as e:
|
||||
print(repr(e))
|
||||
error_return(id=id,data=repr(e))
|
||||
|
||||
|
||||
# 启动验证程序
|
||||
@ -303,7 +360,8 @@ def Export_model_RODY(params_str):
|
||||
exp_inputPath = params.get('exp_inputPath').value # 模型路径
|
||||
print('输入模型:', exp_inputPath)
|
||||
exp_device = params.get('device').value
|
||||
modellist = Start_Model_Export(exp_inputPath, exp_device)
|
||||
imgsz = params.get('imgsz').value
|
||||
modellist = Start_Model_Export(exp_inputPath, exp_device, imgsz)
|
||||
exp_outputPath = exp_inputPath.replace('pt', 'zip') # 压缩文件
|
||||
print('模型路径:',exp_outputPath)
|
||||
zipf = zipfile.ZipFile(exp_outputPath, 'w')
|
||||
@ -312,20 +370,19 @@ def Export_model_RODY(params_str):
|
||||
|
||||
return exp_outputPath
|
||||
|
||||
|
||||
@obtain_train_param()
|
||||
def returnTrainParams():
|
||||
# nvmlInit()
|
||||
# gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
# _kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
nvmlInit()
|
||||
gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
_kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
params_list = [
|
||||
{"index": 0, "name": "epochnum", "value": 10, "description": '训练轮次', "default": 100, "type": "I", 'show': True},
|
||||
{"index": 1, "name": "batch_size", "value": 4, "description": '批次图像数量', "default": 1, "type": "I",
|
||||
'show': True},
|
||||
{"index": 2, "name": "img_size", "value": 640, "description": '训练图像大小', "default": 640, "type": "I",
|
||||
'show': True},
|
||||
{"index": 3, "name": "device", "value": "0", "description": '训练核心', "default": "cuda", "type": "S",
|
||||
"items": '', 'show': True}, # _kernel
|
||||
{"index": 3, "name": "device", "value": 'CUDA', "description": '训练核心', "default": 'CUDA', "type": "E",
|
||||
"items": _kernel, 'show': True}, # _kernel
|
||||
{"index": 4, "name": "saveModDir", "value": "E:/alg_demo-master/alg_demo/app/yolov5/best.pt",
|
||||
"description": '保存模型路径',
|
||||
"default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
@ -381,7 +438,7 @@ def returnDetectParams():
|
||||
{"index": 1, "name": "outputPath", "value": 'E:/aicheck/data_set/11442136178662604800/val_results/',
|
||||
"description": '输出结果路径',
|
||||
"default": './app/maskrcnn/datasets/M006B_waibi/res', "type": "S", 'show': False},
|
||||
{"index": 0, "name": "modPath", "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
|
||||
{"index": 2, "name": "modPath", "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
|
||||
"description": '模型路径', "default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
{"index": 3, "name": "device", "value": "0", "description": '推理核', "default": "cpu", "type": "S",
|
||||
'show': False},
|
||||
@ -399,7 +456,9 @@ def returnDownloadParams():
|
||||
"default": 'E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt/',
|
||||
"type": "S", 'show': False},
|
||||
{"index": 1, "name": "device", "value": 'gpu', "description": 'CPU或GPU', "default": 'gpu', "type": "S",
|
||||
'show': False}
|
||||
'show': False},
|
||||
{"index": 2, "name": "imgsz", "value": 640, "description": '图像大小', "default": 640, "type": "I",
|
||||
'show': True}
|
||||
]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
@ -32,6 +32,7 @@ class Report(BaseModel):
|
||||
train_mod_savepath: str = Field(..., description="模型保存路径")
|
||||
start_time: datetime.date = Field(datetime.datetime.now(), description="开始时间")
|
||||
end_time: datetime.date = Field(datetime.datetime.now(), description="结束时间")
|
||||
alg_code: str = Field(..., description="模型编码")
|
||||
|
||||
|
||||
class ReportDict(BaseModel):
|
||||
|
@ -566,11 +566,12 @@ def run(
|
||||
return f # return list of exported files/dirs
|
||||
|
||||
|
||||
def parse_opt(weights,device):
|
||||
def parse_opt(weights,device,imgsz):
|
||||
imgsz = [imgsz,imgsz]
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
|
||||
parser.add_argument('--weights', nargs='+', type=str, default=weights, help='model.pt path(s)')
|
||||
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
|
||||
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=imgsz, help='image (h, w)') #default=[640, 640]
|
||||
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
||||
parser.add_argument('--device', default=device, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
|
||||
@ -604,13 +605,13 @@ def main(opt):
|
||||
f = run(**vars(opt))
|
||||
return f
|
||||
|
||||
def Start_Model_Export(weights,device):
|
||||
def Start_Model_Export(weights,device,imgsz):
|
||||
# 判断cpu or gpu
|
||||
if device == 'gpu':
|
||||
device = '0'
|
||||
else:
|
||||
device = 'cpu'
|
||||
opt = parse_opt(weights,device)
|
||||
opt = parse_opt(weights,device,imgsz)
|
||||
f = main(opt)
|
||||
return f
|
||||
|
||||
|
@ -61,6 +61,8 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
|
||||
smart_resume, torch_distributed_zero_first)
|
||||
from app.schemas.TrainResult import Report, ProcessValueList
|
||||
from app.controller.AlgorithmController import algorithm_process_value_websocket
|
||||
from app.configs import global_var
|
||||
from app.utils.websocket_tool import manager
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||
@ -72,7 +74,7 @@ def yaml_rewrite(file='data.yaml',data_list=[]):
|
||||
with open(file, errors='ignore') as f:
|
||||
coco_dict = yaml.safe_load(f)
|
||||
#读取img_label_type.json
|
||||
with open(data_list[3], 'r') as f:
|
||||
with open(data_list[3], 'r',encoding='UTF-8') as f:
|
||||
class_dict = json.load(f)
|
||||
f.close()
|
||||
classes = class_dict["classes"]
|
||||
@ -302,7 +304,17 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
report = Report(rate_of_progess=0, precision=[process_value_list],
|
||||
id=id, sum=epochs, progress=0,
|
||||
num_train_img=train_num,
|
||||
train_mod_savepath=best)
|
||||
train_mod_savepath=best,
|
||||
alg_code="R-ODY")
|
||||
|
||||
def kill_return():
|
||||
"""
|
||||
算法中断,返回
|
||||
"""
|
||||
id = report.id
|
||||
data = report.dict()
|
||||
data_res = {'code': 1, "type": 'kill', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
|
||||
@algorithm_process_value_websocket()
|
||||
def report_cellback(i, num_epochs, reportAccu):
|
||||
@ -314,6 +326,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
###################结束#######################
|
||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||
#callbacks.run('on_train_epoch_start')
|
||||
print("start get global_var")
|
||||
ifkill = global_var.get_value(report.id)
|
||||
print("get global_var down:",ifkill)
|
||||
if ifkill:
|
||||
kill_return()
|
||||
break
|
||||
model.train()
|
||||
|
||||
# Update image weights (optional, single-GPU only)
|
||||
@ -334,6 +352,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
optimizer.zero_grad()
|
||||
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
||||
#callbacks.run('on_train_batch_start')
|
||||
print("start get global_var")
|
||||
ifkill = global_var.get_value(report.id)
|
||||
print("get global_var down:",ifkill)
|
||||
if ifkill:
|
||||
kill_return()
|
||||
break
|
||||
if targets.shape[0] == 0:
|
||||
targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553],
|
||||
[0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549],
|
||||
|
@ -430,9 +430,12 @@ class LoadImagesAndLabels(Dataset):
|
||||
self.label_files = img2label_paths(self.im_files) # labels
|
||||
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
|
||||
try:
|
||||
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
|
||||
assert cache['version'] == self.cache_version # matches current version
|
||||
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical has
|
||||
# cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
|
||||
# assert cache['version'] == self.cache_version # matches current version
|
||||
# assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical has
|
||||
if os.path.exists(cache_path):
|
||||
os.remove(cache_path)
|
||||
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
|
||||
except Exception:
|
||||
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user