This commit is contained in:
wudong 2022-11-08 10:02:48 +08:00
commit cd01c9af86
3 changed files with 7 additions and 7 deletions

View File

@ -242,7 +242,7 @@ from app import file_tool
# 启动训练
@start_train_algorithm()
#@start_train_algorithm()
def train_R0DY(params_str, id):
from app.yolov5.train_server import train_start
params = TrainParams()
@ -314,7 +314,7 @@ def train_R0DY(params_str, id):
# zip_outputPath = os.path.join(exp_outputPath, "inference_model.zip")
@obtain_train_param()
#@obtain_train_param()
def returnTrainParams():
# nvmlInit()
# gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
@ -338,7 +338,7 @@ def returnTrainParams():
{"index": 7, "name": "CLASS_NAMES", "value": ['hole', '456'], "description": '类别名称', "default": '', "type": "L",
"items": '',
'show': False},
{"index": 8, "name": "DatasetDir", "value": "E:/aicheck/data_set/11442136178662604800/ori/",
{"index": 8, "name": "DatasetDir", "value": "E:/aicheck/data_set/11442136178662604800/ori",
"description": '数据集路径',
"default": "./app/maskrcnn/datasets/test", "type": "S", 'show': False} # ORI_PATH
]

View File

@ -179,7 +179,7 @@ def get_file(ori_path: str, type_list: Union[object,str]):
test_files = []
# 训练、测试比例强制91
for img in imgs[0:1]:
path = ori_path + '/images/' +img
path = ori_path + '/images/' +img #'/images/'
# print(os.path.exists(path))
print('图像路径',path)
if os.path.exists(path):
@ -187,7 +187,7 @@ def get_file(ori_path: str, type_list: Union[object,str]):
print('1111')
#label = ori_path + 'labels/' + os.path.split(path)[1]
(filename1, extension) = os.path.splitext(img) # 文件名与后缀名分开
label = ori_path + '/labels/' + filename1 + '.json'
label = ori_path + '/labels/' + filename1 + '.json' #'/labels/'
print('标签',label)
if label is not None:
#train_files.append(label)

View File

@ -304,7 +304,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
num_train_img=train_num,
train_mod_savepath=best)
# @algorithm_process_value_websocket()
@algorithm_process_value_websocket()
def report_cellback(i, num_epochs, reportAccu):
report.rate_of_progess = ((i + 1) / num_epochs) * 100
report.progress = (i + 1)
@ -470,7 +470,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
print('##############',best)
for f in best:
print('##################',f)
if os.path.exists(f):
if os.path.exists(best):
strip_optimizer(f) # strip optimizers
if f is best:
LOGGER.info(f'\nValidating {f}...')