373 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:25
# @File : crud.py
# @IDE : PyCharm
# @desc : 数据访问层
from . import schemas, models, params
from apps.vadmin.auth.utils.validation.auth import Auth
from utils import os_utils as os, random_utils as ru
from utils.huawei_obs import MyObs
from utils import status
from core.exception import CustomException
from application.settings import datasets_url, runs_url, images_url
from typing import Any, List
from core.crud import DalBase
from fastapi import UploadFile
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, case
class ProjectInfoDal(DalBase):
    """
    Data-access layer for project information (``models.ProjectInfo``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectInfoDal, self).__init__()
        self.db = db
        self.model = models.ProjectInfo
        self.schema = schemas.ProjectInfoOut

    async def get_project_pager(self, project: params.ProjectInfoParams, auth: Auth):
        """
        Page through the non-deleted projects visible to ``auth``, each
        annotated with marked / unmarked image counts.

        :param project: paging and ordering parameters (``page``, ``limit``,
            ``v_order``, ``v_order_field``); ``limit == 0`` disables paging.
        :param auth: current auth context; ``'*'`` in ``auth.dept_ids`` selects
            every project whose ``dept_id`` is not NULL, otherwise only projects
            of the listed departments are returned.
        :return: ``(datas, count)`` where ``datas`` is a list of dumped
            ``ProjectInfoPagerOut`` dicts and ``count`` is the number of
            matching projects before paging.
        """
        # Per project: joined image rows without a ProjectImgLeafer record are
        # counted as "no_mark", rows with one as "mark".
        subquery = (
            select(
                models.ProjectImage.project_id,
                func.sum(case((models.ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
                func.sum(case((models.ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count')
            )
            .outerjoin(models.ProjectImgLeafer, models.ProjectImage.id == models.ProjectImgLeafer.image_id)
            .group_by(models.ProjectImage.project_id)
            .subquery()
        )
        # Projects without images are absent from the subquery; ifnull turns
        # their NULL counts into 0.
        full_query = select(
            models.ProjectInfo,
            func.ifnull(subquery.c.mark_count, 0).label("mark_count"),
            func.ifnull(subquery.c.no_mark_count, 0).label("no_mark_count")
        ).select_from(models.ProjectInfo).join(
            subquery, models.ProjectInfo.id == subquery.c.project_id, isouter=True
        )
        v_where = [models.ProjectInfo.is_delete.is_(False)]
        if '*' in auth.dept_ids:
            v_where.append(models.ProjectInfo.dept_id.isnot(None))
        else:
            v_where.append(models.ProjectInfo.dept_id.in_(auth.dept_ids))
        sql = await self.filter_core(
            v_start_sql=full_query,
            v_where=v_where,
            v_return_sql=True,
            v_order=project.v_order,
            v_order_field=project.v_order_field
        )
        count = await self.get_count_sql(sql)
        if project.limit != 0:
            sql = sql.offset((project.page - 1) * project.limit).limit(project.limit)
        rows = (await self.db.execute(sql)).all()
        datas = []
        for project_obj, mark_count, no_mark_count in rows:
            data = schemas.ProjectInfoPagerOut.model_validate(project_obj)
            data.mark_count = int(mark_count)
            data.no_mark_count = int(no_mark_count)
            datas.append(data.model_dump())
        return datas, count

    async def check_name(self, project_name: str):
        """
        Return True when a non-deleted project named ``project_name`` exists.
        """
        # ``.is_(False)`` builds a SQL predicate; the former Python
        # ``column is False`` evaluated to the constant False and matched nothing.
        count = await self.get_count(v_where=[
            models.ProjectInfo.project_name == project_name,
            models.ProjectInfo.is_delete.is_(False),
        ])
        return count > 0

    async def add_project(
            self,
            project: schemas.ProjectInfoIn,
            auth: Auth
    ) -> Any:
        """
        Create a project owned by the current user, plus its dataset and run folders.

        The project gets a random 6-character ``project_no``, status ``"0"``
        and ``train_version`` 0. ``dept_id`` is 0 when ``'*'`` is in
        ``auth.dept_ids``, otherwise the first entry of ``auth.dept_ids``.

        :return: the new project serialized with ``schemas.ProjectInfoOut``.
        """
        obj = self.model(**project.model_dump())
        obj.user_id = auth.user.id
        # NOTE(review): no uniqueness check is made on the random project_no here.
        obj.project_no = ru.random_str(6)
        obj.project_status = "0"
        obj.train_version = 0
        obj.dept_id = 0 if '*' in auth.dept_ids else auth.dept_ids[0]
        # Flush before touching the filesystem so a failed insert does not
        # leave orphaned folders behind.
        await self.flush(obj)
        # Dataset folder
        os.create_folder(datasets_url, obj.project_no)
        # Training-run folder
        os.create_folder(runs_url, obj.project_no)
        return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)

    async def update_version(self, data_id):
        """
        Increment the project's ``train_version`` by one.

        Nothing is updated when ``get_data`` returns a falsy value.
        """
        proj = await self.get_data(data_id)
        if proj:
            proj.train_version = proj.train_version + 1
            # NOTE(review): the ORM object itself is passed as ``data``; confirm
            # DalBase.put_data accepts a model instance and not only a schema.
            await self.put_data(data_id=data_id, data=proj)
class ProjectImageDal(DalBase):
    """
    Data-access layer for project images (``models.ProjectImage``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImageDal, self).__init__()
        self.db = db
        self.model = models.ProjectImage
        self.schema = schemas.ProjectImageOut

    @staticmethod
    def _label_count_subquery():
        """Subquery ``(image_id, label_count)`` over images having at least one ProjectImgLabel row."""
        return (
            select(
                models.ProjectImgLabel.image_id,
                func.count(models.ProjectImgLabel.id).label('label_count')
            )
            .group_by(models.ProjectImgLabel.image_id)
            .subquery()
        )

    async def img_page(self, param: params.ProjectImageParams):
        """
        Page through the images of one project and image type, each with its label count.

        :param param: ``project_id`` / ``img_type`` filters plus paging and
            ordering (``page``, ``limit``, ``v_order``, ``v_order_field``);
            ``limit == 0`` disables paging.
        :return: ``(datas, count)`` where ``datas`` are dumped
            ``ProjectImageOut`` dicts with ``label_count`` set and ``count`` is
            the number of matching images before paging.
        """
        subquery = self._label_count_subquery()
        # Images without labels are absent from the subquery -> ifnull gives 0.
        query = (
            select(
                models.ProjectImage,
                func.ifnull(subquery.c.label_count, 0).label('label_count')
            )
            .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
        )
        v_where = [models.ProjectImage.project_id == param.project_id, models.ProjectImage.img_type == param.img_type]
        sql = await self.filter_core(
            v_start_sql=query,
            v_where=v_where,
            v_return_sql=True,
            v_order=param.v_order,
            v_order_field=param.v_order_field
        )
        count = await self.get_count_sql(sql)
        if param.limit != 0:
            sql = sql.offset((param.page - 1) * param.limit).limit(param.limit)
        rows = (await self.db.execute(sql)).all()
        datas = []
        for image, label_count in rows:
            data = schemas.ProjectImageOut.model_validate(image)
            data.label_count = int(label_count)
            datas.append(data.model_dump())
        return datas, count

    async def upload_imgs(self, files: List[UploadFile], pro: schemas.ProjectInfoOut, img_type: str) -> int:
        """
        Save uploaded images locally, upload them to OBS and insert their rows.

        The OBS object key is ``<project_no>/<img_type>/<filename>``.

        :raises CustomException: when an OBS upload fails; files saved and
            objects uploaded earlier in this batch are removed first.
        :return: number of images stored.
        """
        image_models = []
        saved_paths = []
        uploaded_keys = []
        obs = MyObs()
        for file in files:
            # Save the original image locally
            path = os.save_images(images_url, pro.project_no, file=file)
            saved_paths.append(path)
            # Upload the image to OBS
            object_key = pro.project_no + '/' + img_type + '/' + file.filename
            success, _key, url = obs.put_file(object_key=object_key, file_path=path)
            if not success:
                # No rows are created for this batch, so drop what was already
                # stored instead of leaving orphaned files/objects.
                os.delete_file_if_exists(*saved_paths)
                if uploaded_keys:
                    obs.del_objects(uploaded_keys)
                raise CustomException("obs上传失败", code=status.HTTP_ERROR)
            uploaded_keys.append(object_key)

            image = models.ProjectImage()
            image.project_id = pro.id
            image.file_name = file.filename
            image.img_type = img_type
            image.image_url = path
            image.object_key = object_key
            image.thumb_image_url = url
            image_models.append(image)
        await self.create_models(datas=image_models)
        return len(image_models)

    async def check_img_name(self, file_name: str, project_id: int, img_type: str):
        """
        Return True when the project already has an image of ``img_type`` named ``file_name``.
        """
        count = await self.get_count(v_where=[
            models.ProjectImage.file_name == file_name,
            models.ProjectImage.project_id == project_id,
            models.ProjectImage.img_type == img_type
        ])
        return count > 0

    async def del_img(self, ids: List[int]):
        """
        Delete images: their local files, their OBS objects, then their database rows.

        Ids with no matching row are skipped when collecting files/objects.
        """
        if not ids:
            return
        # One query for all ids instead of one get_data call per id.
        images = await self.get_datas(
            limit=0,
            v_return_count=False,
            v_where=[self.model.id.in_(ids)],
            v_return_objs=True)
        file_urls = [image.image_url for image in images if image.image_url]
        object_keys = [image.object_key for image in images if image.object_key]
        os.delete_file_if_exists(*file_urls)
        if object_keys:
            MyObs().del_objects(object_keys)
        await self.delete_datas(ids)

    async def get_img_count(
            self,
            proj_id: int) -> tuple[int, int]:
        """
        Return ``(train_count, val_count)``: the project's image counts per type.
        """
        train_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'])
        val_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'])
        return train_count, val_count

    async def _count_unlabeled(self, proj_id: int, img_type: str) -> int:
        """Count images of ``proj_id``/``img_type`` that have no ProjectImgLabel row."""
        subquery = self._label_count_subquery()
        query = (
            select(models.ProjectImage.id)
            .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
        )
        # An unlabeled image has no subquery row, so the outer join yields
        # NULL; comparing label_count == 0 would never match it.
        sql = await self.filter_core(
            v_start_sql=query,
            v_where=[models.ProjectImage.project_id == proj_id,
                     models.ProjectImage.img_type == img_type,
                     subquery.c.image_id.is_(None)],
            v_return_sql=True)
        return await self.get_count_sql(sql)

    async def check_image_label(
            self,
            proj_id: int) -> tuple[int, int]:
        """
        Return ``(train_count, val_count)``: the project's unlabeled image counts per type.
        """
        train_count = await self._count_unlabeled(proj_id, 'train')
        val_count = await self._count_unlabeled(proj_id, 'val')
        return train_count, val_count
