94 lines
3.4 KiB
Python
Raw Normal View History

2025-04-11 08:54:28 +08:00
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:32
# @File : crud.py
# @IDE : PyCharm
# @desc : 数据访问层
from sqlalchemy.ext.asyncio import AsyncSession
from core.crud import DalBase
from . import models, schemas
2025-04-17 11:03:05 +08:00
from utils import os_utils as os
from utils.csv_utils import read_csv
2025-04-11 08:54:28 +08:00
class ProjectTrainDal(DalBase):
def __init__(self, db: AsyncSession):
super(ProjectTrainDal, self).__init__()
self.db = db
self.model = models.ProjectTrain
2025-04-17 11:03:05 +08:00
self.schema = schemas.ProjectTrainOut
async def get_result(self, train_id: int):
"""
查询训练报告
"""
data = await self.get_data(data_id=train_id)
if data is None:
return None
result_csv_path = os.file_path(data.train_url, 'results.csv')
result_row = read_csv(result_csv_path)
report_data = {}
# 轮数
epoch_data = []
# 边界框回归损失Bounding Box Loss衡量预测框位置中心坐标、宽高与真实框的差异值越低表示定位越准。
train_box_loss = []
# 目标置信度损失Objectness Loss衡量检测到目标的置信度误差即是否包含物体值越低表示模型越能正确判断有无物体。
train_obj_loss = []
# 分类损失Classification Loss衡量预测类别与真实类别的差异值越低表示分类越准。
train_cls_loss = []
# 验证集的边界框回归损失,反映模型在未见数据上的定位能力。
val_box_loss = []
# 验证集的目标置信度损失,反映模型在未见数据上判断物体存在的能力。
val_obj_loss = []
# 验证集的分类损失,反映模型在未见数据上的分类准确性。
val_cls_loss = []
# 精确率Precision正确检测的正样本占所有预测为正样本的比例反映“误检率”。值越高说明误检越少。
m_p = []
# 召回率Recall正确检测的正样本占所有真实正样本的比例反映“漏检率”。值越高说明漏检越少。
m_r = []
# 主干网络Backbone的学习率。
x_lr0 = []
# 检测头Head的学习率。
x_lr1 = []
for row in result_row:
epoch_data.append(row[0].strip())
train_box_loss.append(row[1].strip())
train_obj_loss.append(row[2].strip())
train_cls_loss.append(row[3].strip())
val_box_loss.append(row[8].strip())
val_obj_loss.append(row[9].strip())
val_cls_loss.append(row[10].strip())
m_p.append(row[4].strip())
m_r.append(row[5].strip())
x_lr0.append(row[11].strip())
x_lr1.append(row[12].strip())
report_data['epoch_data'] = epoch_data
report_data['train_box_loss'] = train_box_loss
report_data['train_obj_loss'] = train_obj_loss
report_data['train_cls_loss'] = train_cls_loss
report_data['val_box_loss'] = val_box_loss
report_data['val_obj_loss'] = val_obj_loss
report_data['val_cls_loss'] = val_cls_loss
report_data['m_p'] = m_p
report_data['m_r'] = m_r
report_data['x_lr0'] = x_lr0
report_data['x_lr1'] = x_lr1
return report_data