2025-04-11 08:54:28 +08:00
|
|
|
|
#!/usr/bin/python
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
# @version : 1.0
|
|
|
|
|
# @Create Time : 2025/04/03 10:25
|
|
|
|
|
# @File : crud.py
|
|
|
|
|
# @IDE : PyCharm
|
|
|
|
|
# @desc : 数据访问层
|
|
|
|
|
from . import schemas, models, params
|
|
|
|
|
from apps.vadmin.auth.utils.validation.auth import Auth
|
|
|
|
|
from utils import os_utils as os, random_utils as ru
|
2025-04-18 17:22:57 +08:00
|
|
|
|
from utils.huawei_obs import MyObs
|
2025-04-11 14:30:48 +08:00
|
|
|
|
from utils import status
|
|
|
|
|
from core.exception import CustomException
|
2025-04-17 11:03:05 +08:00
|
|
|
|
from application.settings import datasets_url, runs_url, images_url
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
from typing import Any, List
|
|
|
|
|
from core.crud import DalBase
|
2025-04-11 14:30:48 +08:00
|
|
|
|
from fastapi import UploadFile
|
2025-04-11 08:54:28 +08:00
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2025-04-17 11:03:05 +08:00
|
|
|
|
from sqlalchemy import select, func, case
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProjectInfoDal(DalBase):
    """
    Data-access layer for project information (models.ProjectInfo).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectInfoDal, self).__init__()
        self.db = db
        self.model = models.ProjectInfo
        self.schema = schemas.ProjectInfoOut

    async def get_project_pager(self, project: params.ProjectInfoParams, auth: Auth):
        """
        Page through non-deleted projects, attaching per-project annotation counts.

        Each returned project dict carries ``mark_count`` (images of the project
        that have a ProjectImgLeafer row) and ``no_mark_count`` (images without one).

        :param project: ordering and paging parameters; ``limit == 0`` disables paging.
        :param auth: current auth context; rows are restricted by ``auth.dept_ids``.
        :return: ``(datas, count)`` - serialized projects and the total matching row count.
        """
        # Per-project counts of images with / without a leafer annotation row.
        subquery = (
            select(
                models.ProjectImage.project_id,
                func.sum(case((models.ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
                func.sum(case((models.ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count')
            )
            .outerjoin(models.ProjectImgLeafer, models.ProjectImage.id == models.ProjectImgLeafer.image_id)
            .group_by(models.ProjectImage.project_id)
            .subquery()
        )

        # Outer join so projects without any images still appear, with 0 counts.
        full_query = select(
            models.ProjectInfo,
            func.ifnull(subquery.c.mark_count, 0).label("mark_count"),
            func.ifnull(subquery.c.no_mark_count, 0).label("no_mark_count")
        ).select_from(models.ProjectInfo).join(
            subquery, models.ProjectInfo.id == subquery.c.project_id, isouter=True
        )

        v_where = [models.ProjectInfo.is_delete.is_(False)]
        if '*' in auth.dept_ids:
            # A '*' entry is treated as unrestricted: only require a department to be set.
            v_where.append(models.ProjectInfo.dept_id.isnot(None))
        else:
            v_where.append(models.ProjectInfo.dept_id.in_(auth.dept_ids))

        sql = await self.filter_core(
            v_start_sql=full_query,
            v_where=v_where,
            v_return_sql=True,
            v_order=project.v_order,
            v_order_field=project.v_order_field
        )
        count = await self.get_count_sql(sql)
        if project.limit != 0:
            sql = sql.offset((project.page - 1) * project.limit).limit(project.limit)

        queryset = await self.db.execute(sql)
        datas = []
        for project_info, mark_count, no_mark_count in queryset.all():
            data = schemas.ProjectInfoPagerOut.model_validate(project_info)
            data.mark_count = int(mark_count)
            data.no_mark_count = int(no_mark_count)
            datas.append(data.model_dump())
        return datas, count

    async def check_name(self, project_name: str):
        """
        Check whether a non-deleted project already uses ``project_name``.

        :return: True when a duplicate exists.
        """
        # `.is_(False)` builds a SQL predicate. A Python `is False` test would
        # evaluate to the literal False, and the duplicate check would never match.
        count = await self.get_count(v_where=[
            models.ProjectInfo.project_name == project_name,
            models.ProjectInfo.is_delete.is_(False),
        ])
        return count > 0

    async def add_project(
            self,
            project: schemas.ProjectInfoIn,
            auth: Auth
    ) -> Any:
        """
        Create a new project and its on-disk dataset / training folders.

        :param project: incoming project fields.
        :param auth: current auth context; supplies the owner user and department.
        :return: the created project serialized with schemas.ProjectInfoOut.
        """
        obj = self.model(**project.model_dump())
        obj.user_id = auth.user.id
        obj.project_no = ru.random_str(6)
        obj.project_status = "0"
        obj.train_version = 0
        # A '*' dept scope stores the project under dept 0; otherwise use the first dept.
        # NOTE(review): auth.dept_ids[0] raises IndexError on an empty list - confirm
        # callers always have at least one department.
        if '*' in auth.dept_ids:
            obj.dept_id = 0
        else:
            obj.dept_id = auth.dept_ids[0]

        # Flush before touching the filesystem so that a failed insert does not
        # leave orphaned folders behind.
        await self.flush(obj)
        # Dataset folder
        os.create_folder(datasets_url, obj.project_no)
        # Training-run folder
        os.create_folder(runs_url, obj.project_no)
        return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProjectImageDal(DalBase):
    """
    Data-access layer for project images (models.ProjectImage).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImageDal, self).__init__()
        self.db = db
        self.model = models.ProjectImage
        self.schema = schemas.ProjectImageOut

    async def img_page(self, param: params.ProjectImageParams):
        """
        Page through a project's images of one type, attaching each image's label count.

        :param param: filter (project_id, img_type), ordering and paging parameters;
            ``limit == 0`` disables paging.
        :return: ``(datas, count)`` - serialized images (with ``label_count``) and the total count.
        """
        # 1. Number of ProjectImgLabel rows per image. COUNT is never NULL, so no ifnull is needed here.
        subquery = (
            select(
                models.ProjectImgLabel.image_id,
                func.count(models.ProjectImgLabel.id).label('label_count')
            )
            .group_by(models.ProjectImgLabel.image_id)
            .subquery()
        )
        # 2. Main query: images outer-joined to their counts (0 when there are no label rows).
        query = (
            select(
                models.ProjectImage,
                func.ifnull(subquery.c.label_count, 0).label('label_count')
            )
            .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
        )
        v_where = [
            models.ProjectImage.project_id == param.project_id,
            models.ProjectImage.img_type == param.img_type,
        ]
        sql = await self.filter_core(
            v_start_sql=query,
            v_where=v_where,
            v_return_sql=True,
            v_order=param.v_order,
            v_order_field=param.v_order_field
        )
        count = await self.get_count_sql(sql)
        if param.limit != 0:
            sql = sql.offset((param.page - 1) * param.limit).limit(param.limit)

        queryset = await self.db.execute(sql)
        datas = []
        for image, label_count in queryset.all():
            data = schemas.ProjectImageOut.model_validate(image)
            data.label_count = int(label_count)
            datas.append(data.model_dump())
        return datas, count

    async def upload_imgs(self, files: List[UploadFile], pro: schemas.ProjectInfoOut, img_type: str) -> int:
        """
        Save uploaded images locally, push them to OBS and insert ProjectImage rows.

        :param files: uploaded image files.
        :param pro: owning project; its ``project_no`` namespaces local paths and OBS keys.
        :param img_type: image type stored on each row (the file elsewhere uses 'train' / 'val').
        :return: number of image rows created.
        :raises CustomException: when an OBS upload fails. Local files and OBS objects
            already written by this call are removed before raising.
        """
        image_models = []
        saved_paths = []
        uploaded_keys = []
        obs = MyObs()
        for file in files:
            image = models.ProjectImage()
            image.project_id = pro.id
            image.file_name = file.filename
            image.img_type = img_type

            # Save the original image locally.
            path = os.save_images(images_url, pro.project_no, file=file)
            saved_paths.append(path)
            image.image_url = path

            # Upload to OBS under <project_no>/<img_type>/<file name>.
            object_key = pro.project_no + '/' + img_type + '/' + file.filename
            success, _key, url = obs.put_file(object_key=object_key, file_path=path)
            if not success:
                # Undo this batch's side effects so a failed upload leaves no orphaned files.
                os.delete_file_if_exists(*saved_paths)
                if uploaded_keys:
                    obs.del_objects(uploaded_keys)
                raise CustomException("obs上传失败", code=status.HTTP_ERROR)
            uploaded_keys.append(object_key)
            image.object_key = object_key
            image.thumb_image_url = url
            image_models.append(image)

        await self.create_models(datas=image_models)
        return len(image_models)

    async def check_img_name(self, file_name: str, project_id: int, img_type: str):
        """
        Check whether the same project already has an image of the same type with this file name.

        :return: True when a duplicate exists.
        """
        count = await self.get_count(v_where=[
            models.ProjectImage.file_name == file_name,
            models.ProjectImage.project_id == project_id,
            models.ProjectImage.img_type == img_type
        ])
        return count > 0

    async def del_img(self, ids: List[int]):
        """
        Delete images: their local files, their OBS objects and their database rows.

        :param ids: ProjectImage ids to delete.
        """
        file_urls = []
        object_keys = []
        for img_id in ids:
            image = await self.get_data(data_id=img_id)
            if image:
                # Skip empty paths / keys, e.g. rows whose OBS key was never set.
                if image.image_url:
                    file_urls.append(image.image_url)
                if image.object_key:
                    object_keys.append(image.object_key)
        if file_urls:
            os.delete_file_if_exists(*file_urls)
        if object_keys:
            MyObs().del_objects(object_keys)
        await self.delete_datas(ids)

    async def get_img_count(
            self,
            proj_id: int) -> tuple[int, int]:
        """
        Count a project's images per type.

        :return: ``(train_count, val_count)``.
        """
        train_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'])
        val_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'])
        return train_count, val_count

    async def check_image_label(
            self,
            proj_id: int) -> tuple[int, int]:
        """
        Count a project's images that have no ProjectImgLabel rows (not yet labeled).

        :return: ``(train_unlabeled_count, val_unlabeled_count)``.
        """
        # 1. Image ids that have at least one label row.
        subquery = (
            select(
                models.ProjectImgLabel.image_id,
                func.count(models.ProjectImgLabel.id).label('label_count')
            )
            .group_by(models.ProjectImgLabel.image_id)
            .subquery()
        )
        # 2. Main query: images outer-joined to the labeled set.
        query = (
            select(models.ProjectImage)
            .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
        )

        counts = []
        for img_type in ('train', 'val'):
            sql = await self.filter_core(
                v_start_sql=query,
                v_where=[
                    models.ProjectImage.project_id == proj_id,
                    models.ProjectImage.img_type == img_type,
                    # No matching label row means the image is unlabeled.
                    subquery.c.image_id.is_(None),
                ],
                v_return_sql=True)
            # The statement is a full SELECT, so count it the same way img_page does.
            counts.append(await self.get_count_sql(sql))
        train_count, val_count = counts
        return train_count, val_count
