完成项目训练模块的接口测试

This commit is contained in:
sunyugang 2025-04-22 10:11:44 +08:00
parent 7a9e571a96
commit 0033746fe1
11 changed files with 139 additions and 91 deletions

View File

@ -6,9 +6,10 @@
# @IDE : PyCharm # @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作 # @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict from core.data_types import DatetimeStr
from typing import Optional from typing import Optional
from datetime import datetime from pydantic import BaseModel, Field, ConfigDict
class ProjectDetectIn(BaseModel): class ProjectDetectIn(BaseModel):
@ -36,7 +37,7 @@ class ProjectDetectOut(BaseModel):
file_type: Optional[str] file_type: Optional[str]
folder_url: Optional[str] folder_url: Optional[str]
rtsp_url: Optional[str] rtsp_url: Optional[str]
create_time: Optional[datetime] create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View File

@ -5,9 +5,11 @@
# @File : project_detect_file.py # @File : project_detect_file.py
# @IDE : PyCharm # @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作 # @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr
from typing import Optional from typing import Optional
from datetime import datetime from pydantic import BaseModel, Field, ConfigDict
class ProjectDetectFilePager(BaseModel): class ProjectDetectFilePager(BaseModel):
@ -21,6 +23,6 @@ class ProjectDetectFileOut(BaseModel):
detect_id: Optional[int] = Field(..., description="训练集合id") detect_id: Optional[int] = Field(..., description="训练集合id")
file_name: Optional[str] = Field(None, description="文件名称") file_name: Optional[str] = Field(None, description="文件名称")
thumb_file_url: Optional[str] = Field(None, description="文件路径") thumb_file_url: Optional[str] = Field(None, description="文件路径")
create_time: Optional[datetime] = Field(None, description="上传时间") create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View File

@ -6,9 +6,10 @@
# @IDE : PyCharm # @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作 # @desc : pydantic 模型,用于数据库序列化操作
from pydantic import BaseModel, Field, ConfigDict from core.data_types import DatetimeStr
from typing import Optional from typing import Optional
from datetime import datetime from pydantic import BaseModel, Field, ConfigDict
class ProjectDetectLogIn(BaseModel): class ProjectDetectLogIn(BaseModel):
@ -25,6 +26,6 @@ class ProjectDetectLogOut(BaseModel):
train_id: Optional[int] train_id: Optional[int]
train_version: Optional[str] train_version: Optional[str]
pt_type: Optional[str] pt_type: Optional[str]
create_time: Optional[datetime] create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View File

@ -6,15 +6,16 @@
# @IDE : PyCharm # @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作 # @desc : pydantic 模型,用于数据库序列化操作
from core.data_types import DatetimeStr
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing import Optional from typing import Optional
from datetime import datetime
class ProjectDetectLogFileOut(BaseModel): class ProjectDetectLogFileOut(BaseModel):
id: Optional[int] id: Optional[int]
file_name: Optional[str] file_name: Optional[str]
thumb_file_url: Optional[str] thumb_file_url: Optional[str]
create_time: Optional[datetime] create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View File

@ -17,7 +17,6 @@ from utils.response import SuccessResponse, ErrorResponse
import threading import threading
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import Depends, APIRouter, Form, UploadFile from fastapi import Depends, APIRouter, Form, UploadFile

View File

