47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
from SetParams.DataType.TypeDef import *
|
|
from SetParams.DataType.BaseParam import *
|
|
|
|
|
|
class TrainParams(BaseParam):
    """Example training-parameter set.

    Adapt this to the actual use case before relying on it. Any other
    parameter group that needs to be customised (for example the arguments
    of an inference function) should be defined in its own class that
    likewise inherits from ``BaseParam``.

    Example -- a training run that needs the following parameters::

        gpu_num = IntType("gpu_num", 2)
        support_cpu = BoolType("support_cpu", True)
        labels = EnumType("labels", 0, ["dog", "cat"])
        labels.default = 0

        self.add_param(gpu_num)
        self.add_param(support_cpu)
        self.add_param(labels)
    """

    def __init__(self):
        """Initialise the base class, then add each default parameter via ``add_param``."""
        super().__init__()

        # The tuple literal is fully evaluated (every parameter object is
        # constructed, left to right) before the loop runs, so objects are
        # created first and then handed to add_param in this exact order.
        declared = (
            IntType("num_classes", default=9),
            FloatType("lr", default=0.005),
            # NOTE(review): presumably LR-scheduler milestone epochs -- confirm
            # against the training script that reads this value.
            ListType("lr_schedulerList", default=[30, 60]),
            StringType("device", default="cpu"),
            StringType("DatasetDir", default="./datasets/M006B_duanmian"),
            # Despite "Dir" in the key, the default value is a .pt file path.
            StringType("saveModDir", default="./saved_model/M006B_duanmian.pt"),
            # Defaults to an empty string; how "" is interpreted (e.g. as
            # "no checkpoint") is decided by the consumer, not shown here.
            StringType("resumeModPath", default=""),
            IntType("epochnum", default=100),
            IntType("saveEpoch", default=1),
        )
        for param in declared:
            self.add_param(param)
|
|
|
|
|
|
if __name__ == "__main__":
    # Instantiate the parameter set with its defaults and hand it to
    # save_to_file (inherited from BaseParam; output format presumably JSON,
    # judging by the target extension -- confirm in BaseParam).
    params = TrainParams()
    # Relative path: resolved against the current working directory, so run
    # this from the directory that contains ``SetParams/``.
    target = './SetParams/TrainParams.json'
    params.save_to_file(target)
|
|
|