如何从张量中获得最高置信度的类? [英] How to get the class with highest confidence from a tensor?
问题描述
我想以最高的信心获得课程.下面是执行分类任务的代码:
I want to obtain the class with the highest confidence. Heres the code that performs classification task:
names = ['class A', 'class B', 'class C']
def classify_face(image):
device = torch.device("cpu")
img = process_image(image)
print('Image processed')
# img = image.unsqueeze_(0)
# img = image.float()
pred = model(img)[0]
# Apply NMS
pred = non_max_suppression(pred, 0.4, 0.5, classes = [0, 1, 2], agnostic = None )
if classify:
pred = apply_classifier(pred, modelc, img, im0s)
#print(pred)
model.eval()
model.cpu()
print(pred)
# output = non_max_suppression(output, 0.4, 0.5, classes = class_names, agnostic = False)
#_, predicted = torch.max(output[0], 1)
#print(predicted.data[0], "predicted")
classification = torch.cat(pred)[:, -1]
index = int(classification)
print(names[index])
return names[index]
当 pred 具有一维张量时,上述代码完美运行.但是如果张量大小超过这个值,我就会收到错误消息.
The above code works perfectly when the pred has a 1D tensor. But I get an error if the tensor size is more than that.
有5个元素:x1
、y1
、x2
、y2
、confidence代码>和<代码>类代码>.
There are 5 elements: x1
, y1
, x2
, y2
, confidence
, and class
.
例如:
pred = [torch.tensor([[212.38568, 117.47020, 339.35773, 266.00513, 0.74144, 2.00000],
[214.60651, 118.50694, 339.90192, 265.91696, 0.94277, 0.00000]])]
错误:
Traceback (most recent call last):
File "WEBCAM_DETECT.py", line 172, in <module>
label = classify_face(frame)
File "WEBCAM_DETECT.py", line 154, in classify_face
index = int(classification)
ValueError: only one element tensors can be converted to Python scalars
所以我想以最高的信心参加课程.请让我知道如何做到这一点,或者是否有更好的方法.
So I want to access the class with highest confidence. Please let me know how to do that or if there is a better way to do.
推荐答案
您似乎总是会关注张量 pred[0]
.所以让 pred
成为:
It seems you will always be looking at the tensor pred[0]
. So let pred
be:
pred = torch.tensor([[212.38568, 117.47020, 339.35773, 266.00513, 0.74144, 2.00000],
[214.60651, 118.50694, 339.90192, 265.91696, 0.94277, 0.00000]])
置信度最高的预测指标为:
The index of the prediction with highest confidence is:
i = torch.argmax(pred[:, 4])
因此,您只需获取该索引处的最后一个值:
Therefore, you just have to get the last value at that index:
pred[i, -1]
类名将是 names[int(pred[i, -1])]
.
这篇关于如何从张量中获得最高置信度的类?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!