@ -112,6 +112,12 @@ class ProjectInfoDal(DalBase):
await self.flush(obj) await self.flush(obj)
return await self.out_dict(obj, None, False, schemas.ProjectInfoOut) return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)
async def update_version(self, data_id):
proj = await self.get_data(data_id)
if proj:
proj.train_version = proj.train_version + 1
await self.put_data(data_id=data_id, data=proj)
class ProjectImageDal(DalBase): class ProjectImageDal(DalBase):
""" """
@ -246,22 +252,26 @@ class ProjectImageDal(DalBase):
# 2 主查询 # 2 主查询
query = ( query = (
select( select(
models.ProjectImage, models.ProjectImage.id,
func.ifnull(subquery.c.label_count, 0).label('label_count') func.ifnull(subquery.c.label_count, 0).label('label_count')
) )
.outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id) .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
) )
train_count_sql = await self.filter_core( train_count_sql = await self.filter_core(
v_start_sql=query, v_start_sql=query,
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'], v_where=[models.ProjectImage.project_id == proj_id,
models.ProjectImage.img_type == 'train',
subquery.c.label_count == 0],
v_return_sql=True) v_return_sql=True)
train_count = await self.get_count(train_count_sql) train_count = await self.get_count_sql(train_count_sql)
val_count_sql = await self.filter_core( val_count_sql = await self.filter_core(
v_start_sql=query, v_start_sql=query,
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'], v_where=[models.ProjectImage.project_id == proj_id,
models.ProjectImage.img_type == 'val',
subquery.c.label_count == 0],
v_return_sql=True) v_return_sql=True)
val_count = await self.get_count(val_count_sql) val_count = await self.get_count_sql(val_count_sql)
return train_count, val_count return train_count, val_count
@ -295,14 +305,15 @@ class ProjectLabelDal(DalBase):
async def get_label_for_train(self, project_id: int): async def get_label_for_train(self, project_id: int):
id_list = [] id_list = []
name_list = [] name_list = []
label_list = self.get_datas( label_list = await self.get_datas(
v_where=[self.model.project_id == project_id], v_where=[self.model.project_id == project_id],
limit=0,
v_order='asc', v_order='asc',
v_order_field='id', v_order_field='id',
v_return_count=False) v_return_count=False)
for label in label_list: for label in label_list:
id_list.append(label.id) id_list.append(label['id'])
name_list.append(label.label_name) name_list.append(label['label_name'])
return id_list, name_list return id_list, name_list
@ -326,10 +337,12 @@ class ProjectImgLabelDal(DalBase):
async def get_img_label_list(self, image_id: int): async def get_img_label_list(self, image_id: int):
return await self.get_datas( return await self.get_datas(
limit=0,
v_return_count=False, v_return_count=False,
v_where=[self.model.image_id == image_id], v_where=[self.model.image_id == image_id],
v_order="asc", v_order="asc",
v_order_field="id") v_order_field="id",
v_return_objs=True)
async def del_img_label(self, label_ids: list[int]): async def del_img_label(self, label_ids: list[int]):
img_labels = self.get_datas(v_where=[self.model.label_id.in_(label_ids)]) img_labels = self.get_datas(v_where=[self.model.label_id.in_(label_ids)])

View File

@ -6,7 +6,8 @@
# @IDE : PyCharm # @IDE : PyCharm
# @desc : pydantic 模型,用于数据库序列化操作 # @desc : pydantic 模型,用于数据库序列化操作
from datetime import datetime from core.data_types import DatetimeStr
from pydantic import BaseModel, Field, ConfigDict from pydantic import BaseModel, Field, ConfigDict
from typing import Optional from typing import Optional
@ -18,6 +19,7 @@ from typing import Optional
class ProjectTrainIn(BaseModel): class ProjectTrainIn(BaseModel):
project_id: Optional[int] = Field(..., description="项目id") project_id: Optional[int] = Field(..., description="项目id")
weights_id: Optional[str] = Field(None, description="权重文件") weights_id: Optional[str] = Field(None, description="权重文件")
weights_name: Optional[str] = Field(None, description="权重文件名称")
epochs: Optional[int] = Field(50, description="训练轮数") epochs: Optional[int] = Field(50, description="训练轮数")
patience: Optional[int] = Field(20, description="早停的耐心值") patience: Optional[int] = Field(20, description="早停的耐心值")
@ -28,7 +30,7 @@ class ProjectTrainOut(BaseModel):
weights_name: Optional[str] = Field(None, description="权重名称") weights_name: Optional[str] = Field(None, description="权重名称")
epochs: Optional[int] = Field(None, description="训练轮数") epochs: Optional[int] = Field(None, description="训练轮数")
patience: Optional[int] = Field(None, description="早停的耐心值") patience: Optional[int] = Field(None, description="早停的耐心值")
create_time: Optional[datetime] = Field(None, description="训练时间") create_datetime: DatetimeStr
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View File

