499 lines
20 KiB
Python
Raw Normal View History

2025-04-11 08:54:28 +08:00
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Update Time : 2023/8/21 22:18
# @File : crud.py
# @IDE : PyCharm
# @desc : 数据库 增删改查操作
# sqlalchemy 官方文档https://docs.sqlalchemy.org/en/20/index.html
# sqlalchemy 查询操作(官方文档): https://docs.sqlalchemy.org/en/20/orm/queryguide/select.html
# sqlalchemy 增删改操作https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html
# sqlalchemy 1.x 语法迁移到 2.x :https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-query-usage
import datetime
from fastapi import HTTPException
from fastapi.encoders import jsonable_encoder
from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select, false, insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.strategy_options import _AbstractLoad
from starlette import status
from core.exception import CustomException
from sqlalchemy.sql.selectable import Select as SelectType
from typing import Any, Union
class DalBase:
# 倒叙
ORDER_FIELD = ["desc", "descending"]
def __init__(self, db: AsyncSession = None, model: Any = None, schema: Any = None):
self.db = db
self.model = model
self.schema = schema
async def get_data(
self,
data_id: int = None,
v_start_sql: SelectType = None,
v_select_from: list[Any] = None,
v_join: list[Any] = None,
v_outer_join: list[Any] = None,
v_options: list[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None,
v_order: str = None,
v_order_field: str = None,
v_return_none: bool = False,
v_schema: Any = None,
v_expire_all: bool = False,
**kwargs
) -> Any:
"""
获取单个数据默认使用 ID 查询否则使用关键词查询
:param data_id: 数据 ID
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙
:param v_order_field: 排序字段
:param v_return_none: 是否返回空 None否认 抛出异常默认抛出异常
:param v_schema: 指定使用的序列化对象
:param v_expire_all: 使当前会话Session中所有已加载的对象过期确保您获取的是数据库中的最新数据但可能会有性能损耗博客https://blog.csdn.net/k_genius/article/details/135490378
:param kwargs: 查询参数
:return: 默认返回 ORM 对象如果存在 v_schema 则会返回 v_schema 结果
"""
if v_expire_all:
self.db.expire_all()
if not isinstance(v_start_sql, SelectType):
v_start_sql = select(self.model).where(self.model.is_delete == false())
if data_id is not None:
v_start_sql = v_start_sql.where(self.model.id == data_id)
queryset: ScalarResult = await self.filter_core(
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outer_join=v_outer_join,
v_options=v_options,
v_where=v_where,
v_order=v_order,
v_order_field=v_order_field,
v_return_sql=False,
**kwargs
)
if v_options:
data = queryset.unique().first()
else:
data = queryset.first()
if not data and v_return_none:
return None
if data and v_schema:
return v_schema.model_validate(data).model_dump()
if data:
return data
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到此数据")
async def get_datas(
self,
page: int = 1,
limit: int = 10,
v_start_sql: SelectType = None,
v_select_from: list[Any] = None,
v_join: list[Any] = None,
v_outer_join: list[Any] = None,
v_options: list[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None,
v_order: str = None,
v_order_field: str = None,
v_return_count: bool = False,
v_return_scalars: bool = False,
v_return_objs: bool = False,
v_schema: Any = None,
v_distinct: bool = False,
v_expire_all: bool = False,
**kwargs
) -> Union[list[Any], ScalarResult, tuple]:
"""
获取数据列表
:param page: 页码
:param limit: 当前页数据量
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙
:param v_order_field: 排序字段
:param v_return_count: 默认为 False是否返回 count 过滤后的数据总数不会影响其他返回结果会一起返回为一个数组
:param v_return_scalars: 返回scalars后的结果
:param v_return_objs: 是否返回对象
:param v_schema: 指定使用的序列化对象
:param v_distinct: 是否结果去重
:param v_expire_all: 使当前会话Session中所有已加载的对象过期确保您获取的是数据库中的最新数据但可能会有性能损耗博客https://blog.csdn.net/k_genius/article/details/135490378
:param kwargs: 查询参数使用的是自定义表达式
:return: 返回值优先级v_return_scalars > v_return_objs > v_schema
"""
if v_expire_all:
self.db.expire_all()
sql: SelectType = await self.filter_core(
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outer_join=v_outer_join,
v_options=v_options,
v_where=v_where,
v_order=v_order,
v_order_field=v_order_field,
v_return_sql=True,
**kwargs
)
if v_distinct:
sql = sql.distinct()
count = 0
if v_return_count:
count_sql = select(func.count()).select_from(sql.alias())
count_queryset = await self.db.execute(count_sql)
count = count_queryset.one()[0]
if limit != 0:
sql = sql.offset((page - 1) * limit).limit(limit)
queryset = await self.db.scalars(sql)
if v_return_scalars:
if v_return_count:
return queryset, count
return queryset
if v_options:
result = queryset.unique().all()
else:
result = queryset.all()
if v_return_objs:
if v_return_count:
return list(result), count
return list(result)
datas = [await self.out_dict(i, v_schema=v_schema) for i in result]
if v_return_count:
return datas, count
return datas
async def get_count_sql(
self,
v_start_sql: SelectType
) -> int:
count_sql = select(func.count()).select_from(v_start_sql.alias())
count_queryset = await self.db.execute(count_sql)
count = count_queryset.one()[0]
return count
async def get_count(
self,
v_select_from: list[Any] = None,
v_join: list[Any] = None,
v_outer_join: list[Any] = None,
v_where: list[BinaryExpression] = None,
**kwargs
) -> int:
"""
获取数据总数
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_where: 当前表查询条件原始表达式
:param kwargs: 查询参数
"""
v_start_sql = select(func.count(self.model.id))
sql = await self.filter_core(
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outer_join=v_outer_join,
v_where=v_where,
v_return_sql=True,
**kwargs
)
queryset = await self.db.execute(sql)
return queryset.one()[0]
async def create_data(
self,
data,
v_options: list[_AbstractLoad] = None,
v_return_obj: bool = False,
v_schema: Any = None
) -> Any:
"""
创建单个数据
:param data: 创建数据
:param v_options: 指示应使用select在预加载中加载给定的属性
:param v_schema: 指定使用的序列化对象
:param v_return_obj: 是否返回对象
"""
if isinstance(data, dict):
obj = self.model(**data)
else:
obj = self.model(**data.model_dump())
await self.flush(obj)
return await self.out_dict(obj, v_options, v_return_obj, v_schema)
async def create_datas(self, datas: list[dict]) -> None:
"""
批量创建数据
SQLAlchemy 2.0 批量插入不支持 MySQL 返回值
https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning
:param datas: 字典数据列表
"""
await self.db.execute(insert(self.model), datas)
await self.db.flush()
async def put_data(
self,
data_id: int,
data: Any,
v_options: list[_AbstractLoad] = None,
v_return_obj: bool = False,
v_schema: Any = None
) -> Any:
"""
更新单个数据
:param data_id: 修改行数据的 ID
:param data: 数据内容
:param v_options: 指示应使用select在预加载中加载给定的属性
:param v_return_obj: 是否返回对象
:param v_schema: 指定使用的序列化对象
"""
obj = await self.get_data(data_id, v_options=v_options)
obj_dict = jsonable_encoder(data)
for key, value in obj_dict.items():
setattr(obj, key, value)
await self.flush(obj)
return await self.out_dict(obj, None, v_return_obj, v_schema)
async def delete_datas(self, ids: list[int], v_soft: bool = False, **kwargs) -> None:
"""
删除多条数据
:param ids: 数据集
:param v_soft: 是否执行软删除
:param kwargs: 其他更新字段
"""
if v_soft:
await self.db.execute(
update(self.model).where(self.model.id.in_(ids)).values(
delete_datetime=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
is_delete=True,
**kwargs
)
)
else:
await self.db.execute(delete(self.model).where(self.model.id.in_(ids)))
await self.flush()
async def flush(self, obj: Any = None) -> Any:
"""
刷新到数据库
:param obj:
:return:
"""
if obj:
self.db.add(obj)
await self.db.flush()
if obj:
# 使用 get_data 或者 get_datas 获取到实例后如果更新了实例,并需要序列化实例,那么需要执行 refresh 刷新才能正常序列化
await self.db.refresh(obj)
return obj
async def out_dict(
self,
obj: Any,
v_options: list[_AbstractLoad] = None,
v_return_obj: bool = False,
v_schema: Any = None
) -> Any:
"""
序列化
:param obj:
:param v_options: 指示应使用select在预加载中加载给定的属性
:param v_return_obj: 是否返回对象
:param v_schema: 指定使用的序列化对象
:return:
"""
if v_options:
obj = await self.get_data(obj.id, v_options=v_options)
if v_return_obj:
return obj
if v_schema:
return v_schema.model_validate(obj).model_dump()
return self.schema.model_validate(obj).model_dump()
async def filter_core(
self,
v_start_sql: SelectType = None,
v_select_from: list[Any] = None,
v_join: list[Any] = None,
v_outer_join: list[Any] = None,
v_options: list[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None,
v_order: str = None,
v_order_field: str = None,
v_return_sql: bool = False,
**kwargs
) -> Union[ScalarResult, SelectType]:
"""
数据过滤核心功能
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙
:param v_order_field: 排序字段
:param v_return_sql: 是否直接返回 sql
:return: 返回过滤后的总数居 sql
"""
if not isinstance(v_start_sql, SelectType):
v_start_sql = select(self.model).where(self.model.is_delete == false())
sql = self.add_relation(
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outer_join=v_outer_join,
v_options=v_options
)
if v_where:
sql = sql.where(*v_where)
sql = self.add_filter_condition(sql, **kwargs)
if v_order_field and (v_order in self.ORDER_FIELD):
sql = sql.order_by(getattr(self.model, v_order_field).desc(), self.model.id.desc())
elif v_order_field:
sql = sql.order_by(getattr(self.model, v_order_field), self.model.id)
elif v_order in self.ORDER_FIELD:
sql = sql.order_by(self.model.id.desc())
if v_return_sql:
return sql
queryset = await self.db.scalars(sql)
return queryset
def add_relation(
self,
v_start_sql: SelectType,
v_select_from: list[Any] = None,
v_join: list[Any] = None,
v_outer_join: list[Any] = None,
v_options: list[_AbstractLoad] = None,
) -> SelectType:
"""
关系查询关系加载
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
"""
if v_select_from:
v_start_sql = v_start_sql.select_from(*v_select_from)
if v_join:
for relation in v_join:
table = relation[0]
if isinstance(table, str):
table = getattr(self.model, table)
if len(relation) == 2:
v_start_sql = v_start_sql.join(table, relation[1])
else:
v_start_sql = v_start_sql.join(table)
if v_outer_join:
for relation in v_outer_join:
table = relation[0]
if isinstance(table, str):
table = getattr(self.model, table)
if len(relation) == 2:
v_start_sql = v_start_sql.outerjoin(table, relation[1])
else:
v_start_sql = v_start_sql.outerjoin(table)
if v_options:
v_start_sql = v_start_sql.options(*v_options)
return v_start_sql
def add_filter_condition(self, sql: SelectType, **kwargs) -> SelectType:
"""
添加过滤条件
:param sql:
:param kwargs: 关键词参数
"""
conditions = self.__dict_filter(**kwargs)
if conditions:
sql = sql.where(*conditions)
return sql
def __dict_filter(self, **kwargs) -> list[BinaryExpression]:
"""
字典过滤
:param model:
:param kwargs:
"""
conditions = []
for field, value in kwargs.items():
if value is not None and value != "":
attr = getattr(self.model, field)
if isinstance(value, tuple):
if len(value) == 1:
if value[0] == "None":
conditions.append(attr.is_(None))
elif value[0] == "not None":
conditions.append(attr.isnot(None))
else:
raise CustomException("SQL查询语法错误")
elif len(value) == 2 and value[1] not in [None, [], ""]:
if value[0] == "date":
# 根据日期查询, 关键函数是func.time_format和func.date_format
conditions.append(func.date_format(attr, "%Y-%m-%d") == value[1])
elif value[0] == "like":
conditions.append(attr.like(f"%{value[1]}%"))
elif value[0] == "in":
conditions.append(attr.in_(value[1]))
elif value[0] == "between" and len(value[1]) == 2:
conditions.append(attr.between(value[1][0], value[1][1]))
elif value[0] == "month":
conditions.append(func.date_format(attr, "%Y-%m") == value[1])
elif value[0] == "!=":
conditions.append(attr != value[1])
elif value[0] == ">":
conditions.append(attr > value[1])
elif value[0] == ">=":
conditions.append(attr >= value[1])
elif value[0] == "<=":
conditions.append(attr <= value[1])
else:
raise CustomException("SQL查询语法错误")
else:
conditions.append(attr == value)
return conditions