from SetParams.DataType.TypeDef import *
from SetParams.DataType.BaseParam import *


class TrainParams(BaseParam):
    """
    Example training parameters. Adjust them to your actual setup before use.
    If other parameters are needed (e.g. parameters for the inference
    function), define them in your own class that also inherits from
    BaseParam.

        # For example, if training needs the following parameters:
        gpu_num = IntType("gpu_num", 2)
        support_cpu = BoolType("support_cpu", True)
        labels = EnumType("labels", 0, ["dog", "cat"])
        labels.default = 0

        self.add_param(gpu_num)
        self.add_param(support_cpu)
        self.add_param(labels)
    """

    def __init__(self):
        super().__init__()

        # Create every parameter first, then register them with add_param in
        # exactly this order.
        # NOTE(review): the per-parameter meanings below are inferred from the
        # names and defaults only — confirm against the training script.
        training_params = (
            # presumably the number of output classes
            IntType("num_classes", default=9),
            # presumably the learning rate
            FloatType("lr", default=0.005),
            # presumably the epochs at which the LR scheduler steps — TODO confirm
            ListType("lr_schedulerList", default=[30, 60]),
            # device identifier; defaults to "cpu"
            StringType("device", default="cpu"),
            # presumably the dataset root directory
            StringType("DatasetDir", default="./datasets/M006B_duanmian"),
            # despite the "Dir" in its name, the default is a .pt file path
            StringType("saveModDir", default="./saved_model/M006B_duanmian.pt"),
            # checkpoint to resume from; "" presumably means start fresh — TODO confirm
            StringType("resumeModPath", default=""),
            # presumably the total number of training epochs
            IntType("epochnum", default=100),
            # presumably the checkpoint-saving interval in epochs — TODO confirm
            IntType("saveEpoch", default=1),
        )
        for param in training_params:
            self.add_param(param)


if __name__ == "__main__":
    # Write the default training parameters to a JSON file.
    # NOTE(review): the output path is relative to the current working
    # directory, so this presumably expects to be run from the project root.
    output_path = './SetParams/TrainParams.json'
    TrainParams().save_to_file(output_path)