# Deep SORT appearance-feature extractors (re-ID embedding networks).
import torch
|
|
import torchvision.transforms as transforms
|
|
import numpy as np
|
|
import cv2
|
|
import logging
|
|
|
|
from .model import Net
|
|
from .resnet import resnet18
|
|
# Optional FastReID dependencies — referenced only by FastReIDExtractor below.
# from fastreid.config import get_cfg
# from fastreid.engine import DefaultTrainer
# from fastreid.utils.checkpoint import Checkpointer
|
|
|
|
|
|
class Extractor(object):
    """Extract appearance embeddings from image crops with a small re-ID CNN.

    Used by the Deep SORT tracker: each detection crop is resized to a fixed
    Market1501-style resolution (64x128), normalized with ImageNet statistics,
    and passed through the network to produce a feature vector.
    """

    def __init__(self, model_path, use_cuda=True):
        """Load re-ID network weights.

        Args:
            model_path: path to a torch checkpoint — either a bare state dict
                or a dict wrapping it under the ``'net_dict'`` key.
            use_cuda: run on GPU when one is available.
        """
        self.net = Net(reid=True)
        # self.net = resnet18(reid=True)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        # map_location keeps CPU-trained/GPU-trained checkpoints loadable anywhere.
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
        # Some checkpoints wrap the weights under 'net_dict'; unwrap if present.
        self.net.load_state_dict(
            state_dict if 'net_dict' not in state_dict else state_dict['net_dict'],
            strict=False,
        )
        logger = logging.getLogger("root.tracker")
        logger.info("Loading weights from {}... Done!".format(model_path))
        self.net.to(self.device)
        # Fix: switch to inference mode so BatchNorm/Dropout behave
        # deterministically (the FastReID extractor in this file already
        # does this; the omission here left the net in training mode).
        self.net.eval()
        self.size = (64, 128)  # (width, height), as in the Market1501 dataset
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            # ImageNet channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _preprocess(self, im_crops):
        """Scale crops to [0, 1], resize, normalize, and stack into one batch.

        Args:
            im_crops: list of HxWx3 uint8 image crops.

        Returns:
            Float tensor of shape (len(im_crops), 3, 128, 64).
        """
        def _resize(im, size):
            return cv2.resize(im.astype(np.float32) / 255., size)

        im_batch = torch.cat(
            [self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops],
            dim=0,
        ).float()
        return im_batch

    def __call__(self, im_crops):
        """Return an (N, feature_dim) numpy array of embeddings for the crops."""
        im_batch = self._preprocess(im_crops)
        with torch.no_grad():
            im_batch = im_batch.to(self.device)
            features = self.net(im_batch)
        return features.cpu().numpy()
|
|
|
|
|
|
class FastReIDExtractor(object):
    """Appearance-embedding extractor backed by a FastReID model.

    Requires the third-party ``fastreid`` package. The fastreid imports at the
    top of this module are commented out, so they are performed lazily here —
    otherwise instantiating this class raised ``NameError`` on ``get_cfg``.
    """

    def __init__(self, model_config, model_path, use_cuda=True):
        """Build and load a FastReID model.

        Args:
            model_config: path to a FastReID YAML config file.
            model_path: path to the model checkpoint.
            use_cuda: run on GPU when one is available.

        Raises:
            ImportError: if the ``fastreid`` package is not installed.
        """
        # Fix: import lazily so the module stays importable without fastreid
        # installed; the original code referenced these names while their
        # top-level imports were commented out.
        from fastreid.config import get_cfg
        from fastreid.engine import DefaultTrainer
        from fastreid.utils.checkpoint import Checkpointer

        cfg = get_cfg()
        cfg.merge_from_file(model_config)
        cfg.MODEL.BACKBONE.PRETRAIN = False
        self.net = DefaultTrainer.build_model(cfg)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

        Checkpointer(self.net).load(model_path)
        logger = logging.getLogger("root.tracker")
        logger.info("Loading weights from {}... Done!".format(model_path))
        self.net.to(self.device)
        self.net.eval()
        # cfg.INPUT.SIZE_TEST is (height, width); cv2.resize expects (width, height).
        height, width = cfg.INPUT.SIZE_TEST
        self.size = (width, height)
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            # ImageNet channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _preprocess(self, im_crops):
        """Scale crops to [0, 1], resize to the model's test size, and batch them."""
        def _resize(im, size):
            return cv2.resize(im.astype(np.float32) / 255., size)

        im_batch = torch.cat(
            [self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops],
            dim=0,
        ).float()
        return im_batch

    def __call__(self, im_crops):
        """Return an (N, feature_dim) numpy array of embeddings for the crops."""
        im_batch = self._preprocess(im_crops)
        with torch.no_grad():
            im_batch = im_batch.to(self.device)
            features = self.net(im_batch)
        return features.cpu().numpy()
|
|
|
|
|
|
if __name__ == '__main__':
    # BGR -> RGB channel reorder to match the training pipeline.
    img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
    extr = Extractor("checkpoint/ckpt.t7")
    # Fix: the extractor expects a *list* of crops; passing the bare array
    # made _preprocess iterate over image rows instead of whole images.
    feature = extr([img])
    print(feature.shape)
|