aicheckv2-api/deep_sort/deep/feature_extractor.py

94 lines
3.3 KiB
Python

import torch
import torchvision.transforms as transforms
import numpy as np
import cv2
import logging
from .model import Net
from .resnet import resnet18
# from fastreid.config import get_cfg
# from fastreid.engine import DefaultTrainer
# from fastreid.utils.checkpoint import Checkpointer
class Extractor(object):
    """Compute appearance features for image crops with the ``Net`` re-id model.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. The file
            may contain either a bare state dict or a dict that stores the
            state dict under the ``'net_dict'`` key.
        use_cuda: Run on CUDA when it is available; otherwise (or when
            ``False``) the model runs on the CPU.
    """

    def __init__(self, model_path, use_cuda=True):
        self.net = Net(reid=True)
        # self.net = resnet18(reid=True)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

        # Deserialize onto the CPU first; the model is moved to the target
        # device below. NOTE: torch.load unpickles the file, so only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
        weights = checkpoint['net_dict'] if 'net_dict' in checkpoint else checkpoint
        # strict=False: keys that do not match the model are ignored instead
        # of raising.
        self.net.load_state_dict(weights, strict=False)

        logger = logging.getLogger("root.tracker")
        logger.info("Loading weights from {}... Done!".format(model_path))

        self.net.to(self.device)
        # Switch to inference mode so layers such as dropout / batch norm (if
        # the network has any) behave deterministically and do not update
        # their running statistics while extracting features.
        self.net.eval()

        # (width, height), the order cv2.resize expects for its target size.
        self.size = (64, 128)
        # Per-channel mean / std normalisation (the commonly used ImageNet
        # statistics).
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _preprocess(self, im_crops):
        """Turn a sequence of image crops into one normalised float batch.

        Each crop is:
            1. converted to float32 and scaled by 1/255,
            2. resized to ``self.size`` ((64, 128), as the Market1501 dataset did),
            3. converted to a tensor and normalised with ``self.norm``,
            4. stacked with the other crops along a new leading batch axis.

        Args:
            im_crops: Iterable of numpy image arrays
                (presumably H x W x 3 with values in 0..255 -- confirm with callers).
                Must contain at least one crop, because ``torch.cat`` cannot
                concatenate an empty list.

        Returns:
            A float32 ``torch.Tensor`` with one entry per crop along dim 0.
        """
        def _resize(im, size):
            # Scaling is done by hand because ToTensor only rescales uint8 input.
            return cv2.resize(im.astype(np.float32) / 255., size)

        tensors = [self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops]
        im_batch = torch.cat(tensors, dim=0).float()
        return im_batch

    def __call__(self, im_crops):
        """Extract features for ``im_crops`` and return them as a numpy array.

        Args:
            im_crops: Iterable of image crops; see ``_preprocess``.

        Returns:
            The network output for the batch, moved to the CPU and converted
            to ``numpy.ndarray`` (presumably one feature row per crop).
        """
        im_batch = self._preprocess(im_crops)
        with torch.no_grad():
            im_batch = im_batch.to(self.device)
            features = self.net(im_batch)
        return features.cpu().numpy()
class FastReIDExtractor(object):
    """Compute appearance features for image crops with a FastReID model.

    Args:
        model_config: Path to the FastReID config file merged into the
            default config.
        model_path: Path to the checkpoint loaded through FastReID's
            ``Checkpointer``.
        use_cuda: Run on CUDA when it is available; otherwise (or when
            ``False``) the model runs on the CPU.
    """

    def __init__(self, model_config, model_path, use_cuda=True):
        # Imported here rather than at module level so that this module can
        # still be imported (e.g. to use ``Extractor``) when fastreid is not
        # installed; without these imports the names below are undefined.
        from fastreid.config import get_cfg
        from fastreid.engine import DefaultTrainer
        from fastreid.utils.checkpoint import Checkpointer

        cfg = get_cfg()
        cfg.merge_from_file(model_config)
        # Weights come from ``model_path``; skip loading backbone pretrain weights.
        cfg.MODEL.BACKBONE.PRETRAIN = False
        self.net = DefaultTrainer.build_model(cfg)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

        Checkpointer(self.net).load(model_path)
        logger = logging.getLogger("root.tracker")
        logger.info("Loading weights from {}... Done!".format(model_path))

        self.net.to(self.device)
        self.net.eval()

        # The config stores the test size as (height, width); cv2.resize
        # expects (width, height).
        height, width = cfg.INPUT.SIZE_TEST
        self.size = (width, height)
        # Per-channel mean / std normalisation (the commonly used ImageNet
        # statistics).
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _preprocess(self, im_crops):
        """Turn a sequence of image crops into one normalised float batch.

        Each crop is converted to float32, scaled by 1/255, resized to
        ``self.size``, normalised with ``self.norm`` and stacked with the
        others along a new leading batch axis.

        Args:
            im_crops: Iterable of numpy image arrays
                (presumably H x W x 3 with values in 0..255 -- confirm with callers).
                Must contain at least one crop, because ``torch.cat`` cannot
                concatenate an empty list.

        Returns:
            A float32 ``torch.Tensor`` with one entry per crop along dim 0.
        """
        def _resize(im, size):
            # Scaling is done by hand because ToTensor only rescales uint8 input.
            return cv2.resize(im.astype(np.float32) / 255., size)

        tensors = [self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops]
        im_batch = torch.cat(tensors, dim=0).float()
        return im_batch

    def __call__(self, im_crops):
        """Extract features for ``im_crops`` and return them as a numpy array.

        Args:
            im_crops: Iterable of image crops; see ``_preprocess``.

        Returns:
            The network output for the batch, moved to the CPU and converted
            to ``numpy.ndarray``.
        """
        im_batch = self._preprocess(im_crops)
        with torch.no_grad():
            im_batch = im_batch.to(self.device)
            features = self.net(im_batch)
        return features.cpu().numpy()
if __name__ == '__main__':
    # Manual smoke test: extract the feature of a single demo image.
    img = cv2.imread("demo.jpg")
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError("demo.jpg")
    # OpenCV loads images as BGR; reorder the channels to RGB.
    img = img[:, :, (2, 1, 0)]
    extr = Extractor("checkpoint/ckpt.t7")
    # The extractor iterates over its argument and treats every item as one
    # crop, so a single image has to be wrapped in a list.
    feature = extr([img])
    print(feature.shape)