@ -1,4 +1,4 @@
from utils import os_utils as os from utils import os_utils as osu
from application.settings import * from application.settings import *
from . import schemas, models, crud from . import schemas, models, crud
from utils.websocket_server import room_manager from utils.websocket_server import room_manager
@ -19,34 +19,39 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
:param db: 数据库session :param db: 数据库session
:return: :return:
""" """
proj_dal = proj_crud.ProjectInfoDal(db)
img_dal = proj_crud.ProjectImageDal(db) img_dal = proj_crud.ProjectImageDal(db)
label_dal = proj_crud.ProjectLabelDal(db) label_dal = proj_crud.ProjectLabelDal(db)
# 先查询两个图片列表 # 先查询两个图片列表
project_images_train = img_dal.get_data( project_images_train = await img_dal.get_datas(
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train']) v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'train'],
project_images_val = img_dal.get_data( limit=0,
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val']) v_return_count=False,
v_return_objs=True)
project_images_val = await img_dal.get_datas(
v_where=[proj_models.ProjectImage.project_id == proj_info.id, proj_models.ProjectImage.img_type == 'val'],
limit=0,
v_return_count=False,
v_return_objs=True)
# 得到训练版本 # 得到训练版本
version_path = 'v' + str(proj_info.train_version + 1) version_path = 'v' + str(proj_info.train_version + 1)
# 创建训练的根目录 # 创建训练的根目录
train_path = os.create_folder(datasets_url, proj_info.project_no, version_path) train_path = osu.create_folder(datasets_url, proj_info.project_no, version_path)
# 查询项目所属标签,返回两个 idname一一对应的数组 # 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = label_dal.get_label_for_train(proj_info.id) label_id_list, label_name_list = await label_dal.get_label_for_train(proj_info.id)
# 创建图片的的两个文件夹 # 创建图片的的两个文件夹
img_path_train = os.create_folder(train_path, 'images', 'train') img_path_train = osu.create_folder(train_path, 'images', 'train')
img_path_val = os.create_folder(train_path, 'images', 'val') img_path_val = osu.create_folder(train_path, 'images', 'val')
# 创建标签的两个文件夹 # 创建标签的两个文件夹
label_path_train = os.create_folder(train_path, 'labels', 'train') label_path_train = osu.create_folder(train_path, 'labels', 'train')
label_path_val = os.create_folder(train_path, 'labels', 'val') label_path_val = osu.create_folder(train_path, 'labels', 'val')
# 在根目录下创建yaml文件 # 在根目录下创建yaml文件
yaml_file = os.file_path(train_path, proj_info.project_no + '.yaml') yaml_file = osu.file_path(train_path, proj_info.project_no + '.yaml')
yaml_data = { yaml_data = {
'path': train_path, 'path': train_path,
'train': 'images/train', 'train': 'images/train',
@ -59,20 +64,20 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
# 开始循环复制图片和生成label.txt # 开始循环复制图片和生成label.txt
# 先操作train # 先操作train
operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list) await operate_img_label(project_images_train, img_path_train, label_path_train, db, label_id_list)
# 再操作val # 再操作val
operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list) await operate_img_label(project_images_val, img_path_val, label_path_val, db, label_id_list)
# 开始执行异步训练 # 开始执行异步训练
data = yaml_file data = yaml_file
project = os.file_path(runs_url, proj_info.project_no) project = osu.file_path(runs_url, proj_info.project_no)
name = version_path name = version_path
return data, project, name return data, project, name
async def operate_img_label( async def operate_img_label(
img_list: list[proj_models.ProjectImgLabel], img_list: list[proj_models.ProjectImage],
img_path: str, img_path: str,
label_path: str, label_path: str,
db: AsyncSession, db: AsyncSession,
@ -90,10 +95,10 @@ async def operate_img_label(
image = img_list[i] image = img_list[i]
# 先复制图片,并把图片改名,不改后缀 # 先复制图片,并把图片改名,不改后缀
file_name = 'image' + str(i) file_name = 'image' + str(i)
os.copy_and_rename_file(image.image_url, img_path, file_name) osu.copy_and_rename_file(image.image_url, img_path, file_name)
# 查询这张图片的label信息然后生成这张照片的txt文件 # 查询这张图片的label信息然后生成这张照片的txt文件
img_label_list = await proj_crud.ProjectImgLabelDal(db).get_img_label_list(image.id) img_label_list = await proj_crud.ProjectImgLabelDal(db).get_img_label_list(image.id)
label_txt_path = os.file_path(label_path, file_name + '.txt') label_txt_path = osu.file_path(label_path, file_name + '.txt')
with open(label_txt_path, 'w', encoding='utf-8') as file: with open(label_txt_path, 'w', encoding='utf-8') as file:
for image_label in img_label_list: for image_label in img_label_list:
index = label_id_list.index(image_label.label_id) index = label_id_list.index(image_label.label_id)
@ -103,20 +108,19 @@ async def operate_img_label(
+ image_label.mark_height + '\n') + image_label.mark_height + '\n')
async def run_event_loop( def run_event_loop(
data: str, data: str,
project: str, project: str,
name: str, name: str,
train_in: schemas.ProjectTrainIn, train_in: schemas.ProjectTrainIn,
project_id: int, project_id: int,
db: AsyncSession): train_info: models.ProjectTrain,
loop = asyncio.new_event_loop() is_gup: str):
asyncio.set_event_loop(loop) # 运行异步函数,开始训练
# 运行异步函数 loop_run = asyncio.new_event_loop()
loop.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience, train_in.weights_id, asyncio.set_event_loop(loop_run)
project_id, db)) loop_run.run_until_complete(run_commend(data, project, name, train_in.epochs, train_in.patience,
# 可选: 关闭循环 project_id, train_info, is_gup))
loop.close()
async def run_commend( async def run_commend(
@ -125,10 +129,9 @@ async def run_commend(
name: str, name: str,
epochs: int, epochs: int,
patience: int, patience: int,
weights: str,
project_id: int, project_id: int,
db: AsyncSession, train_info: models.ProjectTrain,
rd: Redis): is_gpu: str):
""" """
执行训练 执行训练
:param data: 训练数据集 :param data: 训练数据集
@ -138,23 +141,20 @@ async def run_commend(
:param patience: 早停耐心值 :param patience: 早停耐心值
:param weights: 权重文件 :param weights: 权重文件
:param project_id: 项目id :param project_id: 项目id
:param db: 数据库session :param train_info: 训练信息
:param rd: redis连接 :param is_gpu: 是否是gpu环境
:return: :return:
""" """
yolo_path = os.file_path(yolo_url, 'train.py') yolo_path = osu.file_path(yolo_url, 'train.py')
room = 'train_' + str(project_id) room = 'train_' + str(project_id)
await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n") await room_manager.send_to_room(room, f"AiCheckV2.0: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)] "--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]
# 增加权重文件,在之前训练的基础上重新巡逻 # 增加权重文件,在之前训练的基础上重新训练
if weights != '' and weights is not None: if train_info is not None:
train_info = await crud.ProjectTrainDal(db).get_data(data_id=int(weights)) commend.append("--weights=" + train_info.best_pt)
if train_info is not None:
commend.append("--weights=" + train_info.best_pt)
is_gpu = rd.get('is_gpu')
# 判断是否存在cuda版本 # 判断是否存在cuda版本
if is_gpu == 'True': if is_gpu == 'True':
commend.append("--device=0") commend.append("--device=0")
@ -180,20 +180,33 @@ async def run_commend(
await room_manager.send_to_room(room, 'error') await room_manager.send_to_room(room, 'error')
else: else:
await room_manager.send_to_room(room, 'success') await room_manager.send_to_room(room, 'success')
# 然后保存版本训练信息
train = models.ProjectTrain()
train.project_id = project_id async def add_train(
train.train_version = name db,
train_url = os.file_path(project, name) project_id,
train.train_url = train_url name,
train.train_data = data project,
bast_pt_path = os.file_path(train_url, 'weights', 'best.pt') data,
last_pt_path = os.file_path(train_url, 'weights', 'last.pt') train_in,
train.best_pt = bast_pt_path user_id):
train.last_pt = last_pt_path # 更新版本信息
if weights is not None and weights != '': await proj_crud.ProjectInfoDal(db).update_version(data_id=project_id)
train.weights_id = weights # 增加训练版本信息
train.weights_name = train_info.train_version train = models.ProjectTrain()
train.patience = patience train.project_id = project_id
train.epochs = epochs train.train_version = name
await crud.ProjectTrainDal(db).create_data(data=train) train_url = osu.file_path(project, name)
train.train_url = train_url
train.train_data = data
train.user_id = user_id
bast_pt_path = osu.file_path(train_url, 'weights', 'best.pt')
last_pt_path = osu.file_path(train_url, 'weights', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
if train_in is not None:
train.weights_id = train_in.weights_id
train.weights_name = train_in.weights_name
train.patience = train_in.patience
train.epochs = train_in.epochs
await crud.ProjectTrainDal(db).create_model(data=train)

View File

@ -3,6 +3,8 @@
# @version : 1.0 # @version : 1.0
# @Create Time : 2025/04/03 10:32 # @Create Time : 2025/04/03 10:32
# @File : views.py # @File : views.py
from core.database import redis_getter
from . import models, schemas, crud, service from . import models, schemas, crud, service
from apps.vadmin.auth.utils.current import AllUserAuth from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth from apps.vadmin.auth.utils.validation.auth import Auth
@ -10,6 +12,7 @@ from utils.response import SuccessResponse, ErrorResponse
from apps.business.project.crud import ProjectInfoDal, ProjectImageDal from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
import threading import threading
from redis.asyncio import Redis
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
@ -19,10 +22,11 @@ app = APIRouter()
########################################################### ###########################################################
# 项目训练信息 # 项目训练信息
########################################################### ###########################################################
@app.post("/", summary="执行训练") @app.post("/start", summary="执行训练")
async def run_train( async def run_train(
train_in: schemas.ProjectTrainIn, train_in: schemas.ProjectTrainIn,
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)):
proj_id = train_in.project_id proj_id = train_in.project_id
proj_dal = ProjectInfoDal(auth.db) proj_dal = ProjectInfoDal(auth.db)
proj_img_dal = ProjectImageDal(auth.db) proj_img_dal = ProjectImageDal(auth.db)
@ -43,12 +47,17 @@ async def run_train(
return ErrorResponse("训练图片中存在未标注的图片") return ErrorResponse("训练图片中存在未标注的图片")
if val_label_count > 0: if val_label_count > 0:
return ErrorResponse("验证图片中存在未标注的图片") return ErrorResponse("验证图片中存在未标注的图片")
data, project, name = service.before_train(proj_info, auth.db) data, project, name = await service.before_train(proj_info, auth.db)
is_gpu = await rd.get('is_gpu')
train_info = None
if train_in.weights_id is not None:
train_info = await crud.ProjectTrainDal(auth.db).get_data(train_in.weights_id)
# 异步执行操作操作过程通过websocket进行同步 # 异步执行操作操作过程通过websocket进行同步
thread_train = threading.Thread( thread_train = threading.Thread(
target=service.run_event_loop, target=service.run_event_loop,
args=(data, project, name, train_in, proj_id, auth.db,)) args=(data, project, name, train_in, proj_id, train_info, is_gpu))
thread_train.start() thread_train.start()
await service.add_train(auth.db, proj_id, name, project, data, train_in, auth.user.id)
return SuccessResponse(msg="执行成功") return SuccessResponse(msg="执行成功")
@ -57,15 +66,16 @@ async def train_list(
proj_id: int, proj_id: int,
auth: Auth = Depends(AllUserAuth())): auth: Auth = Depends(AllUserAuth())):
datas = await crud.ProjectTrainDal(auth.db).get_datas( datas = await crud.ProjectTrainDal(auth.db).get_datas(
limit=0,
v_where=[models.ProjectTrain.project_id == proj_id], v_where=[models.ProjectTrain.project_id == proj_id],
v_schema=schemas.ProjectTrainOut, v_schema=schemas.ProjectTrainOut,
v_order="asc", v_order="asc",
v_order_field="id",v_return_count=False) v_order_field="id",
v_return_count=False)
return SuccessResponse(data=datas) return SuccessResponse(data=datas)
@app.get("/result/{proj_id}", summary="查询训练报告") @app.get("/result/{train_id}", summary="查询训练报告")
async def get_result(train_id:int, auth: Auth = Depends(AllUserAuth())): async def get_result(train_id: int, auth: Auth = Depends(AllUserAuth())):
result = await crud.ProjectTrainDal(auth.db).get_result(train_id) result = await crud.ProjectTrainDal(auth.db).get_result(train_id)
return SuccessResponse(data=result) return SuccessResponse(data=result)

View File

@ -263,11 +263,17 @@ class DalBase:
await self.db.execute(insert(self.model), datas) await self.db.execute(insert(self.model), datas)
await self.db.flush() await self.db.flush()
async def create_model(self, data: Any) -> None:
"""
创建单个model
:param data: model
"""
self.db.add(data)
await self.db.flush()
async def create_models(self, datas: list[Any]) -> None: async def create_models(self, datas: list[Any]) -> None:
""" """
批量创建数据 批量创建数据
SQLAlchemy 2.0 批量插入不支持 MySQL 返回值
https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning
:param datas: model数组 :param datas: model数组
""" """
self.db.add_all(datas) self.db.add_all(datas)

View File

@ -7,7 +7,7 @@ from utils.websocket_server import room_manager
def websocket_config(app: FastAPI): def websocket_config(app: FastAPI):
@app.websocket("/{room}") @app.websocket("/ws/{room}")
async def websocket_room(websocket: WebSocket, room: str): async def websocket_room(websocket: WebSocket, room: str):
""" """
websocket 房间管理 websocket 房间管理