20 lines
1.0 KiB
Python
20 lines
1.0 KiB
Python
|
from .deep_sort import DeepSort
|
||
|
|
||
|
# Public API of this package: the tracker class and the factory that builds
# one from a configuration object.
__all__ = [
    "DeepSort",
    "build_tracker",
]
|
||
|
|
||
|
|
||
|
def build_tracker(cfg, use_cuda):
    """Construct a ``DeepSort`` tracker from a configuration object.

    The tracker hyper-parameters are always read from ``cfg.DEEPSORT``.
    The re-identification model arguments depend on ``cfg.USE_FASTREID``:

    * truthy -- ``model_path=cfg.FASTREID.CHECKPOINT`` and
      ``model_config=cfg.FASTREID.CFG`` are passed;
    * falsy  -- ``model_path=cfg.DEEPSORT.REID_CKPT`` is passed and no
      ``model_config`` argument is given.

    Args:
        cfg: Configuration object exposing ``USE_FASTREID`` and ``DEEPSORT``
            (and ``FASTREID`` when ``USE_FASTREID`` is truthy) via attribute
            access.
        use_cuda: Forwarded unchanged to ``DeepSort`` as ``use_cuda``.

    Returns:
        The ``DeepSort`` instance built from these arguments.
    """
    if cfg.USE_FASTREID:
        # FastReID backend: its checkpoint plus its own model config.
        model_kwargs = dict(model_path=cfg.FASTREID.CHECKPOINT,
                            model_config=cfg.FASTREID.CFG)
    else:
        # Built-in ReID checkpoint; model_config is intentionally not passed
        # in this branch.
        model_kwargs = dict(model_path=cfg.DEEPSORT.REID_CKPT)

    ds_cfg = cfg.DEEPSORT
    # The tracking hyper-parameters are listed once so both ReID backends
    # are always configured identically.
    return DeepSort(
        **model_kwargs,
        max_dist=ds_cfg.MAX_DIST,
        min_confidence=ds_cfg.MIN_CONFIDENCE,
        nms_max_overlap=ds_cfg.NMS_MAX_OVERLAP,
        max_iou_distance=ds_cfg.MAX_IOU_DISTANCE,
        max_age=ds_cfg.MAX_AGE,
        n_init=ds_cfg.N_INIT,
        nn_budget=ds_cfg.NN_BUDGET,
        use_cuda=use_cuda,
    )
|