16 lines
293 B
Python
16 lines
293 B
Python
import torch
|
|
|
|
def top1_accuracy(qf, ql, gf, gl):
    """Return the fraction of queries whose best-scoring gallery item has the same label.

    Scores are raw dot products ``qf @ gf.T``. Features are not normalized
    here, so this is cosine similarity only if the caller normalized them.

    Args:
        qf: query features — assumes shape (num_query, dim); TODO confirm.
        ql: query labels — assumes shape (num_query,); TODO confirm.
        gf: gallery features — assumes shape (num_gallery, dim); TODO confirm.
        gl: gallery labels — assumes shape (num_gallery,); TODO confirm.

    Returns:
        Top-1 accuracy as a Python float.

    Raises:
        ValueError: if there are no queries or the gallery is empty.
    """
    num_queries = ql.size(0)
    if num_queries == 0:
        raise ValueError("cannot compute accuracy with zero queries")
    if gf.size(0) == 0:
        raise ValueError("cannot compute accuracy with an empty gallery")

    scores = qf.mm(gf.t())
    # argmax instead of topk(5)[1][:, 0]: yields the top-1 index directly and
    # does not fail when the gallery holds fewer than 5 items.
    best = scores.argmax(dim=1)
    correct = gl[best].eq(ql).sum().item()
    return correct / num_queries


def main():
    """Load features from features.pth and print the top-1 accuracy."""
    # NOTE(review): torch.load unpickles the file; only load trusted files
    # (consider weights_only=True if the file holds only tensors).
    features = torch.load("features.pth")
    acc = top1_accuracy(features["qf"], features["ql"], features["gf"], features["gl"])
    print("Acc top1:{:.3f}".format(acc))


if __name__ == "__main__":
    main()
|
|
|
|
|