178 lines
6.0 KiB
Python
178 lines
6.0 KiB
Python
|
import datetime
|
|||
|
import json
|
|||
|
from typing import Any
|
|||
|
from bson import ObjectId
|
|||
|
from bson.errors import InvalidId
|
|||
|
from bson.json_util import dumps
|
|||
|
from fastapi.encoders import jsonable_encoder
|
|||
|
from motor.motor_asyncio import AsyncIOMotorDatabase
|
|||
|
from pymongo.results import InsertOneResult, UpdateResult
|
|||
|
from core.exception import CustomException
|
|||
|
from utils import status
|
|||
|
|
|||
|
|
|||
|
class MongoManage:
    """
    MongoDB data manager built on the async Motor driver.

    Blog: https://www.cnblogs.com/aduner/p/13532504.html
    MongoDB docs: https://www.mongodb.com/docs/drivers/motor/
    Motor docs: https://motor.readthedocs.io/en/stable/
    """

    # Values of ``v_order`` that select a descending sort
    ORDER_FIELD = ["desc", "descending"]

    def __init__(
            self,
            db: AsyncIOMotorDatabase = None,
            collection: str = None,
            schema: Any = None,
            is_object_id: bool = True
    ):
        """
        Initialise the manager.

        :param db: Motor database handle.
        :param collection: name of the collection to operate on; when falsy,
            ``self.collection`` is left as ``None``.
        :param schema: default serializer class applied by ``get_datas`` when
            no ``v_schema`` is given.
        :param is_object_id: whether the ``_id`` field is stored as ``ObjectId``;
            if so, ids passed to the data methods are converted with ``ObjectId``.
        """
        self.db = db
        self.collection = db[collection] if collection else None
        self.schema = schema
        self.is_object_id = is_object_id

    def _format_id(self, _id: Any) -> Any:
        """
        Convert ``_id`` into the value stored in the collection's ``_id`` field.

        :return: ``ObjectId(_id)`` when ``is_object_id`` is set, otherwise ``_id``
            unchanged. Returns ``None`` when ``_id`` is not a valid ObjectId, so
            callers can report "not found" instead of letting ``InvalidId`` /
            ``TypeError`` escape as an unhandled error.
        """
        if not self.is_object_id:
            return _id
        try:
            return ObjectId(_id)
        except (InvalidId, TypeError):
            return None

    async def get_data(
            self,
            _id: str = None,
            v_return_none: bool = False,
            v_schema: Any = None,
            **kwargs
    ) -> dict | None:
        """
        Fetch a single document, matched by ``_id`` when given, plus any keyword
        filters in ``kwargs`` (see ``filter_condition``).

        :param _id: document id; converted to ``ObjectId`` when ``is_object_id``
            is set, used as-is otherwise.
        :param v_return_none: return ``None`` instead of raising when nothing is
            found (default: raise).
        :param v_schema: serializer class; when given, the document is built into
            it and passed through ``jsonable_encoder``.
        :return: the raw document, the serialized document, or ``None``.
        :raises CustomException: (404) when nothing is found and
            ``v_return_none`` is false.
        """
        id_is_valid = True
        if _id:
            query_id = self._format_id(_id)
            id_is_valid = query_id is not None
            kwargs["_id"] = query_id
        # A malformed ObjectId cannot match any document, so skip the query
        data = await self.collection.find_one(self.filter_condition(**kwargs)) if id_is_valid else None

        if not data:
            if v_return_none:
                return None
            raise CustomException("查找失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
        if v_schema:
            return jsonable_encoder(v_schema(**data))
        return data

    async def create_data(self, data: dict | Any) -> InsertOneResult:
        """
        Insert a new document.

        Non-dict input is converted with ``jsonable_encoder`` first. Both
        ``create_datetime`` and ``update_datetime`` are set to the same current
        timestamp; when a dict is passed, these keys are written into it directly.

        :raises CustomException: when the insert is not acknowledged.
        """
        if not isinstance(data, dict):
            data = jsonable_encoder(data)
        # One timestamp so the creation and update times are exactly equal
        now = datetime.datetime.now()
        data['create_datetime'] = now
        data['update_datetime'] = now
        result = await self.collection.insert_one(data)
        # Check whether the insert was acknowledged
        if result.acknowledged:
            return result
        raise CustomException("创建新数据失败", code=status.HTTP_ERROR)

    async def put_data(self, _id: str, data: dict | Any) -> UpdateResult:
        """
        Update the document with the given ``_id`` using ``$set``.

        ``update_datetime`` is refreshed to the current time unless ``data``
        already provides a value for it.

        :raises CustomException: (404) when no document matches ``_id``.
        """
        if not isinstance(data, dict):
            data = jsonable_encoder(data)
        query_id = self._format_id(_id)
        if query_id is not None:
            # Caller-supplied fields take precedence over the automatic timestamp
            fields = {'update_datetime': datetime.datetime.now(), **data}
            result = await self.collection.update_one({'_id': query_id}, {'$set': fields})
            if result.matched_count > 0:
                return result
        raise CustomException("更新失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)

    async def delete_data(self, _id: str):
        """
        Delete the document with the given ``_id``.

        :return: ``True`` when a document was deleted.
        :raises CustomException: (404) when no document matches ``_id``.
        """
        query_id = self._format_id(_id)
        if query_id is not None:
            result = await self.collection.delete_one({'_id': query_id})
            if result.deleted_count > 0:
                return True
        raise CustomException("删除失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)

    async def get_datas(
            self,
            page: int = 1,
            limit: int = 10,
            v_schema: Any = None,
            v_order: str = None,
            v_order_field: str = None,
            v_return_objs: bool = False,
            **kwargs
    ):
        """
        Fetch a page of documents matching the keyword filters in ``kwargs``.

        ``find()`` performs no I/O by itself; it only builds a cursor. The query
        runs on the server when the cursor is iterated with ``async for``.

        :param page: 1-based page number.
        :param limit: page size; ``0`` disables paging and returns all matches.
        :param v_schema: serializer class; falls back to ``self.schema``.
        :param v_order: sort direction; a value in ``ORDER_FIELD`` sorts descending.
        :param v_order_field: field to sort on (default ``create_datetime`` when
            only ``v_order`` is given).
        :param v_return_objs: return the decoded documents without serialization.
        """
        cursor = self.collection.find(self.filter_condition(**kwargs))

        if v_order or v_order_field:
            sort_field = v_order_field or 'create_datetime'
            direction = -1 if v_order in self.ORDER_FIELD else 1
            cursor.sort(sort_field, direction)

        if limit != 0:
            # Apply paging to the query
            cursor.skip((page - 1) * limit).limit(limit)

        # Round-trip through bson.json_util so BSON types (e.g. ObjectId) become
        # their Extended JSON representation in plain dicts
        datas = [json.loads(dumps(row)) async for row in cursor]

        if not datas or v_return_objs:
            return datas
        schema = v_schema or self.schema
        if schema:
            datas = [jsonable_encoder(schema(**data)) for data in datas]
        return datas

    async def get_count(self, **kwargs) -> int:
        """
        Count documents matching the keyword filters in ``kwargs``.
        """
        return await self.collection.count_documents(self.filter_condition(**kwargs))

    @classmethod
    def filter_condition(cls, **kwargs):
        """
        Build a MongoDB filter document from keyword arguments.

        Falsy values are skipped entirely, so ``0``/``False``/``""`` cannot be
        used as filter values. A tuple value selects an operator:

        - ``("like", pattern)``: ``{"$regex": pattern}``
        - ``("between", (start, end))``: string range from ``"{start} 00:00:00"``
          to ``"{end} 23:59:59"``, both ends inclusive
        - ``("ObjectId", value)``: equality on ``ObjectId(value)``

        Tuples with any other operator are ignored; non-tuple values are matched
        by equality.

        :raises CustomException: when an ``"ObjectId"`` value is malformed.
        """
        params = {}
        for k, v in kwargs.items():
            if not v:
                continue
            if not isinstance(v, tuple):
                params[k] = v
                continue
            op, value = v[0], v[1]
            if op == "like" and value:
                # NOTE(review): value is used as a raw regular expression; if it can
                # come from end users, consider re.escape() to avoid regex injection.
                params[k] = {'$regex': value}
            elif op == "between" and value and len(value) == 2:
                # $lte so records stamped exactly at end-of-day are included
                params[k] = {'$gte': f"{value[0]} 00:00:00", '$lte': f"{value[1]} 23:59:59"}
            elif op == "ObjectId" and value:
                try:
                    params[k] = ObjectId(value)
                except InvalidId as e:
                    raise CustomException("任务编号格式不正确!") from e
        return params
|