class ProjectLabelDal(DalBase):
    """
    Data-access layer for the labels defined on a project (``models.ProjectLabel``).
    """

    def __init__(self, db: AsyncSession):
        super().__init__()
        self.db = db
        self.model = models.ProjectLabel
        self.schema = schemas.ProjectLabel

    async def check_label_name(
            self,
            name: str,
            pro_id: int,
            label_id: int = None
    ):
        """
        Return True when project ``pro_id`` already has a label called ``name``.

        :param label_id: when truthy, that label is excluded from the check
            (used when renaming an existing label).
        """
        conditions = [self.model.project_id == pro_id, self.model.label_name == name]
        if label_id:
            conditions.append(self.model.id != label_id)
        return await self.get_count(v_where=conditions) > 0

    async def get_label_for_train(self, project_id: int):
        """
        Return ``(ids, names)`` of all labels of ``project_id``, ordered by id ascending.
        """
        labels = await self.get_datas(
            limit=0,
            v_return_count=False,
            v_where=[self.model.project_id == project_id],
            v_order_field='id',
            v_order='asc')
        ids, names = [], []
        for item in labels:
            ids.append(item['id'])
            names.append(item['label_name'])
        return ids, names
class ProjectImgLabelDal(DalBase):
    """
    Data-access layer for the labels attached to images (``models.ProjectImgLabel``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImgLabelDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLabel
        # NOTE(review): unlike the other DALs in this module, no self.schema is set.

    async def _delete_where(self, v_where: list) -> None:
        """Delete every row matching ``v_where``; no-op when nothing matches."""
        rows = await self.get_datas(
            limit=0,
            v_return_count=False,
            v_where=v_where,
            v_return_objs=True)
        ids = [row.id for row in rows]
        if ids:
            await self.delete_datas(ids=ids)

    async def add_img_label(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace all labels of ``img_label_in.image_id`` with ``img_label_in.label_infos``.
        """
        # Remove the image's existing labels first, then save the new set.
        image_id = img_label_in.image_id
        await self._delete_where([self.model.image_id == image_id])
        img_labels = [self.model(**info.model_dump()) for info in img_label_in.label_infos]
        for img in img_labels:
            img.image_id = image_id
        if img_labels:
            # These are model instances, so use create_models (as upload_imgs
            # does) rather than create_datas.
            await self.create_models(datas=img_labels)

    async def get_img_label_list(self, image_id: int):
        """
        Return all label objects of ``image_id``, ordered by id ascending.
        """
        return await self.get_datas(
            limit=0,
            v_return_count=False,
            v_where=[self.model.image_id == image_id],
            v_order="asc",
            v_order_field="id",
            v_return_objs=True)

    async def del_img_label(self, label_ids: list[int]):
        """
        Delete every image label whose ``label_id`` is in ``label_ids``.
        """
        if not label_ids:
            return
        # Both calls must be awaited; previously neither ran, and get_datas'
        # default paging/return shape would not have yielded all objects.
        await self._delete_where([self.model.label_id.in_(label_ids)])
