2025-04-11 08:54:28 +08:00

107 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import inspect
import sys
from pathlib import Path
from typing import Type
from core.database import Base
from .generate_base import GenerateBase
class DalGenerate(GenerateBase):
def __init__(
self,
model: Type[Base],
zh_name: str,
en_name: str,
dal_class_name: str,
schema_simple_out_class_name: str
):
"""
初始化工作
:param model: 提前定义好的 ORM 模型
:param zh_name: 功能中文名称,主要用于描述、注释
:param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名dal、url 命名,默认使用 model class
en_name 例子:
如果 en_name 由多个单词组成那么请使用 _ 下划线拼接
在命名文件名称时,会执行使用 _ 下划线名称
在命名 class 名称时会将下划线名称转换为大驼峰命名CamelCase
在命名 url 时,会将下划线转换为 /
:param dal_class_name:
:param schema_simple_out_class_name:
"""
self.model = model
self.dal_class_name = dal_class_name
self.schema_simple_out_class_name = schema_simple_out_class_name
self.zh_name = zh_name
self.en_name = en_name
# model 文件的地址
self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
# model 文件 app 路径
self.app_dir_path = self.model_file_path.parent.parent
# crud 文件地址
self.crud_file_path = self.app_dir_path / "crud.py"
def write_generate_code(self):
"""
生成 crud 文件,以及代码内容
:return:
"""
if self.crud_file_path.exists():
codes = self.file_code_split_module(self.crud_file_path)
if codes:
print(f"==========dal 文件已存在并已有代码内容,正在追加新代码============")
if not codes[0]:
# 无文件注释则添加文件注释
codes[0] = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层")
codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config())
codes[2] += self.get_base_code_content()
code = ''
code += codes[0]
code += self.generate_modules_code(codes[1])
code += codes[2]
self.crud_file_path.write_text(code, "utf-8")
print(f"=================dal 代码已创建完成=======================")
return
self.crud_file_path.touch()
code = self.generate_code()
self.crud_file_path.write_text(code, "utf-8")
print(f"===========================dal 代码创建完成=================================")
def generate_code(self):
"""
代码生成
:return:
"""
code = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层")
code += self.generate_modules_code(self.get_base_module_config())
code += self.get_base_code_content()
return code
@staticmethod
def get_base_module_config():
"""
获取基础模块导入配置
:return:
"""
modules = {
"sqlalchemy.ext.asyncio": ['AsyncSession'],
"core.crud": ["DalBase"],
".": ["models", "schemas"],
}
return modules
def get_base_code_content(self):
"""
获取基础代码内容
:return:
"""
base_code = f"\n\nclass {self.dal_class_name}(DalBase):\n"
base_code += "\n\tdef __init__(self, db: AsyncSession):"
base_code += f"\n\t\tsuper({self.dal_class_name}, self).__init__()"
base_code += f"\n\t\tself.db = db"
base_code += f"\n\t\tself.model = models.{self.model.__name__}"
base_code += f"\n\t\tself.schema = schemas.{self.schema_simple_out_class_name}"
base_code += "\n"
return base_code.replace("\t", " ")