import sys
import json
from setparams import TrainParams




 


# Pretrained weights used when the caller does not supply a resume path.
_DEFAULT_PRETRAIN_WEIGHTS = "/mnt/sdc/algorithm/AICheck-MaskRCNN/app/maskrcnn_ppx/pretrain/mask_rcnn_r50_fpn_2x_coco.pdparams"


def ppx_train(params, id):
    """Call ``model.train`` with hyper-parameters read from *params*.

    Args:
        params: Object whose ``get(name)`` returns an entry exposing a
            ``.value`` attribute. The entries read are ``num_classes``,
            ``epochnum``, ``saveEpoch``, ``device``, ``DatasetDir``,
            ``saveModDir``, ``lr``, ``lr_schedulerList`` and
            ``resumeModPath``.
        id: Caller-supplied identifier. It is not used by the code in this
            function; the parameter is kept so existing callers keep working.

    Note:
        ``model``, ``train_dataset`` and ``eval_dataset`` are not created
        here. They must already exist at module scope when this function is
        called, otherwise the ``model.train`` call raises ``NameError``.
    """
    # Parameters are read in a fixed order so that a missing entry surfaces
    # at the same point regardless of later changes below.
    # NOTE(review): num_classes, device and dataset_dir are read but not
    # consumed by any code visible in this function -- presumably they feed
    # the model/dataset construction; confirm before removing.
    num_classes = params.get('num_classes').value
    num_epochs = params.get('epochnum').value
    save_interval = params.get('saveEpoch').value
    device = params.get('device').value
    dataset_dir = params.get('DatasetDir').value
    save_dir = params.get('saveModDir').value
    learning_rate = params.get('lr').value
    lr_decay_epochs = params.get('lr_schedulerList').value

    # Any falsy resume path ('' or None) means "not provided": start from the
    # default pretrained weights instead of passing an empty value on.
    pretrain_weights = params.get('resumeModPath').value or _DEFAULT_PRETRAIN_WEIGHTS

    model.train(
        num_epochs=num_epochs,
        save_interval_epochs=save_interval,
        train_dataset=train_dataset,
        train_batch_size=2,
        eval_dataset=eval_dataset,
        pretrain_weights=pretrain_weights,
        learning_rate=learning_rate,
        lr_decay_epochs=lr_decay_epochs,
        warmup_steps=10,
        warmup_start_lr=0.0,
        save_dir=save_dir,
        use_vdl=True)




#@start_train_algorithm
def main(params_str):
    """Load training parameters from *params_str* and start training.

    A fresh ``TrainParams`` is populated via ``read_from_str(params_str)``
    and handed to ``ppx_train`` together with the fixed id ``'1'``.
    """
    train_params = TrainParams()
    train_params.read_from_str(params_str)
    ppx_train(train_params, id='1')


if __name__ == "__main__":
    # Demo parameter table. Each row lists the fields named in `columns`, in
    # that order; the "index" key of every entry is the row's position.
    columns = ("name", "value", "description", "default", "type", "show")
    rows = [
        ("num_classes", 9, '类别数(加背景)', 9, "I", True),
        ("lr", 0.0003, '学习率', 0.0001, "F", True),
        ("lr_schedulerList", [30, 60], '学习率衰减轮次', [30, 60], "L", True),
        ("device", "cpu", '训练核心', "cpu", "S", True),
        ("DatasetDir", "/mnt/sdc/algorithm/PaddleX/datasets/DDX_nb", '数据集路径',
         "/mnt/sdc/algorithm/PaddleX/datasets/DDX_nb", "S", False),
        ("saveModDir", "/mnt/sdc/algorithm/PaddleX/output", '保存模型路径',
         "/mnt/sdc/algorithm/PaddleX/output", "S", False),
        ("resumeModPath", '', '继续训练路径', '', "S", False),
        ("epochnum", 100, '训练轮次', 100, "I", True),
        ("saveEpoch", 2, '保存模型轮次', 2, "I", True),
    ]

    # "index" goes first so the serialized key order matches the layout
    # index, name, value, description, default, type, show.
    params_list = [
        {"index": position, **dict(zip(columns, row))}
        for position, row in enumerate(rows)
    ]

    # Serialize, echo the payload, then run training with it.
    params_str = json.dumps(params_list)
    print(params_str)

    main(params_str)