调试
This commit is contained in:
parent
2af24e2f03
commit
e0d7d37b9e
@ -161,8 +161,8 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
check_suffix(weights, '.pt') # check weights
|
||||
pretrained = weights.endswith('.pt')
|
||||
if pretrained:
|
||||
with torch_distributed_zero_first(LOCAL_RANK):
|
||||
weights = attempt_download(weights) # download if not found locally
|
||||
# with torch_distributed_zero_first(LOCAL_RANK):
|
||||
# weights = attempt_download(weights) # download if not found locally
|
||||
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
|
||||
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
||||
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
|
||||
|
Loading…
Reference in New Issue
Block a user