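"""Multi-GPU training script for a Market1501 classifier.

Trains a ResNet-18 with PyTorch DistributedDataParallel, plots loss and
top-1 error curves to train.jpg, and writes a checkpoint after every epoch.
"""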
import argparse
import math
import os
import tempfile
import warnings

import matplotlib.pyplot as plt
import torch
import torch.distributed as dist
import torchvision
from torch.optim import lr_scheduler

from datasets import ClsDataset, read_split_data
from multi_train_utils.distributed_utils import init_distributed_mode, cleanup
from multi_train_utils.train_eval_utils import train_one_epoch, evaluate, load_model
from resnet import resnet18


# Global figure state for the training curves (loss and top-1 error).
x_epoch = []
record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1_err")
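# NOTE: on a headless training node pyplot may require a non-interactive
# backend; if saving the figure fails, call matplotlib.use('Agg') before
# importing pyplot.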


def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
    """Append this epoch's metrics and redraw the loss/top-1 error curves."""
    global record
    record['train_loss'].append(train_loss)
    record['train_err'].append(train_err)
    record['test_loss'].append(test_loss)
    record['test_err'].append(test_err)

    x_epoch.append(epoch)
    ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
    ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
    ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
    ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
    if epoch == 0:
        ax0.legend()
        ax1.legend()
    fig.savefig("train.jpg")


def main(args):
    # Set up the process group; init_distributed_mode() is expected to fill
    # in args.rank, args.gpu and args.world_size.
    init_distributed_mode(args)

    rank = args.rank
    device = torch.device(args.device)
    batch_size = args.batch_size
    weights_path = args.weights
    args.lr *= args.world_size
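    # NOTE: the base LR is scaled linearly with the number of processes (the
    # "linear scaling rule"), since the effective global batch size is
    # batch_size * world_size.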
    checkpoint_path = ''

    if rank == 0:
        print(args)
        if not os.path.exists('./checkpoint'):
            os.mkdir('./checkpoint')

    train_info, val_info, num_classes = read_split_data(args.data_dir, valid_rate=0.2)
    train_images_path, train_labels = train_info
    val_images_path, val_labels = val_info

    # Market1501 person crops are 128x64 (height x width).
    transform_train = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop((128, 64), padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    transform_val = torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 64)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_dataset = ClsDataset(
        images_path=train_images_path,
        images_labels=train_labels,
        transform=transform_train
    )
    val_dataset = ClsDataset(
        images_path=val_images_path,
        images_labels=val_labels,
        transform=transform_val
    )

    # Give each process a distinct shard of the data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)

    number_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
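    # Heuristic: no more workers than CPU cores, than the per-process batch
    # size, or than 8; with batch_size == 1 everything loads in the main
    # process (num_workers=0).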

    if rank == 0:
        print('Using {} dataloader workers per process'.format(number_workers))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=train_batch_sampler,
        pin_memory=True,
        num_workers=number_workers
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=val_sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=number_workers,
    )

    # Net definition.
    start_epoch = 0
    net = resnet18(num_classes=num_classes)
    if weights_path:
        print('Loading from', weights_path)
        checkpoint = torch.load(weights_path, map_location='cpu')
        net_dict = checkpoint if 'net_dict' not in checkpoint else checkpoint['net_dict']
        start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else start_epoch
        net = load_model(net_dict, net.state_dict(), net)
    else:
        warnings.warn('Consider providing pretrained weights via --weights.')
        # Rank 0 saves its random initialization and every rank loads it, so
        # all processes start from identical parameters.
        checkpoint_path = os.path.join(tempfile.gettempdir(), 'initial_weights.pth')
        if rank == 0:
            torch.save(net.state_dict(), checkpoint_path)

        dist.barrier()
        net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    if args.freeze_layers:
        # Train only the final fully connected layer.
        for name, param in net.named_parameters():
            if 'fc' not in name:
                param.requires_grad = False
    else:
        # Convert BN layers to SyncBatchNorm only when the full network is
        # trained; it synchronizes batch statistics across processes at some
        # speed cost.
        if args.syncBN:
            net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net.to(device)

    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])

    # Optimize only the parameters that still require gradients.
    pg = [p for p in net.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(pg, args.lr, momentum=0.9, weight_decay=5e-4)

    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
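    # The lambda above decays the LR factor from 1.0 at epoch 0 to args.lrf
    # at epoch args.epochs along a half cosine, so the final LR is lrf times
    # the initial (scaled) LR.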
    for epoch in range(start_epoch, start_epoch + args.epochs):
        # Reshuffle every epoch; without set_epoch() the DistributedSampler
        # yields the same ordering each time.
        train_sampler.set_epoch(epoch)
        train_positive, train_loss = train_one_epoch(net, optimizer, train_loader, device, epoch)
        train_acc = train_positive / len(train_dataset)
        scheduler.step()

        test_positive, test_loss = evaluate(net, val_loader, device)
        test_acc = test_positive / len(val_dataset)

        # Only rank 0 logs, checkpoints, and plots.
        if rank == 0:
            print('[epoch {}] accuracy: {}'.format(epoch, test_acc))

            state_dict = {
                'net_dict': net.module.state_dict(),
                'acc': test_acc,
                'epoch': epoch
            }
            torch.save(state_dict, './checkpoint/model_{}.pth'.format(epoch))
            draw_curve(epoch, train_loss, 1 - train_acc, test_loss, 1 - test_acc)

    # Remove the temporary initial-weights file and tear down the process group.
    if rank == 0 and os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    cleanup()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train on market1501")
    parser.add_argument('--data-dir', default='data', type=str)
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--lrf', default=0.1, type=float)
    # argparse's type=bool would treat any non-empty string (even "False") as
    # True, so parse the flag explicitly.
    parser.add_argument('--syncBN', type=lambda x: str(x).lower() in ('true', '1'), default=True)

    parser.add_argument('--weights', type=str, default='./checkpoint/resnet18.pth')
    parser.add_argument('--freeze-layers', action='store_true')

    # Do not change the following parameters; the distributed launcher assigns
    # them automatically.
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    parser.add_argument('--world_size', default=4, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    args = parser.parse_args()

    main(args)
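

# Typical single-machine launch, assuming init_distributed_mode() reads the
# standard environment variables (RANK, WORLD_SIZE, LOCAL_RANK) implied by the
# default dist_url='env://'; substitute this file's actual name for train.py:
#
#   torchrun --nproc_per_node=4 train.py --data-dir data --batch_size 32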