107 lines
4.1 KiB
Python
107 lines
4.1 KiB
Python
import inspect
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Type
|
||
from core.database import Base
|
||
from .generate_base import GenerateBase
|
||
|
||
|
||
class DalGenerate(GenerateBase):
|
||
|
||
def __init__(
|
||
self,
|
||
model: Type[Base],
|
||
zh_name: str,
|
||
en_name: str,
|
||
dal_class_name: str,
|
||
schema_simple_out_class_name: str
|
||
):
|
||
"""
|
||
初始化工作
|
||
:param model: 提前定义好的 ORM 模型
|
||
:param zh_name: 功能中文名称,主要用于描述、注释
|
||
:param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class
|
||
en_name 例子:
|
||
如果 en_name 由多个单词组成那么请使用 _ 下划线拼接
|
||
在命名文件名称时,会执行使用 _ 下划线名称
|
||
在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase)
|
||
在命名 url 时,会将下划线转换为 /
|
||
:param dal_class_name:
|
||
:param schema_simple_out_class_name:
|
||
"""
|
||
self.model = model
|
||
self.dal_class_name = dal_class_name
|
||
self.schema_simple_out_class_name = schema_simple_out_class_name
|
||
self.zh_name = zh_name
|
||
self.en_name = en_name
|
||
# model 文件的地址
|
||
self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
|
||
# model 文件 app 路径
|
||
self.app_dir_path = self.model_file_path.parent.parent
|
||
# crud 文件地址
|
||
self.crud_file_path = self.app_dir_path / "crud.py"
|
||
|
||
def write_generate_code(self):
|
||
"""
|
||
生成 crud 文件,以及代码内容
|
||
:return:
|
||
"""
|
||
if self.crud_file_path.exists():
|
||
codes = self.file_code_split_module(self.crud_file_path)
|
||
if codes:
|
||
print(f"==========dal 文件已存在并已有代码内容,正在追加新代码============")
|
||
if not codes[0]:
|
||
# 无文件注释则添加文件注释
|
||
codes[0] = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层")
|
||
codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config())
|
||
codes[2] += self.get_base_code_content()
|
||
code = ''
|
||
code += codes[0]
|
||
code += self.generate_modules_code(codes[1])
|
||
code += codes[2]
|
||
self.crud_file_path.write_text(code, "utf-8")
|
||
print(f"=================dal 代码已创建完成=======================")
|
||
return
|
||
self.crud_file_path.touch()
|
||
code = self.generate_code()
|
||
self.crud_file_path.write_text(code, "utf-8")
|
||
print(f"===========================dal 代码创建完成=================================")
|
||
|
||
def generate_code(self):
|
||
"""
|
||
代码生成
|
||
:return:
|
||
"""
|
||
code = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层")
|
||
code += self.generate_modules_code(self.get_base_module_config())
|
||
code += self.get_base_code_content()
|
||
return code
|
||
|
||
@staticmethod
|
||
def get_base_module_config():
|
||
"""
|
||
获取基础模块导入配置
|
||
:return:
|
||
"""
|
||
modules = {
|
||
"sqlalchemy.ext.asyncio": ['AsyncSession'],
|
||
"core.crud": ["DalBase"],
|
||
".": ["models", "schemas"],
|
||
}
|
||
return modules
|
||
|
||
def get_base_code_content(self):
|
||
"""
|
||
获取基础代码内容
|
||
:return:
|
||
"""
|
||
base_code = f"\n\nclass {self.dal_class_name}(DalBase):\n"
|
||
base_code += "\n\tdef __init__(self, db: AsyncSession):"
|
||
base_code += f"\n\t\tsuper({self.dal_class_name}, self).__init__()"
|
||
base_code += f"\n\t\tself.db = db"
|
||
base_code += f"\n\t\tself.model = models.{self.model.__name__}"
|
||
base_code += f"\n\t\tself.schema = schemas.{self.schema_simple_out_class_name}"
|
||
base_code += "\n"
|
||
return base_code.replace("\t", " ")
|
||
|