RODY/setparams.py

47 lines
1.6 KiB
Python
Raw Normal View History

2022-11-04 17:37:08 +08:00
from SetParams.DataType.TypeDef import *
from SetParams.DataType.BaseParam import *
class TrainParams(BaseParam):
    """Training parameter set built on ``BaseParam``.

    This is an example: adjust the values to your actual setup. Any other
    parameter set (for instance, parameters for an inference function)
    should be defined as its own class inheriting from ``BaseParam``.

    Example -- if training needs the following parameters::

        gpu_num = IntType("gpu_num", 2)
        support_cpu = BoolType("support_cpu", True)
        labels = EnumType("labels", 0, ["dog", "cat"])
        labels.default = 0
        self.add_param(gpu_num)
        self.add_param(support_cpu)
        self.add_param(labels)
    """

    def __init__(self):
        super().__init__()
        # Every parameter object is constructed first, then each one is handed
        # to add_param in exactly this order.
        training_params = (
            IntType("num_classes", default=9),
            FloatType("lr", default=0.005),
            # NOTE(review): values look like LR-scheduler milestones (epochs) -- confirm.
            ListType("lr_schedulerList", default=[30, 60]),
            StringType("device", default="cpu"),
            StringType("DatasetDir", default="./datasets/M006B_duanmian"),
            # Despite the "Dir" name, the default value is a file path ending in ".pt".
            StringType("saveModDir", default="./saved_model/M006B_duanmian.pt"),
            # Empty by default; presumably means "start without resuming" -- confirm.
            StringType("resumeModPath", default=""),
            IntType("epochnum", default=100),
            # NOTE(review): presumably a checkpoint-save interval in epochs -- confirm.
            IntType("saveEpoch", default=1),
        )
        for param in training_params:
            self.add_param(param)
if __name__ == "__main__":
    # Write the default training parameters out via BaseParam.save_to_file.
    # The path is relative; presumably resolved against the current working
    # directory -- confirm against save_to_file.
    output_path = './SetParams/TrainParams.json'
    TrainParams().save_to_file(output_path)