class ProjectImgLeaferDal(DalBase):
    """
    Data-access layer for image annotation data produced with leafer.js
    (``models.ProjectImgLeafer``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImgLeaferDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLeafer
        self.schema = schemas.ProjectImgLeaferOut

    async def _delete_where(self, v_where: list) -> None:
        """Delete every row matching ``v_where``; no-op when nothing matches."""
        rows = await self.get_datas(
            limit=0,
            v_return_count=False,
            v_where=v_where,
            v_return_objs=True)
        ids = [row.id for row in rows]
        if ids:
            await self.delete_datas(ids=ids)

    async def get_leafer(self, image_id: int):
        """
        Return the ``leafer`` annotation of ``image_id``, or None when no record is returned.
        """
        # get_data is a coroutine; without await, .leafer was read off the coroutine.
        img_label = await self.get_data(v_where=[self.model.image_id == image_id])
        return img_label.leafer if img_label else None

    async def add_leafer(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace the annotation record of ``img_label_in.image_id`` with ``img_label_in``.
        """
        # Remove the image's existing record first, then save the new one.
        image_id = img_label_in.image_id
        await self._delete_where([self.model.image_id == image_id])
        # NOTE(review): assumes every field of img_label_in is a column of
        # ProjectImgLeafer — confirm against the schema.
        # flush(obj) persists a model instance, as add_project does.
        await self.flush(self.model(**img_label_in.model_dump()))