#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:25
# @File           : crud.py
# @IDE            : PyCharm
# @desc           : Data access layer

from . import schemas, models, params
from apps.vadmin.auth.utils.validation.auth import Auth
from utils import os_utils as os, random_utils as ru
from utils.huawei_obs import ObsClient
from utils import status
from core.exception import CustomException
from application.settings import datasets_url, runs_url, images_url
from typing import Any, List, Optional, Tuple
from core.crud import DalBase
from fastapi import UploadFile
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, case


class ProjectInfoDal(DalBase):
    """Data access for project records (``models.ProjectInfo``)."""

    def __init__(self, db: AsyncSession):
        super(ProjectInfoDal, self).__init__()
        self.db = db
        self.model = models.ProjectInfo
        self.schema = schemas.ProjectInfoOut

    async def get_project_pager(
            self,
            project: params.ProjectInfoParams,
            auth: Auth
    ) -> Tuple[List[dict], int]:
        """
        Return one page of projects together with per-project image counts.

        :param project: paging / ordering parameters (page, limit, v_order, v_order_field)
        :param auth: current auth context; ``auth.dept_ids`` restricts visible projects
        :return: ``(rows, total)`` where each row is a dumped ``ProjectInfoPagerOut``
        """
        # Per-project counts: an image with no matching ProjectImgLeafer row is
        # counted as "not marked", otherwise as "marked".
        subquery = (
            select(
                models.ProjectImage.project_id,
                func.sum(case((models.ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
                func.sum(case((models.ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count')
            )
            .outerjoin(models.ProjectImgLeafer, models.ProjectImage.id == models.ProjectImgLeafer.image_id)
            .group_by(models.ProjectImage.project_id)
            .subquery()
        )

        # Outer join so projects without any image still appear (counts default to 0).
        full_query = select(
            models.ProjectInfo,
            func.ifnull(subquery.c.mark_count, 0).label("mark_count"),
            func.ifnull(subquery.c.no_mark_count, 0).label("no_mark_count")
        ).select_from(models.ProjectInfo).join(
            subquery, models.ProjectInfo.id == subquery.c.project_id, isouter=True
        )

        v_where = [models.ProjectInfo.is_delete.is_(False)]
        if '*' in auth.dept_ids:
            # '*' in dept_ids: no department restriction beyond "has a department".
            v_where.append(models.ProjectInfo.dept_id.isnot(None))
        else:
            v_where.append(models.ProjectInfo.dept_id.in_(auth.dept_ids))

        sql = await self.filter_core(
            v_start_sql=full_query,
            v_where=v_where,
            v_return_sql=True,
            v_order=project.v_order,
            v_order_field=project.v_order_field
        )
        # Total is computed before offset/limit are applied.
        count = await self.get_count_sql(sql)
        if project.limit != 0:
            sql = sql.offset((project.page - 1) * project.limit).limit(project.limit)
        queryset = await self.db.execute(sql)
        rows = queryset.all()

        datas = []
        for row in rows:
            data = schemas.ProjectInfoPagerOut.model_validate(row[0])
            data.mark_count = int(row[1])
            data.no_mark_count = int(row[2])
            datas.append(data.model_dump())
        return datas, count

    async def check_name(self, project_name: str) -> bool:
        """
        Check whether a non-deleted project with this name already exists.

        :param project_name: name to look up
        :return: True when at least one matching project exists
        """
        count = await self.get_count(v_where=[
            models.ProjectInfo.project_name == project_name,
            # `.is_(False)` builds a SQL predicate; a Python `is False` test
            # would collapse to the constant False before reaching the query.
            models.ProjectInfo.is_delete.is_(False)
        ])
        return count > 0

    async def add_project(
            self,
            project: schemas.ProjectInfoIn,
            auth: Auth
    ) -> Any:
        """
        Create a project, its dataset folder and its training folder.

        :param project: incoming project fields
        :param auth: current auth context; supplies the owner id and department
        :return: the new project rendered through ``schemas.ProjectInfoOut``
        """
        obj = self.model(**project.model_dump())
        obj.user_id = auth.user.id
        obj.project_no = ru.random_str(6)
        obj.project_status = "0"
        obj.train_version = 0
        if '*' in auth.dept_ids:
            obj.dept_id = 0
        else:
            # NOTE(review): assumes dept_ids is non-empty here - confirm with Auth.
            obj.dept_id = auth.dept_ids[0]

        # Create the dataset folder for this project.
        os.create_folder(datasets_url, obj.project_no)
        # Create the training-run folder for this project.
        os.create_folder(runs_url, obj.project_no)
        await self.flush(obj)
        return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)


class ProjectImageDal(DalBase):
    """Data access for project images (``models.ProjectImage``)."""

    def __init__(self, db: AsyncSession):
        super(ProjectImageDal, self).__init__()
        self.db = db
        self.model = models.ProjectImage
        self.schema = schemas.ProjectImageOut

    async def img_page(self, param: params.ProjectImageParams) -> Tuple[List[dict], int]:
        """
        Return one page of images, each with the number of labels attached to it.

        :param param: project_id, img_type and paging / ordering parameters
        :return: ``(rows, total)`` where each row is a dumped ``ProjectImageOut``
        """
        # 1. Sub-query: label count per image.
        subquery = (
            select(
                models.ProjectImgLabel.image_id,
                func.ifnull(func.count(models.ProjectImgLabel.id), 0).label('label_count')
            )
            .group_by(models.ProjectImgLabel.image_id)
            .subquery()
        )

        # 2. Main query: outer join so images without labels report 0.
        query = (
            select(
                models.ProjectImage,
                func.ifnull(subquery.c.label_count, 0).label('label_count')
            )
            .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
        )

        v_where = [
            models.ProjectImage.project_id == param.project_id,
            models.ProjectImage.img_type == param.img_type
        ]
        sql = await self.filter_core(
            v_start_sql=query,
            v_where=v_where,
            v_return_sql=True,
            v_order=param.v_order,
            v_order_field=param.v_order_field
        )
        # Total is computed before offset/limit are applied.
        count = await self.get_count_sql(sql)
        if param.limit != 0:
            sql = sql.offset((param.page - 1) * param.limit).limit(param.limit)
        queryset = await self.db.execute(sql)
        rows = queryset.all()

        datas = []
        for row in rows:
            data = schemas.ProjectImageOut.model_validate(row[0])
            data.label_count = int(row[1])
            datas.append(data.model_dump())
        return datas, count

    async def upload_imgs(self, files: List[UploadFile], pro: schemas.ProjectInfoOut, img_type: str) -> int:
        """
        Save uploaded images locally, push them to OBS and persist their records.

        :param files: uploaded files
        :param pro: project the images belong to
        :param img_type: image type stored on each record and used in the object key
        :return: number of image records created
        :raises CustomException: when an OBS upload reports failure
        """
        image_models = []
        for file in files:
            image = models.ProjectImage()
            image.project_id = pro.id
            image.file_name = file.filename
            image.img_type = img_type
            # Save the original image locally.
            path = os.save_images(images_url, pro.project_no, file=file)
            image.image_url = path
            # Upload the image to OBS.
            object_key = pro.project_no + '/' + img_type + '/' + file.filename
            success, _, url = ObsClient.put_file(object_key=object_key, file_path=path)
            if success:
                image.object_key = object_key
                image.thumb_image_url = url
            else:
                raise CustomException("obs上传失败", code=status.HTTP_ERROR)
            image_models.append(image)
        await self.create_datas(datas=image_models)
        return len(image_models)

    async def check_img_name(self, file_name: str, project_id: int, img_type: str) -> bool:
        """
        Check whether the project already has an image of this type with this file name.

        :return: True when at least one matching image exists
        """
        count = await self.get_count(v_where=[
            models.ProjectImage.file_name == file_name,
            models.ProjectImage.project_id == project_id,
            models.ProjectImage.img_type == img_type
        ])
        return count > 0

    async def del_img(self, ids: List[int]):
        """
        Delete images: local files, OBS objects and database rows.

        :param ids: ids of the images to delete
        """
        file_urls = []
        object_keys = []
        for img_id in ids:
            # get_data must be awaited; otherwise `image` is a coroutine object.
            image = await self.get_data(data_id=img_id)
            if image:
                file_urls.append(image.image_url)
                object_keys.append(image.object_key)
        # A list must be unpacked with `*`; `**` requires a mapping and raises TypeError.
        # NOTE(review): assumes delete_file_if_exists accepts paths as varargs - confirm.
        os.delete_file_if_exists(*file_urls)
        ObsClient.del_objects(object_keys)
        await self.delete_datas(ids)

    async def get_img_count(
            self,
            proj_id: int) -> Tuple[int, int]:
        """
        Count the project's images by type.

        :param proj_id: project id
        :return: ``(train_count, val_count)``
        """
        train_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'])
        val_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'])
        return train_count, val_count

    async def check_image_label(
            self,
            proj_id: int) -> Tuple[int, int]:
        """
        Count rows of the image/label-count query for 'train' and 'val' images.

        NOTE(review): the original docstring described this as the number of
        *unlabelled* images, but no condition on ``label_count`` is applied, so
        every image of each type is counted - confirm the intended filter.

        :param proj_id: project id
        :return: ``(train_count, val_count)``
        """
        # 1. Sub-query: label count per image.
        subquery = (
            select(
                models.ProjectImgLabel.image_id,
                func.ifnull(func.count(models.ProjectImgLabel.id), 0).label('label_count')
            )
            .group_by(models.ProjectImgLabel.image_id)
            .subquery()
        )

        # 2. Main query.
        query = (
            select(
                models.ProjectImage,
                func.ifnull(subquery.c.label_count, 0).label('label_count')
            )
            .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
        )

        train_count_sql = await self.filter_core(
            v_start_sql=query,
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'],
            v_return_sql=True)
        # A built statement is counted with get_count_sql, as in the pagers;
        # get_count is used with v_where lists elsewhere in this file.
        train_count = await self.get_count_sql(train_count_sql)

        val_count_sql = await self.filter_core(
            v_start_sql=query,
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'],
            v_return_sql=True)
        val_count = await self.get_count_sql(val_count_sql)
        return train_count, val_count


class ProjectLabelDal(DalBase):
    """Data access for project labels (``models.ProjectLabel``)."""

    def __init__(self, db: AsyncSession):
        super(ProjectLabelDal, self).__init__()
        self.db = db
        self.model = models.ProjectLabel
        self.schema = schemas.ProjectLabel

    async def check_label_name(
            self,
            name: str,
            pro_id: int,
            label_id: Optional[int] = None
    ) -> bool:
        """
        Check whether the project already has a label with this name.

        :param name: label name to look up
        :param pro_id: project id
        :param label_id: when given, this label is excluded from the check
        :return: True when at least one other matching label exists
        """
        wheres = [
            self.model.project_id == pro_id,
            self.model.label_name == name
        ]
        if label_id:
            wheres.append(self.model.id != label_id)
        count = await self.get_count(v_where=wheres)
        return count > 0

    async def get_label_for_train(self, project_id: int) -> Tuple[List[int], List[str]]:
        """
        Return the project's label ids and names, ordered by id ascending.

        :param project_id: project id
        :return: ``(id_list, name_list)`` with matching positions
        """
        # get_datas is a coroutine (it is awaited elsewhere in this file).
        label_list = await self.get_datas(
            v_where=[self.model.project_id == project_id],
            v_order='asc',
            v_order_field='id',
            v_return_count=False)
        id_list = [label.id for label in label_list]
        name_list = [label.label_name for label in label_list]
        return id_list, name_list


class ProjectImgLabelDal(DalBase):
    """Data access for per-image label rows (``models.ProjectImgLabel``)."""

    def __init__(self, db: AsyncSession):
        super(ProjectImgLabelDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLabel

    async def add_img_label(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace the labels stored for an image with those in ``img_label_in``.

        :param img_label_in: carries ``image_id`` and the ``label_infos`` to store
        """
        # Remove the existing rows first, then save the new ones.
        image_id = img_label_in.image_id
        # NOTE(review): delete_datas is called with a positional id list in
        # ProjectImageDal.del_img; confirm DalBase.delete_datas accepts an
        # image_id keyword filter like this.
        await self.delete_datas(image_id=image_id)
        img_labels = [self.model(**i.model_dump()) for i in img_label_in.label_infos]
        for img in img_labels:
            img.image_id = image_id
        await self.create_datas(img_labels)

    async def get_img_label_list(self, image_id: int):
        """
        Return the label rows of an image, ordered by id ascending.

        :param image_id: image id
        """
        return await self.get_datas(
            v_return_count=False,
            v_where=[self.model.image_id == image_id],
            v_order="asc",
            v_order_field="id")


class ProjectImgLeaferDal(DalBase):
    """Data access for per-image leafer.js annotation data (``models.ProjectImgLeafer``)."""

    def __init__(self, db: AsyncSession):
        super(ProjectImgLeaferDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLeafer
        self.schema = schemas.ProjectImgLeaferOut

    async def add_leafer(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace the annotation record stored for an image.

        :param img_label_in: annotation payload; ``image_id`` selects the image
        """
        # Remove the existing rows first, then save the new one.
        image_id = img_label_in.image_id
        # NOTE(review): same keyword-style delete_datas call as in
        # ProjectImgLabelDal.add_img_label - confirm DalBase supports it.
        await self.delete_datas(image_id=image_id)
        await self.create_data(data=self.model(**img_label_in.model_dump()))