|
|
|
|
|
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
class ProjectLabelDal(DalBase):
    """
    Data-access layer for project labels (models.ProjectLabel).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectLabelDal, self).__init__()
        self.db = db
        self.model = models.ProjectLabel
        self.schema = schemas.ProjectLabel

    async def check_label_name(
            self,
            name: str,
            pro_id: int,
            label_id: int = None
    ):
        """
        Check whether another label in the project already uses ``name``.

        :param name: label name to check.
        :param pro_id: project id.
        :param label_id: when given (truthy), this label itself is excluded from the check.
        :return: True when a duplicate exists.
        """
        wheres = [
            self.model.project_id == pro_id,
            self.model.label_name == name
        ]
        if label_id:
            wheres.append(self.model.id != label_id)
        count = await self.get_count(v_where=wheres)
        return count > 0

    async def get_label_for_train(self, project_id: int):
        """
        Collect the project's label ids and names, ordered by id ascending.

        :param project_id: project id.
        :return: ``(id_list, name_list)`` - two parallel lists.
        """
        # get_datas is a coroutine; it must be awaited before its result can be iterated.
        # NOTE(review): this relies on get_datas returning objects with attribute access
        # and not paging by default - confirm against DalBase.get_datas.
        label_list = await self.get_datas(
            v_where=[self.model.project_id == project_id],
            v_order='asc',
            v_order_field='id',
            v_return_count=False)
        id_list = [label.id for label in label_list]
        name_list = [label.label_name for label in label_list]
        return id_list, name_list
|
|
|
|
|
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
class ProjectImgLabelDal(DalBase):
    """
    Data-access layer for per-image label rows (models.ProjectImgLabel).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImgLabelDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLabel

    async def add_img_label(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace all label rows of an image with ``img_label_in.label_infos``.

        :param img_label_in: carries ``image_id`` and the new ``label_infos``.
        """
        image_id = img_label_in.image_id
        # Delete the existing rows first, then save the new set.
        # NOTE(review): this passes only a keyword filter to delete_datas, while the rest of
        # this file passes an id list - confirm DalBase.delete_datas supports image_id=....
        await self.delete_datas(image_id=image_id)
        img_labels = [self.model(**info.model_dump()) for info in img_label_in.label_infos]
        for img_label in img_labels:
            img_label.image_id = image_id
        # These are ORM instances, so insert them the way ProjectImageDal.upload_imgs does.
        await self.create_models(datas=img_labels)

    async def get_img_label_list(self, image_id: int):
        """
        Return the label rows of one image, ordered by id ascending.

        :param image_id: ProjectImage id.
        """
        return await self.get_datas(
            v_return_count=False,
            v_where=[self.model.image_id == image_id],
            v_order="asc",
            v_order_field="id")

    async def del_img_label(self, label_ids: list[int]):
        """
        Delete every image-label row that references one of ``label_ids``.

        :param label_ids: ProjectLabel ids whose image-label rows should be removed.
        """
        # Both calls are coroutines; without await nothing was actually queried or deleted.
        # NOTE(review): relies on get_datas returning objects with `.id` and not paging
        # by default - confirm against DalBase.get_datas.
        img_labels = await self.get_datas(v_where=[self.model.label_id.in_(label_ids)])
        img_label_ids = [i.id for i in img_labels]
        if img_label_ids:
            await self.delete_datas(ids=img_label_ids)
|
|
|
|
|
|
2025-04-11 08:54:28 +08:00
|
|
|
|
|
|
|
|
|
class ProjectImgLeaferDal(DalBase):
    """
    Data-access layer for image annotation data produced by leafer.js (models.ProjectImgLeafer).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImgLeaferDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLeafer
        self.schema = schemas.ProjectImgLeaferOut

    async def get_leafer(self, image_id: int):
        """
        Return the stored ``leafer`` payload for an image.

        :param image_id: ProjectImage id.
        :return: the ``leafer`` attribute of the matching row, or None if get_data returns no row.
        """
        # get_data is a coroutine; the original read `.leafer` off the un-awaited coroutine.
        img_leafer = await self.get_data(v_where=[self.model.image_id == image_id])
        if img_leafer is None:
            return None
        return img_leafer.leafer

    async def add_leafer(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace an image's leafer annotation row with the data in ``img_label_in``.

        :param img_label_in: incoming annotation payload, including ``image_id``.
        """
        image_id = img_label_in.image_id
        # Delete the existing rows first, then save the new one.
        # NOTE(review): confirm DalBase.delete_datas supports the keyword filter image_id=...
        await self.delete_datas(image_id=image_id)
        # NOTE(review): model_dump() also includes `label_infos` (read by
        # ProjectImgLabelDal.add_img_label). Confirm ProjectImgLeafer accepts that key,
        # otherwise the constructor will reject it.
        await self.create_data(data=self.model(**img_label_in.model_dump()))
|