aicheckv2-api/core/mongo_manage.py
2025-04-11 08:54:28 +08:00

178 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import json
from typing import Any
from bson import ObjectId
from bson.errors import InvalidId
from bson.json_util import dumps
from fastapi.encoders import jsonable_encoder
from motor.motor_asyncio import AsyncIOMotorDatabase
from pymongo.results import InsertOneResult, UpdateResult
from core.exception import CustomException
from utils import status
class MongoManage:
"""
mongodb 数据库管理器
博客https://www.cnblogs.com/aduner/p/13532504.html
mongodb 官网https://www.mongodb.com/docs/drivers/motor/
motor 文档https://motor.readthedocs.io/en/stable/
"""
# 倒叙
ORDER_FIELD = ["desc", "descending"]
def __init__(
self,
db: AsyncIOMotorDatabase = None,
collection: str = None,
schema: Any = None,
is_object_id: bool = True
):
"""
初始化
:param db:
:param collection: 集合
:param schema:
:param is_object_id: _id 列是否为 ObjectId 格式
"""
self.db = db
self.collection = db[collection] if collection else None
self.schema = schema
self.is_object_id = is_object_id
async def get_data(
self,
_id: str = None,
v_return_none: bool = False,
v_schema: Any = None,
**kwargs
) -> dict | None:
"""
获取单个数据,默认使用 ID 查询,否则使用关键词查询
:param _id: 数据 ID
:param v_return_none: 是否返回空 None否则抛出异常默认抛出异常
:param v_schema: 指定使用的序列化对象
"""
if _id and self.is_object_id:
kwargs["_id"] = ObjectId(_id)
params = self.filter_condition(**kwargs)
data = await self.collection.find_one(params)
if not data and v_return_none:
return None
elif not data:
raise CustomException("查找失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
elif data and v_schema:
return jsonable_encoder(v_schema(**data))
return data
async def create_data(self, data: dict | Any) -> InsertOneResult:
"""
创建数据
"""
if not isinstance(data, dict):
data = jsonable_encoder(data)
data['create_datetime'] = datetime.datetime.now()
data['update_datetime'] = datetime.datetime.now()
result = await self.collection.insert_one(data)
# 判断插入是否成功
if result.acknowledged:
return result
else:
raise CustomException("创建新数据失败", code=status.HTTP_ERROR)
async def put_data(self, _id: str, data: dict | Any) -> UpdateResult:
"""
更新数据
"""
if not isinstance(data, dict):
data = jsonable_encoder(data)
new_data = {'$set': data}
result = await self.collection.update_one({'_id': ObjectId(_id) if self.is_object_id else _id}, new_data)
if result.matched_count > 0:
return result
else:
raise CustomException("更新失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
async def delete_data(self, _id: str):
"""
删除数据
"""
result = await self.collection.delete_one({'_id': ObjectId(_id) if self.is_object_id else _id})
if result.deleted_count > 0:
return True
else:
raise CustomException("删除失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
async def get_datas(
self,
page: int = 1,
limit: int = 10,
v_schema: Any = None,
v_order: str = None,
v_order_field: str = None,
v_return_objs: bool = False,
**kwargs
):
"""
使用 find() 要查询的一组文档。 find() 没有I / O也不需要 await 表达式。它只是创建一个 AsyncIOMotorCursor 实例
当您调用 to_list() 或为循环执行异步时 (async for) ,查询实际上是在服务器上执行的。
"""
params = self.filter_condition(**kwargs)
cursor = self.collection.find(params)
if v_order or v_order_field:
v_order_field = v_order_field if v_order_field else 'create_datetime'
v_order = -1 if v_order in self.ORDER_FIELD else 1
cursor.sort(v_order_field, v_order)
if limit != 0:
# 对查询应用排序(sort),跳过(skip)或限制(limit)
cursor.skip((page - 1) * limit).limit(limit)
datas = []
async for row in cursor:
data = json.loads(dumps(row))
datas.append(data)
if not datas or v_return_objs:
return datas
elif v_schema:
datas = [jsonable_encoder(v_schema(**data)) for data in datas]
elif self.schema:
datas = [jsonable_encoder(self.schema(**data)) for data in datas]
return datas
async def get_count(self, **kwargs) -> int:
"""
获取统计数据
"""
params = self.filter_condition(**kwargs)
return await self.collection.count_documents(params)
@classmethod
def filter_condition(cls, **kwargs):
"""
过滤条件
"""
params = {}
for k, v in kwargs.items():
if not v:
continue
elif isinstance(v, tuple):
if v[0] == "like" and v[1]:
params[k] = {'$regex': v[1]}
elif v[0] == "between" and len(v[1]) == 2:
params[k] = {'$gte': f"{v[1][0]} 00:00:00", '$lt': f"{v[1][1]} 23:59:59"}
elif v[0] == "ObjectId" and v[1]:
try:
params[k] = ObjectId(v[1])
except InvalidId:
raise CustomException("任务编号格式不正确!")
else:
params[k] = v
return params